🤖 Models: Add support for o1-mini o1-preview model
- main.py +4 -0
- request.py +58 -0
- utils.py +9 -0
main.py
CHANGED
@@ -201,6 +201,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
     if "gemini" in provider['model'][request.model] and engine == "vertex":
         engine = "vertex-gemini"

+    if "o1-preview" in provider['model'][request.model] or "o1-mini" in provider['model'][request.model]:
+        engine = "o1"
+        request.stream = False
+
     if endpoint == "/v1/images/generations":
         engine = "dalle"
         request.stream = False
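For orientation: the new branch keys off the upstream model name that the provider config maps the requested model to, switches the engine to "o1", and forces non-streaming (the o1 preview models did not accept streamed responses at release). A minimal sketch of the check, assuming a made-up provider mapping rather than a real api.yaml entry:

# Illustrative only: a made-up provider mapping, not a real api.yaml entry.
provider = {"model": {"gpt-4o": "gpt-4o", "my-o1": "o1-preview"}}

def resolve_engine(requested_model: str, engine: str) -> tuple:
    upstream = provider["model"][requested_model]
    stream = True
    # Mirrors the added branch: o1-preview / o1-mini get the dedicated engine
    # and streaming is turned off.
    if "o1-preview" in upstream or "o1-mini" in upstream:
        engine, stream = "o1", False
    return engine, stream

print(resolve_engine("my-o1", "gpt"))   # ('o1', False)
print(resolve_engine("gpt-4o", "gpt"))  # ('gpt', True)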
request.py
CHANGED
@@ -737,6 +737,62 @@ async def get_cloudflare_payload(request, engine, provider):

     return url, headers, payload

+async def get_o1_payload(request, engine, provider):
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    if provider.get("api"):
+        headers['Authorization'] = f"Bearer {provider['api'].next()}"
+
+    url = provider['base_url']
+
+    messages = []
+    for msg in request.messages:
+        if isinstance(msg.content, list):
+            content = []
+            for item in msg.content:
+                if item.type == "text":
+                    text_message = await get_text_message(msg.role, item.text, engine)
+                    content.append(text_message)
+        else:
+            content = msg.content
+
+        if isinstance(content, list):
+            for item in content:
+                if item["type"] == "text":
+                    messages.append({"role": msg.role, "content": item["text"]})
+        else:
+            messages.append({"role": msg.role, "content": content})
+
+    model = provider['model'][request.model]
+    payload = {
+        "model": model,
+        "messages": messages,
+    }
+
+    miss_fields = [
+        'model',
+        'messages',
+        'tools',
+        'tool_choice',
+        'temperature',
+        'top_p',
+        'max_tokens',
+        'presence_penalty',
+        'frequency_penalty',
+        'n',
+        'user',
+        'include_usage',
+        'logprobs',
+        'top_logprobs'
+    ]
+
+    for field, value in request.model_dump(exclude_unset=True).items():
+        if field not in miss_fields and value is not None:
+            payload[field] = value
+
+    return url, headers, payload
+
 async def gpt2claude_tools_json(json_dict):
     import copy
     json_dict = copy.deepcopy(json_dict)
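Two details of get_o1_payload are worth calling out: multimodal list content is flattened to plain text parts, and the miss_fields list keeps sampling parameters (temperature, top_p, max_tokens, penalties, logprobs) out of the outgoing payload, presumably because the o1 preview endpoints rejected them at release. A stand-alone sketch of the filtering loop; the request dict below is an illustrative stand-in for request.model_dump(exclude_unset=True), not the project's RequestModel:

# Field names mirror the diff; the request dict is a made-up stand-in.
miss_fields = [
    'model', 'messages', 'tools', 'tool_choice', 'temperature', 'top_p',
    'max_tokens', 'presence_penalty', 'frequency_penalty', 'n', 'user',
    'include_usage', 'logprobs', 'top_logprobs',
]

request_fields = {
    "model": "my-o1",
    "messages": [{"role": "user", "content": "hi"}],
    "temperature": 0.2,   # filtered out
    "max_tokens": 256,    # filtered out
    "stream": False,      # copied through (not in miss_fields)
}

payload = {"model": "o1-preview", "messages": request_fields["messages"]}
for field, value in request_fields.items():
    if field not in miss_fields and value is not None:
        payload[field] = value

print(payload)
# {'model': 'o1-preview', 'messages': [{'role': 'user', 'content': 'hi'}], 'stream': False}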
@@ -929,6 +985,8 @@ async def get_payload(request: RequestModel, engine, provider):
         return await get_openrouter_payload(request, engine, provider)
     elif engine == "cloudflare":
         return await get_cloudflare_payload(request, engine, provider)
+    elif engine == "o1":
+        return await get_o1_payload(request, engine, provider)
     elif engine == "dalle":
         return await get_dalle_payload(request, engine, provider)
     else:
utils.py
CHANGED
@@ -53,6 +53,15 @@ def update_config(config_data):
 async def load_config(app=None):
     import yaml
     try:
+        # with open('./api.yaml', 'r') as f:
+        #     tokens = yaml.scan(f)
+        #     for token in tokens:
+        #         if isinstance(token, yaml.ScalarToken):
+        #             value = token.value
+        #             # If plain is False, the string was wrapped in quotes
+        #             is_quoted = not token.plain
+        #             print(f"value: {value}, quoted: {is_quoted}")
+
         with open('./api.yaml', 'r') as f:
             # Check whether the file is empty
             conf = yaml.safe_load(f)
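The commented-out block in load_config sketches a way to detect which scalars in api.yaml were written with quotes, something yaml.safe_load cannot report after parsing. A self-contained version of the same idea, scanning an inline document instead of the real config file:

import io
import yaml

# Inline stand-in for api.yaml (illustrative only).
doc = io.StringIO('api_key: "sk-123"\nmodel: gpt-4o\n')

for token in yaml.scan(doc):
    if isinstance(token, yaml.ScalarToken):
        # token.plain is False when the scalar was quoted in the source.
        print(f"value: {token.value}, quoted: {not token.plain}")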