CyberNative commited on
Commit
ac2d8d3
·
verified ·
1 Parent(s): 19c9ea5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -28
app.py CHANGED
@@ -2,9 +2,8 @@ import gradio as gr
2
  import os
3
  import spaces
4
  from transformers import AutoTokenizer, TextIteratorStreamer
5
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
6
  from threading import Thread
7
- import torch
8
 
9
  # Set an environment variable
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -45,15 +44,6 @@ h1 {
45
  }
46
  """
47
 
48
- # Load the tokenizer and model
49
- tokenizer = AutoTokenizer.from_pretrained("CyberNative-AI/Colibri_8b_v0.1_gptq_128_4bit")
50
- model = AutoGPTQForCausalLM.from_quantized("CyberNative-AI/Colibri_8b_v0.1_gptq_128_4bit", dtype=torch.float32, device="cpu")
51
-
52
- terminators = [
53
- tokenizer.eos_token_id,
54
- tokenizer.convert_tokens_to_ids("<|im_end|>")
55
- ]
56
-
57
  @spaces.GPU(duration=120)
58
  def chat_llama3_8b(message: str,
59
  history: list,
@@ -76,24 +66,16 @@ def chat_llama3_8b(message: str,
76
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
77
  conversation.append({"role": "user", "content": message})
78
 
79
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
80
-
81
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
82
-
83
- generate_kwargs = dict(
84
- input_ids= input_ids,
85
- streamer=streamer,
86
- max_new_tokens=max_new_tokens,
87
- do_sample=True,
88
- top_p=0.7,
89
- temperature=temperature,
90
- eos_token_id=terminators,
91
  )
92
- # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
93
- if temperature == 0:
94
- generate_kwargs['do_sample'] = False
95
-
96
- t = Thread(target=model.generate, kwargs=generate_kwargs)
97
  t.start()
98
 
99
  outputs = []
 
2
  import os
3
  import spaces
4
  from transformers import AutoTokenizer, TextIteratorStreamer
 
5
  from threading import Thread
6
+ from llama_cpp import Llama
7
 
8
  # Set an environment variable
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
44
  }
45
  """
46
 
 
 
 
 
 
 
 
 
 
47
  @spaces.GPU(duration=120)
48
  def chat_llama3_8b(message: str,
49
  history: list,
 
66
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
67
  conversation.append({"role": "user", "content": message})
68
 
69
+ llm = Llama.from_pretrained(
70
+ repo_id="CyberNative-AI/Colibri_8b_v0.1_q5_gguf",
71
+ filename="*Q5_K_M.gguf",
72
+ chat_format="chatml",
73
+ verbose=False,
74
+ max_tokens=max_new_tokens,
75
+ stop=["<|im_end|>"]
 
 
 
 
 
76
  )
77
+
78
+ t = Thread(target=llm.create_chat_completion, messages=conversation, temperature=temperature)
 
 
 
79
  t.start()
80
 
81
  outputs = []