chiayewken's picture
Update model in app.py
d38ce92
raw
history blame
3.11 kB
import re
from typing import Optional, List
import vllm
from fire import Fire
from pydantic import BaseModel
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM
class ZeroShotChatTemplate:
    """Zero-shot "Human: ... / Assistant: " prompt format plus answer extraction.

    This is the default template used in llama-factory for training.
    """

    # NOTE(review): class-level list shared by every instance; it is not
    # referenced anywhere in this class.
    texts: List[str] = []

    @staticmethod
    def make_prompt(prompt: str) -> str:
        """Wrap ``prompt`` as one Human turn followed by an open Assistant turn."""
        return f"Human: {prompt}\nAssistant: "

    @staticmethod
    def get_stopping_words() -> List[str]:
        """Return the strings at which generation should stop (a new Human turn)."""
        return ["Human:"]

    @staticmethod
    def extract_answer(text: str) -> str:
        """Return the last number in ``text``, or ``text`` itself if none is found.

        Every character that is neither a digit nor a space is dropped before
        searching, so separators inside numbers are removed: "1,000" yields
        "1000" and "3.5" yields "35".

        Args:
            text: Raw model output.

        Returns:
            The last run of digits as a string, or the unchanged ``text`` when
            it contains no number.
        """
        filtered = "".join(char for char in text if char.isdigit() or char == " ")
        # str.isdigit() also accepts characters such as superscripts ("²") that
        # the regex \d does not match, so the search can come back empty even
        # when `filtered` is non-blank; fall back to the raw text instead of
        # indexing an empty list.
        numbers = re.findall(pattern=r"\d+", string=filtered)
        return numbers[-1] if numbers else text
class VLLMModel(BaseModel, arbitrary_types_allowed=True):
    """Generate text for one prompt at a time with a vLLM engine.

    The engine (``model``) and tokenizer are created lazily by ``load`` the
    first time a prompt is formatted, so constructing this object is cheap.
    """

    path_model: str  # checkpoint directory or Hub repo id given to vllm / transformers
    # None is the "not loaded yet" sentinel checked in load(); the annotation is
    # Optional so that an explicit model=None is also accepted by validation.
    model: Optional[vllm.LLM] = None
    # Loaded in load(); not otherwise used by this class.
    tokenizer: Optional[PreTrainedTokenizer] = None
    # NOTE(review): not referenced by this class -- prompts are not truncated.
    max_input_length: int = 512
    max_output_length: int = 512  # forwarded to SamplingParams as max_tokens
    # Forwarded to SamplingParams as `stop` when non-empty.
    stopping_words: Optional[List[str]] = None

    def load(self):
        """Create the vLLM engine and the tokenizer if they have not been set."""
        if self.model is None:
            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)

    def format_prompt(self, prompt: str) -> str:
        """Ensure the engine is loaded and strip trailing spaces from ``prompt``."""
        self.load()
        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
        return prompt

    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
        """Build the keyword arguments passed to ``self.model.generate``.

        Args:
            do_sample: Selects temperature 0.5 when True, 0.0 otherwise.
            **kwargs: Extra ``vllm.SamplingParams`` options. Any ``stop`` entry
                is replaced by ``self.stopping_words`` when that is non-empty.

        Returns:
            A dict with keys ``sampling_params`` and ``use_tqdm`` (False).
        """
        if self.stopping_words:
            kwargs.update(stop=self.stopping_words)
        params = vllm.SamplingParams(
            temperature=0.5 if do_sample else 0.0,
            max_tokens=self.max_output_length,
            **kwargs,
        )
        outputs = dict(sampling_params=params, use_tqdm=False)
        return outputs

    def run(self, prompt: str) -> str:
        """Generate one completion for ``prompt`` with ``do_sample=False``.

        Returns:
            The first output's text, cut at the first ``<|endoftext|>`` marker.
        """
        prompt = self.format_prompt(prompt)
        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
        pred = outputs[0].outputs[0].text
        # Drop anything after an end-of-text marker that leaked into the text.
        pred = pred.split("<|endoftext|>")[0]
        return pred
def upload_to_hub(path: str, repo_id: str):
    """Publish a locally saved causal-LM checkpoint and its tokenizer.

    Args:
        path: Local directory both artifacts are loaded from.
        repo_id: Hugging Face Hub repository both artifacts are pushed to.
    """
    tok = AutoTokenizer.from_pretrained(path)
    # NOTE(review): no torch_dtype is passed, so weights load (and are pushed)
    # in transformers' default dtype -- confirm that is intended.
    causal_lm = AutoModelForCausalLM.from_pretrained(path)
    # Push the weights first, then the tokenizer.
    for artifact in (causal_lm, tok):
        artifact.push_to_hub(repo_id)
def main(
    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
    **kwargs,
):
    """Answer one question zero-shot and print the prompt, raw output and answer.

    Args:
        question: Question to ask; defaults to a word problem about tennis balls.
        **kwargs: Forwarded to ``VLLMModel`` (e.g. ``path_model``).
    """
    llm = VLLMModel(**kwargs)
    template = ZeroShotChatTemplate()
    llm.stopping_words = template.get_stopping_words()
    prompt = template.make_prompt(question)
    raw_outputs = llm.run(prompt)
    report = {
        "question": question,
        "prompt": prompt,
        "raw_outputs": raw_outputs,
        "pred": template.extract_answer(raw_outputs),
    }
    print(report)


# Example invocations ("p" presumably aliases `python` -- confirm):
#   p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
#   p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
if __name__ == "__main__":
    # Expose the module's top-level callables (e.g. `main`, `upload_to_hub`)
    # as command-line subcommands via python-fire.
    Fire()