yasserrmd committed
Commit 608de90 (verified)
Parent(s): 8323129

Create app.py

Files changed (1)
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
+ from fastapi.responses import HTMLResponse
+ from fastapi.templating import Jinja2Templates
+ from fastapi.staticfiles import StaticFiles
+ from huggingface_hub import AsyncInferenceClient
+ import os
+
+ app = FastAPI()
+
+ # Mount the static files directory
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Set up Jinja2 templates
+ templates = Jinja2Templates(directory="templates")
+
+ # Initialize the async Hugging Face Inference Client so streaming
+ # responses do not block the event loop
+ client = AsyncInferenceClient()
+
+ async def generate_stream_response(prompt_template: str, **kwargs):
+     """
+     Generate a streaming response using the Hugging Face Inference Client.
+
+     Args:
+         prompt_template (str): Name of the environment variable that holds the prompt template
+         **kwargs: Dynamic arguments used to format the prompt
+
+     Yields:
+         str: Streamed content chunks
+     """
+     # Look up the prompt template from the environment and fail fast if it
+     # is not configured (os.getenv returns None for unset variables)
+     template = os.getenv(prompt_template)
+     if template is None:
+         yield f"Error occurred: prompt template '{prompt_template}' is not set"
+         return
+     prompt = template.format(**kwargs)
+
+     # Prepare the chat messages for the model
+     messages = [
+         {"role": "user", "content": prompt}
+     ]
+
+     try:
+         # Open a streaming chat completion (OpenAI-compatible interface)
+         stream = await client.chat.completions.create(
+             model="Qwen/Qwen2.5-Math-1.5B-Instruct",
+             messages=messages,
+             temperature=0.7,
+             max_tokens=1024,
+             top_p=0.8,
+             stream=True
+         )
+
+         # Yield each generated content chunk as it arrives
+         async for chunk in stream:
+             if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
+                 yield chunk.choices[0].delta.content
+
+     except Exception as e:
+         yield f"Error occurred: {str(e)}"
+
59
+ @app.websocket("/ws/{endpoint}")
60
+ async def websocket_endpoint(websocket: WebSocket, endpoint: str):
61
+ """
62
+ WebSocket endpoint for streaming responses
63
+
64
+ Args:
65
+ websocket (WebSocket): The WebSocket connection
66
+ endpoint (str): The specific endpoint/task to process
67
+ """
68
+ await websocket.accept()
69
+
70
+ try:
71
+ # Receive the initial message with parameters
72
+ data = await websocket.receive_json()
73
+
74
+ # Map the endpoint to the appropriate prompt template
75
+ endpoint_prompt_map = {
76
+ "solve": "PROMPT_SOLVE",
77
+ "hint": "PROMPT_HINT",
78
+ "verify": "PROMPT_VERIFY",
79
+ "generate": "PROMPT_GENERATE",
80
+ "explain": "PROMPT_EXPLAIN"
81
+ }
82
+
83
+ # Get the appropriate prompt template
84
+ prompt_template = endpoint_prompt_map.get(endpoint)
85
+ if not prompt_template:
86
+ await websocket.send_json({"error": "Invalid endpoint"})
87
+ return
88
+
89
+ # Stream the response
90
+ full_response = ""
91
+ async for chunk in generate_stream_response(prompt_template, **data):
92
+ full_response += chunk
93
+ await websocket.send_json({"chunk": chunk})
94
+
95
+ # Send a final message to indicate streaming is complete
96
+ await websocket.send_json({"complete": True, "full_response": full_response})
97
+
98
+ except Exception as e:
99
+ await websocket.send_json({"error": str(e)})
100
+ finally:
101
+ await websocket.close()
102
+
+ # Existing routes remain the same as in the previous implementation
+ @app.get("/", response_class=HTMLResponse)
+ async def home(request: Request):
+     return templates.TemplateResponse("index.html", {
+         "request": request,
+         "title": "Mathematical Insight Tutor"
+     })
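
Note for reviewers: the prompt templates are read from environment variables at request time, so each endpoint needs its variable (e.g. PROMPT_SOLVE) exported before the app starts. Below is a minimal client sketch for exercising the solve endpoint; the {problem} placeholder, the template text, and the localhost:8000 address are assumptions for illustration, and the third-party websockets package stands in for the browser client.

import asyncio
import json
import websockets  # pip install websockets

# Hypothetical template; the {problem} placeholder must match the keys the
# client sends, since the handler calls template.format(**data), e.g.:
#   export PROMPT_SOLVE="Solve this problem step by step: {problem}"

async def solve(problem: str) -> str:
    # Default uvicorn address assumed; adjust host/port to your deployment
    async with websockets.connect("ws://localhost:8000/ws/solve") as ws:
        await ws.send(json.dumps({"problem": problem}))
        while True:
            msg = json.loads(await ws.recv())
            if "error" in msg:
                raise RuntimeError(msg["error"])
            if msg.get("complete"):
                return msg["full_response"]
            print(msg["chunk"], end="", flush=True)

print("\n---\n" + asyncio.run(solve("Integrate x^2 from 0 to 1")))

Because the handler unpacks the incoming JSON straight into template.format(**data), a key missing from the message (or a stray key not present in the template) raises KeyError on the server, which this protocol reports back to the client as an error message rather than a dropped connection.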