yym68686 committed
Commit eb02b52 · 1 Parent(s): 252357c

🤖 Models: Add support for the cohere series models

Files changed (6):
  1. README.md +3 -3
  2. README_CN.md +3 -3
  3. main.py +5 -1
  4. request.py +75 -5
  5. response.py +26 -0
  6. test/test_matplotlib.py +72 -1
README.md CHANGED
@@ -13,13 +13,13 @@

 ## Introduction

-If used personally, one/new-api is too complex and has many commercial functions that individuals do not need. If you do not want a complicated front-end interface and want to support more models, you can try uni-api. This is a project for unified management of large model APIs, allowing you to call multiple backend services through a unified API interface, converting them uniformly to OpenAI format and supporting load balancing. Currently supported backend services include: OpenAI, Anthropic, Gemini, Vertex, Cloudflare, DeepBricks, OpenRouter, etc.
+If used for personal purposes, one/new-api is too complex and has many commercial features that individuals do not need. If you do not want a complex front-end interface and want to support more models, you can try uni-api. This is a project that manages large model APIs uniformly, allowing you to call multiple backend services through a unified API interface, converting them uniformly to the OpenAI format and supporting load balancing. The currently supported backend services include: OpenAI, Anthropic, Gemini, Vertex, Cohere, Cloudflare, DeepBricks, OpenRouter, etc.

 ## Features

 - No frontend, pure configuration file setup for API channels. You can run your own API site by just writing one file, with detailed configuration guides in the documentation, beginner-friendly.
 - Unified management of multiple backend services, supporting providers like OpenAI, Deepseek, DeepBricks, OpenRouter, and other APIs in the OpenAI format. Supports OpenAI Dalle-3 image generation.
-- Supports Anthropic, Gemini, Vertex API, and Cloudflare. Vertex supports both Claude and Gemini API.
+- Supports Anthropic, Gemini, Vertex AI, Cohere, and Cloudflare. Vertex supports both the Claude and Gemini APIs.
 - Supports OpenAI, Anthropic, Gemini, Vertex native tool use function calls.
 - Supports OpenAI, Anthropic, Gemini, Vertex native image recognition API.
 - Supports four types of load balancing.
@@ -125,7 +125,7 @@ api_keys:
 ## Environment Variables

 - CONFIG_URL: The download address of the configuration file; it can be a local file or a remote file. Optional.
-- TIMEOUT: Request timeout, default is 40 seconds. The timeout can control the time needed to switch to the next channel when a channel does not respond. Optional.
+- TIMEOUT: Request timeout, default is 100 seconds. The timeout controls how long to wait before switching to the next channel when a channel does not respond. Optional.

 ## Docker Local Deployment
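Note: this commit raises the TIMEOUT default from 40 to 100 seconds. As a rough illustration of what such a knob gates, the sketch below shows a failover loop bounded by a per-channel client timeout. Only the TIMEOUT variable name comes from the README; the function and channel list are hypothetical, not uni-api's actual code.

import os
import httpx

# Hypothetical failover sketch: TIMEOUT bounds how long one channel may
# hang before the router abandons it and tries the next one.
TIMEOUT = float(os.environ.get("TIMEOUT", 100))

async def first_responding_channel(channels, payload):
    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        for url in channels:          # e.g. ordered by load-balancing policy
            try:
                return await client.post(url, json=payload)
            except httpx.TimeoutException:
                continue              # channel unresponsive; switch channels
    raise RuntimeError("all channels timed out")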
README_CN.md CHANGED
@@ -13,13 +13,13 @@

 ## Introduction

-如果个人使用的话,one/new-api 过于复杂,有很多个人不需要使用的商用功能,如果你不想要复杂的前端界面,又想要支持的模型多一点,可以试试 uni-api。这是一个统一管理大模型API的项目,可以通过一个统一的API接口调用多个后端服务,统一转换为 OpenAI 格式,支持负载均衡。目前支持的后端服务有:OpenAI、Anthropic、Gemini、Vertex、cloudflare、DeepBricks、OpenRouter 等。
+如果个人使用的话,one/new-api 过于复杂,有很多个人不需要使用的商用功能,如果你不想要复杂的前端界面,又想要支持的模型多一点,可以试试 uni-api。这是一个统一管理大模型API的项目,可以通过一个统一的API接口调用多个后端服务,统一转换为 OpenAI 格式,支持负载均衡。目前支持的后端服务有:OpenAI、Anthropic、Gemini、Vertex、Cohere、Cloudflare、DeepBricks、OpenRouter 等。

 ## Features

 - 无前端,纯配置文件配置 API 渠道。只要写一个文件就能运行起一个属于自己的 API 站,文档有详细的配置指南,小白友好。
 - 统一管理多个后端服务,支持 OpenAI、Deepseek、DeepBricks、OpenRouter 等其他 API 是 OpenAI 格式的提供商。支持 OpenAI Dalle-3 图像生成。
-- 同时支持 Anthropic、Gemini、Vertex API、cloudflare。Vertex 同时支持 Claude 和 Gemini API。
+- 同时支持 Anthropic、Gemini、Vertex AI、Cohere、Cloudflare。Vertex 同时支持 Claude 和 Gemini API。
 - 支持 OpenAI、Anthropic、Gemini、Vertex 原生 tool use 函数调用。
 - 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
 - 支持四种负载均衡。
@@ -125,7 +125,7 @@ api_keys:
 ## 环境变量

 - CONFIG_URL: 配置文件的下载地址,可以是本地文件,也可以是远程文件,选填
-- TIMEOUT: 请求超时时间,默认为 40 秒,超时时间可以控制当一个渠道没有响应时,切换下一个渠道需要的时间。选填
+- TIMEOUT: 请求超时时间,默认为 100 秒,超时时间可以控制当一个渠道没有响应时,切换下一个渠道需要的时间。选填

 ## Docker Local Deployment
main.py CHANGED
@@ -229,13 +229,17 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
         engine = "claude"
     elif parsed_url.netloc == 'openrouter.ai':
         engine = "openrouter"
+    elif parsed_url.netloc == 'api.cohere.com':
+        engine = "cohere"
+        request.stream = True
     else:
         engine = "gpt"

     if "claude" not in provider['model'][request.model] \
             and "gpt" not in provider['model'][request.model] \
             and "gemini" not in provider['model'][request.model] \
-            and parsed_url.netloc != 'api.cloudflare.com':
+            and parsed_url.netloc != 'api.cloudflare.com' \
+            and parsed_url.netloc != 'api.cohere.com':
         engine = "openrouter"

     if "claude" in provider['model'][request.model] and engine == "vertex":
request.py CHANGED
@@ -1,12 +1,13 @@
 import os
 import re
 import json
-from models import RequestModel
-from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
-
+import httpx
 import base64
 import urllib.parse

+from models import RequestModel
+from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
+
 def encode_image(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
@@ -82,6 +83,8 @@ async def get_text_message(role, message, engine = None):
         return {"text": message}
     if engine == "cloudflare":
         return message
+    if engine == "cohere":
+        return message
     raise ValueError("Unknown engine")

 async def get_gemini_payload(request, engine, provider):
@@ -215,8 +218,6 @@ async def get_gemini_payload(request, engine, provider):
     return url, headers, payload

 import time
-import httpx
-import base64
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives.serialization import load_pem_private_key
@@ -690,6 +691,73 @@ async def get_openrouter_payload(request, engine, provider):

     return url, headers, payload

+async def get_cohere_payload(request, engine, provider):
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    if provider.get("api"):
+        headers['Authorization'] = f"Bearer {provider['api'].next()}"
+
+    url = provider['base_url']
+
+    role_map = {
+        "user": "USER",
+        "assistant": "CHATBOT",
+        "system": "SYSTEM"
+    }
+
+    messages = []
+    for msg in request.messages:
+        if isinstance(msg.content, list):
+            content = []
+            for item in msg.content:
+                if item.type == "text":
+                    text_message = await get_text_message(msg.role, item.text, engine)
+                    content.append(text_message)
+        else:
+            content = msg.content
+
+        if isinstance(content, list):
+            for item in content:
+                if item["type"] == "text":
+                    messages.append({"role": role_map[msg.role], "message": item["text"]})
+        else:
+            messages.append({"role": role_map[msg.role], "message": content})
+
+    model = provider['model'][request.model]
+    chat_history = messages[:-1]
+    query = messages[-1].get("message")
+    payload = {
+        "model": model,
+        "message": query,
+    }
+
+    if chat_history:
+        payload["chat_history"] = chat_history
+
+    miss_fields = [
+        'model',
+        'messages',
+        'tools',
+        'tool_choice',
+        'temperature',
+        'top_p',
+        'max_tokens',
+        'presence_penalty',
+        'frequency_penalty',
+        'n',
+        'user',
+        'include_usage',
+        'logprobs',
+        'top_logprobs'
+    ]
+
+    for field, value in request.model_dump(exclude_unset=True).items():
+        if field not in miss_fields and value is not None:
+            payload[field] = value
+
+    return url, headers, payload
+
 async def get_cloudflare_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
@@ -989,6 +1057,8 @@ async def get_payload(request: RequestModel, engine, provider):
         return await get_cloudflare_payload(request, engine, provider)
     elif engine == "o1":
         return await get_o1_payload(request, engine, provider)
+    elif engine == "cohere":
+        return await get_cohere_payload(request, engine, provider)
     elif engine == "dalle":
         return await get_dalle_payload(request, engine, provider)
     else:
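To make get_cohere_payload's mapping concrete: the last OpenAI-style message becomes Cohere's message field and everything before it becomes chat_history, with roles renamed to USER/CHATBOT/SYSTEM. A self-contained sketch (the conversation and model name are made up for illustration):

role_map = {"user": "USER", "assistant": "CHATBOT", "system": "SYSTEM"}

# Made-up OpenAI-format conversation.
openai_messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "Name three colors."},
]

converted = [{"role": role_map[m["role"]], "message": m["content"]}
             for m in openai_messages]

# Shape of the payload the function returns (model name illustrative).
payload = {
    "model": "command-r",
    "message": converted[-1]["message"],   # final turn becomes the query
    "chat_history": converted[:-1],        # earlier turns become history
}
print(payload)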
response.py CHANGED
@@ -184,6 +184,29 @@ async def fetch_cloudflare_response_stream(client, url, headers, payload, model)
         sse_string = await generate_sse_response(timestamp, model, content=message)
         yield sse_string

+async def fetch_cohere_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_cohere_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info("line: %s", repr(line))
+                resp: dict = json.loads(line)
+                if resp.get("is_finished") == True:
+                    yield "data: [DONE]\n\r\n"
+                    return
+                if resp.get("event_type") == "text-generation":
+                    message = resp.get("text")
+                    sse_string = await generate_sse_response(timestamp, model, content=message)
+                    yield sse_string
+
 async def fetch_claude_response_stream(client, url, headers, payload, model):
     timestamp = int(datetime.timestamp(datetime.now()))
     async with client.stream('POST', url, headers=headers, json=payload) as response:
@@ -270,6 +293,9 @@ async def fetch_response_stream(client, url, headers, payload, engine, model):
         elif engine == "cloudflare":
             async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
                 yield chunk
+        elif engine == "cohere":
+            async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
+                yield chunk
         else:
             raise ValueError("Unknown response")
     except httpx.ConnectError as e:
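For reference, the parser above treats the response body as newline-delimited JSON rather than SSE: each line is one Cohere event, text-generation events carry the token text, and is_finished marks the end of the stream. A self-contained sketch of the same splitting logic, with fabricated event lines:

import json

# Fabricated sample stream; field names mirror those read by
# fetch_cohere_response_stream above.
raw = (
    '{"event_type": "stream-start"}\n'
    '{"event_type": "text-generation", "text": "Hel"}\n'
    '{"event_type": "text-generation", "text": "lo"}\n'
    '{"event_type": "stream-end", "is_finished": true}\n'
)

buffer = raw
while "\n" in buffer:
    line, buffer = buffer.split("\n", 1)
    resp = json.loads(line)
    if resp.get("is_finished"):
        print("data: [DONE]")
        break
    if resp.get("event_type") == "text-generation":
        print("chunk:", resp["text"])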
test/test_matplotlib.py CHANGED
@@ -2,6 +2,7 @@ import json
 import matplotlib.pyplot as plt
 from datetime import datetime, timedelta
 from collections import defaultdict
+import numpy as np

 import matplotlib.font_manager as fm
 font_path = '/System/Library/Fonts/PingFang.ttc'
@@ -45,5 +46,75 @@ def create_pic(request_arrivals, key):
     # Save the figure
     plt.savefig(f'{key.replace("/", "")}.png')

+def create_pie_chart(model_counts):
+    models = list(model_counts.keys())
+    counts = list(model_counts.values())
+
+    # Assign colors and sort slices from largest to smallest
+    colors = plt.cm.Set3(np.linspace(0, 1, len(models)))
+    sorted_data = sorted(zip(counts, models, colors), reverse=True)
+    counts, models, colors = zip(*sorted_data)
+
+    # Create the pie chart
+    fig, ax = plt.subplots(figsize=(16, 10))
+    wedges, _ = ax.pie(counts, colors=colors, startangle=90, wedgeprops=dict(width=0.5))
+
+    # Add a donut-hole effect
+    centre_circle = plt.Circle((0, 0), 0.35, fc='white')
+    fig.gca().add_artist(centre_circle)
+
+    # Compute the total
+    total = sum(counts)
+
+    # Prepare the annotations
+    bbox_props = dict(boxstyle="round,pad=0.3", fc="w", ec="k", lw=0.72)
+    kw = dict(xycoords='data', textcoords='data', arrowprops=dict(arrowstyle="-"), bbox=bbox_props, zorder=0)
+
+    left_labels = []
+    right_labels = []
+
+    for i, p in enumerate(wedges):
+        ang = (p.theta2 - p.theta1) / 2. + p.theta1
+        y = np.sin(np.deg2rad(ang))
+        x = np.cos(np.deg2rad(ang))
+
+        percentage = counts[i] / total * 100
+        label = f"{models[i]}: {percentage:.1f}%"
+
+        if x > 0:
+            right_labels.append((x, y, label))
+        else:
+            left_labels.append((x, y, label))
+
+    # Draw labels on the left side
+    for i, (x, y, label) in enumerate(left_labels):
+        ax.annotate(label, xy=(x, y), xytext=(-1.2, 0.9 - i * 0.15), **kw)
+
+    # Draw labels on the right side
+    for i, (x, y, label) in enumerate(right_labels):
+        ax.annotate(label, xy=(x, y), xytext=(1.2, 0.9 - i * 0.15), **kw)
+
+    plt.title("各模型使用次数对比", size=16)
+    ax.set_xlim(-1.5, 1.5)
+    ax.set_ylim(-1.2, 1.2)
+    ax.axis('off')
+    plt.tight_layout()
+    plt.savefig('model_usage_pie_chart.png', bbox_inches='tight', pad_inches=0.5)
+
 if __name__ == '__main__':
-    create_pic(request_arrivals, 'POST /v1/chat/completions')
+    model_counts = {
+        "model_counts": {
+            "claude-3-5-sonnet": 94,
+            "o1-preview": 71,
+            "gpt-4o": 512,
+            "gpt-4o-mini": 5,
+            "gemini-1.5-pro": 5,
+            "deepseek-chat": 7,
+            "grok-2-mini": 1,
+            "grok-2": 9,
+            "o1-mini": 8
+        }
+    }
+    # create_pic(request_arrivals, 'POST /v1/chat/completions')
+
+    create_pie_chart(model_counts["model_counts"])