yym68686 committed
Commit edb14b7 · 1 Parent(s): cb6cbda

✨ Feature: Support Vertex AI API and Vertex tool-use invocation.

Files changed (9)
  1. .gitignore +2 -1
  2. README.md +13 -3
  3. main.py +15 -1
  4. request.py +197 -1
  5. requirements.txt +1 -1
  6. response.py +1 -1
  7. test/provider_test.py +1 -0
  8. test/test_httpx.py +64 -0
  9. utils.py +2 -0
.gitignore CHANGED
@@ -4,4 +4,5 @@ api.yaml
 __pycache__
 .vscode
 node_modules
-.wrangler
+.wrangler
+.pytest_cache
README.md CHANGED
@@ -12,14 +12,15 @@
 
 ## Introduction
 
-This project provides unified management of large-model APIs: multiple backend services can be called through a single API endpoint and are uniformly converted to the OpenAI format, with load balancing support. Currently supported backends: OpenAI, Anthropic, DeepBricks, OpenRouter, Gemini, etc.
+This project provides unified management of large-model APIs: multiple backend services can be called through a single API endpoint and are uniformly converted to the OpenAI format, with load balancing support. Currently supported backends: OpenAI, Anthropic, DeepBricks, OpenRouter, Gemini, Vertex, etc.
 
 ## Features
 
 - Unified management of multiple backend services
 - Load balancing support
+- Function calling support for OpenAI, Anthropic, Gemini, and Vertex
 - Support for multiple models
-- Support for multiple API Key
+- Support for multiple API keys
 
 ## Configuration
 
@@ -48,9 +49,18 @@ providers:
     api: AIzaSyAN2k6IRdgw
     model:
      - gemini-1.5-pro
-     - gemini-1.5-flash
+     - gemini-1.5-flash-exp-0827: gemini-1.5-flash
    tools: false
 
+  - provider: vertex
+    project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: your Google Cloud project ID. Format: a string, usually made up of lowercase letters, digits, and hyphens. How to get it: find your project ID in the project picker of the Google Cloud Console.
+    private_key: "-----BEGIN PRIVATE KEY-----\nxxxxx\n-----END PRIVATE" # Description: the private key of the Google Cloud Vertex AI service account. Format: a JSON-formatted string containing the service account's private key information. How to get it: create a service account in the Google Cloud Console, generate a JSON key file, and use its contents as the value of this variable.
+    client_email: [email protected] # Description: the email address of the Google Cloud Vertex AI service account. Format: usually a string like "[email protected]". How to get it: generated when the service account is created; it can also be found in the service account details under "IAM & Admin" in the Google Cloud Console.
+    model:
+      - gemini-1.5-pro
+      - gemini-1.5-flash
+    tools: true
+
   - provider: other-provider
     base_url: https://api.xxx.com/v1/messages
     api: sk-bNnAOJyA-xQw_twAA
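
With a vertex provider configured as above, clients keep calling the gateway's OpenAI-compatible endpoint and the gateway resolves the model to the Vertex backend. A minimal sketch of such a call; the local base URL, port 8000, and the key "sk-xxx" are illustrative placeholders, not values from this commit:

# Hypothetical client call against the gateway's OpenAI-compatible endpoint.
# base_url, api_key, and the model name are placeholders for illustration only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-xxx")
completion = client.chat.completions.create(
    model="gemini-1.5-pro",  # routed to the vertex provider by the gateway
    messages=[{"role": "user", "content": "say test"}],
)
print(completion.choices[0].message.content)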
main.py CHANGED
@@ -21,7 +21,18 @@ from urllib.parse import urlparse
 async def lifespan(app: FastAPI):
     # startup code
     timeout = httpx.Timeout(connect=15.0, read=20.0, write=30.0, pool=30.0)
-    app.state.client = httpx.AsyncClient(timeout=timeout)
+    default_headers = {
+        "User-Agent": "curl/7.68.0",  # mimic curl's User-Agent
+        "Accept": "*/*",  # curl's default Accept header
+    }
+    app.state.client = httpx.AsyncClient(
+        timeout=timeout,
+        headers=default_headers,
+        http2=True,  # enable HTTP/2
+        verify=True,  # keep SSL verification (set to False to disable, but not recommended)
+        follow_redirects=True,  # follow redirects automatically
+    )
+    # app.state.client = httpx.AsyncClient(timeout=timeout)
     app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
     yield
     # shutdown code
@@ -45,6 +56,8 @@ async def process_request(request: RequestModel, provider: Dict):
     engine = None
     if parsed_url.netloc == 'generativelanguage.googleapis.com':
         engine = "gemini"
+    elif parsed_url.netloc == 'aiplatform.googleapis.com':
+        engine = "vertex"
     elif parsed_url.netloc == 'api.anthropic.com' or parsed_url.path.endswith("v1/messages"):
         engine = "claude"
     elif parsed_url.netloc == 'openrouter.ai':
@@ -59,6 +72,7 @@ async def process_request(request: RequestModel, provider: Dict):
 
     if provider.get("engine"):
         engine = provider["engine"]
+
     logger.info(f"provider: {provider['provider']:<10} model: {request.model:<10} engine: {engine}")
 
     url, headers, payload = await get_payload(request, engine, provider)
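
The vertex engine is selected purely from the provider's base_url host (which utils.py sets to aiplatform.googleapis.com whenever a project_id is present). A small standard-library sketch of that detection, with illustrative URLs:

# Standalone sketch of the netloc-based engine detection added above;
# the URLs are illustrative, only the hostname comparison matters.
from urllib.parse import urlparse

def detect_engine(base_url: str) -> str:
    netloc = urlparse(base_url).netloc
    if netloc == 'generativelanguage.googleapis.com':
        return "gemini"
    if netloc == 'aiplatform.googleapis.com':
        return "vertex"
    return "gpt"

print(detect_engine("https://aiplatform.googleapis.com/"))  # vertex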
request.py CHANGED
@@ -165,10 +165,204 @@ async def get_gemini_payload(request, engine, provider):
 
     return url, headers, payload
 
-async def get_gpt_payload(request, engine, provider):
-    headers = {
-        'Content-Type': 'application/json'
-    }
+import time
+import httpx
+import base64
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import padding
+from cryptography.hazmat.primitives.serialization import load_pem_private_key
+
+def create_jwt(client_email, private_key):
+    # JWT Header
+    header = json.dumps({
+        "alg": "RS256",
+        "typ": "JWT"
+    }).encode()
+
+    # JWT Payload
+    now = int(time.time())
+    payload = json.dumps({
+        "iss": client_email,
+        "scope": "https://www.googleapis.com/auth/cloud-platform",
+        "aud": "https://oauth2.googleapis.com/token",
+        "exp": now + 3600,
+        "iat": now
+    }).encode()
+
+    # Encode header and payload
+    segments = [
+        base64.urlsafe_b64encode(header).rstrip(b'='),
+        base64.urlsafe_b64encode(payload).rstrip(b'=')
+    ]
+
+    # Create signature
+    signing_input = b'.'.join(segments)
+    private_key = load_pem_private_key(private_key.encode(), password=None)
+    signature = private_key.sign(
+        signing_input,
+        padding.PKCS1v15(),
+        hashes.SHA256()
+    )
+
+    segments.append(base64.urlsafe_b64encode(signature).rstrip(b'='))
+    return b'.'.join(segments).decode()
+
+def get_access_token(client_email, private_key):
+    jwt = create_jwt(client_email, private_key)
+
+    with httpx.Client() as client:
+        response = client.post(
+            "https://oauth2.googleapis.com/token",
+            data={
+                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+                "assertion": jwt
+            },
+            headers={'Content-Type': "application/x-www-form-urlencoded"}
+        )
+        response.raise_for_status()
+        return response.json()["access_token"]
+
+async def get_vertex_payload(request, engine, provider):
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    if provider.get("client_email") and provider.get("private_key"):
+        access_token = get_access_token(provider['client_email'], provider['private_key'])
+        headers['Authorization'] = f"Bearer {access_token}"
+    model = provider['model'][request.model]
+    if request.stream:
+        gemini_stream = "streamGenerateContent"
+    if provider.get("project_id"):
+        project_id = provider.get("project_id")
+        url = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}".format(PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
+
+    messages = []
+    systemInstruction = None
+    function_arguments = None
+    for msg in request.messages:
+        if msg.role == "assistant":
+            msg.role = "model"
+        tool_calls = None
+        if isinstance(msg.content, list):
+            content = []
+            for item in msg.content:
+                if item.type == "text":
+                    text_message = await get_text_message(msg.role, item.text, engine)
+                    content.append(text_message)
+                elif item.type == "image_url":
+                    image_message = await get_image_message(item.image_url.url, engine)
+                    content.append(image_message)
+        else:
+            content = [{"text": msg.content}]
+            tool_calls = msg.tool_calls
+
+        if tool_calls:
+            tool_call = tool_calls[0]
+            function_arguments = {
+                "functionCall": {
+                    "name": tool_call.function.name,
+                    "args": json.loads(tool_call.function.arguments)
+                }
+            }
+            messages.append(
+                {
+                    "role": "model",
+                    "parts": [function_arguments]
+                }
+            )
+        elif msg.role == "tool":
+            function_call_name = function_arguments["functionCall"]["name"]
+            messages.append(
+                {
+                    "role": "function",
+                    "parts": [{
+                        "functionResponse": {
+                            "name": function_call_name,
+                            "response": {
+                                "name": function_call_name,
+                                "content": {
+                                    "result": msg.content,
+                                }
+                            }
+                        }
+                    }]
+                }
+            )
+        elif msg.role != "system":
+            messages.append({"role": msg.role, "parts": content})
+        elif msg.role == "system":
+            systemInstruction = {"parts": content}
+
+    payload = {
+        "contents": messages,
+        # "safetySettings": [
+        #     {
+        #         "category": "HARM_CATEGORY_HARASSMENT",
+        #         "threshold": "BLOCK_NONE"
+        #     },
+        #     {
+        #         "category": "HARM_CATEGORY_HATE_SPEECH",
+        #         "threshold": "BLOCK_NONE"
+        #     },
+        #     {
+        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        #         "threshold": "BLOCK_NONE"
+        #     },
+        #     {
+        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+        #         "threshold": "BLOCK_NONE"
+        #     }
+        # ]
+        "generationConfig": {
+            "temperature": 0.5,
+            "max_output_tokens": 8192,
+            "top_k": 40,
+            "top_p": 0.95
+        },
+    }
+    if systemInstruction:
+        payload["system_instruction"] = systemInstruction
+
+    miss_fields = [
+        'model',
+        'messages',
+        'stream',
+        'tool_choice',
+        'temperature',
+        'top_p',
+        'max_tokens',
+        'presence_penalty',
+        'frequency_penalty',
+        'n',
+        'user',
+        'include_usage',
+        'logprobs',
+        'top_logprobs'
+    ]
+
+    for field, value in request.model_dump(exclude_unset=True).items():
+        if field not in miss_fields and value is not None:
+            if field == "tools":
+                payload.update({
+                    "tools": [{
+                        "function_declarations": [tool["function"] for tool in value]
+                    }],
+                    "tool_config": {
+                        "function_calling_config": {
+                            "mode": "AUTO"
+                        }
+                    }
+                })
+            else:
+                payload[field] = value
+
+    return url, headers, payload
+
+async def get_gpt_payload(request, engine, provider):
+    headers = {
+        'Content-Type': 'application/json',
+    }
     if provider.get("api"):
         headers['Authorization'] = f"Bearer {provider['api']}"
     url = provider['base_url']
@@ -426,6 +620,8 @@ async def get_claude_payload(request, engine, provider):
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
         return await get_gemini_payload(request, engine, provider)
+    elif engine == "vertex":
+        return await get_vertex_payload(request, engine, provider)
     elif engine == "claude":
         return await get_claude_payload(request, engine, provider)
     elif engine == "gpt":
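
Note that get_access_token signs a new JWT and performs a blocking token exchange on every vertex request, even though the assertion above is valid for 3600 seconds. A hypothetical refinement (not part of this commit) would cache the token per service account and reuse it until shortly before expiry:

# Hypothetical token cache around get_access_token (illustrative only).
# Reuses a token until shortly before its 3600-second expiry.
import time

_token_cache = {}  # client_email -> (access_token, expiry_timestamp)

def get_cached_access_token(client_email, private_key, skew=60):
    token, expiry = _token_cache.get(client_email, (None, 0))
    if token is None or time.time() > expiry - skew:
        token = get_access_token(client_email, private_key)
        _token_cache[client_email] = (token, time.time() + 3600)
    return token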
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-httpx
+httpx[http2]
 pyyaml
 pytest
 uvicorn
response.py CHANGED
@@ -202,7 +202,7 @@ async def fetch_response(client, url, headers, payload):
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
-        if engine == "gemini":
+        if engine == "gemini" or engine == "vertex":
             async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                 yield chunk
         elif engine == "claude":
test/provider_test.py CHANGED
@@ -74,6 +74,7 @@ def test_request_model(test_client, api_key, get_model):
     }
 
     headers = {
+        'Content-Type': 'application/json',
         "Authorization": f"Bearer {api_key}"
     }
 
test/test_httpx.py ADDED
@@ -0,0 +1,64 @@
+import httpx
+import asyncio
+import ssl
+import logging
+
+# set up logging
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+async def make_request():
+    # SSL context setup
+    # ssl_context = ssl.create_default_context()
+    # ssl_context.set_alpn_protocols(["h2", "http/1.1"])
+
+    # create a custom transport
+    transport = httpx.AsyncHTTPTransport(
+        http2=True,
+        # verify=ssl_context,
+        verify=False,
+        retries=1
+    )
+
+    # set headers
+    headers = {
+        "User-Agent": "curl/8.7.1",
+        "Accept": "*/*",
+        "Content-Type": "application/json",
+        "Authorization": "Bearer sk-xxxxxxx"
+    }
+
+    # request data
+    data = {
+        "model": "gpt-4o",
+        "messages": [
+            {
+                "role": "user",
+                "content": "say test"
+            }
+        ],
+        "stream": True
+    }
+
+    async with httpx.AsyncClient(transport=transport) as client:
+        try:
+            response = await client.post(
+                "https://api.xxxxxxxxxx.me/v1/chat/completions",
+                headers=headers,
+                json=data,
+                timeout=30.0
+            )
+
+            logger.info(f"Status Code: {response.status_code}")
+            logger.info(f"Headers: {response.headers}")
+
+            # process the streaming response
+            async for line in response.aiter_lines():
+                if line:
+                    print(line)
+
+        except httpx.RequestError as e:
+            logger.error(f"An error occurred while requesting {e.request.url!r}.")
+
+# run the async function
+asyncio.run(make_request())
utils.py CHANGED
@@ -13,6 +13,8 @@ def update_config(config_data):
             if type(model) == dict:
                 model_dict.update({new: old for old, new in model.items()})
         provider['model'] = model_dict
+        if provider.get('project_id'):
+            provider['base_url'] = 'https://aiplatform.googleapis.com/'
         config_data['providers'][index] = provider
     api_keys_db = config_data['api_keys']
     api_list = [item["api"] for item in api_keys_db]
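
With this change, any provider entry carrying a project_id gets its base_url rewritten to the Vertex endpoint, which is what the netloc check in main.py then maps to the "vertex" engine. A minimal sketch of that transformation on an illustrative provider dict (values are placeholders):

# Illustrative walk-through: a vertex provider gets its base_url set here,
# and main.py later maps that hostname to the "vertex" engine.
from urllib.parse import urlparse

provider = {"provider": "vertex", "project_id": "gen-lang-client-xxxxxxxxxxxxxx"}
if provider.get('project_id'):
    provider['base_url'] = 'https://aiplatform.googleapis.com/'
assert urlparse(provider['base_url']).netloc == 'aiplatform.googleapis.com'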