yym68686 committed on
Commit
819dd2f
·
1 Parent(s): ad7d342

Fix the bug causing the tools search error.

Browse files
Files changed (4) hide show
  1. main.py +30 -4
  2. models.py +1 -0
  3. request.py +24 -14
  4. response.py +90 -83
main.py CHANGED
@@ -1,10 +1,10 @@
1
- import os
2
  import json
3
  import httpx
4
  import logging
5
  import yaml
6
  import secrets
7
  import traceback
 
8
  from contextlib import asynccontextmanager
9
 
10
  from fastapi import FastAPI, Request, HTTPException, Depends
@@ -21,7 +21,7 @@ from urllib.parse import urlparse
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
  # 启动时的代码
24
- timeout = httpx.Timeout(connect=10.0, read=30.0, write=30.0, pool=30.0)
25
  app.state.client = httpx.AsyncClient(timeout=timeout)
26
  yield
27
  # 关闭时的代码
@@ -48,7 +48,7 @@ def load_config():
48
  conf['providers'][index] = provider
49
  api_keys_db = conf['api_keys']
50
  api_list = [item["api"] for item in api_keys_db]
51
- print(json.dumps(conf, indent=4, ensure_ascii=False))
52
  return conf, api_keys_db, api_list
53
  except FileNotFoundError:
54
  print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
@@ -59,6 +59,24 @@ def load_config():
59
 
60
  config, api_keys_db, api_list = load_config()
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  async def process_request(request: RequestModel, provider: Dict):
63
  print("provider: ", provider['provider'])
64
  url = provider['base_url']
@@ -84,7 +102,15 @@ async def process_request(request: RequestModel, provider: Dict):
84
 
85
  if request.stream:
86
  model = provider['model'][request.model]
87
- return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, model), media_type="text/event-stream")
 
 
 
 
 
 
 
 
88
  else:
89
  return await fetch_response(app.state.client, url, headers, payload)
90
 
 
 
1
  import json
2
  import httpx
3
  import logging
4
  import yaml
5
  import secrets
6
  import traceback
7
+ from fastapi.responses import JSONResponse
8
  from contextlib import asynccontextmanager
9
 
10
  from fastapi import FastAPI, Request, HTTPException, Depends
 
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
  # 启动时的代码
24
+ timeout = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=30.0)
25
  app.state.client = httpx.AsyncClient(timeout=timeout)
26
  yield
27
  # 关闭时的代码
 
48
  conf['providers'][index] = provider
49
  api_keys_db = conf['api_keys']
50
  api_list = [item["api"] for item in api_keys_db]
51
+ # print(json.dumps(conf, indent=4, ensure_ascii=False))
52
  return conf, api_keys_db, api_list
53
  except FileNotFoundError:
54
  print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
 
59
 
60
  config, api_keys_db, api_list = load_config()
61
 
62
+ async def error_handling_wrapper(generator, status_code=200):
63
+ try:
64
+ first_item = await generator.__anext__()
65
+ if isinstance(first_item, dict) and "error" in first_item:
66
+ # 如果第一个 yield 的项是错误信息,抛出 HTTPException
67
+ raise HTTPException(status_code=status_code, detail=first_item)
68
+
69
+ # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
70
+ async def new_generator():
71
+ yield first_item
72
+ async for item in generator:
73
+ yield item
74
+
75
+ return new_generator()
76
+ except StopAsyncIteration:
77
+ # 处理生成器为空的情况
78
+ return []
79
+
80
  async def process_request(request: RequestModel, provider: Dict):
81
  print("provider: ", provider['provider'])
82
  url = provider['base_url']
 
102
 
103
  if request.stream:
104
  model = provider['model'][request.model]
105
+ try:
106
+ generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
107
+ wrapped_generator = await error_handling_wrapper(generator, status_code=500)
108
+ return StreamingResponse(wrapped_generator, media_type="text/event-stream")
109
+ except HTTPException as e:
110
+ return JSONResponse(status_code=e.status_code, content={"error": str(e.detail)})
111
+ except Exception as e:
112
+ # 处理其他异常
113
+ return JSONResponse(status_code=500, content={"error": str(e)})
114
  else:
115
  return await fetch_response(app.state.client, url, headers, payload)
116
 
models.py CHANGED
@@ -28,6 +28,7 @@ class ContentItem(BaseModel):
28
  class Message(BaseModel):
29
  role: str
30
  name: Optional[str] = None
 
31
  content: Union[str, List[ContentItem]]
32
 
33
  class RequestModel(BaseModel):
 
28
  class Message(BaseModel):
29
  role: str
30
  name: Optional[str] = None
31
+ arguments: Optional[str] = None
32
  content: Union[str, List[ContentItem]]
33
 
34
  class RequestModel(BaseModel):
request.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from models import RequestModel
2
 
3
  async def get_image_message(base64_image, engine = None):
@@ -191,6 +192,9 @@ async def get_claude_payload(request, engine, provider):
191
  else:
192
  content = msg.content
193
  name = msg.name
 
 
 
194
  if name:
195
  # messages.append({"role": "assistant", "name": name, "content": content})
196
  messages.append(
@@ -201,7 +205,7 @@ async def get_claude_payload(request, engine, provider):
201
  "type": "tool_use",
202
  "id": "toolu_01RofFmKHUKsEaZvqESG5Hwz",
203
  "name": name,
204
- "input": {"text": messages[-1]["content"][0]["text"]},
205
  }
206
  ]
207
  }
@@ -223,23 +227,30 @@ async def get_claude_payload(request, engine, provider):
223
  elif msg.role == "system":
224
  system_prompt = content
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  model = provider['model'][request.model]
227
  payload = {
228
  "model": model,
229
  "messages": messages,
230
  "system": system_prompt,
231
  }
232
- # json_post = {
233
- # "model": model or self.engine,
234
- # "messages": self.conversation[convo_id] if pass_history else [{
235
- # "role": "user",
236
- # "content": prompt
237
- # }],
238
- # "temperature": kwargs.get("temperature", self.temperature),
239
- # "top_p": kwargs.get("top_p", self.top_p),
240
- # "max_tokens": model_max_tokens,
241
- # "stream": True,
242
- # }
243
 
244
  miss_fields = [
245
  'model',
@@ -258,7 +269,7 @@ async def get_claude_payload(request, engine, provider):
258
  if request.tools:
259
  tools = []
260
  for tool in request.tools:
261
- print("tool", type(tool), tool)
262
 
263
  json_tool = await gpt2claude_tools_json(tool.dict()["function"])
264
  tools.append(json_tool)
@@ -267,7 +278,6 @@ async def get_claude_payload(request, engine, provider):
267
  payload["tool_choice"] = {
268
  "type": "auto"
269
  }
270
- import json
271
  print("payload", json.dumps(payload, indent=2, ensure_ascii=False))
272
 
273
  return url, headers, payload
 
1
+ import json
2
  from models import RequestModel
3
 
4
  async def get_image_message(base64_image, engine = None):
 
192
  else:
193
  content = msg.content
194
  name = msg.name
195
+ arguments = msg.arguments
196
+ if arguments:
197
+ arguments = json.loads(arguments)
198
  if name:
199
  # messages.append({"role": "assistant", "name": name, "content": content})
200
  messages.append(
 
205
  "type": "tool_use",
206
  "id": "toolu_01RofFmKHUKsEaZvqESG5Hwz",
207
  "name": name,
208
+ "input": arguments,
209
  }
210
  ]
211
  }
 
227
  elif msg.role == "system":
228
  system_prompt = content
229
 
230
+ conversation_len = len(messages) - 1
231
+ message_index = 0
232
+ while message_index < conversation_len:
233
+ if messages[message_index]["role"] == messages[message_index + 1]["role"]:
234
+ if messages[message_index].get("content"):
235
+ if isinstance(messages[message_index]["content"], list):
236
+ messages[message_index]["content"].extend(messages[message_index + 1]["content"])
237
+ elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
238
+ content_list = [{"type": "text", "text": messages[message_index]["content"]}]
239
+ content_list.extend(messages[message_index + 1]["content"])
240
+ messages[message_index]["content"] = content_list
241
+ else:
242
+ messages[message_index]["content"] += messages[message_index + 1]["content"]
243
+ messages.pop(message_index + 1)
244
+ conversation_len = conversation_len - 1
245
+ else:
246
+ message_index = message_index + 1
247
+
248
  model = provider['model'][request.model]
249
  payload = {
250
  "model": model,
251
  "messages": messages,
252
  "system": system_prompt,
253
  }
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  miss_fields = [
256
  'model',
 
269
  if request.tools:
270
  tools = []
271
  for tool in request.tools:
272
+ # print("tool", type(tool), tool)
273
 
274
  json_tool = await gpt2claude_tools_json(tool.dict()["function"])
275
  tools.append(json_tool)
 
278
  payload["tool_choice"] = {
279
  "type": "auto"
280
  }
 
281
  print("payload", json.dumps(payload, indent=2, ensure_ascii=False))
282
 
283
  return url, headers, payload
response.py CHANGED
@@ -1,6 +1,7 @@
1
- from datetime import datetime
2
  import json
3
  import httpx
 
 
4
 
5
  async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
6
  sample_data = {
@@ -34,102 +35,108 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
34
  return sse_response
35
 
36
  async def fetch_gemini_response_stream(client, url, headers, payload, model):
37
- try:
38
- timestamp = datetime.timestamp(datetime.now())
39
- async with client.stream('POST', url, headers=headers, json=payload) as response:
40
- buffer = ""
41
- async for chunk in response.aiter_text():
42
- buffer += chunk
43
- while "\n" in buffer:
44
- line, buffer = buffer.split("\n", 1)
45
- print(line)
46
- if line and '\"text\": \"' in line:
47
- try:
48
- json_data = json.loads( "{" + line + "}")
49
- content = json_data.get('text', '')
50
- content = "\n".join(content.split("\\n"))
51
- sse_string = await generate_sse_response(timestamp, model, content)
52
- yield sse_string
53
- except json.JSONDecodeError:
54
- print(f"无法解析JSON: {line}")
55
-
56
- # 处理缓冲区中剩余的内容
57
- if buffer:
58
- # print(buffer)
59
- if '\"text\": \"' in buffer:
60
  try:
61
- json_data = json.loads(buffer)
62
  content = json_data.get('text', '')
63
  content = "\n".join(content.split("\\n"))
64
  sse_string = await generate_sse_response(timestamp, model, content)
65
  yield sse_string
66
  except json.JSONDecodeError:
67
- print(f"无法解析JSON: {buffer}")
68
 
69
- yield "data: [DONE]\n\n"
70
- except httpx.ConnectError as e:
71
- print(f"连接错误: {e}")
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  async def fetch_gpt_response_stream(client, url, headers, payload):
74
- try:
75
- async with client.stream('POST', url, headers=headers, json=payload) as response:
76
- async for chunk in response.aiter_bytes():
77
- print(chunk.decode('utf-8'))
78
- yield chunk
79
- except httpx.ConnectError as e:
80
- print(f"连接错误: {e}")
81
 
82
  async def fetch_claude_response_stream(client, url, headers, payload, model):
83
- try:
84
- timestamp = datetime.timestamp(datetime.now())
85
- async with client.stream('POST', url, headers=headers, json=payload) as response:
86
- buffer = ""
87
- async for chunk in response.aiter_bytes():
88
- buffer += chunk.decode('utf-8')
89
- while "\n" in buffer:
90
- line, buffer = buffer.split("\n", 1)
91
- print(line)
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- if line.startswith("data:"):
94
- print(line)
95
- line = line[6:]
96
- resp: dict = json.loads(line)
97
- message = resp.get("message")
98
- if message:
99
- tokens_use = resp.get("usage")
100
- role = message.get("role")
101
- if role:
102
- sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
103
- yield sse_string
104
- if tokens_use:
105
- total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
106
- # print("\n\rtotal_tokens", total_tokens)
107
- tool_use = resp.get("content_block")
108
- tools_id = None
109
- function_call_name = None
110
- if tool_use and "tool_use" == tool_use['type']:
111
- # print("tool_use", tool_use)
112
- tools_id = tool_use["id"]
113
- if "name" in tool_use:
114
- function_call_name = tool_use["name"]
115
- sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
116
- yield sse_string
117
- delta = resp.get("delta")
118
- # print("delta", delta)
119
- if not delta:
120
- continue
121
- if "text" in delta:
122
- content = delta["text"]
123
- sse_string = await generate_sse_response(timestamp, model, content, None, None)
124
  yield sse_string
125
- if "partial_json" in delta:
126
- # {"type":"input_json_delta","partial_json":""}
127
- function_call_content = delta["partial_json"]
128
- sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
 
 
 
 
 
 
 
 
129
  yield sse_string
130
- yield "data: [DONE]\n\n"
131
- except httpx.ConnectError as e:
132
- print(f"连接错误: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  async def fetch_response(client, url, headers, payload):
135
  response = await client.post(url, headers=headers, json=payload)
 
 
1
  import json
2
  import httpx
3
+ from datetime import datetime
4
+
5
 
6
  async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
7
  sample_data = {
 
35
  return sse_response
36
 
37
  async def fetch_gemini_response_stream(client, url, headers, payload, model):
38
+ timestamp = datetime.timestamp(datetime.now())
39
+ async with client.stream('POST', url, headers=headers, json=payload) as response:
40
+ buffer = ""
41
+ async for chunk in response.aiter_text():
42
+ buffer += chunk
43
+ while "\n" in buffer:
44
+ line, buffer = buffer.split("\n", 1)
45
+ print(line)
46
+ if line and '\"text\": \"' in line:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  try:
48
+ json_data = json.loads( "{" + line + "}")
49
  content = json_data.get('text', '')
50
  content = "\n".join(content.split("\\n"))
51
  sse_string = await generate_sse_response(timestamp, model, content)
52
  yield sse_string
53
  except json.JSONDecodeError:
54
+ print(f"无法解析JSON: {line}")
55
 
56
+ # 处理缓冲区中剩余的内容
57
+ if buffer:
58
+ # print(buffer)
59
+ if '\"text\": \"' in buffer:
60
+ try:
61
+ json_data = json.loads(buffer)
62
+ content = json_data.get('text', '')
63
+ content = "\n".join(content.split("\\n"))
64
+ sse_string = await generate_sse_response(timestamp, model, content)
65
+ yield sse_string
66
+ except json.JSONDecodeError:
67
+ print(f"无法解析JSON: {buffer}")
68
+
69
+ # yield "data: [DONE]\n\n"
70
 
71
  async def fetch_gpt_response_stream(client, url, headers, payload):
72
+ async with client.stream('POST', url, headers=headers, json=payload) as response:
73
+ async for chunk in response.aiter_bytes():
74
+ print(chunk.decode('utf-8'))
75
+ yield chunk
 
 
 
76
 
77
  async def fetch_claude_response_stream(client, url, headers, payload, model):
78
+ timestamp = datetime.timestamp(datetime.now())
79
+ async with client.stream('POST', url, headers=headers, json=payload) as response:
80
+ # response.raise_for_status()
81
+ if response.status_code == 200:
82
+ print("请求成功,状态码是200")
83
+ else:
84
+ print('\033[31m')
85
+ print(f"请求失败,状态码是{response.status_code},错误信息:")
86
+ error_message = await response.aread()
87
+ error_str = error_message.decode('utf-8', errors='replace')
88
+ error_json = json.loads(error_str)
89
+ print(json.dumps(error_json, indent=4, ensure_ascii=False))
90
+ print('\033[0m')
91
+ yield {"error": f"HTTP Error {response.status_code}", "details": error_json}
92
+ # raise HTTPStatusError(f"HTTP Error {response.status_code}", request=response.request, response=response)
93
+ # raise HTTPException(status_code=response.status_code, detail=error_json)
94
+ buffer = ""
95
+ async for chunk in response.aiter_bytes():
96
+ buffer += chunk.decode('utf-8')
97
+ while "\n" in buffer:
98
+ line, buffer = buffer.split("\n", 1)
99
+ print(line)
100
 
101
+ if line.startswith("data:"):
102
+ print(line)
103
+ line = line[6:]
104
+ resp: dict = json.loads(line)
105
+ message = resp.get("message")
106
+ if message:
107
+ tokens_use = resp.get("usage")
108
+ role = message.get("role")
109
+ if role:
110
+ sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  yield sse_string
112
+ if tokens_use:
113
+ total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
114
+ # print("\n\rtotal_tokens", total_tokens)
115
+ tool_use = resp.get("content_block")
116
+ tools_id = None
117
+ function_call_name = None
118
+ if tool_use and "tool_use" == tool_use['type']:
119
+ # print("tool_use", tool_use)
120
+ tools_id = tool_use["id"]
121
+ if "name" in tool_use:
122
+ function_call_name = tool_use["name"]
123
+ sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
124
  yield sse_string
125
+ delta = resp.get("delta")
126
+ # print("delta", delta)
127
+ if not delta:
128
+ continue
129
+ if "text" in delta:
130
+ content = delta["text"]
131
+ sse_string = await generate_sse_response(timestamp, model, content, None, None)
132
+ yield sse_string
133
+ if "partial_json" in delta:
134
+ # {"type":"input_json_delta","partial_json":""}
135
+ function_call_content = delta["partial_json"]
136
+ sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
137
+ yield sse_string
138
+
139
+ # yield "data: [DONE]\n\n"
140
 
141
  async def fetch_response(client, url, headers, payload):
142
  response = await client.post(url, headers=headers, json=payload)