prithivMLmods committed
Commit aff3623 · verified · 1 Parent(s): c0bf3d3

Update app.py

Files changed (1)
  1. app.py +94 -112
app.py CHANGED
@@ -1,145 +1,127 @@
- import spaces
  import os
- import json
- import subprocess
- from llama_cpp import Llama
- from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
- from llama_cpp_agent.providers import LlamaCppPythonProvider
- from llama_cpp_agent.chat_history import BasicChatHistory
- from llama_cpp_agent.chat_history.messages import Roles
  import gradio as gr
- from huggingface_hub import hf_hub_download

- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

- hf_hub_download(
-     repo_id="mradermacher/GWQ-9B-Preview-GGUF",
-     filename="GWQ-9B-Preview.Q4_K_M.gguf",
-     local_dir="./models"
- )

- hf_hub_download(
-     repo_id="mradermacher/GWQ-9B-Preview2-GGUF",
-     filename="GWQ-9B-Preview2.Q4_K_M.gguf",
-     local_dir="./models"
  )

- llm = None
- llm_model = None

  @spaces.GPU(duration=120)
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     model,
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     top_k,
-     repeat_penalty,
- ):
-     chat_template = MessagesFormatterType.GEMMA_2
-
-     global llm
-     global llm_model
-
-     if llm is None or llm_model != model:
-         llm = Llama(
-             model_path=f"models/{model}",
-             flash_attn=True,
-             n_gpu_layers=81,
-             n_batch=1024,
-             n_ctx=8192,
-         )
-         llm_model = model

-     provider = LlamaCppPythonProvider(llm)

-     agent = LlamaCppAgent(
-         provider,
-         system_prompt=f"{system_message}",
-         predefined_messages_formatter_type=chat_template,
-         debug_output=True
      )
-
-     settings = provider.get_provider_default_settings()
-     settings.temperature = temperature
-     settings.top_k = top_k
-     settings.top_p = top_p
-     settings.max_tokens = max_tokens
-     settings.repeat_penalty = repeat_penalty
-     settings.stream = True

-     messages = BasicChatHistory()

-     for msn in history:
-         user = {
-             'role': Roles.user,
-             'content': msn[0]
-         }
-         assistant = {
-             'role': Roles.assistant,
-             'content': msn[1]
-         }
-         messages.add_message(user)
-         messages.add_message(assistant)
-
-     stream = agent.get_chat_response(
-         message,
-         llm_sampling_settings=settings,
-         chat_history=messages,
-         returns_streaming_generator=True,
-         print_output=False
-     )
-
-     outputs = ""
-     for output in stream:
-         outputs += output
-         yield outputs

  demo = gr.ChatInterface(
-     respond,
      additional_inputs=[
-         gr.Dropdown([
-                 'GWQ-9B-Preview.Q4_K_M.gguf',
-                 'GWQ-9B-Preview2.Q4_K_M.gguf'
-             ],
-             value="GWQ-9B-Preview.Q4_K_M.gguf",
-             label="Model"
          ),
-         gr.Textbox(value="You are a helpful assistant.", label="System message"),
-         gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
-             value=0.95,
              step=0.05,
-             label="Top-p",
          ),
          gr.Slider(
-             minimum=0,
-             maximum=100,
-             value=40,
-             step=1,
              label="Top-k",
          ),
          gr.Slider(
-             minimum=0.0,
-             maximum=2.0,
-             value=1.1,
-             step=0.1,
              label="Repetition penalty",
          ),
      ],
-     title="GWQ PREV",
-     chatbot=gr.Chatbot(
-         scale=1,
-         show_copy_button=True,
-         type="messages"
-     )
  )

  if __name__ == "__main__":
-     demo.launch()
 
 
  import os
+ from collections.abc import Iterator
+ from threading import Thread
+
  import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+ DESCRIPTION = """
+ # GWQ PREV
+ """

+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+ model_id = "prithivMLmods/GWQ-9B-Preview"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
  )
+ model.config.sliding_window = 4096
+ model.eval()


  @spaces.GPU(duration=120)
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = chat_history.copy()
+     conversation.append({"role": "user", "content": message})

+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)

+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
      )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)


  demo = gr.ChatInterface(
+     fn=generate,
      additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
          ),
          gr.Slider(
+             label="Temperature",
              minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
              maximum=1.0,
              step=0.05,
+             value=0.9,
          ),
          gr.Slider(
              label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
          ),
          gr.Slider(
              label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
          ),
      ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     cache_examples=False,
+     type="messages",
+     description=DESCRIPTION,
+     css_paths="style.css",
+     fill_height=True,
  )

+
  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
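The updated app.py streams replies by running model.generate in a background thread and reading decoded tokens from a transformers TextIteratorStreamer. A minimal, self-contained sketch of that pattern is below; it uses "gpt2" only as a small stand-in model for illustration (the Space itself loads prithivMLmods/GWQ-9B-Preview on GPU).

# Sketch of the streaming pattern used by the new app.py.
# generate() blocks, so it runs in a worker thread while the main
# thread consumes the streamer iterator as text becomes available.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # stand-in for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

input_ids = tokenizer("Open-source AI is", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=40, do_sample=True),
)
thread.start()

# Print each decoded chunk as soon as the generation thread produces it.
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()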