# Copied from a Hugging Face file view: author FredZhang7, commit 806fc8c
# ("Create app.py"), original file size 1.64 kB.
import urllib.request
# Fetch the RWKV-5 World checkpoint unless a complete copy is already on disk,
# so restarts do not re-download the weights.
import os

url = "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
filename = "RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
if not os.path.exists(filename):
    # Download under a temporary name and rename only after the transfer
    # finishes; an interrupted download then cannot leave a truncated file
    # at `filename` that the existence check above would accept next time.
    partial_path = filename + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, filename)
import functools

import gradio as gr
import torch
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Load the checkpoint downloaded above. Reuse `filename` so the path is
# declared in one place. Strategy "cpu bf16" selects the CPU device with
# bfloat16 weights (rwkv strategy syntax: "<device> <precision>").
model = RWKV(model=filename, strategy='cpu bf16')
# Vocabulary name that rwkv.utils.PIPELINE uses for RWKV "World" checkpoints.
WORLD_VOCAB = "rwkv_vocab_v20230424"
# Lower bound for the temperature slider: the rwkv sampler raises probabilities
# to the power 1/temperature, so a slider value of 0 would divide by zero.
MIN_TEMPERATURE = 0.2


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Return the text-generation pipeline wrapping the module-level ``model``.

    Built on first use and cached, so the tokenizer is loaded only once.
    """
    return PIPELINE(model, WORLD_VOCAB)


def build_prompt(instruction, input_prompt=""):
    """Combine the two text boxes into a single prompt string.

    With a non-empty ``input_prompt`` the ``Instruction:/Input:/Response:``
    layout is used; otherwise a single ``User:/Assistant:`` turn is built.
    Blank lines inside the user's text are removed because a blank line is
    what separates the prompt's sections. ``None`` is treated as empty text.
    """

    def clean(text):
        text = (text or "").strip()
        return "\n".join(line for line in text.splitlines() if line.strip())

    instruction = clean(instruction)
    input_prompt = clean(input_prompt)
    if input_prompt:
        return f"Instruction: {instruction}\n\nInput: {input_prompt}\n\nResponse:"
    return f"User: {instruction}\n\nAssistant:"


def chatbot_model(instruction, input_prompt, temperature, token_count, top_p,
                  presence_penalty, count_penalty):
    """Generate a completion for the Gradio form and return it as text.

    Sampling settings go through ``PIPELINE_ARGS``. The model object itself
    does not read sampling attributes, and ``model.forward`` takes token ids,
    not strings, so tokenization and sampling are delegated to the pipeline.

    Args:
        instruction: The main request text.
        input_prompt: Optional extra context; when non-empty the
            Instruction/Input prompt layout is used.
        temperature: Sampling temperature, floored at ``MIN_TEMPERATURE``.
        token_count: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        presence_penalty: Passed to the sampler as ``alpha_presence``.
        count_penalty: Passed to the sampler as ``alpha_frequency``.

    Returns:
        The generated text with surrounding whitespace stripped.
    """
    args = PIPELINE_ARGS(
        temperature=max(MIN_TEMPERATURE, float(temperature)),
        top_p=float(top_p),
        alpha_presence=float(presence_penalty),
        alpha_frequency=float(count_penalty),
        # NOTE(review): token 0 is treated as end-of-text for World models in
        # upstream RWKV demos; confirm against the tokenizer in use.
        token_stop=[0],
    )
    prompt = build_prompt(instruction, input_prompt)
    # Sliders may deliver floats; the pipeline loops over an int token budget.
    output = _get_pipeline().generate(prompt, token_count=int(token_count), args=args)
    return output.strip()
# Define the Gradio interface. gr.Interface passes the input values to `fn`
# positionally, so the component order must match chatbot_model's parameters.
# Components use the top-level gr.Textbox / gr.Slider API (`value=` for the
# initial value); the old gr.inputs / gr.outputs namespaces were removed in Gradio 4.
iface = gr.Interface(
    fn=chatbot_model,
    inputs=[
        gr.Textbox(lines=2, label="Instruction"),
        gr.Textbox(lines=2, label="Input Prompt"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Temperature"),
        gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Token Count"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Top P"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Presence Penalty"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Count Penalty"),
    ],
    outputs=gr.Textbox(label="Output"),
    # No `theme=`: "dark" is not a built-in Gradio theme name (Gradio only
    # warned and used the default theme). For dark mode, open the app URL
    # with `?__theme=dark` appended.
)
# Launch the interface
iface.launch()