Spaces:
Runtime error
Runtime error
File size: 1,195 Bytes
46dbe9e 46a14b8 46dbe9e 46a14b8 46dbe9e 46a14b8 46dbe9e 46a14b8 46dbe9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
import matplotlib.pyplot as plt
def fn(upload_data):
    """Train a short DeepAR model on an uploaded CSV and plot its forecasts.

    The CSV is read with its first column as a date-parsed index. The first
    remaining column is used both as the forecasting target and as the
    "true values" curve in the plot, so the file does not need any
    particular column name.

    Args:
        upload_data: The value delivered by the upload button. Either an
            object whose ``name`` attribute is the path of the uploaded
            file, or the path itself.

    Returns:
        The matplotlib figure containing the observed series and one
        forecast per test window.

    Raises:
        gr.Error: If the CSV has no value column besides the index.
    """
    prediction_length = 12
    windows = 3

    # Accept both a temp-file-like object (path in `.name`) and a bare path.
    csv_path = getattr(upload_data, "name", upload_data)
    df = pd.read_csv(csv_path, index_col=0, parse_dates=True)

    # The first CSV column became the index, so a single-column file leaves
    # nothing to forecast; report that clearly instead of an IndexError.
    if df.shape[1] == 0:
        raise gr.Error(
            "The uploaded CSV needs a date column followed by at least one "
            "value column."
        )
    target = df.columns[0]

    dataset = PandasDataset(df, target=target)

    # Split far enough from the end to leave room for every test window
    # (12 * 3 = 36 steps, i.e. offset=-36).
    training_data, test_gen = split(
        dataset, offset=-prediction_length * windows
    )

    # One epoch only: this is a quick demo, not a tuned model.
    model = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs=dict(max_epochs=1),
    ).train(
        training_data=training_data,
    )

    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=windows
    )
    forecasts = list(model.predict(test_data.input))

    fig = plt.figure()
    # Plot the column the model was trained on rather than a hard-coded
    # name, which raised KeyError for any other dataset.
    df[target].plot(color="black")
    for forecast, color in zip(forecasts, ["green", "blue", "purple"]):
        forecast.plot(color=f"tab:{color}")
    plt.legend(["True values"], loc="upper left", fontsize="xx-large")
    return fig
# Assemble the UI: a plot area above an upload button. Components are
# created in display order, so the plot is declared first.
with gr.Blocks() as demo:
    forecast_plot = gr.Plot()
    csv_upload = gr.UploadButton()

    # When an upload finishes, pass the uploaded file to `fn` and render the
    # figure it returns in the plot area.
    csv_upload.upload(fn=fn, inputs=csv_upload, outputs=forecast_plot)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|