yym68686 commited on
Commit
1126d73
·
1 Parent(s): f4d6dda

🤖 Models: Add support for o1-mini o1-preview model

Browse files
Files changed (3) hide show
  1. main.py +4 -0
  2. request.py +58 -0
  3. utils.py +9 -0
main.py CHANGED
@@ -201,6 +201,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
201
  if "gemini" in provider['model'][request.model] and engine == "vertex":
202
  engine = "vertex-gemini"
203
 
 
 
 
 
204
  if endpoint == "/v1/images/generations":
205
  engine = "dalle"
206
  request.stream = False
 
201
  if "gemini" in provider['model'][request.model] and engine == "vertex":
202
  engine = "vertex-gemini"
203
 
204
+ if "o1-preview" in provider['model'][request.model] or "o1-mini" in provider['model'][request.model]:
205
+ engine = "o1"
206
+ request.stream = False
207
+
208
  if endpoint == "/v1/images/generations":
209
  engine = "dalle"
210
  request.stream = False
request.py CHANGED
@@ -737,6 +737,62 @@ async def get_cloudflare_payload(request, engine, provider):
737
 
738
  return url, headers, payload
739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
  async def gpt2claude_tools_json(json_dict):
741
  import copy
742
  json_dict = copy.deepcopy(json_dict)
@@ -929,6 +985,8 @@ async def get_payload(request: RequestModel, engine, provider):
929
  return await get_openrouter_payload(request, engine, provider)
930
  elif engine == "cloudflare":
931
  return await get_cloudflare_payload(request, engine, provider)
 
 
932
  elif engine == "dalle":
933
  return await get_dalle_payload(request, engine, provider)
934
  else:
 
737
 
738
  return url, headers, payload
739
 
740
+ async def get_o1_payload(request, engine, provider):
741
+ headers = {
742
+ 'Content-Type': 'application/json'
743
+ }
744
+ if provider.get("api"):
745
+ headers['Authorization'] = f"Bearer {provider['api'].next()}"
746
+
747
+ url = provider['base_url']
748
+
749
+ messages = []
750
+ for msg in request.messages:
751
+ if isinstance(msg.content, list):
752
+ content = []
753
+ for item in msg.content:
754
+ if item.type == "text":
755
+ text_message = await get_text_message(msg.role, item.text, engine)
756
+ content.append(text_message)
757
+ else:
758
+ content = msg.content
759
+
760
+ if isinstance(content, list):
761
+ for item in content:
762
+ if item["type"] == "text":
763
+ messages.append({"role": msg.role, "content": item["text"]})
764
+ else:
765
+ messages.append({"role": msg.role, "content": content})
766
+
767
+ model = provider['model'][request.model]
768
+ payload = {
769
+ "model": model,
770
+ "messages": messages,
771
+ }
772
+
773
+ miss_fields = [
774
+ 'model',
775
+ 'messages',
776
+ 'tools',
777
+ 'tool_choice',
778
+ 'temperature',
779
+ 'top_p',
780
+ 'max_tokens',
781
+ 'presence_penalty',
782
+ 'frequency_penalty',
783
+ 'n',
784
+ 'user',
785
+ 'include_usage',
786
+ 'logprobs',
787
+ 'top_logprobs'
788
+ ]
789
+
790
+ for field, value in request.model_dump(exclude_unset=True).items():
791
+ if field not in miss_fields and value is not None:
792
+ payload[field] = value
793
+
794
+ return url, headers, payload
795
+
796
  async def gpt2claude_tools_json(json_dict):
797
  import copy
798
  json_dict = copy.deepcopy(json_dict)
 
985
  return await get_openrouter_payload(request, engine, provider)
986
  elif engine == "cloudflare":
987
  return await get_cloudflare_payload(request, engine, provider)
988
+ elif engine == "o1":
989
+ return await get_o1_payload(request, engine, provider)
990
  elif engine == "dalle":
991
  return await get_dalle_payload(request, engine, provider)
992
  else:
utils.py CHANGED
@@ -53,6 +53,15 @@ def update_config(config_data):
53
  async def load_config(app=None):
54
  import yaml
55
  try:
 
 
 
 
 
 
 
 
 
56
  with open('./api.yaml', 'r') as f:
57
  # 判断是否为空文件
58
  conf = yaml.safe_load(f)
 
53
  async def load_config(app=None):
54
  import yaml
55
  try:
56
+ # with open('./api.yaml', 'r') as f:
57
+ # tokens = yaml.scan(f)
58
+ # for token in tokens:
59
+ # if isinstance(token, yaml.ScalarToken):
60
+ # value = token.value
61
+ # # 如果plain为False,表示字符串被引号包裹
62
+ # is_quoted = not token.plain
63
+ # print(f"值: {value}, 是否被引号包裹: {is_quoted}")
64
+
65
  with open('./api.yaml', 'r') as f:
66
  # 判断是否为空文件
67
  conf = yaml.safe_load(f)