# Spaces:
# Sleeping
# Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
from chronos import ChronosPipeline | |
class TimeSeriesForecaster:
    """Wrapper around a pretrained Chronos pipeline for forecasting one series.

    Intended call order on a single instance:
    ``preprocess_data`` -> ``forecast`` -> ``visualize_forecast``.
    ``preprocess_data`` stores the full sorted series on ``self`` and
    ``visualize_forecast`` reads it back, so the instance carries state
    between calls.
    """

    def __init__(self, model_name="amazon/chronos-t5-small"):
        """Load the Chronos pipeline ``model_name``.

        Uses CUDA with bfloat16 weights when a GPU is available, otherwise
        CPU with float32 weights.
        """
        use_cuda = torch.cuda.is_available()
        self.pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cuda" if use_cuda else "cpu",
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        )
        # Set by preprocess_data(); read by visualize_forecast().
        self.original_series = None
        self.context = None

    def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
        """Prepare the model context from a DataFrame.

        Parses ``date_column`` as datetimes, orders the rows chronologically,
        stores the whole ``value_column`` as ``self.original_series`` and its
        last ``context_length`` values as a float32 tensor in ``self.context``.

        Args:
            df: Input frame. It is not modified.
            date_column: Name of the column holding timestamps.
            value_column: Name of the column holding the values to forecast.
            context_length: Number of most recent observations fed to the
                model. Must be >= 1.
            prediction_length: Not used here; kept for interface compatibility.

        Returns:
            ``(context, context_length)`` where ``context`` is a 1-D float32
            tensor holding at most ``context_length`` values.

        Raises:
            ValueError: If ``context_length`` < 1 or the series is empty.
        """
        if context_length < 1:
            # series[-0:] would silently return the entire series.
            raise ValueError(f"context_length must be >= 1, got {context_length}")

        # Parse dates BEFORE sorting: sorting the raw strings orders them
        # lexically ("1/10/2020" < "1/2/2020"), not chronologically.
        # Work on a copy so the caller's DataFrame keeps its original column.
        frame = df.copy()
        frame[date_column] = pd.to_datetime(frame[date_column])
        # mergesort is stable: rows sharing a timestamp keep their file order.
        frame = frame.sort_values(by=date_column, kind="mergesort")

        series = frame[value_column].values
        if len(series) == 0:
            raise ValueError(f"Column '{value_column}' contains no data")

        self.original_series = series
        self.context = torch.tensor(series[-context_length:], dtype=torch.float32)
        return self.context, context_length

    def forecast(self, context, prediction_length=7, num_samples=100):
        """Sample ``num_samples`` forecast paths of ``prediction_length`` steps.

        Returns the result of ``ChronosPipeline.predict`` unchanged. Code in
        this file indexes it as ``forecasts[0]`` and treats that as
        (num_samples, prediction_length) — presumably one entry per input
        series; confirm against the Chronos API.
        """
        return self.pipeline.predict(context, prediction_length, num_samples=num_samples)

    def visualize_forecast(self, context, forecasts):
        """Plot the stored history followed by the forecast median and 10-90% band.

        Draws on a new pyplot figure and returns the ``matplotlib.pyplot``
        module; callers render its current figure. ``context`` is not used.
        Requires ``preprocess_data`` to have been called first so that
        ``self.original_series`` is set.
        """
        history = self.original_series
        n_hist = len(history)

        plt.figure(figsize=(12, 6))
        plt.plot(range(n_hist), history, label='Historical Data', color='blue')

        # Quantiles over axis 0 give per-step statistics, assuming samples are
        # laid out as (num_samples, prediction_length).
        samples = forecasts[0].numpy()
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        # The forecast continues the integer x-axis right after the history.
        forecast_index = range(n_hist, n_hist + len(median))
        plt.plot(forecast_index, median, color='red', label='Median Forecast')
        plt.fill_between(forecast_index, low, high, color='red', alpha=0.3, label='80% Prediction Interval')
        plt.title('Time Series Forecasting with Amazon Chronos')
        plt.xlabel('Time Index')
        plt.ylabel('Value')
        plt.legend()
        return plt
def _get_forecaster():
    """Return this session's TimeSeriesForecaster, creating it on first use.

    Kept in ``st.session_state`` so the model is loaded once per session
    instead of on every button click. The instance is per session rather than
    shared process-wide because the forecaster stores per-request data on
    ``self`` between calls.
    """
    if "forecaster" not in st.session_state:
        st.session_state["forecaster"] = TimeSeriesForecaster()
    return st.session_state["forecaster"]


def main():
    """Render the Streamlit forecasting app.

    Sidebar: CSV upload, date/value column pickers, context/prediction length
    sliders and a 'Perform Forecast' button. On click, runs the forecast and
    shows a chart plus a table with the per-step mean and 10th/90th
    percentiles of the sampled forecasts.
    """
    st.title('🕰️ Time Series Forecasting with Amazon Chronos')
    st.sidebar.header('Forecast Settings')

    uploaded_file = st.sidebar.file_uploader(
        "Upload CSV File",
        type=['csv'],
        help="Ensure CSV file has date and numeric columns"
    )
    # Nothing further to render until a file is uploaded.
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # Report unreadable uploads in the UI instead of a raw traceback.
        st.error(f"Could not read the CSV file: {e}")
        return

    date_column = st.sidebar.selectbox(
        'Select Date Column',
        options=df.columns
    )
    value_column = st.sidebar.selectbox(
        'Select Value Column',
        options=[col for col in df.columns if col != date_column]
    )
    if value_column is None:
        # The value picker has no options when the CSV has a single column.
        st.warning("The CSV file needs at least two columns: a date column and a value column.")
        return

    context_length = st.sidebar.slider(
        'Context Length',
        min_value=10,
        max_value=100,
        value=30
    )
    prediction_length = st.sidebar.slider(
        'Prediction Length',
        min_value=1,
        max_value=30,
        value=7
    )

    if st.sidebar.button('Perform Forecast'):
        try:
            forecaster = _get_forecaster()
            context, _ = forecaster.preprocess_data(
                df,
                date_column,
                value_column,
                context_length,
                prediction_length
            )
            forecasts = forecaster.forecast(context, prediction_length)

            st.subheader('Forecast Visualization')
            # Named `chart` so the module-level `plt` is not shadowed.
            chart = forecaster.visualize_forecast(context, forecasts)
            # visualize_forecast returns the pyplot module; hand Streamlit the
            # concrete figure and close it afterwards so figures don't pile up
            # in pyplot's global registry across reruns.
            fig = chart.gcf() if chart is plt else chart
            try:
                st.pyplot(fig)
            finally:
                plt.close(fig)

            forecast_np = forecasts[0].numpy()
            forecast_mean = forecast_np.mean(axis=0)
            forecast_lower = np.percentile(forecast_np, 10, axis=0)
            forecast_upper = np.percentile(forecast_np, 90, axis=0)
            prediction_df = pd.DataFrame({
                'Mean Forecast': forecast_mean,
                'Lower Bound (10%)': forecast_lower,
                'Upper Bound (90%)': forecast_upper
            })
            st.subheader('Forecast Details')
            st.dataframe(prediction_df)
        except Exception as e:
            # UI boundary: surface any failure (model load, bad data) to the user.
            st.error(f"An error occurred: {str(e)}")
if __name__ == "__main__":
    # Script entry point: build the Streamlit UI.
    main()