# Tempus / app.py
# Committed by: vincentiusyoshuac
# Commit 8a839a2 "Update app.py" (verified)
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from chronos import ChronosPipeline
class TimeSeriesForecaster:
    """Wrap a pretrained Chronos pipeline for CSV-driven forecasting and plotting.

    Typical use: ``preprocess_data`` -> ``forecast`` -> ``visualize_forecast``.
    The forecast horizon is capped at ``MAX_PREDICTION_LENGTH`` steps.
    """

    # Upper bound on the forecast horizon enforced by preprocess_data/forecast.
    MAX_PREDICTION_LENGTH = 30

    def __init__(self, model_name="amazon/chronos-t5-small"):
        """Load the Chronos pipeline named *model_name*.

        Uses CUDA with bfloat16 weights when a GPU is available, otherwise
        CPU with float32 weights.
        """
        use_cuda = torch.cuda.is_available()
        self.pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cuda" if use_cuda else "cpu",
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        )
        # Both are populated by preprocess_data().
        self.original_series = None
        self.context = None

    def preprocess_data(self, df, date_column, value_column, context_length=100, prediction_length=30):
        """Extract the value series from *df* in chronological order.

        Args:
            df: Input DataFrame. It is not modified.
            date_column: Column holding timestamps (parsed with ``pd.to_datetime``).
            value_column: Column holding the values to forecast.
            context_length: Number of most recent observations fed to the model.
            prediction_length: Requested horizon. Values above
                ``MAX_PREDICTION_LENGTH`` are clamped, with a UI warning.

        Returns:
            ``(context, prediction_length)``: a 1-D float32 tensor holding the
            last *context_length* values, and the (possibly clamped) horizon.
            The full ordered series is also stored in ``self.original_series``
            and the context in ``self.context``.

        Raises:
            ValueError: if the value column cannot be converted to floats, or
                if the series is empty.
        """
        max_len = self.MAX_PREDICTION_LENGTH
        # Validate the prediction length.
        if prediction_length > max_len:
            st.warning(f"Prediction length dibatasi maksimal {max_len} langkah. Akan disesuaikan.")
            prediction_length = max_len

        # Parse the dates BEFORE sorting. Sorting the raw strings would be
        # lexicographic, e.g. "12/31/2022" would land after "01/15/2023".
        # Work on a copy so the caller's DataFrame is left untouched.
        ordered = df.copy()
        ordered[date_column] = pd.to_datetime(ordered[date_column])
        # A stable sort keeps rows with identical timestamps in file order.
        ordered = ordered.sort_values(by=date_column, kind="stable").set_index(date_column)

        # Convert explicitly to float so nullable/object numeric columns work
        # with torch. Missing values become NaN; non-numeric text raises ValueError.
        values = ordered[value_column].to_numpy(dtype=float, na_value=np.nan)
        if values.size == 0:
            raise ValueError("No data points found in the selected value column.")

        self.original_series = values
        self.context = torch.tensor(values[-context_length:], dtype=torch.float32)
        return self.context, prediction_length

    def forecast(self, context, prediction_length=30, num_samples=100):
        """Sample *num_samples* forecast paths for the next *prediction_length* steps.

        The horizon is clamped to ``MAX_PREDICTION_LENGTH``. The return value is
        whatever ``ChronosPipeline.predict`` returns; visualize_forecast indexes
        it as ``forecasts[0]`` with samples on axis 0 (presumably shape
        ``(num_series, num_samples, prediction_length)`` -- confirm against the
        Chronos docs).
        """
        # Never exceed the supported horizon.
        prediction_length = min(prediction_length, self.MAX_PREDICTION_LENGTH)
        return self.pipeline.predict(context, prediction_length, num_samples=num_samples)

    def visualize_forecast(self, forecasts, original_series):
        """Plot the history, the median forecast, and the 10th-90th percentile band.

        Args:
            forecasts: Output of ``forecast()``. ``forecasts[0]`` is reduced
                over axis 0 (the sample axis) into per-step quantiles.
            original_series: Full historical series, drawn before the forecast.

        Returns:
            A ``matplotlib.figure.Figure`` suitable for ``st.pyplot``. It is
            built without pyplot, so it is never registered in pyplot's global
            figure list and does not leak across Streamlit reruns.
        """
        from matplotlib.figure import Figure

        # Calculate forecast statistics.
        forecast_np = forecasts[0].numpy()
        low, median, high = np.quantile(forecast_np, [0.1, 0.5, 0.9], axis=0)

        fig = Figure(figsize=(16, 8), dpi=100, facecolor='white')
        ax = fig.subplots()

        # Historical data occupies x = 0 .. len-1; the forecast continues right after.
        history_len = len(original_series)
        ax.plot(
            range(history_len),
            original_series,
            label='Historical Data',
            color='#2C3E50',
            linewidth=2,
            alpha=0.7,
        )

        forecast_index = range(history_len, history_len + len(median))
        ax.plot(
            forecast_index,
            median,
            color='#3498DB',
            linewidth=3,
            label='Median Forecast',
        )
        # The 10th..90th percentile band covers the central 80% of samples.
        ax.fill_between(
            forecast_index,
            low,
            high,
            color='#3498DB',
            alpha=0.2,
            label='80% Prediction Interval',
        )

        ax.set_title(
            f'Advanced Time Series Forecast (Max {self.MAX_PREDICTION_LENGTH} Steps)',
            fontsize=18,
            fontweight='bold',
            color='#2C3E50',
        )
        ax.set_xlabel('Time Steps', fontsize=12, color='#34495E')
        ax.set_ylabel('Value', fontsize=12, color='#34495E')
        ax.legend(frameon=False)
        ax.grid(True, linestyle='--', color='#BDC3C7', alpha=0.5)
        fig.tight_layout()
        return fig
def main():
    """Render the Tempus Streamlit page.

    Flow:
      1. Configure the page and styles.
      2. Accept a CSV upload.
      3. Let the user pick the date/value columns and the context/horizon sizes.
      4. On button press, run TimeSeriesForecaster and show the chart plus a
         per-step summary table.

    Errors while reading the CSV or forecasting are shown with ``st.error``
    instead of propagating, since this function is the app's top level.
    """
    # Page configuration. The icon is the bar-chart emoji (U+1F4CA), written as
    # an escape so it survives source re-encoding; the previous literal had
    # been mangled into mojibake.
    st.set_page_config(
        page_title="Tempus",
        page_icon="\U0001F4CA",
        layout="wide"
    )

    # Modern, minimalist styling
    st.markdown("""
<style>
.stApp {
background-color: #FFFFFF;
font-family: 'Inter', 'Roboto', sans-serif;
}
.stButton>button {
background-color: #3498DB;
color: white;
border: none;
border-radius: 8px;
padding: 10px 20px;
transition: all 0.3s ease;
font-weight: 600;
}
.stButton>button:hover {
background-color: #2980B9;
transform: scale(1.05);
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
/* Sleek file uploader */
.stFileUploader {
background-color: #F8F9FA;
border: 2px solid #3498DB;
border-radius: 10px;
padding: 20px;
text-align: center;
transition: all 0.3s ease;
}
.stFileUploader:hover {
border-color: #2980B9;
background-color: #EAF2F8;
}
</style>
""", unsafe_allow_html=True)

    # Elegant title
    st.markdown(
        "<h1 style='text-align: center; color: #3498DB; margin-bottom: 30px;'>Tempus</h1>",
        unsafe_allow_html=True
    )

    # File upload
    uploaded_file = st.file_uploader(
        "Upload Time Series Data",
        type=['csv'],
        help="CSV with timestamp and numeric columns"
    )
    if uploaded_file is None:
        return

    # Read the CSV; report malformed files in the UI instead of a raw traceback.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
        st.error(f"Could not read CSV file: {e}")
        return

    # The value selectbox excludes the date column, so a single-column file
    # would leave it empty (value_column=None).
    if len(df.columns) < 2:
        st.error("The CSV needs at least two columns: a date column and a value column.")
        return

    # Column selection
    col1, col2 = st.columns(2)
    with col1:
        date_column = st.selectbox('Date Column', options=df.columns)
    with col2:
        value_column = st.selectbox(
            'Value Column',
            options=[col for col in df.columns if col != date_column]
        )

    # Advanced prediction settings
    col3, col4 = st.columns(2)
    with col3:
        context_length = st.slider(
            'Historical Context',
            min_value=30,
            max_value=500,
            value=100,
            help="Number of past data points to analyze"
        )
    with col4:
        prediction_length = st.slider(
            'Forecast Horizon',
            min_value=1,
            max_value=30,  # Capped at 30 steps
            value=30,  # Default: 30 steps
            help="Number of future time steps to predict (max 30)"
        )

    # Forecast generation
    if not st.button('Generate Forecast'):
        return

    try:
        # Initialize and run forecaster
        forecaster = TimeSeriesForecaster()
        context, prediction_length = forecaster.preprocess_data(
            df,
            date_column,
            value_column,
            context_length,
            prediction_length
        )
        forecasts = forecaster.forecast(context, prediction_length)

        # Visualization. Named `fig` so the matplotlib `plt` module isn't shadowed.
        fig = forecaster.visualize_forecast(forecasts, forecaster.original_series)
        st.pyplot(fig)

        # Forecast details: per-step statistics across the sampled paths.
        forecast_np = forecasts[0].numpy()
        forecast_mean = forecast_np.mean(axis=0)
        forecast_lower = np.percentile(forecast_np, 10, axis=0)
        forecast_upper = np.percentile(forecast_np, 90, axis=0)

        st.subheader('Forecast Insights')
        prediction_df = pd.DataFrame({
            'Mean Forecast': forecast_mean,
            'Lower Bound (10%)': forecast_lower,
            'Upper Bound (90%)': forecast_upper
        })
        st.dataframe(prediction_df)
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"Forecast generation error: {e}")
if __name__ == "__main__":
    # Launch the app only when this file is executed as a script, not when imported.
    main()