nroggendorff commited on
Commit
b0f4d21
·
verified ·
1 Parent(s): 813d81f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -1,10 +1,26 @@
1
- from transformers import pipeline as pl
2
  import gradio as gr
 
3
 
4
- oracle = pl("stabilityai/stablelm-2-zephyr-1_6b")
 
 
 
 
5
 
6
  def pipe(text: str):
7
- return oracle(text)
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  if __name__ == "__main__":
10
  interface = gr.Interface(pipe, gr.Textbox(label="Prompt"), gr.Textbox(label="Response"), title="Text Completion")
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ tokenizer = AutoTokenizer.from_pretrained('stabilityai/stablelm-2-zephyr-1_6b')
5
+ model = AutoModelForCausalLM.from_pretrained(
6
+ 'stabilityai/stablelm-2-zephyr-1_6b',
7
+ device_map="auto"
8
+ )
9
 
10
  def pipe(text: str):
11
+ tokens = model.generate(
12
+ inputs.to(model.device),
13
+ max_new_tokens=1024,
14
+ temperature=0.5,
15
+ do_sample=True
16
+ )
17
+
18
+ inputs = tokenizer.apply_chat_template(
19
+ text,
20
+ add_generation_prompt=True,
21
+ return_tensors='pt'
22
+ )
23
+ return tokenizer.decode(tokens[0], skip_special_tokens=False)
24
 
25
  if __name__ == "__main__":
26
  interface = gr.Interface(pipe, gr.Textbox(label="Prompt"), gr.Textbox(label="Response"), title="Text Completion")