bstraehle committed on
Commit
1939ff5
·
verified ·
1 Parent(s): 9d8f256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -18
app.py CHANGED
@@ -79,26 +79,15 @@ def fine_tune_model(base_model_name, dataset_name):
79
  print("###")
80
 
81
  # Configure training arguments
82
-
 
83
  training_args = Seq2SeqTrainingArguments(
84
  output_dir=f"./{FT_MODEL_NAME}",
85
  logging_dir="./logs",
86
  num_train_epochs=1,
87
  max_steps=1, # overwrites num_train_epochs
88
- push_to_hub=True, # only model, also need to push tokenizer
89
- ### TODO ###
90
- #per_device_train_batch_size=16,
91
- #per_device_eval_batch_size=64,
92
- #eval_strategy="steps",
93
- #save_total_limit=2,
94
- #save_steps=500,
95
- #eval_steps=500,
96
- #warmup_steps=500,
97
- #weight_decay=0.01,
98
- #metric_for_best_model="accuracy",
99
- #greater_is_better=True,
100
- #load_best_model_at_end=True,
101
- #save_on_each_node=True,
102
  )
103
 
104
  print("### Training arguments")
@@ -106,14 +95,14 @@ def fine_tune_model(base_model_name, dataset_name):
106
  print("###")
107
 
108
  # Create trainer
109
-
 
110
  trainer = Seq2SeqTrainer(
111
  model=model,
112
  args=training_args,
113
  train_dataset=train_dataset,
114
  eval_dataset=test_dataset,
115
- ### TODO ###
116
- #compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
117
  )
118
 
119
  # Train model
 
79
  print("###")
80
 
81
  # Configure training arguments
82
+
83
+ # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
84
  training_args = Seq2SeqTrainingArguments(
85
  output_dir=f"./{FT_MODEL_NAME}",
86
  logging_dir="./logs",
87
  num_train_epochs=1,
88
  max_steps=1, # overwrites num_train_epochs
89
+ push_to_hub=True, # only pushes model, also need to push tokenizer (see below)
90
+ # TODO
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
92
 
93
  print("### Training arguments")
 
95
  print("###")
96
 
97
  # Create trainer
98
+
99
+ # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
100
  trainer = Seq2SeqTrainer(
101
  model=model,
102
  args=training_args,
103
  train_dataset=train_dataset,
104
  eval_dataset=test_dataset,
105
+ # TODO
 
106
  )
107
 
108
  # Train model