Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
from chronos import ChronosPipeline | |
class TimeSeriesForecaster:
    """Wrapper around a Chronos pipeline for forecasting a single series.

    Holds the loaded pipeline plus the most recently preprocessed series and
    provides helpers to prepare data, sample forecasts, and plot them.
    """

    # Hard cap on the forecast horizon, enforced by preprocess_data() and forecast().
    MAX_PREDICTION_LENGTH = 30

    def __init__(self, model_name="amazon/chronos-t5-small"):
        """Load the pretrained Chronos pipeline named ``model_name``.

        Uses CUDA with bfloat16 weights when a GPU is available, otherwise the
        CPU with float32 weights.
        """
        use_cuda = torch.cuda.is_available()
        self.pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cuda" if use_cuda else "cpu",
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        )
        # Full value series from the last preprocess_data() call (None until then).
        self.original_series = None
        # Trailing window of original_series used as model context (float32 tensor).
        self.context = None

    def preprocess_data(self, df, date_column, value_column, context_length=100, prediction_length=30):
        """Prepare a chronologically ordered series from ``df``.

        Args:
            df: Input table; it is not modified.
            date_column: Column holding timestamps (anything ``pd.to_datetime``
                can parse).
            value_column: Column holding the numeric values to forecast.
            context_length: Number of most recent points kept as model context.
            prediction_length: Requested horizon; values above
                MAX_PREDICTION_LENGTH are clamped, with a Streamlit warning.

        Returns:
            ``(context, prediction_length)``: the float32 context tensor and the
            possibly clamped horizon. Also stores ``self.original_series`` (the
            full float64 series) and ``self.context``.

        Raises:
            ValueError: if no rows with a valid timestamp remain, or if the
                value column cannot be converted to floats.
        """
        # Clamp the prediction length to the supported maximum.
        if prediction_length > self.MAX_PREDICTION_LENGTH:
            st.warning(
                f"Prediction length dibatasi maksimal {self.MAX_PREDICTION_LENGTH} "
                "langkah. Akan disesuaikan."
            )
            prediction_length = self.MAX_PREDICTION_LENGTH

        # Parse dates *before* sorting: sorting raw strings orders formats like
        # "MM/DD/YYYY" lexically rather than chronologically.
        data = df.copy()
        data[date_column] = pd.to_datetime(data[date_column])
        # Rows without a timestamp cannot be placed in time; drop them instead of
        # letting them sort to the end as if they were the newest observations.
        data = data.dropna(subset=[date_column])
        # Stable sort keeps file order for rows sharing a timestamp.
        data = data.sort_values(by=date_column, kind="stable")

        # Float conversion gives a clear ValueError for non-numeric text and turns
        # nullable-integer NA into NaN, which torch.tensor can accept.
        values = data[value_column].to_numpy(dtype=np.float64, na_value=np.nan)
        if values.size == 0:
            raise ValueError(f"Column '{value_column}' contains no data to forecast.")

        self.original_series = values
        self.context = torch.tensor(values[-context_length:], dtype=torch.float32)
        return self.context, prediction_length

    def forecast(self, context, prediction_length=30, num_samples=100):
        """Sample ``num_samples`` forecast paths from the pipeline.

        ``prediction_length`` is clamped to MAX_PREDICTION_LENGTH. Returns
        whatever ``self.pipeline.predict`` returns.
        """
        prediction_length = min(prediction_length, self.MAX_PREDICTION_LENGTH)
        return self.pipeline.predict(context, prediction_length, num_samples=num_samples)

    def visualize_forecast(self, forecasts, original_series):
        """Plot the history, then the forecast median and a 10-90% quantile band.

        Args:
            forecasts: Output of forecast(). Only ``forecasts[0]`` is plotted;
                quantiles are taken over its axis 0, which presumably indexes
                samples (TODO confirm against the Chronos output shape).
            original_series: Historical values drawn before the forecast.

        Returns:
            The ``matplotlib.pyplot`` module, with the newly created figure as
            the current figure. This is kept so existing callers that pass the
            result to ``st.pyplot`` keep working.
        """
        lower_q, upper_q = 0.1, 0.9
        # The band between the two quantiles covers this share of samples (80%).
        interval_pct = round((upper_q - lower_q) * 100)

        plt.figure(figsize=(16, 8), dpi=100, facecolor='white')

        forecast_np = forecasts[0].numpy()
        low, median, high = np.quantile(forecast_np, [lower_q, 0.5, upper_q], axis=0)

        # History occupies x = 0..n-1; the forecast continues from x = n.
        history_len = len(original_series)
        plt.plot(
            range(history_len),
            original_series,
            label='Historical Data',
            color='#2C3E50',
            linewidth=2,
            alpha=0.7,
        )
        forecast_index = range(history_len, history_len + len(median))

        plt.plot(
            forecast_index,
            median,
            color='#3498DB',
            linewidth=3,
            label='Median Forecast',
        )
        plt.fill_between(
            forecast_index,
            low,
            high,
            color='#3498DB',
            alpha=0.2,
            label=f'{interval_pct}% Prediction Interval',
        )

        plt.title(
            f'Advanced Time Series Forecast (Max {self.MAX_PREDICTION_LENGTH} Steps)',
            fontsize=18,
            fontweight='bold',
            color='#2C3E50',
        )
        plt.xlabel('Time Steps', fontsize=12, color='#34495E')
        plt.ylabel('Value', fontsize=12, color='#34495E')
        plt.legend(frameon=False)
        plt.grid(True, linestyle='--', color='#BDC3C7', alpha=0.5)
        plt.tight_layout()
        return plt
@st.cache_resource
def _load_forecaster():
    """Build the Chronos-backed forecaster once and reuse it across reruns.

    ``TimeSeriesForecaster()`` loads pretrained model weights. Caching it with
    ``st.cache_resource`` avoids repeating that load on every button click.
    """
    return TimeSeriesForecaster()


def main():
    """Render the Tempus app: upload a CSV, pick columns and settings, forecast.

    Errors raised while generating the forecast are shown with ``st.error``
    instead of propagating.
    """
    import copy  # only needed to clone the cached forecaster for each run

    # Page configuration.
    # NOTE(review): "π" looks like a mis-encoded emoji (UTF-8 bytes decoded
    # with a legacy code page) — confirm the intended icon.
    st.set_page_config(
        page_title="Tempus",
        page_icon="π",
        layout="wide",
    )

    # Modern, minimalist styling
    st.markdown("""
    <style>
    .stApp {
        background-color: #FFFFFF;
        font-family: 'Inter', 'Roboto', sans-serif;
    }
    .stButton>button {
        background-color: #3498DB;
        color: white;
        border: none;
        border-radius: 8px;
        padding: 10px 20px;
        transition: all 0.3s ease;
        font-weight: 600;
    }
    .stButton>button:hover {
        background-color: #2980B9;
        transform: scale(1.05);
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    }
    /* Sleek file uploader */
    .stFileUploader {
        background-color: #F8F9FA;
        border: 2px solid #3498DB;
        border-radius: 10px;
        padding: 20px;
        text-align: center;
        transition: all 0.3s ease;
    }
    .stFileUploader:hover {
        border-color: #2980B9;
        background-color: #EAF2F8;
    }
    </style>
    """, unsafe_allow_html=True)

    # Centered title
    st.markdown(
        "<h1 style='text-align: center; color: #3498DB; margin-bottom: 30px;'>Tempus</h1>",
        unsafe_allow_html=True
    )

    # File upload
    uploaded_file = st.file_uploader(
        "Upload Time Series Data",
        type=['csv'],
        help="CSV with timestamp and numeric columns"
    )
    if uploaded_file is None:
        return

    # Report unreadable uploads in the UI instead of crashing the script run.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        return

    # A date column and a distinct value column are both required; with a single
    # column the value selectbox would have no options and return None.
    if len(df.columns) < 2:
        st.error("The CSV needs at least two columns: a date column and a value column.")
        return

    # Column selection
    col1, col2 = st.columns(2)
    with col1:
        date_column = st.selectbox('Date Column', options=df.columns)
    with col2:
        value_column = st.selectbox(
            'Value Column',
            options=[col for col in df.columns if col != date_column]
        )

    # Prediction settings
    col3, col4 = st.columns(2)
    with col3:
        context_length = st.slider(
            'Historical Context',
            min_value=30,
            max_value=500,
            value=100,
            help="Number of past data points to analyze"
        )
    with col4:
        prediction_length = st.slider(
            'Forecast Horizon',
            min_value=1,
            max_value=30,  # capped at 30 steps
            value=30,  # default: 30
            help="Number of future time steps to predict (max 30)"
        )

    if not st.button('Generate Forecast'):
        return

    try:
        # The cached instance is shared by every session. A shallow copy reuses its
        # pipeline but gives this run its own original_series/context attributes.
        forecaster = copy.copy(_load_forecaster())

        context, prediction_length = forecaster.preprocess_data(
            df,
            date_column,
            value_column,
            context_length,
            prediction_length
        )
        forecasts = forecaster.forecast(context, prediction_length)

        # visualize_forecast draws on a new current pyplot figure. Render that
        # figure explicitly and close it so figures don't pile up across reruns.
        forecaster.visualize_forecast(forecasts, forecaster.original_series)
        figure = plt.gcf()
        try:
            st.pyplot(figure)
        finally:
            plt.close(figure)

        # Forecast details: per-step summary statistics over the sampled paths.
        forecast_np = forecasts[0].numpy()
        st.subheader('Forecast Insights')
        prediction_df = pd.DataFrame({
            'Mean Forecast': forecast_np.mean(axis=0),
            # Same statistic as the line drawn in the chart.
            'Median Forecast': np.median(forecast_np, axis=0),
            'Lower Bound (10%)': np.percentile(forecast_np, 10, axis=0),
            'Upper Bound (90%)': np.percentile(forecast_np, 90, axis=0),
        })
        st.dataframe(prediction_df)
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"Forecast generation error: {str(e)}")
if __name__ == "__main__":
    # Script entry point: launch the app only when executed directly.
    main()