Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
32 |
trust_remote_code=True,
|
33 |
).eval()
|
34 |
|
35 |
-
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True)
|
36 |
|
37 |
class StopOnTokens(StoppingCriteria):
|
38 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
@@ -56,7 +56,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
56 |
print(f"Conversation is -\n{conversation}")
|
57 |
stop = StopOnTokens()
|
58 |
|
59 |
-
input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(model.device)
|
60 |
#input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
|
61 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
62 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
@@ -64,8 +64,8 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
64 |
|
65 |
generate_kwargs = dict(
|
66 |
input_ids=input_ids,
|
67 |
-
max_new_tokens=max_new_tokens,
|
68 |
streamer=streamer,
|
|
|
69 |
do_sample=True,
|
70 |
top_k=1,
|
71 |
temperature=temperature,
|
|
|
32 |
trust_remote_code=True,
|
33 |
).eval()
|
34 |
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True, use_fast=False)
|
36 |
|
37 |
class StopOnTokens(StoppingCriteria):
|
38 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
56 |
print(f"Conversation is -\n{conversation}")
|
57 |
stop = StopOnTokens()
|
58 |
|
59 |
+
input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(next(model.parameters()).device)
|
60 |
#input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
|
61 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
62 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
|
|
64 |
|
65 |
generate_kwargs = dict(
|
66 |
input_ids=input_ids,
|
|
|
67 |
streamer=streamer,
|
68 |
+
max_new_tokens=max_new_tokens,
|
69 |
do_sample=True,
|
70 |
top_k=1,
|
71 |
temperature=temperature,
|