# Spaces: Sleeping
import logging
import os
import warnings
from typing import Any, List, Optional

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import Scaler
from darts.metrics import mape
from darts.models.forecasting.tft_model import TFTModel
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from dateutil.relativedelta import relativedelta

# Silence all warnings and disable logging at CRITICAL level and below.
warnings.filterwarnings("ignore")
logging.disable(logging.CRITICAL)
# ---------------------------------------------------------------------------
# Data loading and preprocessing. Runs once, when the module is executed.
# ---------------------------------------------------------------------------
df_final = pd.read_csv('data/all_afghan.csv', parse_dates=['Date'])
df_comtrade_flour = pd.read_csv('data/comtrade_flour.csv', parse_dates=['Date'])
df_comtrade_grain = pd.read_csv('data/comtrade_grain.csv', parse_dates=['Date'])

series = TimeSeries.from_dataframe(
    df_final,
    time_col='Date',
    value_cols=['price', 'usdprice', 'wheat_grain', 'exchange_rate', 'common_unit_price', 'black_sea'],
)

# Train/validation cut-off: six months before the most recent date in the data.
six_months = df_final['Date'].max() + relativedelta(months=-6)

# Each Scaler below is fit on the training split (up to `six_months`) only,
# then applied to the validation split and to the full series.

# Target series: common_unit_price. `transformer` is reused by get_forecast()
# to map model output back to original price units.
data_series = series['common_unit_price']
train, val = data_series.split_after(six_months)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(data_series)

# Calendar covariates: year, one-hot month, an integer time index, holidays.
covariates = datetime_attribute_timeseries(series_transformed, attribute="year", one_hot=False)
covariates = covariates.stack(
    datetime_attribute_timeseries(series_transformed, attribute="month", one_hot=True)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series_transformed.time_index,
        values=np.arange(len(series_transformed)),
    )
)
# NOTE(review): "ES" is the ISO country code for Spain, although this data is
# Afghan. Left unchanged because the saved TFT checkpoint presumably expects
# exactly these covariate columns — confirm before changing.
covariates = covariates.add_holidays(country_code="ES")
covariates = covariates.astype(np.float32)
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(six_months)
cov_train = scaler_covs.fit_transform(cov_train)
cov_val = scaler_covs.transform(cov_val)
covariates_transformed = scaler_covs.transform(covariates)

# Per-feature scaling of the remaining columns. For grain, pakistan and usd the
# `*_train` / `*_val` names are rebound to the SCALED splits; for erate and
# black_sea they stay raw and the scaled splits are the `*_transformed` names.
grain_series = series['wheat_grain']
grain_scaler = Scaler()
grain_train, grain_val = grain_series.split_after(six_months)
grain_train = grain_scaler.fit_transform(grain_train)
grain_val = grain_scaler.transform(grain_val)
grain_series_scaled = grain_scaler.transform(grain_series)

pakistan_series = series["price"]
pakistan_scaler = Scaler()
pakistan_train, pakistan_val = pakistan_series.split_after(six_months)
pakistan_train = pakistan_scaler.fit_transform(pakistan_train)
pakistan_val = pakistan_scaler.transform(pakistan_val)
pakistan_series_scaled = pakistan_scaler.transform(pakistan_series)

usd_series = series['usdprice']
usd_scaler = Scaler()
usd_train, usd_val = usd_series.split_after(six_months)
usd_train = usd_scaler.fit_transform(usd_train)
usd_val = usd_scaler.transform(usd_val)
usd_series_scaled = usd_scaler.transform(usd_series)

erate_series = series['exchange_rate']
erate_scaler = Scaler()
erate_train, erate_val = erate_series.split_after(six_months)
erate_train_transformed = erate_scaler.fit_transform(erate_train)
erate_val_transformed = erate_scaler.transform(erate_val)
erate_series_scaled = erate_scaler.transform(erate_series)

black_sea = series['black_sea']
black_sea_scaler = Scaler()
black_train, black_val = black_sea.split_after(six_months)
black_train_transformed = black_sea_scaler.fit_transform(black_train)
black_val_transformed = black_sea_scaler.transform(black_val)
black_sea_series = black_sea_scaler.transform(black_sea)

# Trade series: every non-Date column of each CSV becomes a component.
comtrade_flour_series = TimeSeries.from_dataframe(df_comtrade_flour, time_col="Date")
comtrade_grain_series = TimeSeries.from_dataframe(df_comtrade_grain, time_col="Date")

# Past covariates passed to historical_forecasts() in get_forecast().
# usd_series_scaled is excluded.
# NOTE(review): the comtrade series are concatenated unscaled while every other
# component is scaled; presumably this matches how the checkpoint was trained —
# confirm before changing.
my_multivariate_series = concatenate(
    [
        grain_series_scaled,
        pakistan_series_scaled,
        erate_series_scaled,
        black_sea_series,
        comtrade_flour_series,
        comtrade_grain_series,
        covariates_transformed,
    ],
    axis=1,
)

# Scaled training-split covariates (without usd, black_sea and comtrade).
# Not referenced by get_forecast().
multivariate_series_train = concatenate(
    [
        grain_train,
        pakistan_train,
        # The scaled exchange-rate split, consistent with the other (scaled)
        # entries; plain `erate_train` is the raw, unscaled split.
        erate_train_transformed,
        cov_train,
    ],
    axis=1,
)
class FlaggingHandler(gr.FlaggingCallback):
    """Flagging callback that echoes flagged samples to stdout and stores
    them through Gradio's built-in ``gr.CSVLogger``.
    """

    def __init__(self):
        # All persistence is delegated to this wrapped CSV logger.
        self._csv_logger = gr.CSVLogger()

    def setup(self, components: List[gr.components.Component], flagging_dir: str):
        """Remember the flagged components and prepare the wrapped CSV logger.

        Must be called before ``flag``; in this app it is invoked explicitly
        while the Blocks UI is being built.

        Parameters:
            components: Components whose values will later be passed to ``flag``.
            flagging_dir: Directory in which the CSV logger stores flagged data.
        """
        self.components = components
        self._csv_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        username: Optional[str] = None,
    ) -> int:
        """Print a flagged sample to stdout, then hand it to the CSV logger.

        Parameters:
            flag_data: The values being flagged (presumably one per component
                registered in ``setup``, in the same order).
            flag_option (optional): The flag label, e.g. "Incorrect". Printed
                only when truthy; always forwarded to the CSV logger.
            username (optional): Accepted for callback-interface compatibility;
                it is NOT forwarded to the CSV logger.

        Returns:
            Whatever ``gr.CSVLogger.flag`` returns (per Gradio's
            FlaggingCallback contract, the number of flagged samples).
        """
        for value in flag_data:
            print(f"Flagging: {value}")
        if flag_option:
            print(f"Flag option: {flag_option}")
        return self._csv_logger.flag(flag_data=flag_data, flag_option=flag_option)
def _forecast_horizon(period_: str) -> int:
    """Return the horizon encoded by the leading integer of a label like "3 months".

    Unlike indexing the first character, this also accepts multi-digit
    horizons such as "12 months".

    Raises:
        ValueError: if no horizon was selected, the label does not start with
            an integer, or that integer is not positive.
    """
    import re

    if not period_:
        raise ValueError("Please select a forecast horizon.")
    match = re.match(r"\s*(\d+)", str(period_))
    if match is None:
        raise ValueError(f"Unrecognised forecast horizon: {period_!r}")
    horizon = int(match.group(1))
    if horizon <= 0:
        raise ValueError(f"Forecast horizon must be positive, got {period_!r}")
    return horizon


def _forecast_table(model, period: int) -> pd.DataFrame:
    """Predict `period` steps and tabulate the forecast with a +/-10% band.

    Returns a DataFrame with columns Date, Lower_Limit, Wheat_Forecast and
    Upper_Limit, in original (inverse-scaled) price units.
    """
    preds = transformer.inverse_transform(model.predict(n=period, num_samples=1))

    def as_frame(ts, column):
        return ts.pd_dataframe().rename(columns={'common_unit_price': column})

    # The "limits" are a fixed +/-10% of the point forecast, not a
    # model-estimated prediction interval.
    lower = as_frame(preds * 0.9, 'Lower_Limit')
    point = as_frame(preds, 'Wheat_Forecast')
    upper = as_frame(preds * 1.1, 'Upper_Limit')
    return pd.merge(lower, point, on=['Date']).merge(upper, on=['Date']).reset_index()


def _build_figure(df_backtest, df_series, merged_df, error):
    """Assemble the plotly figure: backtest, actual prices and the forecast band."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df_backtest.date, y=df_backtest.historical_forecasts, name='historical forecasts'))
    fig.add_trace(go.Scatter(x=df_series.date, y=df_series.actual_prices, name="actual prices"))
    fig.add_trace(go.Scatter(x=merged_df.Date, y=merged_df.Upper_Limit, name="Upper limit"))
    fig.add_trace(go.Scatter(x=merged_df.Date, y=merged_df.Lower_Limit, name="Lower limit"))
    fig.add_trace(go.Scatter(x=merged_df.Date, y=merged_df.Wheat_Forecast, name=" Wheat Forecast"))
    fig.update_layout(
        title_text=f"\n Mean Absolute Percentage Error {error:.2f}%",
        xaxis=dict(
            rangeselector=dict(
                buttons=[
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="todate"),
                    dict(count=1, label="YTD", step="year", stepmode="todate"),
                ]
            ),
            rangeslider=dict(visible=True),
            type="date",
        ),
    )
    return fig


def get_forecast(period_: str, pred_model: str):
    """Forecast wheat prices and build a backtesting plot.

    Parameters:
        period_: Horizon label from the UI, e.g. "3 months"; its leading
            integer is passed as ``n`` to the model's ``predict``.
        pred_model: The selected commodity label. NOTE(review): currently
            ignored — the wheat checkpoint is used for every selection,
            including "Maize Price Forecasting".

    Returns:
        (merged_df, fig): a DataFrame with columns Date, Lower_Limit,
        Wheat_Forecast, Upper_Limit, and a plotly Figure showing historical
        forecasts, actual prices and the new forecast band, titled with the
        backtest MAPE.

    Raises:
        ValueError: if ``period_`` is empty or has no positive leading integer.
    """
    period = _forecast_horizon(period_)
    # NOTE(review): the checkpoint is re-loaded from disk on every request;
    # caching it would be faster but needs care if requests run concurrently.
    afgh_model = TFTModel.load("Afghan_w_blacksea_allcomtrade_jun06.pt", map_location=torch.device('cpu'))

    merged_df = _forecast_table(afgh_model, period)

    # Backtest from 2018-01-31 with the loaded weights (no retraining), using
    # the past covariates assembled at module load time.
    backtest_scaled = afgh_model.historical_forecasts(
        series_transformed,
        past_covariates=my_multivariate_series,
        start=pd.Timestamp("20180131"),
        forecast_horizon=period,
        retrain=False,
        verbose=False,
    )
    backtest = transformer.inverse_transform(backtest_scaled)
    # Actual prices over the same number of trailing points as the backtest.
    actual_tail = transformer.inverse_transform(series_transformed[-len(backtest_scaled):])

    df_backtest = pd.DataFrame(
        data={'date': backtest.time_index, 'historical_forecasts': backtest.values().ravel()}
    )
    df_series = pd.DataFrame(
        data={'date': actual_tail.time_index, 'actual_prices': actual_tail.values().ravel()}
    )

    error = mape(transformer.inverse_transform(series_transformed), backtest)
    fig = _build_figure(df_backtest, df_series, merged_df, error)
    return merged_df, fig
def main():
    """Build the Gradio Blocks UI and launch it with ``debug=True, inline=False``.

    Layout: commodity and horizon radios, a Forecast button wired to
    ``get_forecast``, a feedback textbox, the forecast table and backtest plot,
    two flag buttons backed by ``FlaggingHandler``, and two static price images.
    """
    flagging_handler = FlaggingHandler()
    with gr.Blocks() as iface:
        gr.Markdown(
            """
    **Timeseries Forecasting model Temporal Fusion Transformer(TFT) built on Darts library**.
    """)
        commodity = gr.Radio(["Wheat Price Forecasting", "Maize Price Forecasting"], label="Commodity to Forecast")
        period = gr.Radio(['3 months', "6 months"], label="Forecast horizon")
        with gr.Row():
            btn = gr.Button("Forecast.")
            feedback = gr.Textbox(label="Give feedback")
        data_points = gr.Textbox(label="Forecast values. Lower and upper values include a 10% error rate")
        plot = gr.Plot(label="Backtesting plot, from 2018")
        btn.click(
            get_forecast,
            inputs=[period, commodity],
            outputs=[data_points, plot],
        )
        with gr.Row():
            btn_incorrect = gr.Button("Flag as incorrect")
            btn_other = gr.Button("Flag as other")
        # gr.CSVLogger pairs each flagged value with the registered component
        # at the same position, so the components given to setup() must be
        # exactly the flag buttons' inputs, in the same order.
        flag_inputs = [commodity, data_points, period, feedback]
        flagging_handler.setup(
            components=flag_inputs,
            flagging_dir="data/flagged",
        )
        with gr.Row():
            gr.Image('wheat_prices.png')
            gr.Image('maize_prices.png')
        btn_incorrect.click(
            lambda *args: flagging_handler.flag(flag_data=args, flag_option="Incorrect"),
            flag_inputs,
            None,
            preprocess=False,
        )
        btn_other.click(
            lambda *args: flagging_handler.flag(flag_data=args, flag_option="Other"),
            flag_inputs,
            None,
            preprocess=False,
        )
    iface.launch(debug=True, inline=False)
# Launch the UI only when run as a script, not when imported.
if __name__ == "__main__":
    main()