from fastapi import FastAPI, WebSocket, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
import os
import json
import asyncio

app = FastAPI()

# Mount static files directory
app.mount("/static", StaticFiles(directory="static"), name="static")

# Setup Jinja2 templates
templates = Jinja2Templates(directory="templates")

# Initialize the Hugging Face Inference Client
client = InferenceClient()
async def generate_stream_response(prompt_template: str, **kwargs):
    """
    Generate a streaming response using the Hugging Face Inference Client.

    Args:
        prompt_template (str): Name of the environment variable holding the prompt template
        **kwargs: Dynamic arguments used to format the prompt

    Yields:
        str: Streamed content chunks
    """
    # Look up the prompt template (you'll need to set the corresponding environment variables)
    template = os.getenv(prompt_template)
    if template is None:
        yield f"Error occurred: prompt template '{prompt_template}' is not configured"
        return
    prompt = template.format(**kwargs)

    # Prepare messages for the model
    messages = [
        {"role": "user", "content": prompt}
    ]

    try:
        # Create a stream for the chat completion
        stream = client.chat.completions.create(
            model="Qwen/Qwen2.5-Math-1.5B-Instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.8,
            stream=True
        )

        # Stream the generated content as it arrives
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as e:
        yield f"Error occurred: {str(e)}"
# The route path here is an assumption; adjust it to match how the frontend opens the socket
@app.websocket("/ws/{endpoint}")
async def websocket_endpoint(websocket: WebSocket, endpoint: str):
    """
    WebSocket endpoint for streaming responses.

    Args:
        websocket (WebSocket): The WebSocket connection
        endpoint (str): The specific endpoint/task to process
    """
    await websocket.accept()
    try:
        # Receive the initial message with parameters
        data = await websocket.receive_json()

        # Map the endpoint to the appropriate prompt template
        endpoint_prompt_map = {
            "solve": "PROMPT_SOLVE",
            "hint": "PROMPT_HINT",
            "verify": "PROMPT_VERIFY",
            "generate": "PROMPT_GENERATE",
            "explain": "PROMPT_EXPLAIN"
        }

        # Get the appropriate prompt template
        prompt_template = endpoint_prompt_map.get(endpoint)
        if not prompt_template:
            await websocket.send_json({"error": "Invalid endpoint"})
            return

        # Stream the response chunk by chunk
        full_response = ""
        async for chunk in generate_stream_response(prompt_template, **data):
            full_response += chunk
            await websocket.send_json({"chunk": chunk})

        # Send a final message to indicate streaming is complete
        await websocket.send_json({"complete": True, "full_response": full_response})
    except Exception as e:
        await websocket.send_json({"error": str(e)})
    finally:
        await websocket.close()
# Existing routes remain the same as in the previous implementation
@app.get("/")
async def home(request: Request):
    return HTMLResponse(open("static/index.html").read())
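
# Example of how a client might exercise the streaming endpoint. This is a sketch assuming
# the /ws/{endpoint} route above, the hypothetical "problem" placeholder, and port 7860
# (the usual Spaces default); run it in a separate process with the `websockets` package:
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/solve") as ws:
#           await ws.send(json.dumps({"problem": "What is 12 * 7?"}))
#           while True:
#               msg = json.loads(await ws.recv())
#               if "chunk" in msg:
#                   print(msg["chunk"], end="", flush=True)
#               if msg.get("complete") or "error" in msg:
#                   break
#
#   asyncio.run(demo())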