Add feature: support OpenAI dall-e-3 image generation
- README.md +14 -9
- main.py +21 -10
- models.py +7 -0
- request.py +22 -1
- response.py +33 -39
- utils.py +27 -1
README.md
CHANGED
@@ -12,20 +12,23 @@
 
 ## Introduction
 
+For personal use, one/new-api is overly complex, with many commercial features that individuals do not need. If you do not want a complicated frontend and would like support for more models, try uni-api. It is a project for managing large-model APIs in one place: it lets you call multiple backend services through a single unified API interface, converts them all to the OpenAI format, and supports load balancing. Currently supported backends include OpenAI, Anthropic, Gemini, Vertex, DeepBricks, and OpenRouter.
 
 ## Features
 
+- No frontend: API channels are configured purely through a configuration file. Writing one file is enough to run your own API site, and the documentation includes a detailed configuration guide, beginner friendly.
+- Unified management of multiple backend services, supporting OpenAI, Deepseek, DeepBricks, OpenRouter, and other providers whose APIs follow the OpenAI format. Supports OpenAI DALL-E 3 image generation.
+- Simultaneous support for the Anthropic, Gemini, and Vertex APIs. Vertex supports both the Claude and Gemini APIs.
+- Native tool-use (function calling) support for OpenAI, Anthropic, Gemini, and Vertex.
+- Native image-recognition API support for OpenAI, Anthropic, Gemini, and Vertex.
+- Load balancing, including Vertex regional load balancing and high-concurrency Vertex usage, which can raise Gemini and Claude concurrency by up to (number of APIs * number of regions) times. Besides Vertex regional load balancing, every API supports channel-level load balancing, improving the Immersive Translate experience.
+- Automatic retry: when one API channel fails to respond, the next API channel is tried automatically.
+- Fine-grained permission control: wildcards can be used to limit an API key to specific models of its available channels.
+- Multiple API keys.
 
 ## Configuration
 
 Use the api.yaml configuration file to configure multiple models; each model can be configured with multiple backend services, and load balancing is supported. Below is an example api.yaml configuration file:
 
 ```yaml
 providers:
@@ -35,6 +38,7 @@ providers:
     model: # at least one model is required
       - gpt-4o # name of a model that can be used, required
       - claude-3-5-sonnet-20240620: claude-3-5-sonnet # renames a model: claude-3-5-sonnet-20240620 is the provider's model name and claude-3-5-sonnet is the new name, so a concise name can replace the original long one, optional
+      - dall-e-3
 
   - provider: anthropic
     base_url: https://api.anthropic.com/v1/messages
@@ -86,7 +90,7 @@ api_keys:
     model:
       - anthropic/claude-3-5-sonnet # only the claude-3-5-sonnet model offered by the provider named anthropic can be used; claude-3-5-sonnet models from other providers cannot be used
     preferences:
-      USE_ROUND_ROBIN: true # whether to use round-robin load balancing; true = enabled, false = disabled, default true
+      USE_ROUND_ROBIN: true # whether to use round-robin load balancing; true = enabled, false = disabled, default true. With round robin enabled, each request goes to the models in the order configured under model, independent of the original channel order under providers, so each API key can be given a different request order.
       AUTO_RETRY: true # whether to retry automatically by moving on to the next provider; true = auto retry, false = no retry, default true
 ```
 
@@ -152,6 +156,7 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
   -d '{"model": "gpt-4o","messages": [{"role": "user", "content": "Hello"}],"stream": true}'
 ```
 
+
 ## Star History
 
 <a href="https://github.com/yym68686/uni-api/stargazers">
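The curl example above covers chat completions; for the endpoint this commit adds, here is a minimal sketch in Python using httpx. The host, port, and sk-... key are the README's placeholders, and the key's channel must list dall-e-3:

```python
# Minimal sketch: call the new /v1/images/generations endpoint of uni-api.
# Assumes the server from the README is running on http://127.0.0.1:8000.
import httpx

response = httpx.post(
    "http://127.0.0.1:8000/v1/images/generations",
    headers={"Authorization": "Bearer sk-..."},  # an API key from api.yaml
    json={
        "model": "dall-e-3",
        "prompt": "a white siamese cat",
        "n": 1,                # required: the request model declares no default
        "size": "1024x1024",   # required as well
    },
    timeout=120,               # image generation is slow compared to chat
)
print(response.json())         # OpenAI-format body: {"created": ..., "data": [...]}
```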
main.py
CHANGED
@@ -5,16 +5,16 @@ import secrets
 from contextlib import asynccontextmanager
 
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
-from models import RequestModel
+from models import RequestModel, ImageGenerationRequest
 from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
 from request import get_payload
 from response import fetch_response, fetch_response_stream
 
-from typing import List, Dict
+from typing import List, Dict, Union
 from urllib.parse import urlparse
 
 @asynccontextmanager
@@ -80,7 +80,7 @@ app.add_middleware(
     allow_headers=["*"],  # allow all header fields
 )
 
-async def process_request(request: RequestModel, provider: Dict):
+async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
     url = provider['base_url']
     parsed_url = urlparse(url)
     # print(parsed_url)
@@ -101,6 +101,10 @@ async def process_request(request: RequestModel, provider: Dict):
             and "gemini" not in provider['model'][request.model]:
         engine = "openrouter"
 
+    if endpoint == "/v1/images/generations":
+        engine = "dalle"
+        request.stream = False
+
     if provider.get("engine"):
         engine = provider["engine"]
 
@@ -122,7 +126,7 @@ async def process_request(request: RequestModel, provider: Dict):
         wrapped_generator = await error_handling_wrapper(generator, status_code=500)
         return StreamingResponse(wrapped_generator, media_type="text/event-stream")
     else:
-        return await fetch_response(app.state.client, url, headers, payload)
+        return await anext(fetch_response(app.state.client, url, headers, payload))
 
 import asyncio
 class ModelRequestHandler:
@@ -171,7 +175,7 @@ class ModelRequestHandler:
         # print(json.dumps(provider, indent=4, ensure_ascii=False))
         return provider_list
 
-    async def request_model(self, request: RequestModel, token: str):
+    async def request_model(self, request: Union[RequestModel, ImageGenerationRequest], token: str, endpoint=None):
        config = app.state.config
        # api_keys_db = app.state.api_keys_db
        api_list = app.state.api_list
@@ -193,9 +197,9 @@ class ModelRequestHandler:
         if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
             auto_retry = False
 
-        return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry)
+        return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
 
-    async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool, auto_retry: bool):
+    async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
         num_providers = len(providers)
         start_index = self.last_provider_index + 1 if use_round_robin else 0
 
@@ -203,7 +207,7 @@ class ModelRequestHandler:
             self.last_provider_index = (start_index + i) % num_providers
             provider = providers[self.last_provider_index]
             try:
-                response = await process_request(request, provider)
+                response = await process_request(request, provider, endpoint)
                 return response
             except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
                 logger.error(f"Error with provider {provider['provider']}: {str(e)}")
@@ -228,7 +232,7 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
     return token
 
 @app.post("/v1/chat/completions")
-async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
+async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
     return await model_handler.request_model(request, token)
 
 @app.options("/v1/chat/completions")
@@ -251,6 +255,13 @@ async def list_models():
         "data": models
     })
 
+@app.post("/v1/images/generations")
+async def images_generations(
+    request: ImageGenerationRequest,
+    token: str = Depends(verify_api_key)
+):
+    return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
+
 @app.get("/generate-api-key")
 def generate_api_key():
     api_key = "sk-" + secrets.token_urlsafe(32)
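One detail worth noting above: the non-streaming branch changed from `await fetch_response(...)` to `await anext(fetch_response(...))` because `fetch_response` is now an async generator (see response.py below). A small sketch of that consumption pattern, assuming Python 3.10+ where `anext()` is a builtin:

```python
# await anext(agen) resumes an async generator once and returns its first
# yielded value, which is exactly what process_request needs here.
import asyncio

async def fake_fetch_response():
    # Stand-in for the new response.fetch_response: it yields either an
    # error dict or the upstream JSON body.
    yield {"created": 0, "data": []}

async def main():
    result = await anext(fake_fetch_response())  # anext() requires Python 3.10+
    print(result)                                # {'created': 0, 'data': []}

asyncio.run(main())
```

If the generator raised instead of yielding, the except block in try_all_providers would catch it and move on to the next provider.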
models.py
CHANGED
@@ -1,6 +1,13 @@
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional, Union
 
+class ImageGenerationRequest(BaseModel):
+    model: str
+    prompt: str
+    n: int
+    size: str
+    stream: bool = False
+
 class FunctionParameter(BaseModel):
     type: str
     properties: Dict[str, Dict[str, str]]
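Since `n` and `size` carry no defaults, FastAPI rejects image requests that omit them with a 422 before any provider is tried. A quick validation sketch using the class added above:

```python
# Validation behavior of the new pydantic request model (a sketch).
from pydantic import ValidationError
from models import ImageGenerationRequest

ok = ImageGenerationRequest(model="dall-e-3", prompt="a red panda",
                            n=1, size="1024x1024")
print(ok.stream)   # False by default; process_request also forces it to False

try:
    ImageGenerationRequest(model="dall-e-3", prompt="no size given", n=1)
except ValidationError as e:
    print(e.errors()[0]["loc"])   # ('size',) — the missing required field
```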
request.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 from models import RequestModel
-from utils import c35s, c3s, c3o, c3h, gem
+from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
 
 async def get_image_message(base64_image, engine = None):
     if "gpt" == engine:
@@ -748,6 +748,25 @@ async def get_claude_payload(request, engine, provider):
 
     return url, headers, payload
 
+async def get_dalle_payload(request, engine, provider):
+    model = provider['model'][request.model]
+    headers = {
+        "Content-Type": "application/json",
+    }
+    if provider.get("api"):
+        headers['Authorization'] = f"Bearer {provider['api']}"
+    url = provider['base_url']
+    url = BaseAPI(url).image_url
+
+    payload = {
+        "model": model,
+        "prompt": request.prompt,
+        "n": request.n,
+        "size": request.size
+    }
+
+    return url, headers, payload
+
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
         return await get_gemini_payload(request, engine, provider)
@@ -761,5 +780,7 @@ async def get_payload(request: RequestModel, engine, provider):
         return await get_gpt_payload(request, engine, provider)
     elif engine == "openrouter":
         return await get_openrouter_payload(request, engine, provider)
+    elif engine == "dalle":
+        return await get_dalle_payload(request, engine, provider)
     else:
         raise ValueError("Unknown payload")
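A hypothetical direct call showing what the new payload builder produces. The provider dict below mimics a processed api.yaml entry, assuming (as `process_request` does) that `provider['model']` maps the requested model name to the upstream name:

```python
# Hypothetical invocation of get_dalle_payload via the get_payload dispatcher.
import asyncio
from models import ImageGenerationRequest
from request import get_payload

provider = {
    "provider": "openai",
    "base_url": "https://api.openai.com/v1/chat/completions",
    "api": "sk-...",                    # placeholder key
    "model": {"dall-e-3": "dall-e-3"},  # requested name -> upstream name
}
req = ImageGenerationRequest(model="dall-e-3", prompt="a watercolor fox",
                             n=1, size="1024x1024")
url, headers, payload = asyncio.run(get_payload(req, "dalle", provider))
print(url)      # https://api.openai.com/v1/images/generations, rewritten by BaseAPI
print(payload)  # {'model': 'dall-e-3', 'prompt': 'a watercolor fox', 'n': 1, 'size': '1024x1024'}
```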
response.py
CHANGED
@@ -36,17 +36,24 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
 
     return sse_response
 
+async def check_response(response, error_log):
+    if response.status_code != 200:
+        error_message = await response.aread()
+        error_str = error_message.decode('utf-8', errors='replace')
+        try:
+            error_json = json.loads(error_str)
+        except json.JSONDecodeError:
+            error_json = error_str
+        return {"error": f"{error_log} HTTP Error {response.status_code}", "details": error_json}
+    return None
+
 async def fetch_gemini_response_stream(client, url, headers, payload, model):
     timestamp = datetime.timestamp(datetime.now())
     async with client.stream('POST', url, headers=headers, json=payload) as response:
-        if response.status_code != 200:
-            error_message = await response.aread()
-            error_str = error_message.decode('utf-8', errors='replace')
-            try:
-                error_json = json.loads(error_str)
-            except json.JSONDecodeError:
-                error_json = error_str
-            yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
+        error_message = await check_response(response, "fetch_gemini_response_stream")
+        if error_message:
+            yield error_message
+            return
         buffer = ""
         revicing_function_call = False
         function_full_response = "{"
@@ -87,14 +94,11 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
 async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
     timestamp = datetime.timestamp(datetime.now())
     async with client.stream('POST', url, headers=headers, json=payload) as response:
-        if response.status_code != 200:
-            error_message = await response.aread()
-            error_str = error_message.decode('utf-8', errors='replace')
-            try:
-                error_json = json.loads(error_str)
-            except json.JSONDecodeError:
-                error_json = error_str
-            yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
+        error_message = await check_response(response, "fetch_vertex_claude_response_stream")
+        if error_message:
+            yield error_message
+            return
+
         buffer = ""
         revicing_function_call = False
         function_full_response = "{"
@@ -138,14 +142,9 @@ async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects
     while redirect_count < max_redirects:
         # logger.info(f"fetch_gpt_response_stream: {url}")
         async with client.stream('POST', url, headers=headers, json=payload) as response:
-            if response.status_code != 200:
-                error_message = await response.aread()
-                error_str = error_message.decode('utf-8', errors='replace')
-                try:
-                    error_json = json.loads(error_str)
-                except json.JSONDecodeError:
-                    error_json = error_str
-                yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
+            error_message = await check_response(response, "fetch_gpt_response_stream")
+            if error_message:
+                yield error_message
                 return
 
             buffer = ""
@@ -185,14 +184,10 @@ async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects
 async def fetch_claude_response_stream(client, url, headers, payload, model):
     timestamp = datetime.timestamp(datetime.now())
     async with client.stream('POST', url, headers=headers, json=payload) as response:
-        if response.status_code != 200:
-            error_message = await response.aread()
-            error_str = error_message.decode('utf-8', errors='replace')
-            try:
-                error_json = json.loads(error_str)
-            except json.JSONDecodeError:
-                error_json = error_str
-            yield {"error": f"fetch_claude_response_stream HTTP Error {response.status_code}", "details": error_json}
+        error_message = await check_response(response, "fetch_claude_response_stream")
+        if error_message:
+            yield error_message
+            return
         buffer = ""
         async for chunk in response.aiter_text():
             # logger.info(f"chunk: {repr(chunk)}")
@@ -241,13 +236,12 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
     yield sse_string
 
 async def fetch_response(client, url, headers, payload):
-    return {"error": f"500", "details": "fetch_response Read Response Timeout"}
+    response = await client.post(url, headers=headers, json=payload)
+    error_message = await check_response(response, "fetch_response")
+    if error_message:
+        yield error_message
+        return
+    yield response.json()
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
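`check_response` gives every fetch function one error contract: a dict with `error` and `details` keys for any non-200 status, `None` otherwise. A sketch with a stub standing in for the httpx response object:

```python
# FakeResponse is not httpx; it only provides the two members that
# check_response touches (status_code and aread).
import asyncio
from response import check_response

class FakeResponse:
    status_code = 429
    async def aread(self):
        return b'{"error": "rate limited"}'

async def main():
    err = await check_response(FakeResponse(), "fetch_gpt_response_stream")
    print(err["error"])    # fetch_gpt_response_stream HTTP Error 429
    print(err["details"])  # {'error': 'rate limited'}

asyncio.run(main())
```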
utils.py
CHANGED
@@ -222,4 +222,30 @@ c35s = CircularList(["us-east5", "europe-west1"])
 c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
 c3o = CircularList(["us-east5"])
 c3h = CircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
 gem = CircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
+
+class BaseAPI:
+    def __init__(
+        self,
+        api_url: str = "https://api.openai.com/v1/chat/completions",
+    ):
+        if api_url == "":
+            api_url = "https://api.openai.com/v1/chat/completions"
+        self.source_api_url: str = api_url
+        from urllib.parse import urlparse, urlunparse
+        parsed_url = urlparse(self.source_api_url)
+        if parsed_url.scheme == "":
+            raise Exception("Error: API_URL is not set")
+        if parsed_url.path != '/':
+            before_v1 = parsed_url.path.split("/v1")[0]
+        else:
+            before_v1 = ""
+        self.base_url: str = urlunparse(parsed_url[:2] + (before_v1,) + ("",) * 3)
+        self.v1_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1",) + ("",) * 3)
+        self.v1_models: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/models",) + ("",) * 3)
+        if parsed_url.netloc == "api.deepseek.com":
+            self.chat_url: str = urlunparse(parsed_url[:2] + ("/chat/completions",) + ("",) * 3)
+        else:
+            self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
+        self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
+        self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)