bstraehle committed on
Commit
7f9f34a
·
verified ·
1 Parent(s): 7e05fe4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -4
app.py CHANGED
@@ -18,8 +18,8 @@ system_prompt = "You are a text to SQL query translator. Given a question in Eng
18
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
19
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
20
 
21
- base_model_id = "google/gemma-2-9b-it" # "meta-llama/Meta-Llama-3-8B-Instruct"
22
- dataset = "b-mc2/sql-create-context"
23
 
24
  def prompt_model(model_id, system_prompt, user_prompt, schema):
25
  pipe = pipeline("text-generation",
@@ -64,12 +64,73 @@ def prompt_model(model_id, system_prompt, user_prompt, schema):
64
  # print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
65
 
66
  def fine_tune_model(base_model_id, dataset):
67
- tokenizer = download_model(base_model_id)
 
68
  #prepare_dataset(dataset)
69
  #train_model(base_model_id)
70
- fine_tuned_model_id = upload_model(base_model_id, tokenizer)
71
  return fine_tuned_model_id
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def download_model(base_model_id):
74
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
75
  model = AutoModelForCausalLM.from_pretrained(base_model_id)
 
18
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
19
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
20
 
21
+ base_model_id = "microsoft/Phi-3-mini-4k-instruct"
22
+ dataset = "gretelai/synthetic_text_to_sql"
23
 
24
  def prompt_model(model_id, system_prompt, user_prompt, schema):
25
  pipe = pipeline("text-generation",
 
64
  # print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
65
 
66
  def fine_tune_model(base_model_id, dataset):
67
+ test(base_model_id, dataset)
68
+ ##tokenizer = download_model(base_model_id)
69
  #prepare_dataset(dataset)
70
  #train_model(base_model_id)
71
+ ##fine_tuned_model_id = upload_model(base_model_id, tokenizer)
72
  return fine_tuned_model_id
73
 
74
def test(base_model_id, dataset):
    """Fine-tune ``base_model_id`` on ``dataset`` with SFTTrainer and save it.

    Loads the model and tokenizer, loads the ``train`` split of ``dataset``,
    flattens each conversation into a single ``text`` string terminated by the
    tokenizer's EOS token, trains for three epochs, and saves the model to
    ``NEW_MODEL_NAME``.

    Args:
        base_model_id: Hub id of the causal LM to fine-tune.
        dataset: Hub id of the dataset to load.

    NOTE(review): the formatting step reads a ``conversations`` column whose
    turns are dicts with ``from`` / ``value`` keys. Confirm the configured
    dataset actually has that schema; otherwise the map step raises KeyError.
    """
    print("111")
    # Load in the default dtype: fp16=True below turns on mixed precision, and
    # its gradient scaler cannot unscale gradients of weights that are
    # themselves stored in float16.
    model = AutoModelForCausalLM.from_pretrained(base_model_id)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    # Load the dataset for fine-tuning (kept in its own name so the
    # ``dataset`` argument is not rebound to a different type).
    print("222")
    train_dataset = load_dataset(dataset, split="train")

    role_prefix = {"system": "system\n", "human": "\nuser\n", "gpt": "\nassistant\n"}
    role_suffix = {"system": "", "human": "", "gpt": ""}

    def formatting_prompts_func(examples):
        """Turn a batch of conversations into a batch of flat prompt strings."""
        texts = []
        for convo in examples["conversations"]:
            text = "".join(
                f"{role_prefix[turn['from']]} {turn['value']}\n{role_suffix[turn['from']]}"
                for turn in convo
            )
            texts.append(f"{text}{tokenizer.eos_token}")
        return {"text": texts}

    # Apply the formatting function to the dataset; this adds the "text" column
    # that the trainer reads through dataset_text_field.
    print("333")
    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

    # Define the training arguments. No evaluation strategy is set because no
    # eval dataset is supplied; requesting step-wise evaluation without one
    # makes the trainer fail at construction.
    print("444")
    args = TrainingArguments(
        per_device_train_batch_size=7,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=1e-4,
        fp16=True,
        max_steps=-1,
        num_train_epochs=3,
        save_strategy="epoch",
        logging_steps=10,
        output_dir=NEW_MODEL_NAME,
        optim="paged_adamw_32bit",
        lr_scheduler_type="linear",
    )

    # Create the trainer. The dataset is already formatted, so only
    # dataset_text_field is passed; formatting_prompts_func returns a dict and
    # is not a valid SFTTrainer formatting_func.
    print("555")
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=128,
    )

    # Start the training process
    print("666")
    trainer.train()

    print("777")
    trainer.save_model()
133
+
134
  def download_model(base_model_id):
135
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
136
  model = AutoModelForCausalLM.from_pretrained(base_model_id)