nroggendorff committed
Commit 482f6f1 · verified · 1 Parent(s): 699a605

Update app.py

Files changed (1):
  1. app.py +2 -5
app.py CHANGED

@@ -1,8 +1,6 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
-import spaces
-import torch
 
 torch.set_default_device("cuda")
 
@@ -18,7 +16,6 @@ model_id = "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
 
-@spaces.GPU(duration=120)
 def predict(input_text, history):
     chat = []
     for item in history:
@@ -26,12 +23,12 @@ def predict(input_text, history):
         if item[1] is not None:
             chat.append({"role": "assistant", "content": item[1]})
     chat.append({"role": "user", "content": input_text})
-    conv = tokenizer.apply_chat_template(chat, tokenize=False)
 
+    conv = tokenizer.apply_chat_template(chat, tokenize=False)
     inputs = tokenizer(conv, return_tensors="pt").to("cuda")
     outputs = model.generate(**inputs, max_new_tokens=512)
 
-    generated_text = tokenizer.batch_decode(outputs)[0]
+    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
     return generated_text.split("<|assistant|>")[-1]
 
 gr.ChatInterface(predict, theme="soft").launch()
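
For context on the first two hunks: import spaces and @spaces.GPU(duration=120) come from Hugging Face's spaces package, which Spaces running on ZeroGPU hardware use to reserve a GPU for the decorated function. The pattern this commit removes looked like this (reconstructed from the deleted lines; the function body is elided):

    import spaces  # Hugging Face Spaces ZeroGPU helper

    @spaces.GPU(duration=120)  # hold a ZeroGPU allocation for up to 120 s per call
    def predict(input_text, history):
        ...

Dropping the decorator makes sense only if the Space runs on dedicated GPU hardware, where torch.set_default_device("cuda") alone is enough.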
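
The context lines reference bnb_config and model_id, which are defined in a part of app.py the diff does not show (old lines 9-17). As a sketch only, a typical 4-bit setup for this kind of Space might look like the following; the exact values in this file are not visible in the diff and are an assumption:

    import torch
    from transformers import BitsAndBytesConfig

    # Hypothetical reconstruction of the unshown setup; the actual
    # quantization settings in this Space may differ.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # load weights in 4-bit precision
        bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
    )

    model_id = "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k"  # from the hunk header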
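
The other addition passes skip_special_tokens=True to batch_decode, so markers such as BOS/EOS and any chat-template tokens registered as special tokens are dropped from the decoded string instead of leaking into the chat window. A minimal sketch of the difference, reusing tokenizer and outputs from predict:

    # Default decoding keeps special tokens verbatim in the text,
    # e.g. "<s>" or template markers.
    raw = tokenizer.batch_decode(outputs)[0]

    # With skip_special_tokens=True, only the human-readable content remains.
    clean = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

One caveat: if this tokenizer registers <|assistant|> itself as a special token, it would also be stripped here, making the later split("<|assistant|>") a no-op; whether that applies depends on this model's vocabulary.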