Daemontatox committed on
Commit 13880c3 · verified · 1 Parent(s): d83f798

Update app.py

Files changed (1)
  1. app.py +214 -183
app.py CHANGED
@@ -1,26 +1,56 @@
- import subprocess
- subprocess.run(
-     'pip install flash-attn --no-build-isolation',
-     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-     shell=True
- )
-
  import os
  import time
- import spaces
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
  import gradio as gr
  from threading import Thread

- HF_TOKEN = os.environ.get("HF_TOKEN", None)
- MODEL = "Daemontatox/AetherDrake"

- TITLE = "<h1><center>Tags Reasoner</center></h1>"

  PLACEHOLDER = """
  <center>
- <p>Ask me Anything !!</p>
  </center>
  """

@@ -54,87 +84,115 @@ h3 {
      padding: 0.2em 0.4em;
      font-family: monospace;
  }
  """

- device = "cuda" # for GPU usage or "cpu" for CPU usage

- quantization_config = BitsAndBytesConfig(
-     load_in_4bit=True,  # Use 8-bit instead of 4-bit
-     bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 for compute
-     bnb_4bit_use_double_quant=True  # Disable double quantization
- )

- tokenizer = AutoTokenizer.from_pretrained(MODEL)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL,
-     torch_dtype=torch.float16,
-     device_map="auto",
-     attn_implementation="flash_attention_2",
-     quantization_config=quantization_config)

- # Ensure `pad_token_id` is set
- if tokenizer.pad_token_id is None:
-     tokenizer.pad_token_id = tokenizer.eos_token_id

  def format_text(text):
-     """Helper function to format text with proper line breaks and spacing"""
-     # Replace single newlines with double newlines for paragraph spacing
-     formatted = text.replace('\n', '\n\n')
-     # Remove extra spaces between paragraphs
-     formatted = '\n'.join(line.strip() for line in formatted.split('\n'))
      return formatted

  @spaces.GPU()
  def stream_chat(
-     message: str,
      history: list,
      system_prompt: str,
-     temperature: float = 1.0,
-     max_new_tokens: int = 8192,
-     top_p: float = 1.0,
-     top_k: int = 20,
      penalty: float = 1.2,
  ):
-     print(f'message: {message}')
-     print(f'history: {history}')
-
      conversation = [
          {"role": "system", "content": system_prompt}
      ]
      for prompt, answer in history:
          conversation.extend([
-             {"role": "user", "content": prompt},
-             {"role": "assistant", "content": answer},
          ])
-
      conversation.append({"role": "user", "content": message})
-
      input_ids = tokenizer.apply_chat_template(
-         conversation,
-         add_generation_prompt=True,
          return_tensors="pt"
      ).to(model.device)

      streamer = TextIteratorStreamer(
-         tokenizer,
-         timeout=60.0,
-         skip_prompt=True,
          skip_special_tokens=True
      )

      generate_kwargs = dict(
-         input_ids=input_ids,
          max_new_tokens=max_new_tokens,
          do_sample=False if temperature == 0 else True,
          top_p=top_p,
          top_k=top_k,
-         eos_token_id=tokenizer.eos_token_id,
-         pad_token_id=tokenizer.pad_token_id,
          temperature=temperature,
          repetition_penalty=penalty,
          streamer=streamer,
      )
-
      buffer = ""
      current_line = ""

@@ -142,140 +200,113 @@ def stream_chat(
      thread = Thread(target=model.generate, kwargs=generate_kwargs)
      thread.start()

-     for new_text in streamer:
-         # Add the new text to both buffers
-         buffer += new_text
-         current_line += new_text
-
-         # Check if we have complete lines to process
-         if '\n' in current_line:
-             lines = current_line.split('\n')
-             # The last element might be incomplete, so keep it in current_line
-             current_line = lines[-1]
-             # Format the complete text
-             formatted_buffer = format_text(buffer)
-             yield formatted_buffer
-         else:
-             yield buffer
-

- chatbot = gr.Chatbot(
-     height=600,
-     placeholder=PLACEHOLDER,
-     bubble_full_width=False,
-     show_copy_button=True
- )

- DEFAULT_SYSTEM_PROMPT = """You are a Sentient Reasoning AI , expert at providing high-quality answers.
- Your process involves these steps:
- 1. Initial Thought: Use the <Thinking> tag to reason step-by-step and generate your best possible response to the following request: [User's Request Here].
- Example:
- <Thinking>
- Step 1: Understand the request.
- Step 2: Analyze potential solutions.
- Step 3: Choose the optimal response.
- </Thinking>
- 2. Self-Critique: Critically evaluate your initial response within <Critique> tags,
- focusing on:
- Accuracy: Is it factually correct and verifiable?
- Clarity: Is it easy to understand and free of ambiguity?
- Completeness: Does it fully address the user's request?
- Improvement: What specific aspects could be better?
- Example:
- <Critique>
- Accuracy: Verified.
- Clarity: Needs simplification.
- Completeness: Add examples.
- </Critique>
- 3. Revision: Based on your critique, use <Revising> tags to refine and improve your response.
- Example:
- <Revising>
- Adjusting for clarity and adding an example to improve understanding.
- </Revising>
- 4. Final Response: Present your revised answer clearly within <Final> tags.
- Example:
- <Final>
- This is the improved response.
- </Final>
- 5. Tag Innovation: If necessary, create and define new tags to better structure your reasoning or enhance clarity. Use them consistently.
- Example:
- <Definition>
- This tag defines a new term introduced in the response.
- </Definition>
- Ensure every part of your thought process and output is properly enclosed in appropriate tags for clarity and organization."""

- with gr.Blocks(css=CSS, theme="soft") as demo:
-     gr.HTML(TITLE)
-     gr.DuplicateButton(
-         value="Duplicate Space for private use",
-         elem_classes="duplicate-button"
      )

-     gr.ChatInterface(
-         fn=stream_chat,
-         chatbot=chatbot,
-         fill_height=True,
-         additional_inputs_accordion=gr.Accordion(
-             label="⚙️ Parameters",
-             open=False,
-             render=False
-         ),
-         additional_inputs=[
-             gr.Textbox(
-                 value=DEFAULT_SYSTEM_PROMPT,
-                 label="System Prompt",
-                 lines=5,
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0,
-                 maximum=1,
-                 step=0.1,
-                 value=0.2,
-                 label="Temperature",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=128,
-                 maximum=32000,
-                 step=1,
-                 value=8192,
-                 label="Max new tokens",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0.0,
-                 maximum=1.0,
-                 step=0.1,
-                 value=1.0,
-                 label="top_p",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=1,
-                 maximum=20,
-                 step=1,
-                 value=20,
-                 label="top_k",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0.0,
-                 maximum=2.0,
-                 step=0.1,
-                 value=1.2,
-                 label="Repetition penalty",
-                 render=False,
              ),
-         ],
-         examples=[
-             ["What is meant by a Singularity?"],
-             ["Explain the theory of Relativity"],
-             ["Explain your thought process in details"],
-             ["Explain how mamba2 structure LLMs work and how do they differ from transformers?"],
-         ],
-         cache_examples=False,
-     )

  if __name__ == "__main__":
      demo.launch()
  import os
+ import re
  import time
  import torch
+ import spaces
  import gradio as gr
  from threading import Thread
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     TextIteratorStreamer
+ )

+ # Configuration Constants
+ MODEL_ID = "Daemontatox/AetherDrake"
+ DEFAULT_SYSTEM_PROMPT = """You are a Sentient Reasoning AI, expert at providing high-quality answers.
+ Your process involves these steps:
+ 1. Initial Thought: Use the <Thinking> tag to reason step-by-step about any given request.
+ Example:
+ <Thinking>
+ Step 1: Understand the core request
+ Step 2: Analyze key components
+ Step 3: Formulate comprehensive response
+ </Thinking>

+ 2. Self-Critique: Use <Critique> tags to evaluate your response:
+ <Critique>
+ - Accuracy: Verify facts and logic
+ - Clarity: Assess explanation clarity
+ - Completeness: Check all points addressed
+ - Improvements: Identify enhancement areas
+ </Critique>
+
+ 3. Revision: Use <Revising> tags to refine your response:
+ <Revising>
+ Making identified improvements...
+ Enhancing clarity...
+ Adding examples...
+ </Revising>
+
+ 4. Final Response: Present your polished answer in <Final> tags:
+ <Final>
+ Your complete, refined response goes here.
+ </Final>

+ Always organize your responses using these tags for clear reasoning structure."""
+
+ # UI Configuration
+ TITLE = "<h1><center>AI Reasoning Assistant</center></h1>"
  PLACEHOLDER = """
  <center>
+ <p>Ask me anything! I'll think through it step by step.</p>
  </center>
  """

      padding: 0.2em 0.4em;
      font-family: monospace;
  }
+ .custom-tag {
+     color: #0066cc;
+     font-weight: bold;
+ }
  """

+ def initialize_model():
+     """Initialize the model with appropriate configurations"""
+     # Quantization configuration
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_use_double_quant=True
+     )

+     # Initialize tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id

+     # Initialize model
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         attn_implementation="flash_attention_2",
+         quantization_config=quantization_config
+     )

+     return model, tokenizer

  def format_text(text):
+     """Format text with proper spacing and tag highlighting"""
+     # Add newlines around tags
+     tag_patterns = [
+         (r'<Thinking>', '\n<Thinking>\n'),
+         (r'</Thinking>', '\n</Thinking>\n'),
+         (r'<Critique>', '\n<Critique>\n'),
+         (r'</Critique>', '\n</Critique>\n'),
+         (r'<Revising>', '\n<Revising>\n'),
+         (r'</Revising>', '\n</Revising>\n'),
+         (r'<Final>', '\n<Final>\n'),
+         (r'</Final>', '\n</Final>\n')
+     ]
+
+     formatted = text
+     for pattern, replacement in tag_patterns:
+         formatted = re.sub(pattern, replacement, formatted)
+
+     # Remove extra blank lines
+     formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
+
      return formatted

  @spaces.GPU()
  def stream_chat(
+     message: str,
      history: list,
      system_prompt: str,
+     temperature: float = 0.2,
+     max_new_tokens: int = 8192,
+     top_p: float = 1.0,
+     top_k: int = 20,
      penalty: float = 1.2,
  ):
+     """Generate streaming chat responses with proper tag handling"""
+     # Format conversation context
      conversation = [
          {"role": "system", "content": system_prompt}
      ]
+
+     # Add conversation history
      for prompt, answer in history:
          conversation.extend([
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": answer}
          ])
+
+     # Add current message
      conversation.append({"role": "user", "content": message})
+
+     # Prepare input for model
      input_ids = tokenizer.apply_chat_template(
+         conversation,
+         add_generation_prompt=True,
          return_tensors="pt"
      ).to(model.device)

+     # Configure streamer
      streamer = TextIteratorStreamer(
+         tokenizer,
+         timeout=60.0,
+         skip_prompt=True,
          skip_special_tokens=True
      )

+     # Set generation parameters
      generate_kwargs = dict(
+         input_ids=input_ids,
          max_new_tokens=max_new_tokens,
          do_sample=False if temperature == 0 else True,
          top_p=top_p,
          top_k=top_k,
          temperature=temperature,
          repetition_penalty=penalty,
          streamer=streamer,
      )
+
+     # Generate and stream response
      buffer = ""
      current_line = ""

      thread = Thread(target=model.generate, kwargs=generate_kwargs)
      thread.start()

+     for new_text in streamer:
+         buffer += new_text
+         current_line += new_text

+         if '\n' in current_line:
+             lines = current_line.split('\n')
+             current_line = lines[-1]
+             formatted_buffer = format_text(buffer)
+             yield formatted_buffer
+         else:
+             yield buffer

+ def create_examples():
+     """Create example queries that demonstrate the system's capabilities"""
+     return [
+         ["Explain how neural networks learn through backpropagation."],
+         ["What are the key differences between classical and quantum computing?"],
+         ["Analyze the environmental impact of renewable energy sources."],
+         ["How does the human memory system work?"],
+         ["Explain the concept of ethical AI and its importance."]
+     ]

+ def main():
+     """Main function to set up and launch the Gradio interface"""
+     # Initialize model and tokenizer
+     global model, tokenizer
+     model, tokenizer = initialize_model()
+
+     # Create chatbot interface
+     chatbot = gr.Chatbot(
+         height=600,
+         placeholder=PLACEHOLDER,
+         bubble_full_width=False,
+         show_copy_button=True
      )

+     # Create interface
+     with gr.Blocks(css=CSS, theme="soft") as demo:
+         gr.HTML(TITLE)
+         gr.DuplicateButton(
+             value="Duplicate Space for private use",
+             elem_classes="duplicate-button"
+         )
+
+         gr.ChatInterface(
+             fn=stream_chat,
+             chatbot=chatbot,
+             fill_height=True,
+             additional_inputs_accordion=gr.Accordion(
+                 label="⚙️ Advanced Settings",
+                 open=False,
+                 render=False
              ),
+             additional_inputs=[
+                 gr.Textbox(
+                     value=DEFAULT_SYSTEM_PROMPT,
+                     label="System Prompt",
+                     lines=5,
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     step=0.1,
+                     value=0.2,
+                     label="Temperature",
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=128,
+                     maximum=32000,
+                     step=128,
+                     value=8192,
+                     label="Max Tokens",
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     step=0.1,
+                     value=1.0,
+                     label="Top-p",
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=1,
+                     maximum=100,
+                     step=1,
+                     value=20,
+                     label="Top-k",
+                     render=False,
+                 ),
+                 gr.Slider(
+                     minimum=1.0,
+                     maximum=2.0,
+                     step=0.1,
+                     value=1.2,
+                     label="Repetition Penalty",
+                     render=False,
+                 ),
+             ],
+             examples=create_examples(),
+             cache_examples=False,
+         )
+
+     return demo

  if __name__ == "__main__":
+     demo = main()
      demo.launch()
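
The tag-aware formatting this commit introduces can be sanity-checked on its own. The sketch below mirrors the format_text logic from the updated app.py; the sample_output string is an invented stand-in for a model response, not actual output from the Space.

# Standalone sketch of the tag-aware formatting added in this commit.
# sample_output is made up for illustration only.
import re

TAGS = ["Thinking", "Critique", "Revising", "Final"]

def format_text(text: str) -> str:
    # Put each reasoning tag on its own line, mirroring app.py's tag_patterns loop.
    for tag in TAGS:
        text = re.sub(rf"<{tag}>", f"\n<{tag}>\n", text)
        text = re.sub(rf"</{tag}>", f"\n</{tag}>\n", text)
    # Drop the blank lines the substitutions introduce, as the app does.
    return "\n".join(line for line in text.split("\n") if line.strip())

sample_output = "<Thinking>Check the claim.</Thinking><Final>The claim holds.</Final>"
print(format_text(sample_output))
# <Thinking>
# Check the claim.
# </Thinking>
# <Final>
# The claim holds.
# </Final>

Re-joining only the non-empty lines is what keeps the streamed output free of the blank lines that the tag substitutions insert.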