azrai99 committed
Commit f3f15dd · verified · 1 Parent(s): 0e39b3d

Upload 5 files

Files changed (5)
  1. .gitattributes +34 -35
  2. .gitignore +131 -0
  3. README.md +13 -13
  4. app.py +374 -0
  5. requirements.txt +12 -0
.gitattributes CHANGED
@@ -1,35 +1,34 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text

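The only substantive change in this rewrite is that the `*.tar` pattern is dropped from the new file, so bare `.tar` archives would no longer be routed through Git LFS; every other rule survives (the wholesale remove-and-re-add likely reflects a whitespace or line-ending rewrite). Rules like these are normally maintained with the git-lfs CLI rather than edited by hand; a minimal sketch, assuming `git-lfs` is installed:

```bash
# Appends "*.ckpt filter=lfs diff=lfs merge=lfs -text" to .gitattributes
git lfs track "*.ckpt"
git add .gitattributes
```
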
.gitignore ADDED
@@ -0,0 +1,131 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ models/
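The final rule, `models/`, is project-specific; presumably it keeps locally downloaded `.ckpt` checkpoints out of the repository (everything above it is GitHub's stock Python template). A quick way to confirm a path is covered, using standard git (the checkpoint filename here is only an illustration):

```bash
# Prints the matching source file, line number, and pattern for the path
git check-ignore -v models/nhits_m4_hourly.ckpt
# expected: .gitignore:131:models/	models/nhits_m4_hourly.ckpt
```
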
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: Transfer Learning Time Series
- emoji:
- colorFrom: green
- colorTo: pink
- sdk: streamlit
- sdk_version: 1.36.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Transfer Learning Time Series
+ emoji: 🐠
+ colorFrom: indigo
+ colorTo: indigo
+ sdk: streamlit
+ sdk_version: 1.21.0
+ app_file: app.py
+ pinned: false
+ license: bsd-3-clause
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
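The front matter is the Space's configuration: this commit changes the emoji and theme colors, pins `sdk_version` to 1.21.0 (for a Streamlit Space this field, not requirements.txt, determines the Streamlit version the Space runs), and switches the license from apache-2.0 to bsd-3-clause. A minimal sketch for sanity-checking the block locally, assuming PyYAML is available:

```python
# Parse the YAML block between the first two "---" markers of README.md
import yaml

with open("README.md") as f:
    front_matter = yaml.safe_load(f.read().split("---")[1])

assert front_matter["sdk"] == "streamlit"
print(front_matter["sdk_version"])  # "1.21.0" after this commit
```
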
app.py ADDED
@@ -0,0 +1,374 @@
+ from time import time
+
+ import numpy as np
+ import pandas as pd
+ import plotly.express as px
+ import plotly.graph_objects as go
+ import streamlit as st
+ from datasetsforecast.losses import rmse, mae, smape, mse, mape
+ from st_aggrid import AgGrid
+
+ from src.nf import MODELS, forecast_pretrained_model
+ from src.model_descriptions import model_cards
+
+ DATASETS = {
+     "Electricity (Ercot COAST)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_COAST.csv",
+     # "Electricity (ERCOT, multiple markets)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/ercot_multiple_ts.csv",
+     "Web Traffic (Peyton Manning)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/peyton_manning.csv",
+     "Demand (AirPassengers)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv",
+     "Finance (Exchange USD-EUR)": "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/usdeur.csv",
+ }
+
+
+ @st.cache_data
+ def convert_df(df):
+     # IMPORTANT: cache the conversion to prevent recomputation on every rerun
+     return df.to_csv(index=False).encode("utf-8")
+
+
+ def plot(df, uid, df_forecast, model):
+     figs = []
+     figs += [
+         go.Scatter(
+             x=df["ds"],
+             y=df["y"],
+             mode="lines",
+             marker=dict(color="#236796"),
+             legendrank=1,
+             name=uid,
+         ),
+     ]
+     if df_forecast is not None:
+         ds_f = df_forecast["ds"].to_list()
+         lo = df_forecast["forecast_lo_90"].to_list()
+         hi = df_forecast["forecast_hi_90"].to_list()
+         figs += [
+             go.Scatter(
+                 x=ds_f + ds_f[::-1],
+                 y=hi + lo[::-1],
+                 fill="toself",
+                 fillcolor="#E7C4C0",
+                 mode="lines",
+                 line=dict(color="#E7C4C0"),
+                 name="Prediction Intervals (90%)",
+                 legendrank=5,
+                 opacity=0.5,
+                 hoverinfo="skip",
+             ),
+             go.Scatter(
+                 x=ds_f,
+                 y=df_forecast["forecast"],
+                 mode="lines",
+                 legendrank=4,
+                 marker=dict(color="#E7C4C0"),
+                 name=f"Forecast {uid}",
+             ),
+         ]
+     fig = go.Figure(figs)
+     fig.update_layout(
+         {"plot_bgcolor": "rgba(0, 0, 0, 0)", "paper_bgcolor": "rgba(0, 0, 0, 0)"}
+     )
+     fig.update_layout(
+         title=f"Forecasts for {uid} using Transfer Learning (from {model})",
+         legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+         margin=dict(l=20, b=20),
+         xaxis=dict(rangeslider=dict(visible=True)),
+     )
+     # open the view on the last 200 observations plus the forecast horizon
+     initial_range = [df.tail(200)["ds"].iloc[0], ds_f[-1]]
+     fig["layout"]["xaxis"].update(range=initial_range)
+     return fig
+
+
+ def st_transfer_learning():
+     st.set_page_config(
+         page_title="Time Series Visualization",
+         page_icon="🔮",
+         layout="wide",
+         initial_sidebar_state="expanded",
+     )
+
+     st.title(
+         "Transfer Learning: Revolutionizing Time Series by Nixtla"
+     )
+     st.write(
+         "<style>div.block-container{padding-top:2rem;}</style>", unsafe_allow_html=True
+     )
+
+     intro = """
+ The success of startups like OpenAI and Stability AI highlights the potential for transfer learning (TL) techniques to have a similar impact on the field of time series forecasting.
+
+ TL can achieve lightning-fast predictions with a fraction of the computational cost by pre-training a flexible model on a large dataset and then using it on another dataset with little to no additional training.
+
+ In this live demo, you can use pre-trained models by Nixtla (trained on the M4 dataset) to predict your own datasets. You can also see how the models perform on unseen example datasets.
+ """
+     st.write(intro)
+
+     required_cols = ["ds", "y"]
+
+     with st.sidebar.expander("Dataset", expanded=False):
+         data_selection = st.selectbox("Select example dataset", DATASETS.keys())
+         data_url = DATASETS[data_selection]
+         url_json = st.text_input("Data (you can pass your own url here)", data_url)
+         st.write(
+             "You can also upload a CSV file like [this one](https://github.com/Nixtla/transfer-learning-time-series/blob/main/datasets/air_passengers.csv)."
+         )
+
+         uploaded_file = st.file_uploader("Upload CSV")
+         with st.form("Data"):
+             if uploaded_file is not None:
+                 df = pd.read_csv(uploaded_file)
+                 cols = df.columns
+                 timestamp_col = st.selectbox("Timestamp column", options=cols)
+                 value_col = st.selectbox("Value column", options=cols)
+             else:
+                 timestamp_col = st.text_input("Timestamp column", value="timestamp")
+                 value_col = st.text_input("Value column", value="value")
+             st.write("You must press Submit each time you want to forecast.")
+             submitted = st.form_submit_button("Submit")
+             if submitted:
+                 if uploaded_file is None:
+                     st.write("Please provide a dataframe.")
+                     if url_json.endswith("json"):
+                         df = pd.read_json(url_json)
+                     else:
+                         df = pd.read_csv(url_json)
+                     df = df.rename(
+                         columns=dict(zip([timestamp_col, value_col], required_cols))
+                     )
+                 else:
+                     # df = pd.read_csv(uploaded_file)
+                     df = df.rename(
+                         columns=dict(zip([timestamp_col, value_col], required_cols))
+                     )
+             else:
+                 if url_json.endswith("json"):
+                     df = pd.read_json(url_json)
+                 else:
+                     df = pd.read_csv(url_json)
+                 cols = df.columns
+                 if "unique_id" in cols:
+                     cols = cols[-2:]
+                 df = df.rename(columns=dict(zip(cols, required_cols)))
+
+     if "unique_id" not in df:
+         df.insert(0, "unique_id", "ts_0")
+
+     df["ds"] = pd.to_datetime(df["ds"])
+     df = df.sort_values(["unique_id", "ds"])
+
+     with st.sidebar:
+         st.write("Define the pretrained model you want to use to forecast your data")
+         model_name = st.selectbox("Select your model", tuple(MODELS.keys()))
+         model_file = MODELS[model_name]["model"]
+         st.write("Choose how many steps you want to forecast")
+         fh = st.number_input("Forecast horizon", value=18)
+         st.write(
+             "Choose for how many steps the pretrained model will be updated using your data (use 0 for fast computation)"
+         )
+         max_steps = st.number_input("N-shot inference", value=0)
+
+     # tabs
+     tab_fcst, tab_cv, tab_docs, tab_nixtla = st.tabs(
+         [
+             "📈 Forecast",
+             "🔎 Cross Validation",
+             "📚 Documentation",
+             "🔮 Nixtlaverse",
+         ]
+     )
+
+     uids = df["unique_id"].unique()
+     fcst_cols = ["forecast_lo_90", "forecast", "forecast_hi_90"]
+
+     with tab_fcst:
+         uid = uids[0]  # st.selectbox("Dataset", options=uids)
+         col1, col2 = st.columns([2, 4])
+         with col1:
+             tab_insample, tab_forecast = st.tabs(
+                 ["Modify input data", "Modify forecasts"]
+             )
+             with tab_insample:
+                 df_grid = df.query("unique_id == @uid").drop(columns="unique_id")
+                 grid_table = AgGrid(
+                     df_grid,
+                     editable=True,
+                     theme="streamlit",
+                     fit_columns_on_grid_load=True,
+                     height=360,
+                 )
+                 df.loc[df["unique_id"] == uid, "y"] = (
+                     grid_table["data"].sort_values("ds")["y"].values
+                 )
+             # forecast code
+             init = time()
+             df_forecast = forecast_pretrained_model(df, model_file, fh, max_steps)
+             end = time()
+             df_forecast = df_forecast.rename(
+                 columns=dict(zip(["y_5", "y_50", "y_95"], fcst_cols))
+             )
+             with tab_forecast:
+                 df_fcst_grid = df_forecast.query("unique_id == @uid").filter(
+                     ["ds", "forecast"]
+                 )
+                 grid_fcst_table = AgGrid(
+                     df_fcst_grid,
+                     editable=True,
+                     theme="streamlit",
+                     fit_columns_on_grid_load=True,
+                     height=360,
+                 )
+                 # propagate manual edits of the point forecast to the interval columns
+                 changes = (
+                     df_forecast.query("unique_id == @uid")["forecast"].values
+                     - grid_fcst_table["data"].sort_values("ds")["forecast"].values
+                 )
+                 for col in fcst_cols:
+                     df_forecast.loc[df_forecast["unique_id"] == uid, col] = (
+                         df_forecast.loc[df_forecast["unique_id"] == uid, col] - changes
+                     )
+         with col2:
+             st.plotly_chart(
+                 plot(
+                     df.query("unique_id == @uid"),
+                     uid,
+                     df_forecast.query("unique_id == @uid"),
+                     model_name,
+                 ),
+                 use_container_width=True,
+             )
+         st.success(f"Done! Approximate inference time CPU: {0.7 * (end - init):.2f} seconds.")
+
+     with tab_cv:
+         col_uid, col_n_windows = st.columns(2)
+         uid = uids[0]
+         # with col_uid:
+         #     uid = st.selectbox("Time series to analyse", options=uids, key="uid_cv")
+         with col_n_windows:
+             n_windows = st.number_input("Cross validation windows", value=1)
+         # rolling-origin evaluation: hold out the last i_window * fh points per series
+         df_forecast = []
+         for i_window in range(n_windows, 0, -1):
+             test = df.groupby("unique_id").tail(i_window * fh)
+             df_forecast_w = forecast_pretrained_model(
+                 df.drop(test.index), model_file, fh, max_steps
+             )
+             df_forecast_w = df_forecast_w.rename(
+                 columns=dict(zip(["y_5", "y_50", "y_95"], fcst_cols))
+             )
+             df_forecast_w.insert(2, "window", i_window)
+             df_forecast.append(df_forecast_w)
+         df_forecast = pd.concat(df_forecast)
+         df_forecast["ds"] = pd.to_datetime(df_forecast["ds"])
+         df_forecast = df_forecast.merge(df, how="left", on=["unique_id", "ds"])
+         metrics = [mae, mape, rmse, smape]
+         evaluation = df_forecast.groupby(["unique_id", "window"]).apply(
+             lambda df: [f'{fn(df["y"].values, df["forecast"]):.2f}' for fn in metrics]
+         )
+         evaluation = evaluation.rename("eval").reset_index()
+         evaluation["eval"] = evaluation["eval"].str.join(",")
+         evaluation[["MAE", "MAPE", "RMSE", "sMAPE"]] = evaluation["eval"].str.split(
+             ",", expand=True
+         )
+         col_eval, col_plot = st.columns([2, 4])
+         with col_eval:
+             st.write("Evaluation metrics for each cross validation window")
+             st.dataframe(
+                 evaluation.query("unique_id == @uid")
+                 .drop(columns=["unique_id", "eval"])
+                 .set_index("window")
+             )
+         with col_plot:
+             st.plotly_chart(
+                 plot(
+                     df.query("unique_id == @uid"),
+                     uid,
+                     df_forecast.query("unique_id == @uid").drop(columns="y"),
+                     model_name,
+                 ),
+                 use_container_width=True,
+             )
+
+     with tab_docs:
+         tab_transfer, tab_desc, tab_ref = st.tabs(
+             [
+                 "🚀 Transfer Learning",
+                 "🔎 Description of the model",
+                 "📚 References",
+             ]
+         )
+
+         with tab_desc:
+             model_card_name = MODELS[model_name]["card"]
+             st.subheader("Abstract")
+             st.write(f"""{model_cards[model_card_name]['Abstract']}""")
+             st.subheader("Intended use")
+             st.write(f"""{model_cards[model_card_name]['Intended use']}""")
+             st.subheader("Secondary use")
+             st.write(f"""{model_cards[model_card_name]['Secondary use']}""")
+             st.subheader("Limitations")
+             st.write(f"""{model_cards[model_card_name]['Limitations']}""")
+             st.subheader("Training data")
+             st.write(f"""{model_cards[model_card_name]['Training data']}""")
+             st.subheader("BibTex/Citation Info")
+             st.code(f"""{model_cards[model_card_name]['Citation Info']}""")
+
+         with tab_transfer:
+             transfer_text = """
+ Transfer learning refers to the process of pre-training a flexible model on a large dataset and using it later on other data with little to no training. It is one of the most outstanding 🚀 achievements in Machine Learning 🧠 and has many practical applications.
+
+ For time series forecasting, the technique allows you to get lightning-fast predictions ⚡ while bypassing the tradeoff between accuracy and speed.
+
+ [This notebook](https://colab.research.google.com/drive/1uFCO2UBpH-5l2fk3KmxfU0oupsOC6v2n?authuser=0&pli=1#cell-5=) shows how to generate a pre-trained model and store it in a checkpoint to make it available for public use to forecast new time series never seen by the model.
+ **You can contribute your pre-trained models by following [this notebook](https://github.com/Nixtla/transfer-learning-time-series/blob/main/nbs/Transfer_Learning.ipynb) and sending us an email at federico[at]nixtla.io**
+
+ You can also take a look at the list of pretrained models below. These are currently available in our [API](https://docs.nixtla.io/reference/neural_transfer_neural_transfer_post) and [Demo](http://nixtla.io/transfer-learning/). You can also download the `.ckpt` files:
+ - [Pretrained N-HiTS M4 Hourly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_hourly.ckpt)
+ - [Pretrained N-HiTS M4 Hourly (Tiny)](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_hourly_tiny.ckpt)
+ - [Pretrained N-HiTS M4 Daily](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_daily.ckpt)
+ - [Pretrained N-HiTS M4 Monthly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_monthly.ckpt)
+ - [Pretrained N-HiTS M4 Yearly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nhits_m4_yearly.ckpt)
+ - [Pretrained N-BEATS M4 Hourly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_hourly.ckpt)
+ - [Pretrained N-BEATS M4 Daily](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_daily.ckpt)
+ - [Pretrained N-BEATS M4 Weekly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_weekly.ckpt)
+ - [Pretrained N-BEATS M4 Monthly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_monthly.ckpt)
+ - [Pretrained N-BEATS M4 Yearly](https://nixtla-public.s3.amazonaws.com/transfer/pretrained_models/nbeats_m4_yearly.ckpt)
+ """
+             st.write(transfer_text)
+
+         with tab_ref:
+             ref_text = """
+ If you are interested in the transfer learning literature applied to time series forecasting, take a look at these papers:
+ - [Meta-learning framework with applications to zero-shot time-series forecasting](https://arxiv.org/abs/2002.02887)
+ - [N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting](https://arxiv.org/abs/2201.12886)
+ """
+             st.write(ref_text)
+
+     with tab_nixtla:
+         nixtla_text = """
+ Nixtla is a startup that builds forecasting software for data scientists and developers.
+
+ We have been developing open-source libraries for statistical, machine learning, and deep learning forecasting.
+
+ In our [GitHub repo](https://github.com/Nixtla), you can find the projects that support this app.
+ """
+         st.write(nixtla_text)
+         st.image(
+             "https://files.readme.io/168cdb2-Screen_Shot_2022-09-30_at_10.40.09.png",
+             width=800,
+         )
+
+     with st.sidebar:
+         st.download_button(
+             label="Download historical data as CSV",
+             data=convert_df(df),
+             file_name="history.csv",
+             mime="text/csv",
+         )
+         st.download_button(
+             label="Download forecasts as CSV",
+             data=convert_df(df_forecast),
+             file_name="forecasts.csv",
+             mime="text/csv",
+         )
+
+
+ if __name__ == "__main__":
+     st_transfer_learning()
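Note that app.py imports `src.nf` and `src.model_descriptions`, neither of which is among the five files in this commit, so the Space presumably relies on a `src/` package already present in the repository. Under that assumption, the core flow reduces to the following headless sketch (no Streamlit; horizon 18 and `max_steps=0` mirror the app's zero-shot defaults, and the column handling mirrors what the app does to user data):

```python
import pandas as pd

from src.nf import MODELS, forecast_pretrained_model  # not part of this commit

url = "https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv"
df = pd.read_csv(url)
df.columns = ["ds", "y"]           # app.py renames the user's columns to ds/y
df.insert(0, "unique_id", "ts_0")  # single-series id, as app.py inserts
df["ds"] = pd.to_datetime(df["ds"])
df = df.sort_values(["unique_id", "ds"])

model_file = MODELS[list(MODELS)[0]]["model"]  # first checkpoint in the registry

# positional args as in app.py: (df, model_file, fh, max_steps)
df_forecast = forecast_pretrained_model(df, model_file, 18, 0)
df_forecast = df_forecast.rename(
    columns=dict(zip(["y_5", "y_50", "y_95"],
                     ["forecast_lo_90", "forecast", "forecast_hi_90"]))
)
print(df_forecast.head())
```
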
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ datasetsforecast
+ fire
+ neuralforecast==0.1.0
+ pandas
+ plotly
+ python-dotenv
+ torch==2.3.0
+ pytorch-lightning
+ statsforecast
+ streamlit
+ streamlit-aggrid
+ hyperopt
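Two pins worth flagging: `neuralforecast==0.1.0` is an early 2022 release while `torch==2.3.0` is far newer, so this pair may not resolve or run cleanly together; and on Spaces the Streamlit version comes from the README's `sdk_version`, so the unpinned `streamlit` entry only matters locally. A hedged sketch for running the app outside Spaces:

```bash
pip install -r requirements.txt
streamlit run app.py
```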