# Tempus / app.py
# Last commit 0dc2bb2 (verified) by vincentiusyoshuac: "Update app.py" (5.36 kB)
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from chronos import ChronosPipeline
class TimeSeriesForecaster:
    """Prepare a univariate series, forecast it with Amazon Chronos, and plot the result.

    Attributes:
        pipeline: the ``ChronosPipeline`` loaded in ``__init__``.
        original_series: the full value series, ordered by date, from the most
            recent ``preprocess_data`` call (``None`` before the first call).
        context: float32 tensor of the trailing ``context_length`` values from
            the most recent ``preprocess_data`` call (``None`` before the first call).
    """

    def __init__(self, model_name="amazon/chronos-t5-small"):
        """Load the pretrained Chronos model ``model_name``.

        Uses the CUDA device with bfloat16 weights when a GPU is available,
        otherwise the CPU with float32 weights.
        """
        use_cuda = torch.cuda.is_available()
        self.pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cuda" if use_cuda else "cpu",
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        )
        self.original_series = None
        self.context = None

    def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
        """Extract a date-ordered value series from ``df`` and build the model context.

        Args:
            df: input table. It is not modified.
            date_column: column holding dates, parsed with ``pd.to_datetime``.
            value_column: column holding the values to forecast.
            context_length: number of most recent observations fed to the model.
                The whole series is used if it is shorter than this.
            prediction_length: unused here; accepted for call-site symmetry
                with ``forecast``.

        Returns:
            ``(context, context_length)``, where ``context`` is a 1-D float32
            tensor holding the last ``context_length`` values.

        Raises:
            ValueError: if ``context_length`` is < 1 or ``df`` has no rows.
        """
        if context_length < 1:
            # A slice of [-0:] would silently return the whole series.
            raise ValueError(f"context_length must be >= 1, got {context_length}")

        frame = df.copy()
        # Parse dates *before* sorting. Sorting the raw strings would put
        # e.g. "1/10/2020" ahead of "1/2/2020" and pick the wrong context window.
        frame[date_column] = pd.to_datetime(frame[date_column])
        # mergesort is stable, so rows sharing a timestamp keep their file order.
        frame = frame.sort_values(by=date_column, kind="mergesort").set_index(date_column)

        series = frame[value_column].values
        if len(series) == 0:
            raise ValueError("The input data contains no rows to forecast from.")

        self.original_series = series
        self.context = torch.tensor(series[-context_length:], dtype=torch.float32)
        return self.context, context_length

    def forecast(self, context, prediction_length=7, num_samples=100):
        """Draw ``num_samples`` sampled forecast paths of ``prediction_length`` steps.

        Returns the result of ``ChronosPipeline.predict`` unchanged.
        ``visualize_forecast`` reads ``forecasts[0]`` with samples on axis 0.
        NOTE(review): presumably the shape is (batch, num_samples,
        prediction_length); confirm against the Chronos documentation.
        """
        return self.pipeline.predict(context, prediction_length, num_samples=num_samples)

    def visualize_forecast(self, context, forecasts):
        """Plot the stored history followed by the median forecast and a 10-90% band.

        Args:
            context: unused; kept for interface compatibility.
            forecasts: output of ``forecast``. ``forecasts[0]`` is converted to
                NumPy, and axis 0 is treated as the sample axis.

        Returns:
            The ``matplotlib.pyplot`` module, with the newly drawn figure current.
        """
        history = self.original_series
        n_hist = len(history)

        plt.figure(figsize=(12, 6))
        plt.plot(range(n_hist), history, label='Historical Data', color='blue')

        samples = forecasts[0].numpy()
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        # Forecast steps continue the integer index right after the history.
        horizon = range(n_hist, n_hist + len(median))
        plt.plot(horizon, median, color='red', label='Median Forecast')
        plt.fill_between(horizon, low, high, color='red', alpha=0.3, label='80% Prediction Interval')

        plt.title('Time Series Forecasting with Amazon Chronos')
        plt.xlabel('Time Index')
        plt.ylabel('Value')
        plt.legend()
        return plt
def main():
    """Streamlit entry point: upload a CSV, choose columns and horizon, run a Chronos forecast."""
    st.title('🕰️ Time Series Forecasting with Amazon Chronos')

    # Sidebar for upload and configuration
    st.sidebar.header('Forecast Settings')
    uploaded_file = st.sidebar.file_uploader(
        "Upload CSV File",
        type=['csv'],
        help="Ensure CSV file has date and numeric columns"
    )
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
        st.error(f"Could not read the CSV file: {e}")
        return

    # One column is needed for dates and a different one for values. Otherwise
    # the value selectbox below would have no options and return None.
    if len(df.columns) < 2:
        st.error("The CSV file needs at least two columns: a date column and a value column.")
        return

    # Column selection
    date_column = st.sidebar.selectbox(
        'Select Date Column',
        options=df.columns
    )
    value_column = st.sidebar.selectbox(
        'Select Value Column',
        options=[col for col in df.columns if col != date_column]
    )

    # Prediction parameters
    context_length = st.sidebar.slider(
        'Context Length',
        min_value=10,
        max_value=100,
        value=30
    )
    prediction_length = st.sidebar.slider(
        'Prediction Length',
        min_value=1,
        max_value=30,
        value=7
    )

    if not st.sidebar.button('Perform Forecast'):
        return

    try:
        with st.spinner('Loading model and forecasting...'):
            # NOTE(review): the Chronos weights are reloaded on every click.
            # Caching (e.g. st.cache_resource) would need care, because the
            # forecaster stores per-request state (original_series / context).
            forecaster = TimeSeriesForecaster()
            context, _ = forecaster.preprocess_data(
                df,
                date_column,
                value_column,
                context_length,
                prediction_length
            )
            forecasts = forecaster.forecast(context, prediction_length)

        st.subheader('Forecast Visualization')
        # Named `chart` rather than `plt` so the module-level pyplot import is not shadowed.
        chart = forecaster.visualize_forecast(context, forecasts)
        st.pyplot(chart)
        # visualize_forecast returns the pyplot module, so close() releases the
        # figure it just drew. Without this, figures accumulate across reruns.
        chart.close()

        st.subheader('Forecast Details')
        st.dataframe(_forecast_table(forecasts))
    except Exception as e:  # top-level UI boundary: report instead of crashing the app
        st.error(f"An error occurred: {str(e)}")


def _forecast_table(forecasts):
    """Summarise the sampled forecast paths in ``forecasts[0]`` as a per-step table.

    Statistics are taken over axis 0 (the sample axis), one row per forecast step.
    """
    samples = forecasts[0].numpy()
    return pd.DataFrame({
        'Mean Forecast': samples.mean(axis=0),
        # The median is the line drawn in the chart, so show it in the table too.
        'Median Forecast': np.median(samples, axis=0),
        'Lower Bound (10%)': np.percentile(samples, 10, axis=0),
        'Upper Bound (90%)': np.percentile(samples, 90, axis=0),
    })
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not when imported.
    main()