yym68686 commited on
Commit
2cdadb6
·
1 Parent(s): 93c3f12

Add GPT format test module

Browse files

Fix the bug of decoding error in reading stream response.

Files changed (4) hide show
  1. main.py +1 -2
  2. requirements.txt +2 -1
  3. response.py +8 -7
  4. test/provider_test.py +82 -0
main.py CHANGED
@@ -1,10 +1,9 @@
1
  import json
 
2
  import httpx
3
  import logging
4
- import yaml
5
  import secrets
6
  import traceback
7
- from fastapi.responses import JSONResponse
8
  from contextlib import asynccontextmanager
9
 
10
  from fastapi import FastAPI, Request, HTTPException, Depends
 
1
  import json
2
+ import yaml
3
  import httpx
4
  import logging
 
5
  import secrets
6
  import traceback
 
7
  from contextlib import asynccontextmanager
8
 
9
  from fastapi import FastAPI, Request, HTTPException, Depends
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- fastapi
 
 
1
+ fastapi
2
+ pytest
response.py CHANGED
@@ -77,12 +77,11 @@ async def fetch_gpt_response_stream(client, url, headers, payload):
77
  print(json.dumps(error_json, indent=4, ensure_ascii=False))
78
  yield {"error": f"HTTP Error {response.status_code}", "details": error_json}
79
  buffer = ""
80
- async for chunk in response.aiter_bytes():
81
- # print("chunk.decode('utf-8')", chunk.decode('utf-8'))
82
- buffer += chunk.decode('utf-8')
83
  while "\n" in buffer:
84
  line, buffer = buffer.split("\n", 1)
85
- # print(line)
86
  yield line + "\n"
87
 
88
  async def fetch_claude_response_stream(client, url, headers, payload, model):
@@ -98,14 +97,13 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
98
  print('\033[0m')
99
  yield {"error": f"HTTP Error {response.status_code}", "details": error_json}
100
  buffer = ""
101
- async for chunk in response.aiter_bytes():
102
- buffer += chunk.decode('utf-8')
103
  while "\n" in buffer:
104
  line, buffer = buffer.split("\n", 1)
105
  # print(line)
106
 
107
  if line.startswith("data:"):
108
- print(line)
109
  line = line[6:]
110
  resp: dict = json.loads(line)
111
  message = resp.get("message")
@@ -166,4 +164,7 @@ async def fetch_response_stream(client, url, headers, payload, engine, model):
166
  break
167
  except httpx.ConnectError as e:
168
  print(f"连接错误: {e}")
 
 
 
169
  continue
 
77
  print(json.dumps(error_json, indent=4, ensure_ascii=False))
78
  yield {"error": f"HTTP Error {response.status_code}", "details": error_json}
79
  buffer = ""
80
+ async for chunk in response.aiter_text():
81
+ # print(chunk)
82
+ buffer += chunk
83
  while "\n" in buffer:
84
  line, buffer = buffer.split("\n", 1)
 
85
  yield line + "\n"
86
 
87
  async def fetch_claude_response_stream(client, url, headers, payload, model):
 
97
  print('\033[0m')
98
  yield {"error": f"HTTP Error {response.status_code}", "details": error_json}
99
  buffer = ""
100
+ async for chunk in response.aiter_text():
101
+ buffer += chunk
102
  while "\n" in buffer:
103
  line, buffer = buffer.split("\n", 1)
104
  # print(line)
105
 
106
  if line.startswith("data:"):
 
107
  line = line[6:]
108
  resp: dict = json.loads(line)
109
  message = resp.get("message")
 
164
  break
165
  except httpx.ConnectError as e:
166
  print(f"连接错误: {e}")
167
+ continue
168
+ except httpx.ReadTimeout as e:
169
+ print(f"读取响应超时: {e}")
170
  continue
test/provider_test.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+ from fastapi.testclient import TestClient
4
+ import sys
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ from main import app
7
+
8
+ @pytest.fixture
9
+ def test_client():
10
+ with TestClient(app) as client:
11
+ yield client
12
+
13
+ @pytest.fixture
14
+ def api_key():
15
+ return os.environ.get("API")
16
+
17
+ def test_request_model(test_client, api_key, model="gpt-4o"):
18
+ request_data = {
19
+ "model": model,
20
+ "messages": [
21
+ {
22
+ "role": "user",
23
+ "content": "say test"
24
+ }
25
+ ],
26
+ "max_tokens": 4096,
27
+ "stream": True,
28
+ "temperature": 0.5,
29
+ "top_p": 1.0,
30
+ "presence_penalty": 0.0,
31
+ "frequency_penalty": 0.0,
32
+ "n": 1,
33
+ "user": "user",
34
+ "tools": [
35
+ {
36
+ "type": "function",
37
+ "function": {
38
+ "name": "get_search_results",
39
+ "description": "Search Google to enhance knowledge.",
40
+ "parameters": {
41
+ "type": "object",
42
+ "properties": {
43
+ "prompt": {
44
+ "type": "string",
45
+ "description": "The prompt to search."
46
+ }
47
+ },
48
+ "required": ["prompt"]
49
+ }
50
+ }
51
+ },
52
+ {
53
+ "type": "function",
54
+ "function": {
55
+ "name": "get_url_content",
56
+ "description": "Get the webpage content of a URL.",
57
+ "parameters": {
58
+ "type": "object",
59
+ "properties": {
60
+ "url": {
61
+ "type": "string",
62
+ "description": "The URL to request."
63
+ }
64
+ },
65
+ "required": ["url"]
66
+ }
67
+ }
68
+ }
69
+ ]
70
+ }
71
+
72
+ headers = {
73
+ "Authorization": f"Bearer {api_key}"
74
+ }
75
+
76
+ response = test_client.post("/v1/chat/completions", json=request_data, headers=headers)
77
+ for line in response.iter_lines():
78
+ print(line)
79
+ assert response.status_code == 200
80
+
81
+ if __name__ == "__main__":
82
+ pytest.main(["-s", "test/test.py"])