# Author: Kashif Rasul
# Commit 46a14b8: "added forecaster"
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
import matplotlib.pyplot as plt
def fn(upload_data):
    """Train a quick DeepAR model on an uploaded CSV and plot its forecasts.

    The CSV is read with its first column as a parsed datetime index. The
    first remaining column is the target series. The last
    ``prediction_length * windows`` observations are held out and forecast
    in ``windows`` consecutive windows.

    Args:
        upload_data: The value Gradio passes from the upload button. This is
            either a file path string or a file-like object exposing
            ``.name``, depending on the Gradio version.

    Returns:
        A matplotlib figure showing the full target series in black and one
        coloured forecast per held-out window.
    """
    prediction_length = 12
    windows = 3
    # One colour per window; zip() below stops at the shorter of the two.
    colors = ["green", "blue", "purple"]

    # Older Gradio hands over a tempfile wrapper with a ``.name`` path;
    # newer versions pass the path string directly.
    path = getattr(upload_data, "name", upload_data)
    df = pd.read_csv(path, index_col=0, parse_dates=True)
    target = df.columns[0]
    dataset = PandasDataset(df, target=target)

    # Hold out exactly enough history for ``windows`` back-to-back forecasts.
    training_data, test_gen = split(dataset, offset=-prediction_length * windows)
    model = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        # A single epoch keeps training quick at the expense of accuracy.
        trainer_kwargs=dict(max_epochs=1),
    ).train(training_data=training_data)

    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=windows
    )
    forecasts = list(model.predict(test_data.input))

    fig = plt.figure()
    # Plot the column that was actually modelled instead of a hard-coded name.
    df[target].plot(color="black")
    for forecast, color in zip(forecasts, colors):
        forecast.plot(color=f"tab:{color}")
    plt.legend(["True values"], loc="upper left", fontsize="xx-large")
    return fig
# UI: a plot area followed by an upload button; uploading a file runs the
# forecaster and renders the figure it returns.
with gr.Blocks() as demo:
    plot = gr.Plot()
    upload_btn = gr.UploadButton()
    # The uploaded file is the handler's only input; its figure fills ``plot``.
    upload_btn.upload(fn=fn, inputs=upload_btn, outputs=plot)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    demo.launch()