vilarin commited on
Commit
0961bc7
·
verified ·
1 Parent(s): f663115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -11
app.py CHANGED
@@ -49,7 +49,6 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
49
 
50
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
 
52
- # streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
53
 
54
  generate_kwargs = dict(
55
  input_ids=input_ids,
@@ -60,17 +59,9 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
60
  temperature=temperature,
61
  repetition_penalty=1.2,
62
  )
63
- '''
64
- thread = Thread(target=model.generate, kwargs=generate_kwargs)
65
- thread.start()
66
-
67
- buffer = ""
68
- for new_text in streamer:
69
- buffer[-1][1] += new_text
70
- yield buffer
71
- '''
72
  with torch.no_grad():
73
- outputs = model.generate(**inputs, **gen_kwargs)
74
  outputs = outputs[:, inputs['input_ids'].shape[1]:]
75
  results = tokenizer.decode(outputs[0], skip_special_tokens=True)
76
  return results
 
49
 
50
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
 
 
52
 
53
  generate_kwargs = dict(
54
  input_ids=input_ids,
 
59
  temperature=temperature,
60
  repetition_penalty=1.2,
61
  )
62
+
 
 
 
 
 
 
 
 
63
  with torch.no_grad():
64
+ outputs = model.generate(**generate_kwargs)
65
  outputs = outputs[:, inputs['input_ids'].shape[1]:]
66
  results = tokenizer.decode(outputs[0], skip_special_tokens=True)
67
  return results