azrai99 commited on
Commit
d9b7495
·
verified ·
1 Parent(s): 5f28cd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -197,8 +197,8 @@ def select_model(horizon, model_type, max_steps=50):
197
  elif model_type == 'TimesNet':
198
  return TimesNet(h=horizon,
199
  input_size=horizon * 5,
200
- hidden_size=16,
201
- conv_hidden_size=32,
202
  loss=HuberMQLoss(level=[90]),
203
  scaler_type='standard',
204
  learning_rate=1e-3,
@@ -212,16 +212,16 @@ def select_model(horizon, model_type, max_steps=50):
212
  input_size=horizon * 5,
213
  loss=HuberMQLoss(level=[90]),
214
  scaler_type='standard',
215
- encoder_n_layers=2,
216
- encoder_hidden_size=64,
217
  context_size=10,
218
- decoder_hidden_size=64,
219
- decoder_layers=2,
220
  max_steps=max_steps)
221
  elif model_type == 'TFT':
222
  return TFT(h=horizon,
223
  input_size=horizon*5,
224
- hidden_size=16,
225
  loss=HuberMQLoss(level=[90]),
226
  learning_rate=0.005,
227
  scaler_type='standard',
 
197
  elif model_type == 'TimesNet':
198
  return TimesNet(h=horizon,
199
  input_size=horizon * 5,
200
+ hidden_size=32,
201
+ conv_hidden_size=64,
202
  loss=HuberMQLoss(level=[90]),
203
  scaler_type='standard',
204
  learning_rate=1e-3,
 
212
  input_size=horizon * 5,
213
  loss=HuberMQLoss(level=[90]),
214
  scaler_type='standard',
215
+ encoder_n_layers=3,
216
+ encoder_hidden_size=256,
217
  context_size=10,
218
+ decoder_hidden_size=256,
219
+ decoder_layers=3,
220
  max_steps=max_steps)
221
  elif model_type == 'TFT':
222
  return TFT(h=horizon,
223
  input_size=horizon*5,
224
+ hidden_size=96,
225
  loss=HuberMQLoss(level=[90]),
226
  learning_rate=0.005,
227
  scaler_type='standard',