# Hugging Face Spaces status when this file was captured: Runtime error
# Download the RWKV-5 "World" 1B5 checkpoint (only if it is not already on
# disk) and load it with the rwkv package for CPU inference.
import os
import urllib.request

import gradio as gr
from rwkv.model import RWKV
import torch

url = "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
filename = "RWKV-5-World-1B5-v2-20231025-ctx4096.pth"

# The checkpoint is large, so skip the download when a complete copy already
# exists (e.g. after an app restart).  Download to a temporary ".part" name and
# rename only on success: an interrupted download must not leave a truncated
# file that the existence check would later mistake for a complete one.
if not os.path.exists(filename):
    _partial_path = filename + ".part"
    urllib.request.urlretrieve(url, _partial_path)
    os.replace(_partial_path, filename)

# Load the model.  ``strategy`` is the rwkv package's device/dtype spec:
# run on CPU with bfloat16 weights.
model = RWKV(model=filename, strategy='cpu bf16')
# Tokenizer + sampler wrapped around the global ``model``; built on first use
# so the vocabulary is loaded once rather than on every request.
_pipeline = None


def _get_pipeline():
    """Return the shared rwkv ``PIPELINE`` for ``model``, creating it on first use."""
    global _pipeline
    if _pipeline is None:
        from rwkv.utils import PIPELINE

        # "rwkv_vocab_v20230424" is the tokenizer vocabulary for RWKV "World" models.
        _pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
    return _pipeline


def _clean_prompt_text(text):
    """Strip *text* and drop its blank lines (``None`` is treated as empty)."""
    # A blank line ("\n\n") separates the prompt's sections, so user text must
    # not contain one or it would look like a new section to the model.
    return "\n".join(line for line in (text or "").strip().splitlines() if line.strip())


def _build_prompt(instruction, input_prompt):
    """Format the request as an ``Instruction:`` / ``Input:`` / ``Response:`` prompt."""
    instruction = _clean_prompt_text(instruction)
    input_prompt = _clean_prompt_text(input_prompt)
    if input_prompt:
        return f"Instruction: {instruction}\n\nInput: {input_prompt}\n\nResponse:"
    return f"Instruction: {instruction}\n\nResponse:"


def chatbot_model(instruction, input_prompt, temperature, token_count, top_p, presence_penalty, count_penalty):
    """Generate a text response from the RWKV model for the Gradio UI.

    Args:
        instruction: Task description typed by the user.
        input_prompt: Optional extra context for the instruction; may be empty.
        temperature: Sampling temperature.  Clamped to at least 0.1 because the
            rwkv sampler divides by it (0 would raise ZeroDivisionError).
        token_count: Maximum number of tokens to generate.  Converted to
            ``int`` because slider values may arrive as floats.
        top_p: Nucleus-sampling probability mass.
        presence_penalty: Flat logit penalty applied to any token already
            generated (``PIPELINE_ARGS.alpha_presence``).
        count_penalty: Logit penalty scaled by how often a token has been
            generated (``PIPELINE_ARGS.alpha_frequency``).

    Returns:
        The generated text as a ``str``, suitable for a Gradio Textbox.
    """
    from rwkv.utils import PIPELINE_ARGS

    # Sampling settings are passed per call via PIPELINE_ARGS; ``model.forward``
    # itself only consumes token ids and knows nothing about sampling.
    args = PIPELINE_ARGS(
        temperature=max(float(temperature), 0.1),
        top_p=float(top_p),
        alpha_presence=float(presence_penalty),
        alpha_frequency=float(count_penalty),
        token_ban=[],
        # Stop on token 0, treated as end-of-text for World models in the
        # official RWKV World demos.
        token_stop=[0],
    )
    output = _get_pipeline().generate(
        _build_prompt(instruction, input_prompt),
        token_count=int(token_count),
        args=args,
    )
    return output.strip()
# Define the Gradio interface.
# Components are the top-level classes (gr.Textbox / gr.Slider), and initial
# slider values use ``value=``.  The legacy ``gr.inputs`` / ``gr.outputs``
# namespaces and their ``default=`` keyword were removed in Gradio 4.
iface = gr.Interface(
    fn=chatbot_model,
    inputs=[
        gr.Textbox(lines=2, label="Instruction"),
        gr.Textbox(lines=2, label="Input Prompt"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Temperature"),
        gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Token Count"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Top P"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Presence Penalty"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Count Penalty"),
    ],
    outputs=gr.Textbox(label="Response"),
    # No ``theme="dark"``: "dark" is not a built-in Gradio theme name.  Dark
    # mode follows the visitor's system setting, or can be forced by opening
    # the app with the ``?__theme=dark`` URL parameter.
)

# Launch the interface
iface.launch()