Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import pytorch_lightning as pl | |
from neuralforecast.core import NeuralForecast | |
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT | |
from neuralforecast.losses.pytorch import HuberMQLoss | |
from neuralforecast.utils import AirPassengersDF | |
import time | |
from st_aggrid import AgGrid | |
from nixtla import NixtlaClient | |
import os | |
# App-wide Streamlit page configuration: use the wide layout.
st.set_page_config(layout='wide')
def load_model(path, freq):
    """Load a saved ``NeuralForecast`` object from the directory *path*.

    ``freq`` is accepted so every call site can pass it, but it is not used
    here.
    """
    return NeuralForecast.load(path=path)
@st.cache_resource
def load_all_models():
    """Load every pre-trained M4 checkpoint once and reuse it across reruns.

    Streamlit re-executes the page script on every widget interaction; without
    caching, all 20 checkpoints would be re-read from disk each time.
    ``st.cache_resource`` keeps one loaded copy per server process.

    NOTE(review): the cached objects are shared by all sessions; confirm that
    concurrent ``predict`` calls on one ``NeuralForecast`` object are safe.

    Returns:
        Tuple ``(nhits_models, timesnet_models, lstm_models, tft_models)``.
        Each element is a dict mapping a frequency key ('D', 'M', 'H', 'W',
        'Y') to the object loaded from ``./M4/<Model>/<granularity>``.
    """
    # Every architecture uses the same layout: ./M4/<model>/<granularity>.
    granularity_dirs = {
        'D': 'daily',
        'M': 'monthly',
        'H': 'hourly',
        'W': 'weekly',
        'Y': 'yearly',
    }
    # Order matters: it defines the order of the returned tuple.
    model_dirs = ('NHITS', 'TimesNet', 'LSTM', 'TFT')
    return tuple(
        {
            freq: load_model(f'./M4/{model_dir}/{granularity}', freq)
            for freq, granularity in granularity_dirs.items()
        }
        for model_dir in model_dirs
    )
def generate_forecast(model, df, tag=False):
    """Run ``predict`` on a ``NeuralForecast`` object and return its output.

    With ``tag == 'retrain'`` the model is asked to predict without new data.
    ``forecast_time_series`` uses this right after fitting on *df*.
    In every other case *df* is passed as the history to forecast from.
    """
    if tag != 'retrain':
        return model.predict(df=df)
    return model.predict()
def determine_frequency(df):
    """Infer the pandas frequency alias of ``df['ds']``.

    Side effect: the caller's ``df['ds']`` column is converted to datetime in
    place.

    The unique timestamps are sorted before inference, because
    ``pd.infer_freq`` returns None for non-monotonic input. It also raises
    ValueError for fewer than three timestamps. In both cases the user is
    warned and ``'D'`` is returned.

    Returns:
        A pandas frequency string such as ``'D'``, ``'ME'`` or ``'W-SUN'``.
    """
    df['ds'] = pd.to_datetime(df['ds'])
    timestamps = pd.DatetimeIndex(df['ds'].drop_duplicates().sort_values())
    try:
        freq = pd.infer_freq(timestamps)
    except ValueError:
        # Too few timestamps to infer anything.
        freq = None
    if not freq:
        st.warning('The forecast will use default Daily forecast due to date inconsistency. Please check your data.', icon="⚠️")
        freq = 'D'
    return freq
def plot_forecasts_matplotlib(forecast_df, train_df, title):
    """Render history, median forecast and an optional 90% band with Matplotlib.

    Args:
        forecast_df: Forecast frame with a ``ds`` column and prediction columns
            whose names contain ``'median'`` and, optionally, ``'lo-90'`` /
            ``'hi-90'``.
        train_df: History frame with ``ds`` and ``y`` columns.
        title: Figure title.

    Raises:
        KeyError: If no column name contains ``'median'``.
    """
    plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
    historical_col = 'y'
    forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
    lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
    hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
    if forecast_col is None:
        raise KeyError("No forecast column found in the data.")
    # Create the figure only after validation so the error path leaves nothing open.
    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    try:
        # Plot each series explicitly so the legend carries the intended names.
        ax.plot(plot_df.index, plot_df[historical_col], linewidth=2, label='Historical')
        ax.plot(plot_df.index, plot_df[forecast_col], linewidth=2, label='Forecast')
        if lo_col and hi_col:
            ax.fill_between(
                plot_df.index,
                plot_df[lo_col],
                plot_df[hi_col],
                color='blue',
                alpha=0.3,
                label='90% Confidence Interval'
            )
        ax.set_title(title, fontsize=22)
        ax.set_ylabel('Value', fontsize=20)
        ax.set_xlabel('Timestamp [t]', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid()
        st.pyplot(fig)
    finally:
        # Free the figure once Streamlit has rendered it, so reruns do not pile up open figures.
        plt.close(fig)
import plotly.graph_objects as go | |
def plot_forecasts(forecast_df, train_df, title):
    """Render history, median forecast and an optional 90% band with Plotly.

    The median/lo-90/hi-90 series are located by substring match on column
    names. A KeyError is raised when no ``'median'`` column exists.
    """
    combined = pd.concat([train_df, forecast_df]).set_index('ds')
    column_names = list(combined.columns)

    def _first_matching(token):
        # First column whose name contains *token*, or None.
        return next((name for name in column_names if token in name), None)

    median_col = _first_matching('median')
    lower_col = _first_matching('lo-90')
    upper_col = _first_matching('hi-90')
    if median_col is None:
        raise KeyError("No forecast column found in the data.")

    x_values = combined.index
    band_color = 'rgba(0,100,80,0.2)'
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_values, y=combined['y'], mode='lines', name='Historical'))
    fig.add_trace(go.Scatter(x=x_values, y=combined[median_col], mode='lines', name='Forecast'))
    if lower_col and upper_col:
        # Draw the upper bound first; the lower bound then fills 'tonexty' up to it.
        fig.add_trace(go.Scatter(
            x=x_values,
            y=combined[upper_col],
            mode='lines',
            line=dict(color=band_color),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=combined[lower_col],
            mode='lines',
            line=dict(color=band_color),
            fill='tonexty',
            fillcolor=band_color,
            name='90% Confidence Interval',
        ))
    fig.update_layout(
        title=title,
        xaxis_title='Timestamp [t]',
        yaxis_title='Value',
        template='plotly_white',
    )
    st.plotly_chart(fig)
def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the pre-trained checkpoints matching the inferred data frequency.

    ``pd.infer_freq`` spells the same offset differently across pandas
    versions: monthly is 'M' or 'ME', hourly is 'H' or 'h', and year-end is
    'A-DEC', 'Y-DEC' or 'YE-DEC'. All of these are mapped onto the dict keys
    used by ``load_all_models`` ('D', 'M', 'H', 'W', 'Y').

    Returns:
        Tuple ``(nhits, timesnet, lstm, tft)`` for the matching frequency.

    Raises:
        ValueError: If *freq* is not one of the supported aliases.
    """
    freq_keys = {
        'D': 'D',
        'M': 'M', 'ME': 'M',
        'H': 'H', 'h': 'H',
        'W': 'W', 'W-SUN': 'W',
        'Y': 'Y', 'Y-DEC': 'Y', 'YE': 'Y', 'YE-DEC': 'Y', 'A': 'Y', 'A-DEC': 'Y',
    }
    key = freq_keys.get(freq)
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]
def select_model(horizon, model_type, max_steps=50):
    """Build an untrained NeuralForecast model of the requested type.

    Every model shares ``h=horizon``, ``input_size=5 * horizon``,
    ``max_steps``, standard scaling and a 90%-level ``HuberMQLoss``. The
    architecture-specific hyper-parameters live in ``extra_kwargs``.

    Raises:
        ValueError: If *model_type* is not 'NHITS', 'TimesNet', 'LSTM' or 'TFT'.
    """
    extra_kwargs = {
        'NHITS': dict(
            stack_types=3 * ['identity'],
            n_blocks=3 * [1],
            mlp_units=[[256, 256] for _ in range(3)],
            n_pool_kernel_size=3 * [1],
            batch_size=32,
            n_freq_downsample=[12, 4, 1],
        ),
        'TimesNet': dict(
            hidden_size=32,
            conv_hidden_size=64,
            learning_rate=1e-3,
            val_check_steps=200,
            valid_batch_size=64,
            windows_batch_size=128,
            inference_windows_batch_size=512,
        ),
        'LSTM': dict(
            encoder_n_layers=3,
            encoder_hidden_size=256,
            context_size=10,
            decoder_hidden_size=256,
            decoder_layers=3,
        ),
        'TFT': dict(
            hidden_size=96,
            learning_rate=0.005,
            windows_batch_size=128,
            val_check_steps=200,
            valid_batch_size=64,
            enable_progress_bar=True,
        ),
    }
    if model_type not in extra_kwargs:
        raise ValueError(f"Unsupported model type: {model_type}")
    # Resolved only after validation, so no model class or loss is touched for bad input.
    model_cls = {'NHITS': NHITS, 'TimesNet': TimesNet, 'LSTM': LSTM, 'TFT': TFT}[model_type]
    return model_cls(
        h=horizon,
        input_size=5 * horizon,
        max_steps=max_steps,
        scaler_type='standard',
        loss=HuberMQLoss(level=[90]),
        **extra_kwargs[model_type],
    )
def model_train(df, model, freq):
    """Wrap *model* in a ``NeuralForecast`` with frequency *freq* and fit it on *df*.

    ``df['ds']`` is converted to datetime in place before fitting.
    Returns the fitted ``NeuralForecast`` wrapper.
    """
    forecaster = NeuralForecast(models=[model], freq=freq)
    df['ds'] = pd.to_datetime(df['ds'])
    forecaster.fit(df)
    return forecaster
def forecast_time_series(df, model_type, horizon, max_steps, y_col):
    """Train *model_type* from scratch on *df*, forecast, and render the results.

    Infers the data frequency, fits a fresh model for ``max_steps`` steps, and
    predicts ``horizon`` steps ahead. It then plots the forecast, reports the
    elapsed time, and shows the input and forecast tables. The forecast dict is
    also stored in ``st.session_state.forecast_results``.
    """
    start_time = time.time()  # Start timing
    freq = determine_frequency(df)
    st.sidebar.write(f"Data frequency: {freq}")
    selected_model = select_model(horizon, model_type, max_steps)
    # st.spinner is a context manager: it is only shown while the `with` body runs.
    with st.spinner(f"Training {model_type} model..."):
        model = model_train(df, selected_model, freq)
    forecast_results = {model_type: generate_forecast(model, df, tag='retrain')}
    st.session_state.forecast_results = forecast_results
    for model_name, forecast_df in forecast_results.items():
        plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
    time_taken = time.time() - start_time
    st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
    st.markdown('You can download Input and Forecast Data below')
    tab_insample, tab_forecast = st.tabs(
        ["Input data", "Forecast"]
    )
    with tab_insample:
        st.write(df.drop(columns="unique_id"))
    with tab_forecast:
        if model_type in forecast_results:
            st.write(forecast_results[model_type])
def load_default():
    """Return a fresh copy of the bundled AirPassengers demo dataset."""
    return AirPassengersDF.copy()
def transfer_learning_forecasting():
    """Page: zero-shot forecasting with the pre-trained M4 checkpoints.

    The user uploads a CSV (or the AirPassengers demo is used) and picks the
    date and target columns. After "Submit", the pre-trained model matching
    the detected frequency forecasts the series; the stored results are shown
    in tabs on every rerun.
    """
    st.title("Zero-shot Forecasting")
    st.markdown("""
    Instant time series forecasting and visualization by using various pre-trained deep neural network-based model trained on M4 data.
    """)
    nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        if 'uploaded_file' not in st.session_state:
            uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.session_state.df = df
                st.session_state.uploaded_file = uploaded_file
            else:
                df = load_default()
                st.session_state.df = df
        else:
            if st.checkbox("Upload a new file (CSV)"):
                uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
                if uploaded_file:
                    df = pd.read_csv(uploaded_file)
                    st.session_state.df = df
                    st.session_state.uploaded_file = uploaded_file
                else:
                    df = st.session_state.df
            else:
                df = st.session_state.df
        columns = df.columns.tolist()
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
        y_col = st.selectbox("Select Target column", options=target_columns, index=0)
        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col
    # Model selection and forecasting
    st.sidebar.subheader("Model Selection and Forecasting")
    model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
    # A non-positive horizon would slice the forecast into nothing (or drop rows).
    horizon = st.sidebar.number_input("Forecast horizon", value=12, min_value=1)
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    # Independent copy: determine_frequency writes to df['ds'] in place.
    df = df[['unique_id', 'ds', 'y']].copy()
    # Determine frequency of data
    frequency = determine_frequency(df)
    st.sidebar.write(f"Detected frequency: {frequency}")
    try:
        nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(
            frequency, nhits_models, timesnet_models, lstm_models, tft_models
        )
    except ValueError as err:
        # Show a readable message instead of an uncaught traceback.
        st.error(f"{err}. Zero-shot models are available for hourly, daily, weekly, monthly and yearly data.")
        st.stop()
    pretrained = {
        "NHITS": nhits_model,
        "TimesNet": timesnet_model,
        "LSTM": lstm_model,
        "TFT": tft_model,
    }
    forecast_results = {}
    if st.sidebar.button("Submit"):
        start_time = time.time()  # Start timing
        forecast_results[model_choice] = generate_forecast(pretrained[model_choice], df)
        st.session_state.forecast_results = forecast_results
        for model_name, forecast_df in forecast_results.items():
            plot_forecasts(forecast_df.iloc[:horizon, :], df, f'{model_name} Forecast for {y_col}')
        time_taken = time.time() - start_time
        st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
    if 'forecast_results' in st.session_state:
        forecast_results = st.session_state.forecast_results
        st.markdown('You can download Input and Forecast Data below')
        tab_insample, tab_forecast = st.tabs(
            ["Input data", "Forecast"]
        )
        with tab_insample:
            st.write(df.drop(columns="unique_id"))
        with tab_forecast:
            if model_choice in forecast_results:
                st.write(forecast_results[model_choice])
def dynamic_forecasting():
    """Page: train a forecasting model from scratch on the user's data.

    The user uploads a CSV (or the AirPassengers demo is used) and picks the
    date/target columns, the model, the horizon and the training steps.
    "Submit" then trains, forecasts and renders via ``forecast_time_series``.
    """
    st.title("Personalized Neural Forecasting")
    st.markdown("""
    Train time series forecasting model from scratch and provide forecasts/visualization by using various deep neural network-based model trained on user data.
    Forecasting speed depends on CPU/GPU availabilty.
    """)
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        if 'uploaded_file' not in st.session_state:
            uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.session_state.df = df
                st.session_state.uploaded_file = uploaded_file
            else:
                df = load_default()
                st.session_state.df = df
        else:
            if st.checkbox("Upload a new file (CSV)"):
                uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
                if uploaded_file:
                    df = pd.read_csv(uploaded_file)
                    st.session_state.df = df
                    st.session_state.uploaded_file = uploaded_file
                else:
                    df = st.session_state.df
            else:
                df = st.session_state.df
        columns = df.columns.tolist()
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
        y_col = st.selectbox("Select Target column", options=target_columns, index=0)
        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    # Independent copy: determine_frequency / model_train write to df['ds'] in place.
    df = df[['unique_id', 'ds', 'y']].copy()
    # Dynamic forecasting
    st.sidebar.subheader("Dynamic Model Selection and Forecasting")
    dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
    # Both values feed straight into model constructors; non-positive values are invalid.
    dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=12, min_value=1)
    dynamic_max_steps = st.sidebar.number_input('Max steps', value=20, min_value=1)
    if st.sidebar.button("Submit"):
        with st.spinner('Training model. This may take few minutes...'):
            forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, y_col)
def timegpt_fcst():
    """Page: forecast the uploaded series with Nixtla's TimeGPT API.

    The API is called only when "Submit" is pressed. The result is kept in
    ``st.session_state.forecast_df``, so later reruns re-render the stored
    forecast without calling the API again. Reruns include changing the
    visualization selector. The elapsed time is measured and reported only
    on the run that made the API call.
    """
    nixtla_token = os.environ.get("NIXTLA_API_KEY")
    nixtla_client = NixtlaClient(
        api_key=nixtla_token
    )
    st.title("TimeGPT Forecasting")
    st.markdown("""
    Instant time series forecasting and visualization by using the TimeGPT API provided by Nixtla.
    """)
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        if 'uploaded_file' not in st.session_state:
            uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.session_state.df = df
                st.session_state.uploaded_file = uploaded_file
            else:
                df = load_default()
                st.session_state.df = df
        else:
            if st.checkbox("Upload a new file (CSV)"):
                uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
                if uploaded_file:
                    df = pd.read_csv(uploaded_file)
                    st.session_state.df = df
                    st.session_state.uploaded_file = uploaded_file
                else:
                    df = st.session_state.df
            else:
                df = st.session_state.df
        columns = df.columns.tolist()
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
        y_col = st.selectbox("Select Target column", options=target_columns, index=0)
        h = st.number_input("Forecast horizon", value=14, min_value=1)
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    id_col = 'ts_test'
    df['unique_id'] = id_col
    # Independent copy: determine_frequency writes to df['ds'] in place.
    df = df[['unique_id', 'ds', 'y']].copy()
    freq = determine_frequency(df)
    df = df.drop_duplicates(subset=['ds']).reset_index(drop=True)
    plot_type = st.sidebar.selectbox("Select Visualization", ["Matplotlib", "Plotly"])
    # Only set on the run that actually calls the API.
    time_taken = None
    if st.sidebar.button("Submit"):
        start_time = time.time()
        forecast_df = nixtla_client.forecast(
            df=df,
            h=h,
            freq=freq,
            level=[90]
        )
        st.session_state.forecast_df = forecast_df
        time_taken = time.time() - start_time
    if 'forecast_df' in st.session_state:
        forecast_df = st.session_state.forecast_df
        if plot_type == "Matplotlib":
            fig = nixtla_client.plot(df, forecast_df, level=[90], engine='matplotlib')
            st.pyplot(fig)
            # Free the figure once Streamlit has rendered it.
            plt.close(fig)
        elif plot_type == "Plotly":
            fig = nixtla_client.plot(df, forecast_df, level=[90], engine='plotly')
            st.plotly_chart(fig)
        if time_taken is not None:
            st.success(f"Time taken for TimeGPT forecast: {time_taken:.2f} seconds")
        st.markdown('You can download Input and Forecast Data below')
        tab_insample, tab_forecast = st.tabs(
            ["Input data", "Forecast"]
        )
        with tab_insample:
            st.write(df.drop(columns="unique_id"))
        with tab_forecast:
            st.write(forecast_df)
def timegpt_anom():
    """Page: detect anomalies in the uploaded series with Nixtla's TimeGPT API.

    The API is called only when "Submit" is pressed. The result is kept in
    ``st.session_state.anom_df``, so later reruns re-render the stored result
    without calling the API again. Reruns include changing the visualization
    selector. Detected anomalies are highlighted in the plot.
    """
    nixtla_token = os.environ.get("NIXTLA_API_KEY")
    nixtla_client = NixtlaClient(
        api_key=nixtla_token
    )
    st.title("TimeGPT Anomaly Detection")
    st.markdown("""
    Instant time series anomaly detection and visualization by using the TimeGPT API provided by Nixtla.
    """)
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        if 'uploaded_file' not in st.session_state:
            uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.session_state.df = df
                st.session_state.uploaded_file = uploaded_file
            else:
                df = load_default()
                st.session_state.df = df
        else:
            if st.checkbox("Upload a new file (CSV)"):
                uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
                if uploaded_file:
                    df = pd.read_csv(uploaded_file)
                    st.session_state.df = df
                    st.session_state.uploaded_file = uploaded_file
                else:
                    df = st.session_state.df
            else:
                df = st.session_state.df
        columns = df.columns.tolist()
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
        y_col = st.selectbox("Select Target column", options=target_columns, index=0)
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    id_col = 'ts_test'
    df['unique_id'] = id_col
    # Independent copy: determine_frequency writes to df['ds'] in place.
    df = df[['unique_id', 'ds', 'y']].copy()
    freq = determine_frequency(df)
    df = df.drop_duplicates(subset=['ds']).reset_index(drop=True)
    plot_type = st.sidebar.selectbox("Select Visualization", ["Matplotlib", "Plotly"])
    # Only set on the run that actually calls the API.
    time_taken = None
    if st.sidebar.button("Submit"):
        start_time = time.time()
        anom_df = nixtla_client.detect_anomalies(
            df=df,
            freq=freq,
            level=90
        )
        st.session_state.anom_df = anom_df
        time_taken = time.time() - start_time
    if 'anom_df' in st.session_state:
        anom_df = st.session_state.anom_df
        # plot_anomalies=True marks the detected anomalies; it needs the matching level.
        if plot_type == "Matplotlib":
            fig = nixtla_client.plot(df, anom_df, level=[90], plot_anomalies=True, engine='matplotlib')
            st.pyplot(fig)
            # Free the figure once Streamlit has rendered it.
            plt.close(fig)
        elif plot_type == "Plotly":
            fig = nixtla_client.plot(df, anom_df, level=[90], plot_anomalies=True, engine='plotly')
            st.plotly_chart(fig)
        if time_taken is not None:
            st.success(f"Time taken for TimeGPT anomaly detection: {time_taken:.2f} seconds")
        st.markdown('You can download Input and Anomaly Data below')
        tab_insample, tab_anomalies = st.tabs(
            ["Input data", "Anomalies"]
        )
        with tab_insample:
            st.write(df.drop(columns="unique_id"))
        with tab_anomalies:
            st.write(anom_df)
# Multipage layout: two sections, each mapping page functions to sidebar entries.
# The zero-shot page is the landing page (default=True).
pages = {
    "Neuralforecast": [
        st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
        st.Page(dynamic_forecasting, title="Personalized Neural Forecasting", icon=":material/monitoring:"),
    ],
    "TimeGPT": [
        st.Page(timegpt_fcst, title="TimeGPT Forecast", icon=":material/smart_toy:"),
        st.Page(timegpt_anom, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:"),
    ],
}
pg = st.navigation(pages)
# Execute whichever page the user selected.
pg.run()