bstraehle commited on
Commit
b85865d
·
verified ·
1 Parent(s): a99ba01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -93,6 +93,7 @@ def fine_tune_model(model_name, dataset_name):
93
  args=training_args,
94
  train_dataset=train_dataset,
95
  eval_dataset=test_dataset,
 
96
  compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
97
  )
98
 
@@ -100,9 +101,10 @@ def fine_tune_model(model_name, dataset_name):
100
  print(trainer)
101
  print("###")
102
 
103
- # Train model
104
  #trainer.train()
105
-
 
106
  def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
107
  pipe = pipeline("text-generation",
108
  model=model_name,
 
93
  args=training_args,
94
  train_dataset=train_dataset,
95
  eval_dataset=test_dataset,
96
+ tokenizer=tokenizer,
97
  compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
98
  )
99
 
 
101
  print(trainer)
102
  print("###")
103
 
104
+ # Train and save model
105
  #trainer.train()
106
+ #trainer.save_model()
107
+
108
  def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
109
  pipe = pipeline("text-generation",
110
  model=model_name,