Fix the bug causing tool search errors.

- main.py +30 -4
- models.py +1 -0
- request.py +24 -14
- response.py +90 -83
main.py
CHANGED
@@ -1,10 +1,10 @@
-import os
 import json
 import httpx
 import logging
 import yaml
 import secrets
 import traceback
+from fastapi.responses import JSONResponse
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI, Request, HTTPException, Depends
@@ -21,7 +21,7 @@ from urllib.parse import urlparse
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup code
-    timeout = httpx.Timeout(connect=
+    timeout = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=30.0)
     app.state.client = httpx.AsyncClient(timeout=timeout)
     yield
     # Shutdown code
@@ -48,7 +48,7 @@ def load_config():
             conf['providers'][index] = provider
         api_keys_db = conf['api_keys']
         api_list = [item["api"] for item in api_keys_db]
-        print(json.dumps(conf, indent=4, ensure_ascii=False))
+        # print(json.dumps(conf, indent=4, ensure_ascii=False))
         return conf, api_keys_db, api_list
     except FileNotFoundError:
         print("Config file 'config.yaml' not found. Please make sure it exists in the correct location.")
@@ -59,6 +59,24 @@ def load_config():
 
 config, api_keys_db, api_list = load_config()
 
+async def error_handling_wrapper(generator, status_code=200):
+    try:
+        first_item = await generator.__anext__()
+        if isinstance(first_item, dict) and "error" in first_item:
+            # If the first yielded item is an error message, raise an HTTPException
+            raise HTTPException(status_code=status_code, detail=first_item)
+
+        # Not an error: create a new generator that yields the first item, then the remaining items
+        async def new_generator():
+            yield first_item
+            async for item in generator:
+                yield item
+
+        return new_generator()
+    except StopAsyncIteration:
+        # Handle the case where the generator is empty
+        return []
+
 async def process_request(request: RequestModel, provider: Dict):
     print("provider: ", provider['provider'])
     url = provider['base_url']
@@ -84,7 +102,15 @@ async def process_request(request: RequestModel, provider: Dict):
 
     if request.stream:
         model = provider['model'][request.model]
-        …
+        try:
+            generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
+            wrapped_generator = await error_handling_wrapper(generator, status_code=500)
+            return StreamingResponse(wrapped_generator, media_type="text/event-stream")
+        except HTTPException as e:
+            return JSONResponse(status_code=e.status_code, content={"error": str(e.detail)})
+        except Exception as e:
+            # Handle other exceptions
+            return JSONResponse(status_code=500, content={"error": str(e)})
     else:
         return await fetch_response(app.state.client, url, headers, payload)
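The heart of the fix is error_handling_wrapper: it probes the first item of the upstream stream, so an error dict yielded by a fetcher such as fetch_claude_response_stream becomes a real HTTP error before any 200 response is committed, while a healthy stream is replayed untouched. Below is a minimal, self-contained sketch of that behavior; the toy ok_stream and failing_stream generators are illustrative, not part of this change.

import asyncio
from fastapi import HTTPException

async def error_handling_wrapper(generator, status_code=200):
    # Same probing logic as the wrapper added in main.py
    try:
        first_item = await generator.__anext__()
        if isinstance(first_item, dict) and "error" in first_item:
            raise HTTPException(status_code=status_code, detail=first_item)

        async def new_generator():
            yield first_item              # replay the probed item
            async for item in generator:  # then pass the rest through
                yield item

        return new_generator()
    except StopAsyncIteration:
        # Upstream yielded nothing at all
        return []

async def ok_stream():
    yield "data: chunk-1\n\n"
    yield "data: chunk-2\n\n"

async def failing_stream():
    yield {"error": "HTTP Error 500", "details": {"message": "upstream failure"}}

async def main():
    wrapped = await error_handling_wrapper(ok_stream())
    print([chunk async for chunk in wrapped])  # both chunks survive, including the probed one

    try:
        await error_handling_wrapper(failing_stream(), status_code=500)
    except HTTPException as e:
        print(e.status_code, e.detail)  # 500 {'error': 'HTTP Error 500', ...}

asyncio.run(main())

Probing the first item is what lets the streaming branch in process_request return a JSON error with the upstream status instead of failing mid-stream.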
models.py
CHANGED
@@ -28,6 +28,7 @@ class ContentItem(BaseModel):
 class Message(BaseModel):
     role: str
     name: Optional[str] = None
+    arguments: Optional[str] = None
     content: Union[str, List[ContentItem]]
 
 class RequestModel(BaseModel):
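The new arguments field lets a message carry a tool call's JSON-encoded argument string, which get_claude_payload in request.py decodes into the "input" of a tool_use block. A small sketch of a message using the field, assuming a simplified stand-in for ContentItem rather than the full model from models.py:

from typing import List, Optional, Union
from pydantic import BaseModel

class ContentItem(BaseModel):
    # Simplified stand-in; the real class defines more fields
    type: str
    text: Optional[str] = None

class Message(BaseModel):
    role: str
    name: Optional[str] = None
    arguments: Optional[str] = None  # JSON-encoded tool-call arguments, decoded later in request.py
    content: Union[str, List[ContentItem]]

msg = Message(
    role="assistant",
    name="get_weather",
    arguments='{"city": "Tokyo"}',  # arrives as a string; json.loads turns it into the tool_use "input" dict
    content="",
)
print(msg.arguments)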
request.py
CHANGED
@@ -1,3 +1,4 @@
+import json
 from models import RequestModel
 
 async def get_image_message(base64_image, engine = None):
@@ -191,6 +192,9 @@ async def get_claude_payload(request, engine, provider):
         else:
             content = msg.content
             name = msg.name
+            arguments = msg.arguments
+            if arguments:
+                arguments = json.loads(arguments)
         if name:
             # messages.append({"role": "assistant", "name": name, "content": content})
             messages.append(
@@ -201,7 +205,7 @@ async def get_claude_payload(request, engine, provider):
                         "type": "tool_use",
                         "id": "toolu_01RofFmKHUKsEaZvqESG5Hwz",
                         "name": name,
-                        "input":
+                        "input": arguments,
                     }
                 ]
             }
@@ -223,23 +227,30 @@ async def get_claude_payload(request, engine, provider):
         elif msg.role == "system":
             system_prompt = content
 
+    conversation_len = len(messages) - 1
+    message_index = 0
+    while message_index < conversation_len:
+        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
+            if messages[message_index].get("content"):
+                if isinstance(messages[message_index]["content"], list):
+                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
+                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
+                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
+                    content_list.extend(messages[message_index + 1]["content"])
+                    messages[message_index]["content"] = content_list
+                else:
+                    messages[message_index]["content"] += messages[message_index + 1]["content"]
+            messages.pop(message_index + 1)
+            conversation_len = conversation_len - 1
+        else:
+            message_index = message_index + 1
+
     model = provider['model'][request.model]
     payload = {
         "model": model,
         "messages": messages,
         "system": system_prompt,
     }
-    # json_post = {
-    #         "model": model or self.engine,
-    #         "messages": self.conversation[convo_id] if pass_history else [{
-    #             "role": "user",
-    #             "content": prompt
-    #         }],
-    #         "temperature": kwargs.get("temperature", self.temperature),
-    #         "top_p": kwargs.get("top_p", self.top_p),
-    #         "max_tokens": model_max_tokens,
-    #         "stream": True,
-    #     }
 
     miss_fields = [
         'model',
@@ -258,7 +269,7 @@ async def get_claude_payload(request, engine, provider):
     if request.tools:
         tools = []
        for tool in request.tools:
-            print("tool", type(tool), tool)
+            # print("tool", type(tool), tool)
 
             json_tool = await gpt2claude_tools_json(tool.dict()["function"])
             tools.append(json_tool)
@@ -267,7 +278,6 @@ async def get_claude_payload(request, engine, provider):
         payload["tool_choice"] = {
             "type": "auto"
        }
-    import json
     print("payload", json.dumps(payload, indent=2, ensure_ascii=False))
 
     return url, headers, payload
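The new while-loop merges consecutive messages that share a role, since the Claude API rejects conversations where the same role appears twice in a row, and it promotes string content to a text-block list when it must merge with list content. Here is a self-contained run of the same loop on an illustrative sample conversation:

# Sample messages are illustrative, not taken from the diff.
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "user", "content": [{"type": "text", "text": "How are you?"}]},
    {"role": "assistant", "content": "Hi!"},
]

conversation_len = len(messages) - 1
message_index = 0
while message_index < conversation_len:
    if messages[message_index]["role"] == messages[message_index + 1]["role"]:
        if messages[message_index].get("content"):
            if isinstance(messages[message_index]["content"], list):
                messages[message_index]["content"].extend(messages[message_index + 1]["content"])
            elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                # Promote the string to a text block before merging the list in
                content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                content_list.extend(messages[message_index + 1]["content"])
                messages[message_index]["content"] = content_list
            else:
                messages[message_index]["content"] += messages[message_index + 1]["content"]
        messages.pop(message_index + 1)
        conversation_len -= 1
    else:
        message_index += 1

print(messages)
# [{'role': 'user', 'content': [{'type': 'text', 'text': 'Hello'},
#                               {'type': 'text', 'text': 'How are you?'}]},
#  {'role': 'assistant', 'content': 'Hi!'}]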
response.py
CHANGED
@@ -1,6 +1,7 @@
-from datetime import datetime
 import json
 import httpx
+from datetime import datetime
+
 
 async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
     sample_data = {
@@ -34,102 +35,108 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
     return sse_response
 
 async def fetch_gemini_response_stream(client, url, headers, payload, model):
- …
-        if line and '\"text\": \"' in line:
-            try:
-                json_data = json.loads( "{" + line + "}")
-                content = json_data.get('text', '')
-                content = "\n".join(content.split("\\n"))
-                sse_string = await generate_sse_response(timestamp, model, content)
-                yield sse_string
-            except json.JSONDecodeError:
-                print(f"Failed to parse JSON: {line}")
-
-    # Process the content remaining in the buffer
-    if buffer:
-        # print(buffer)
-        if '\"text\": \"' in buffer:
-            try:
-                json_data = json.loads(
-                content = json_data.get('text', '')
-                content = "\n".join(content.split("\\n"))
-                sse_string = await generate_sse_response(timestamp, model, content)
-                yield sse_string
-            except json.JSONDecodeError:
-                print(f"Failed to parse JSON: {
- …
+    timestamp = datetime.timestamp(datetime.now())
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        buffer = ""
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                print(line)
+                if line and '\"text\": \"' in line:
+                    try:
+                        json_data = json.loads( "{" + line + "}")
+                        content = json_data.get('text', '')
+                        content = "\n".join(content.split("\\n"))
+                        sse_string = await generate_sse_response(timestamp, model, content)
+                        yield sse_string
+                    except json.JSONDecodeError:
+                        print(f"Failed to parse JSON: {line}")
+
+        # Process the content remaining in the buffer
+        if buffer:
+            # print(buffer)
+            if '\"text\": \"' in buffer:
+                try:
+                    json_data = json.loads(buffer)
+                    content = json_data.get('text', '')
+                    content = "\n".join(content.split("\\n"))
+                    sse_string = await generate_sse_response(timestamp, model, content)
+                    yield sse_string
+                except json.JSONDecodeError:
+                    print(f"Failed to parse JSON: {buffer}")
+
+    # yield "data: [DONE]\n\n"
 
 async def fetch_gpt_response_stream(client, url, headers, payload):
- …
-    async
- …
-            yield chunk
-    except httpx.ConnectError as e:
-        print(f"Connection error: {e}")
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        async for chunk in response.aiter_bytes():
+            print(chunk.decode('utf-8'))
+            yield chunk
 
 async def fetch_claude_response_stream(client, url, headers, payload, model):
- …
-                            yield sse_string
-                    if tokens_use:
-                        total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
-                        # print("\n\rtotal_tokens", total_tokens)
-                    tool_use = resp.get("content_block")
-                    tools_id = None
-                    function_call_name = None
-                    if tool_use and "tool_use" == tool_use['type']:
-                        # print("tool_use", tool_use)
-                        tools_id = tool_use["id"]
-                        if "name" in tool_use:
-                            function_call_name = tool_use["name"]
-                            sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
-                            yield sse_string
-                    delta = resp.get("delta")
-                    # print("delta", delta)
-                    if not delta:
-                        continue
-                    if "text" in delta:
-                        content = delta["text"]
-                        sse_string = await generate_sse_response(timestamp, model, content, None, None)
-                        yield sse_string
-                    if
- …
+    timestamp = datetime.timestamp(datetime.now())
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        # response.raise_for_status()
+        if response.status_code == 200:
+            print("Request succeeded with status code 200")
+        else:
+            print('\033[31m')
+            print(f"Request failed with status code {response.status_code}, error message:")
+            error_message = await response.aread()
+            error_str = error_message.decode('utf-8', errors='replace')
+            error_json = json.loads(error_str)
+            print(json.dumps(error_json, indent=4, ensure_ascii=False))
+            print('\033[0m')
+            yield {"error": f"HTTP Error {response.status_code}", "details": error_json}
+            # raise HTTPStatusError(f"HTTP Error {response.status_code}", request=response.request, response=response)
+            # raise HTTPException(status_code=response.status_code, detail=error_json)
+        buffer = ""
+        async for chunk in response.aiter_bytes():
+            buffer += chunk.decode('utf-8')
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                print(line)
+
+                if line.startswith("data:"):
+                    print(line)
+                    line = line[6:]
+                    resp: dict = json.loads(line)
+                    message = resp.get("message")
+                    if message:
+                        tokens_use = resp.get("usage")
+                        role = message.get("role")
+                        if role:
+                            sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
+                            yield sse_string
+                    if tokens_use:
+                        total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
+                        # print("\n\rtotal_tokens", total_tokens)
+                    tool_use = resp.get("content_block")
+                    tools_id = None
+                    function_call_name = None
+                    if tool_use and "tool_use" == tool_use['type']:
+                        # print("tool_use", tool_use)
+                        tools_id = tool_use["id"]
+                        if "name" in tool_use:
+                            function_call_name = tool_use["name"]
+                            sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
+                            yield sse_string
+                    delta = resp.get("delta")
+                    # print("delta", delta)
+                    if not delta:
+                        continue
+                    if "text" in delta:
+                        content = delta["text"]
+                        sse_string = await generate_sse_response(timestamp, model, content, None, None)
+                        yield sse_string
+                    if "partial_json" in delta:
+                        # {"type":"input_json_delta","partial_json":""}
+                        function_call_content = delta["partial_json"]
+                        sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
+                        yield sse_string
+
+    # yield "data: [DONE]\n\n"
 
 async def fetch_response(client, url, headers, payload):
     response = await client.post(url, headers=headers, json=payload)
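With this change, fetch_claude_response_stream parses each SSE "data:" line into a dict and emits separate SSE chunks for text deltas, tool_use blocks, and the new partial_json argument deltas. The sketch below replays that per-line dispatch on sample lines that mimic Anthropic's event shapes; the samples are illustrative, not captured output.

import json

# Illustrative SSE lines in the event shapes the parser above expects.
sample_lines = [
    'data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}',
    'data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"toolu_123","name":"get_weather"}}',
    'data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\\"city\\":"}}',
]

for line in sample_lines:
    if line.startswith("data:"):
        resp: dict = json.loads(line[6:])  # strip the 6-character "data: " prefix, as the stream parser does
        tool_use = resp.get("content_block")
        if tool_use and tool_use["type"] == "tool_use":
            print("tool call:", tool_use["id"], tool_use.get("name"))
        delta = resp.get("delta")
        if not delta:
            continue
        if "text" in delta:
            print("text delta:", delta["text"])
        if "partial_json" in delta:
            print("arguments delta:", delta["partial_json"])

Each branch maps onto one generate_sse_response call in the real function, so the three event kinds arrive at the client as distinct OpenAI-style SSE chunks.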