import os
import urllib.request

import gradio as gr
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Download the model weights (skip if the file is already present)
url = "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
filename = "RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)

# Load the model on CPU in bfloat16 and wrap it in the RWKV pipeline, which handles
# tokenization and sampling (World models use the rwkv_vocab_v20230424 vocabulary)
model = RWKV(model=filename, strategy='cpu bf16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

def chatbot_model(instruction, input_prompt, temperature, token_count, top_p, presence_penalty, count_penalty):
    # Build a prompt in the Instruction/Input/Response format used by RWKV World instruct models
    ctx = f"Instruction: {instruction.strip()}\n\nInput: {input_prompt.strip()}\n\nResponse:"

    # Sampling parameters are passed per call via PIPELINE_ARGS rather than set on the model
    args = PIPELINE_ARGS(
        temperature=temperature,
        top_p=top_p,
        alpha_presence=presence_penalty,   # presence penalty
        alpha_frequency=count_penalty,     # frequency ("count") penalty
    )

    # Generate and return plain text for the Gradio Textbox output
    return pipeline.generate(ctx, token_count=int(token_count), args=args)

# Define the Gradio interface (gr.inputs / gr.outputs are deprecated; use the top-level components)
iface = gr.Interface(
    fn=chatbot_model,
    inputs=[
        gr.Textbox(lines=2, label="Instruction"),
        gr.Textbox(lines=2, label="Input Prompt"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Temperature"),
        gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Token Count"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Top P"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Presence Penalty"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Count Penalty"),
    ],
    outputs=gr.Textbox(label="Response"),
)

# Launch the interface (open the app URL with ?__theme=dark appended for dark mode)
iface.launch()
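
# Optional quick test (a sketch, not part of the app): instead of launching the UI,
# you could call chatbot_model() directly to confirm that generation works. The
# instruction/input strings and parameter values below are illustrative only.
#
#   print(chatbot_model(
#       "Summarize the following text.",
#       "RWKV is an RNN architecture with transformer-level performance.",
#       temperature=0.5, token_count=50, top_p=0.5,
#       presence_penalty=0.5, count_penalty=0.5,
#   ))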