FredZhang7 commited on
Commit
806fc8c
·
1 Parent(s): c8004c8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib.request
2
+
3
+ url = "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
4
+ filename = "RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
5
+ urllib.request.urlretrieve(url, filename)
6
+
7
+ import gradio as gr
8
+ from rwkv.model import RWKV
9
+ import torch
10
+
11
+ # Load the model
12
+ model = RWKV(model='RWKV-5-World-1B5-v2-20231025-ctx4096.pth', strategy='cpu bf16')
13
+
14
+ def chatbot_model(instruction, input_prompt, temperature, token_count, top_p, presence_penalty, count_penalty):
15
+ # Set the parameters
16
+ model.temperature = temperature
17
+ model.token_count = token_count
18
+ model.top_p = top_p
19
+ model.presence_penalty = presence_penalty
20
+ model.count_penalty = count_penalty
21
+
22
+ # Generate the output
23
+ out, state = model.forward([instruction, input_prompt], None)
24
+ return out.detach().cpu().numpy()
25
+
26
+ # Define the Gradio interface
27
+ iface = gr.Interface(
28
+ fn=chatbot_model,
29
+ inputs=[
30
+ gr.inputs.Textbox(lines=2, label="Instruction"),
31
+ gr.inputs.Textbox(lines=2, label="Input Prompt"),
32
+ gr.inputs.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Temperature"),
33
+ gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50, label="Token Count"),
34
+ gr.inputs.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Top P"),
35
+ gr.inputs.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Presence Penalty"),
36
+ gr.inputs.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Count Penalty"),
37
+ ],
38
+ outputs=gr.outputs.Textbox(),
39
+ theme="dark" # Change to "dark" for dark mode
40
+ )
41
+
42
+ # Launch the interface
43
+ iface.launch()