# MathMentor / app.py
# yasserrmd's picture
# Update app.py
# 4cfe4c5 verified
from fastapi import FastAPI, WebSocket, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
import os
import json
import asyncio
# Module-level Hugging Face Inference client used to stream model output
client = InferenceClient()

# Jinja2 template loader rooted at the "templates" directory
templates = Jinja2Templates(directory="templates")

# The FastAPI application; files in the "static" directory are served at /static
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
async def generate_stream_response(prompt_template: str, **kwargs):
    """
    Generate a streaming response using the Hugging Face Inference Client.

    The prompt text is read from the environment variable named by
    ``prompt_template`` and filled in with ``kwargs`` via ``str.format``.

    Args:
        prompt_template (str): Name of the environment variable that holds
            the prompt template (for example ``"PROMPT_SOLVE"``).
        **kwargs: Values substituted into the template's named placeholders.

    Yields:
        str: Streamed content chunks. If the inference call or the stream
        fails, a single ``"Error occurred: ..."`` chunk is yielded instead
        of the exception being raised.

    Raises:
        ValueError: If the environment variable is not set, or the template
            uses a named placeholder that ``kwargs`` does not supply.
    """
    # Fail with a clear message instead of "'NoneType' object has no
    # attribute 'format'" when the template variable is missing.
    template = os.getenv(prompt_template)
    if template is None:
        raise ValueError(
            f"Prompt template environment variable '{prompt_template}' is not set"
        )

    try:
        prompt = template.format(**kwargs)
    except KeyError as err:
        raise ValueError(
            f"Missing value for prompt parameter '{err.args[0]}'"
        ) from err

    # Prepare messages for the model
    messages = [{"role": "user", "content": prompt}]

    def open_stream():
        # Blocking network call; executed in a worker thread (see below).
        return client.chat.completions.create(
            model="Qwen/Qwen2.5-Math-1.5B-Instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.8,
            stream=True,
        )

    # The module-level client is called synchronously here, so both opening
    # the stream and pulling each chunk are pushed to the default executor to
    # keep the event loop free while waiting on the network.
    loop = asyncio.get_running_loop()
    # Sentinel passed as next()'s default so stream exhaustion never raises
    # StopIteration inside the executor future.
    end_of_stream = object()

    try:
        stream = iter(await loop.run_in_executor(None, open_stream))
        while True:
            chunk = await loop.run_in_executor(None, next, stream, end_of_stream)
            if chunk is end_of_stream:
                break
            # Skip chunks that carry no text content
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as e:
        yield f"Error occurred: {str(e)}"
@app.websocket("/ws/{endpoint}")
async def websocket_endpoint(websocket: WebSocket, endpoint: str):
    """
    WebSocket endpoint for streaming responses.

    After accepting the connection, one JSON message is read and its keys are
    passed as keyword arguments when formatting the prompt. Each generated
    chunk is sent as ``{"chunk": ...}``, followed by a final
    ``{"complete": True, "full_response": ...}`` message. Failures are
    reported as ``{"error": ...}`` when the connection is still usable.

    Args:
        websocket (WebSocket): The WebSocket connection
        endpoint (str): The specific endpoint/task to process
    """
    # Imported here so the handler is self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import WebSocketDisconnect

    # Map the endpoint to the environment variable holding its prompt template
    endpoint_prompt_map = {
        "solve": "PROMPT_SOLVE",
        "hint": "PROMPT_HINT",
        "verify": "PROMPT_VERIFY",
        "generate": "PROMPT_GENERATE",
        "explain": "PROMPT_EXPLAIN",
    }

    await websocket.accept()
    try:
        # Receive the initial message with parameters
        data = await websocket.receive_json()

        # Get the appropriate prompt template
        prompt_template = endpoint_prompt_map.get(endpoint)
        if not prompt_template:
            await websocket.send_json({"error": "Invalid endpoint"})
            return

        # The payload is unpacked with ** below, so it must be a JSON object.
        if not isinstance(data, dict):
            await websocket.send_json({"error": "Parameters must be a JSON object"})
            return

        # Stream the response, keeping the pieces for the final message
        pieces = []
        async for chunk in generate_stream_response(prompt_template, **data):
            pieces.append(chunk)
            await websocket.send_json({"chunk": chunk})

        # Send a final message to indicate streaming is complete
        await websocket.send_json(
            {"complete": True, "full_response": "".join(pieces)}
        )
    except WebSocketDisconnect:
        # The client went away; there is nobody left to report to.
        pass
    except Exception as e:
        try:
            await websocket.send_json({"error": str(e)})
        except (RuntimeError, WebSocketDisconnect):
            # The connection is already closed, so the error cannot be delivered.
            pass
    finally:
        try:
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect):
            # Closing an already-closed connection is not an error here.
            pass
# HTTP route that serves the front-end page
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """
    Serve the front-end page.

    Args:
        request (Request): The incoming HTTP request (not used by the handler)
    Returns:
        HTMLResponse: The contents of static/index.html
    """
    # A context manager closes the file handle promptly, and the explicit
    # encoding avoids depending on the platform's locale default.
    with open("static/index.html", encoding="utf-8") as page:
        return HTMLResponse(page.read())