File size: 2,855 Bytes
47a2526 7360ef0 47a2526 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 47a2526 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 20d974d 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 09454cf 7360ef0 de44ff5 7360ef0 09454cf 7360ef0 47a2526 09454cf 7360ef0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
from gradio_client import Client, handle_file
MODELS = {"Paligemma-10B": "akhaliq/paligemma2-10b-ft-docci-448"}
def create_chat_fn(client, system_prompt, temperature, max_tokens, top_k, rep_penalty, top_p):
def chat(message, history):
text = message.get("text", "")
files = message.get("files", [])
processed_files = [handle_file(f) for f in files]
response = client.predict(
message={"text": text, "files": processed_files},
system_prompt=system_prompt,
temperature=temperature,
max_new_tokens=max_tokens,
top_k=top_k,
repetition_penalty=rep_penalty,
top_p=top_p,
api_name="/chat",
)
return response
return chat
def set_client_for_session(model_name, request: gr.Request):
headers = {}
if request and hasattr(request, "headers"):
x_ip_token = request.headers.get("x-ip-token")
if x_ip_token:
headers["X-IP-Token"] = x_ip_token
return Client(MODELS[model_name], headers=headers)
def safe_chat_fn(message, history, client, system_prompt, temperature, max_tokens, top_k, rep_penalty, top_p):
if client is None:
return "Error: Client not initialized. Please refresh the page."
try:
return create_chat_fn(client, system_prompt, temperature, max_tokens, top_k, rep_penalty, top_p)(
message, history
)
except Exception as e:
print(f"Error during chat: {e!s}")
return f"Error during chat: {e!s}"
with gr.Blocks() as demo:
client = gr.State()
with gr.Accordion("Advanced Settings", open=False):
system_prompt = gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt")
with gr.Row():
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, label="Top P")
with gr.Row():
top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top K")
rep_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
max_tokens = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Tokens")
chat_interface = gr.ChatInterface(
fn=safe_chat_fn,
additional_inputs=[client, system_prompt, temperature, max_tokens, top_k, rep_penalty, top_p],
multimodal=True,
)
# Initialize client on page load with default model
demo.load(fn=set_client_for_session, inputs=[gr.State("Paligemma-10B")], outputs=[client]) # Using default model
# Move the API access check here, after demo is defined
if hasattr(demo, "fns"):
for fn in demo.fns.values():
fn.api_name = False
if __name__ == "__main__":
demo.launch()
|