bstraehle committed on
Commit
a9bd106
·
verified ·
1 Parent(s): 6137b7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -21
app.py CHANGED
@@ -1,12 +1,9 @@
1
  # https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset
2
  import gradio as gr
3
- import os#, torch
4
  from datasets import load_dataset
5
  from huggingface_hub import HfApi, login
6
- #from peft import AutoPeftModelForCausalLM, LoraConfig
7
- #from random import randint
8
- from transformers import AutoTokenizer, AutoModelForCausalLM#, BitsAndBytesConfig, TrainingArguments, pipeline
9
- #from trl import SFTTrainer, setup_chat_format
10
 
11
  hf_profile = "bstraehle"
12
 
@@ -20,6 +17,20 @@ schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255),
20
  base_model_id = "codellama/CodeLlama-7b-hf"
21
  dataset = "b-mc2/sql-create-context"
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def prompt_model(model_id, system_prompt, user_prompt, schema):
24
  pipe = pipeline("text-generation",
25
  model=model_id,
@@ -34,12 +45,7 @@ def prompt_model(model_id, system_prompt, user_prompt, schema):
34
  output = pipe(messages)
35
  result = output[0]["generated_text"][-1]["content"]
36
  print(result)
37
- return result
38
-
39
- def fine_tune_model(base_model_id, dataset):
40
- tokenizer = download_model(base_model_id)
41
- fine_tuned_model_id = upload_model(base_model_id, tokenizer)
42
- return fine_tuned_model_id
43
 
44
  def download_model(base_model_id):
45
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
@@ -51,7 +57,7 @@ def upload_model(base_model_id, tokenizer):
51
  fine_tuned_model_id = replace_hf_profile(base_model_id)
52
  login(token=os.environ["HF_TOKEN"])
53
  api = HfApi()
54
- #api.delete_repo(repo_id=fine_tuned_model_id, repo_type="model")
55
  api.create_repo(repo_id=fine_tuned_model_id)
56
  api.upload_folder(
57
  folder_path=base_model_id,
@@ -64,15 +70,6 @@ def replace_hf_profile(base_model_id):
64
  model_id = base_model_id[base_model_id.rfind('/')+1:]
65
  return f"{hf_profile}/{model_id}"
66
 
67
- def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
68
- #raise gr.Error("Please clone and bring your own credentials.")
69
- if action == action_1:
70
- result = fine_tune_model(base_model_id, dataset)
71
- elif action == action_2:
72
- fine_tuned_model_id = replace_hf_profile(base_model_id)
73
- result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
74
- return result
75
-
76
  demo = gr.Interface(fn=process,
77
  inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_1),
78
  gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),
 
1
  # https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset
2
  import gradio as gr
3
+ import os
4
  from datasets import load_dataset
5
  from huggingface_hub import HfApi, login
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
7
 
8
  hf_profile = "bstraehle"
9
 
 
17
  base_model_id = "codellama/CodeLlama-7b-hf"
18
  dataset = "b-mc2/sql-create-context"
19
 
20
+ def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
21
+ #raise gr.Error("Please clone and bring your own credentials.")
22
+ if action == action_1:
23
+ result = fine_tune_model(base_model_id, dataset)
24
+ elif action == action_2:
25
+ fine_tuned_model_id = replace_hf_profile(base_model_id)
26
+ result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
27
+ return result
28
+
29
+ def fine_tune_model(base_model_id, dataset):
30
+ tokenizer = download_model(base_model_id)
31
+ fine_tuned_model_id = upload_model(base_model_id, tokenizer)
32
+ return fine_tuned_model_id
33
+
34
  def prompt_model(model_id, system_prompt, user_prompt, schema):
35
  pipe = pipeline("text-generation",
36
  model=model_id,
 
45
  output = pipe(messages)
46
  result = output[0]["generated_text"][-1]["content"]
47
  print(result)
48
+ return result
 
 
 
 
 
49
 
50
  def download_model(base_model_id):
51
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
 
57
  fine_tuned_model_id = replace_hf_profile(base_model_id)
58
  login(token=os.environ["HF_TOKEN"])
59
  api = HfApi()
60
+ api.delete_repo(repo_id=fine_tuned_model_id, repo_type="model")
61
  api.create_repo(repo_id=fine_tuned_model_id)
62
  api.upload_folder(
63
  folder_path=base_model_id,
 
70
  model_id = base_model_id[base_model_id.rfind('/')+1:]
71
  return f"{hf_profile}/{model_id}"
72
 
 
 
 
 
 
 
 
 
 
73
  demo = gr.Interface(fn=process,
74
  inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_1),
75
  gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),