Tonic committed
Commit 22d5543 · verified · 1 Parent(s): 8c00703

Update app.py

Files changed (1)
  1. app.py +109 -169
app.py CHANGED
@@ -1,30 +1,19 @@
 import os
 import gradio as gr
 import torch
-from rwkv.model import RWKV
-from rwkv.utils import PIPELINE, PIPELINE_ARGS
-from copy import deepcopy
 import requests
-import os.path
 from tqdm import tqdm

-# Set environment variables
-os.environ['RWKV_JIT_ON'] = '1'
-os.environ["RWKV_CUDA_ON"] = '0'
-os.environ["RWKV_V7_ON"] = '1'
-
-# Model options
-MODELS = {
-    "0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
-    "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
-}
-
-# Download vocab file if not present
-VOCAB_FILE = "rwkv_vocab_v20230424.txt"
-VOCAB_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/v2/rwkv_vocab_v20230424.txt"

 def download_file(url, filename):
-    """Generic file downloader with progress bar"""
     if not os.path.exists(filename):
         print(f"Downloading {filename}...")
         response = requests.get(url, stream=True)
@@ -41,195 +30,146 @@ def download_file(url, filename):
                 size = file.write(data)
                 pbar.update(size)

-def download_model(model_name):
-    """Download model if not present"""
-    if not os.path.exists(model_name):
-        url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
-        download_file(url, model_name)
-
-def ensure_vocab():
-    """Ensure vocab file is present"""
-    if not os.path.exists(VOCAB_FILE):
-        download_file(VOCAB_URL, VOCAB_FILE)

-class ModelManager:
-    def __init__(self):
-        self.current_model = None
-        self.current_model_name = None
-        self.pipeline = None
-        ensure_vocab()
-
-    def load_model(self, model_choice):
-        model_file = MODELS[model_choice]
-        if model_file != self.current_model_name:
-            download_model(model_file)
-            self.current_model = RWKV(
-                model=model_file,
-                strategy='cpu fp32'
-            )
-            self.pipeline = PIPELINE(self.current_model, VOCAB_FILE)
-            self.current_model_name = model_file
-        return self.pipeline

-model_manager = ModelManager()
 
-def generate_response(
-    model_choice,
-    user_prompt,
-    system_prompt,
-    temperature,
-    top_p,
-    top_k,
-    alpha_frequency,
-    alpha_presence,
-    alpha_decay,
-    max_tokens
 ):
     try:
-        # Get or load the model
-        pipeline = model_manager.load_model(model_choice)

-        # Prepare the context
-        if system_prompt.strip():
-            ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
-        else:
-            ctx = f"User: {user_prompt.strip()}\n\nA:"

-        # Prepare generation arguments
-        args = PIPELINE_ARGS(
             temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            alpha_frequency=alpha_frequency,
-            alpha_presence=alpha_presence,
-            alpha_decay=alpha_decay,
-            token_ban=[],
-            token_stop=[],
-            chunk_len=256
         )

-        # Generate response
-        response = ""
-        def callback(text):
-            nonlocal response
-            response += text
-            return response
-
-        pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
-        return response
     except Exception as e:
-        import traceback
-        return f"Error: {str(e)}\nStack trace: {traceback.format_exc()}"
 
-# Create the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# RWKV-7 Language Model Demo")

     with gr.Row():
         with gr.Column():
-            model_choice = gr.Radio(
-                choices=list(MODELS.keys()),
-                value=list(MODELS.keys())[0],
-                label="Model Selection"
             )
-            system_prompt = gr.Textbox(
-                label="System Prompt",
-                placeholder="Optional system prompt to set the context",
-                lines=3,
-                value="You are a helpful AI assistant. You provide detailed and accurate responses."
-            )
-            user_prompt = gr.Textbox(
-                label="User Prompt",
-                placeholder="Enter your prompt here",
-                lines=3
-            )
-            max_tokens = gr.Slider(
-                minimum=1,
-                maximum=1000,
-                value=200,
-                step=1,
-                label="Max Tokens"
             )

         with gr.Column():
-            temperature = gr.Slider(
                 minimum=0.1,
                 maximum=2.0,
                 value=1.0,
-                step=0.1,
                 label="Temperature"
             )
-            top_p = gr.Slider(
-                minimum=0.0,
                 maximum=1.0,
-                value=0.7,
-                step=0.05,
-                label="Top P"
             )
-            top_k = gr.Slider(
-                minimum=0,
-                maximum=200,
                 value=100,
-                step=1,
-                label="Top K"
-            )
-            alpha_frequency = gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                value=0.25,
-                step=0.05,
-                label="Alpha Frequency"
-            )
-            alpha_presence = gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                value=0.25,
-                step=0.05,
-                label="Alpha Presence"
-            )
-            alpha_decay = gr.Slider(
-                minimum=0.9,
-                maximum=1.0,
-                value=0.996,
-                step=0.001,
-                label="Alpha Decay"
             )
 
     generate_button = gr.Button("Generate")
-    output = gr.Textbox(label="Generated Response", lines=10)

     generate_button.click(
-        fn=generate_response,
         inputs=[
-            model_choice,
-            user_prompt,
-            system_prompt,
-            temperature,
-            top_p,
-            top_k,
-            alpha_frequency,
-            alpha_presence,
-            alpha_decay,
-            max_tokens
         ],
-        outputs=output
     )

     gr.Markdown("""
-    ## Model Information
-    - **0.1B Model**: Smaller model, faster but less capable
-    - **0.4B Model**: Larger model, slower but more capable
-
-    ## Parameter Descriptions
-    - **Temperature**: Controls randomness in the output (higher = more random)
-    - **Top P**: Nucleus sampling threshold (lower = more focused)
-    - **Top K**: Limits the number of tokens considered for each step
-    - **Alpha Frequency**: Penalizes frequent tokens
-    - **Alpha Presence**: Penalizes tokens that have appeared before
-    - **Alpha Decay**: Rate at which penalties decay
-    - **Max Tokens**: Maximum length of generated response
     """)

-# Launch the demo
 if __name__ == "__main__":
     demo.launch()
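The listing above is the previous version of `app.py`, which routed generation through the `rwkv` package's `PIPELINE` helper; the listing below is the rewritten file. For reference, a condensed sketch of the generation path this commit removes, assuming the `rwkv` pip package plus a locally downloaded checkpoint and vocab file, and using the same defaults the old UI exposed:

```python
# Condensed sketch of the removed flow; values mirror the old UI defaults.
import os
os.environ['RWKV_JIT_ON'] = '1'    # set before importing rwkv so the library picks them up
os.environ["RWKV_CUDA_ON"] = '0'
os.environ["RWKV_V7_ON"] = '1'

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

model = RWKV(model="RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth", strategy='cpu fp32')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424.txt")

args = PIPELINE_ARGS(
    temperature=1.0,
    top_p=0.7,               # nucleus sampling threshold
    top_k=100,                # number of candidate tokens per step
    alpha_frequency=0.25,     # penalty that grows with how often a token was emitted
    alpha_presence=0.25,      # flat penalty once a token has appeared at all
    alpha_decay=0.996,        # per-step decay of the accumulated penalties
    token_ban=[], token_stop=[], chunk_len=256,
)

chunks = []  # the callback receives incremental text; collect it here
pipeline.generate("User: Hello\n\nA:", token_count=200, args=args, callback=chunks.append)
print("".join(chunks))
```

The three `alpha_*` arguments correspond to the repetition penalties described in the old help text: frequency and presence penalties on previously emitted tokens that decay over time.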
 
 import os
 import gradio as gr
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tokenizers import Tokenizer
+import json
+import math
 import requests
 from tqdm import tqdm

+# Download tokenizer if not present
+TOKENIZER_FILE = "20B_tokenizer.json"
+TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"

 def download_file(url, filename):
     if not os.path.exists(filename):
         print(f"Downloading {filename}...")
         response = requests.get(url, stream=True)

                 size = file.write(data)
                 pbar.update(size)

+# Ensure tokenizer exists
+if not os.path.exists(TOKENIZER_FILE):
+    download_file(TOKENIZER_URL, TOKENIZER_FILE)

+tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
 
+class RWKV_Model:
+    def __init__(self, model_path):
+        self.model_path = model_path
+        self.model = None
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def load_model(self):
+        if not os.path.exists(self.model_path):
+            raise FileNotFoundError(f"Model file {self.model_path} not found")
+
+        self.model = torch.load(self.model_path, map_location=self.device)
+        print("Model loaded successfully")
+
+    def generate(self, prompt, max_length=100, temperature=1.0, top_p=0.9):
+        if self.model is None:
+            self.load_model()
+
+        input_ids = tokenizer.encode(prompt).ids
+        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
+
+        with torch.no_grad():
+            output_sequence = []
+
+            for _ in range(max_length):
+                outputs = self.model(input_tensor)
+                next_token_logits = outputs[0, -1, :] / temperature
+
+                # Apply top-p sampling
+                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                sorted_indices_to_remove[..., 0] = 0
+                indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                next_token_logits[indices_to_remove] = float('-inf')
+
+                probs = F.softmax(next_token_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+
+                output_sequence.append(next_token.item())
+                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
+
+                if next_token.item() == tokenizer.token_to_id("</s>"):
+                    break
+
+        return tokenizer.decode(output_sequence)
 
+def generate_text(
+    prompt,
+    temperature=1.0,
+    top_p=0.9,
+    max_length=100,
+    model_size="small"
 ):
     try:
+        # Select model based on size
+        model_path = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth" if model_size == "small" else "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"

+        model = RWKV_Model(model_path)

+        generated_text = model.generate(
+            prompt=prompt,
+            max_length=max_length,
             temperature=temperature,
+            top_p=top_p
         )

+        return generated_text
+
     except Exception as e:
+        return f"Error: {str(e)}"
 
+# Create Gradio interface
 with gr.Blocks() as demo:
+    gr.Markdown("# RWKV-7 Text Generation Demo")

     with gr.Row():
         with gr.Column():
+            prompt_input = gr.Textbox(
+                label="Input Prompt",
+                placeholder="Enter your prompt here...",
+                lines=5
             )
+            model_size = gr.Radio(
+                choices=["small", "large"],
+                label="Model Size",
+                value="small"
             )

         with gr.Column():
+            temperature_slider = gr.Slider(
                 minimum=0.1,
                 maximum=2.0,
                 value=1.0,
                 label="Temperature"
             )
+            top_p_slider = gr.Slider(
+                minimum=0.1,
                 maximum=1.0,
+                value=0.9,
+                label="Top-p"
             )
+            max_length_slider = gr.Slider(
+                minimum=10,
+                maximum=500,
                 value=100,
+                step=10,
+                label="Maximum Length"
             )
 
     generate_button = gr.Button("Generate")
+    output_text = gr.Textbox(label="Generated Output", lines=10)

     generate_button.click(
+        fn=generate_text,
         inputs=[
+            prompt_input,
+            temperature_slider,
+            top_p_slider,
+            max_length_slider,
+            model_size
         ],
+        outputs=output_text
     )

     gr.Markdown("""
+    ## Parameters:
+    - **Temperature**: Controls randomness (higher = more random)
+    - **Top-p**: Controls diversity (higher = more diverse)
+    - **Maximum Length**: Maximum number of tokens to generate
+    - **Model Size**:
+      - Small (0.1B parameters)
+      - Large (0.4B parameters)
     """)

 if __name__ == "__main__":
     demo.launch()
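The new `RWKV_Model.generate` loop applies temperature scaling followed by top-p (nucleus) filtering, which is what the "Top-p: Controls diversity" note in the interface refers to. A self-contained sketch of just that sampling step, using made-up logits rather than the RWKV checkpoint:

```python
import torch
import torch.nn.functional as F

def sample_top_p(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0) -> int:
    """Sample one token id from 1-D `logits` using temperature plus nucleus (top-p) filtering."""
    logits = logits / temperature                       # works on a copy; the caller's tensor is untouched
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Drop tokens once the cumulative probability passes top_p,
    # shifting the mask so the single most likely token is always kept.
    mask = cumulative_probs > top_p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False
    logits[sorted_indices[mask]] = float('-inf')
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# Example with made-up logits over a five-token vocabulary.
example_logits = torch.tensor([2.0, 1.5, 0.5, -1.0, -3.0])
print(sample_top_p(example_logits, top_p=0.9))
```

With these example logits the softmax probabilities of the top three tokens are roughly 0.53, 0.32, and 0.12, so `top_p=0.9` keeps exactly those three candidates and discards the rest; lowering `top_p` shrinks the candidate set further.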