Nefertury commited on
Commit
9d0493a
·
1 Parent(s): 3bb4dff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os.path as osp
3
+ import random
4
+ from typing import Union
5
+ import os
6
+ import sys
7
+ from typing import List
8
+ import torch
9
+ import transformers
10
+ from datasets import load_dataset
11
+ from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
12
+ import gradio as gr
13
+ import torch.nn as nn
14
+
15
+
16
+ from peft import (
17
+ LoraConfig,
18
+ get_peft_model,
19
+ get_peft_model_state_dict,
20
+ prepare_model_for_int8_training,
21
+ set_peft_model_state_dict,
22
+ PeftModel
23
+ )
24
+ from transformers import LlamaForCausalLM, LlamaTokenizer
25
+
26
+
27
+ base_model='nickypro/tinyllama-15M'
28
+
29
+
30
+ class Prompter(object):
31
+
32
+ def generate_prompt(
33
+ self,
34
+ instruction: str,
35
+ label: Union[None, str] = None,
36
+ ) -> str:
37
+
38
+ res = f"{instruction}\nAnswer: "
39
+
40
+ if label:
41
+ res = f"{res}{label}"
42
+
43
+ return res
44
+
45
+ def get_response(self, output: str) -> str:
46
+ return output.split("Answer:")[1].strip().replace("/", "\u00F7").replace("*", "\u00D7")
47
+
48
+ model = LlamaForCausalLM.from_pretrained(
49
+ base_model,
50
+ torch_dtype=torch.float32,
51
+ device_map="auto",
52
+ )
53
+ model = PeftModel.from_pretrained(
54
+ model,
55
+ f'checkpoint-16000',
56
+ torch_dtype=torch.float32,
57
+ )
58
+
59
+ model.eval()
60
+ if torch.__version__ >= "2" and sys.platform != "win32":
61
+ model = torch.compile(model)
62
+
63
+ tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
64
+ tokenizer.pad_token_id = 0
65
+ tokenizer.padding_side = "left"
66
+
67
+
68
+ def generate_answers(instructions, model, tokenizer):
69
+ prompter = Prompter()
70
+ raw_answers = []
71
+
72
+ for instruction in instructions:
73
+ prompt = prompter.generate_prompt(instruction)
74
+ inputs = tokenizer(prompt, return_tensors="pt")
75
+
76
+ input_ids = inputs["input_ids"]
77
+
78
+ generation_output = model.generate(
79
+ input_ids=input_ids,
80
+ return_dict_in_generate=True,
81
+ output_scores=True,
82
+ pad_token_id=0,
83
+ eos_token_id=tokenizer.eos_token_id,
84
+ max_new_tokens=16
85
+ )
86
+ s = generation_output.sequences[0]
87
+ raw_answers.append(tokenizer.decode(s, skip_special_tokens=True).strip())
88
+
89
+ return raw_answers
90
+
91
+
92
+ def evaluate(instruction):
93
+ return generate_answers([instruction], model, tokenizer)[0]
94
+
95
+
96
+ if __name__ == "__main__":
97
+ gr.Interface(
98
+ fn=evaluate,
99
+ inputs=[
100
+ gr.components.Textbox(
101
+ lines=1,
102
+ label="Arithmetic",
103
+ placeholder="63303235 + 20239503",
104
+ )
105
+ ],
106
+ outputs=[
107
+ gr.Textbox(
108
+ lines=5,
109
+ label="Output",
110
+ )
111
+ ],
112
+ title="Arithmetic LLaMA",
113
+ description="This model is 15M llama model, finetuned on a+b tasks",
114
+ ).queue().launch(server_name="0.0.0.0", share=True)