Update app.py
Browse files
app.py
CHANGED
@@ -93,6 +93,7 @@ def fine_tune_model(model_name, dataset_name):
|
|
93 |
args=training_args,
|
94 |
train_dataset=train_dataset,
|
95 |
eval_dataset=test_dataset,
|
|
|
96 |
compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
|
97 |
)
|
98 |
|
@@ -100,9 +101,10 @@ def fine_tune_model(model_name, dataset_name):
|
|
100 |
print(trainer)
|
101 |
print("###")
|
102 |
|
103 |
-
# Train model
|
104 |
#trainer.train()
|
105 |
-
|
|
|
106 |
def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
|
107 |
pipe = pipeline("text-generation",
|
108 |
model=model_name,
|
|
|
93 |
args=training_args,
|
94 |
train_dataset=train_dataset,
|
95 |
eval_dataset=test_dataset,
|
96 |
+
tokenizer=tokenizer,
|
97 |
compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
|
98 |
)
|
99 |
|
|
|
101 |
print(trainer)
|
102 |
print("###")
|
103 |
|
104 |
+
# Train and save model
|
105 |
#trainer.train()
|
106 |
+
#trainer.save_model()
|
107 |
+
|
108 |
def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
|
109 |
pipe = pipeline("text-generation",
|
110 |
model=model_name,
|