chiayewken commited on
Commit
d38ce92
·
1 Parent(s): 9b30274

Update model in app.py

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +6 -7
  3. run_demo.py +97 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea/
app.py CHANGED
@@ -7,6 +7,8 @@ import spaces
7
  import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
 
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 1024
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -34,7 +36,7 @@ if not torch.cuda.is_available():
34
 
35
 
36
  if torch.cuda.is_available():
37
- model_id = "meta-llama/Llama-2-7b-chat-hf"
38
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
39
  tokenizer = AutoTokenizer.from_pretrained(model_id)
40
  tokenizer.use_default_system_prompt = False
@@ -51,13 +53,10 @@ def generate(
51
  top_k: int = 50,
52
  repetition_penalty: float = 1.2,
53
  ) -> Iterator[str]:
54
- conversation = []
55
- if system_prompt:
56
- conversation.append({"role": "system", "content": system_prompt})
57
- conversation += chat_history
58
- conversation.append({"role": "user", "content": message})
59
 
60
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
61
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
62
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
63
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
7
  import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
+ from run_demo import ZeroShotChatTemplate
11
+
12
  MAX_MAX_NEW_TOKENS = 2048
13
  DEFAULT_MAX_NEW_TOKENS = 1024
14
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
36
 
37
 
38
  if torch.cuda.is_available():
39
+ model_id = "chiayewken/llama3-8b-gsm8k-rpo"
40
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
41
  tokenizer = AutoTokenizer.from_pretrained(model_id)
42
  tokenizer.use_default_system_prompt = False
 
53
  top_k: int = 50,
54
  repetition_penalty: float = 1.2,
55
  ) -> Iterator[str]:
56
+ demo = ZeroShotChatTemplate()
57
+ prompt = demo.make_prompt(message)
58
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
 
59
 
 
60
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
61
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
62
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
run_demo.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Optional, List
3
+
4
+ import vllm
5
+ from fire import Fire
6
+ from pydantic import BaseModel
7
+ from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM
8
+
9
+
10
+ class ZeroShotChatTemplate:
11
+ # This is the default template used in llama-factory for training
12
+ texts: List[str] = []
13
+
14
+ @staticmethod
15
+ def make_prompt(prompt: str) -> str:
16
+ return f"Human: {prompt}\nAssistant: "
17
+
18
+ @staticmethod
19
+ def get_stopping_words() -> List[str]:
20
+ return ["Human:"]
21
+
22
+ @staticmethod
23
+ def extract_answer(text: str) -> str:
24
+ filtered = "".join([char for char in text if char.isdigit() or char == " "])
25
+ if not filtered.strip():
26
+ return text
27
+ return re.findall(pattern=r"\d+", string=filtered)[-1]
28
+
29
+
30
+ class VLLMModel(BaseModel, arbitrary_types_allowed=True):
31
+ path_model: str
32
+ model: vllm.LLM = None
33
+ tokenizer: Optional[PreTrainedTokenizer] = None
34
+ max_input_length: int = 512
35
+ max_output_length: int = 512
36
+ stopping_words: Optional[List[str]] = None
37
+
38
+ def load(self):
39
+ if self.model is None:
40
+ self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
41
+ if self.tokenizer is None:
42
+ self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)
43
+
44
+ def format_prompt(self, prompt: str) -> str:
45
+ self.load()
46
+ prompt = prompt.rstrip(" ") # Llama is sensitive (eg "Answer:" vs "Answer: ")
47
+ return prompt
48
+
49
+ def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
50
+ if self.stopping_words:
51
+ kwargs.update(stop=self.stopping_words)
52
+ params = vllm.SamplingParams(
53
+ temperature=0.5 if do_sample else 0.0,
54
+ max_tokens=self.max_output_length,
55
+ **kwargs,
56
+ )
57
+
58
+ outputs = dict(sampling_params=params, use_tqdm=False)
59
+ return outputs
60
+
61
+ def run(self, prompt: str) -> str:
62
+ prompt = self.format_prompt(prompt)
63
+ outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
64
+ pred = outputs[0].outputs[0].text
65
+ pred = pred.split("<|endoftext|>")[0]
66
+ return pred
67
+
68
+
69
+ def upload_to_hub(path: str, repo_id: str):
70
+ tokenizer = AutoTokenizer.from_pretrained(path)
71
+ model = AutoModelForCausalLM.from_pretrained(path)
72
+ model.push_to_hub(repo_id)
73
+ tokenizer.push_to_hub(repo_id)
74
+
75
+
76
+ def main(
77
+ question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
78
+ **kwargs,
79
+ ):
80
+ model = VLLMModel(**kwargs)
81
+ demo = ZeroShotChatTemplate()
82
+ model.stopping_words = demo.get_stopping_words()
83
+
84
+ prompt = demo.make_prompt(question)
85
+ raw_outputs = model.run(prompt)
86
+ pred = demo.extract_answer(raw_outputs)
87
+ print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))
88
+
89
+
90
+ """
91
+ p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
92
+ p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
93
+ """
94
+
95
+
96
+ if __name__ == "__main__":
97
+ Fire()