🐛 Bug: Fix the bug of tool use request body format error
Browse files- .gitignore +3 -1
- models.py +9 -4
- request.py +30 -26
- test/test_nostream.py +2 -3
.gitignore
CHANGED
@@ -5,4 +5,6 @@ __pycache__
|
|
5 |
.vscode
|
6 |
node_modules
|
7 |
.wrangler
|
8 |
-
.pytest_cache
|
|
|
|
|
|
5 |
.vscode
|
6 |
node_modules
|
7 |
.wrangler
|
8 |
+
.pytest_cache
|
9 |
+
*.jpg
|
10 |
+
*.json
|
models.py
CHANGED
@@ -10,16 +10,14 @@ class ImageGenerationRequest(BaseModel):
|
|
10 |
|
11 |
class FunctionParameter(BaseModel):
|
12 |
type: str
|
13 |
-
properties: Dict[str, Dict[str, str]]
|
14 |
required: List[str]
|
15 |
|
16 |
-
# 定义 Function 模型
|
17 |
class Function(BaseModel):
|
18 |
name: str
|
19 |
description: str
|
20 |
parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)
|
21 |
|
22 |
-
# 定义 Tool 模型
|
23 |
class Tool(BaseModel):
|
24 |
type: str
|
25 |
function: Function
|
@@ -58,6 +56,13 @@ class Message(BaseModel):
|
|
58 |
class Config:
|
59 |
extra = "allow" # 允许额外的字段
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
class RequestModel(BaseModel):
|
62 |
model: str
|
63 |
messages: List[Message]
|
@@ -72,5 +77,5 @@ class RequestModel(BaseModel):
|
|
72 |
frequency_penalty: Optional[float] = 0.0
|
73 |
n: Optional[int] = 1
|
74 |
user: Optional[str] = None
|
75 |
-
tool_choice: Optional[str] = None
|
76 |
tools: Optional[List[Tool]] = None
|
|
|
10 |
|
11 |
class FunctionParameter(BaseModel):
|
12 |
type: str
|
13 |
+
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
14 |
required: List[str]
|
15 |
|
|
|
16 |
class Function(BaseModel):
|
17 |
name: str
|
18 |
description: str
|
19 |
parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)
|
20 |
|
|
|
21 |
class Tool(BaseModel):
|
22 |
type: str
|
23 |
function: Function
|
|
|
56 |
class Config:
|
57 |
extra = "allow" # 允许额外的字段
|
58 |
|
59 |
+
class FunctionChoice(BaseModel):
|
60 |
+
name: str
|
61 |
+
|
62 |
+
class ToolChoice(BaseModel):
|
63 |
+
type: str
|
64 |
+
function: Optional[FunctionChoice] = None
|
65 |
+
|
66 |
class RequestModel(BaseModel):
|
67 |
model: str
|
68 |
messages: List[Message]
|
|
|
77 |
frequency_penalty: Optional[float] = 0.0
|
78 |
n: Optional[int] = 1
|
79 |
user: Optional[str] = None
|
80 |
+
tool_choice: Optional[Union[str, ToolChoice]] = None
|
81 |
tools: Optional[List[Tool]] = None
|
request.py
CHANGED
@@ -474,19 +474,21 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
474 |
tools.append(json_tool)
|
475 |
payload["tools"] = tools
|
476 |
if "tool_choice" in payload:
|
477 |
-
if payload["tool_choice"]
|
478 |
-
payload["tool_choice"]
|
479 |
-
"
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
|
|
|
|
490 |
|
491 |
if provider.get("tools") == False:
|
492 |
payload.pop("tools", None)
|
@@ -746,19 +748,21 @@ async def get_claude_payload(request, engine, provider):
|
|
746 |
tools.append(json_tool)
|
747 |
payload["tools"] = tools
|
748 |
if "tool_choice" in payload:
|
749 |
-
if payload["tool_choice"]
|
750 |
-
payload["tool_choice"]
|
751 |
-
"
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
|
|
|
|
762 |
|
763 |
if provider.get("tools") == False:
|
764 |
payload.pop("tools", None)
|
|
|
474 |
tools.append(json_tool)
|
475 |
payload["tools"] = tools
|
476 |
if "tool_choice" in payload:
|
477 |
+
if isinstance(payload["tool_choice"], dict):
|
478 |
+
if payload["tool_choice"]["type"] == "function":
|
479 |
+
payload["tool_choice"] = {
|
480 |
+
"type": "tool",
|
481 |
+
"name": payload["tool_choice"]["function"]["name"]
|
482 |
+
}
|
483 |
+
if isinstance(payload["tool_choice"], str):
|
484 |
+
if payload["tool_choice"] == "auto":
|
485 |
+
payload["tool_choice"] = {
|
486 |
+
"type": "auto"
|
487 |
+
}
|
488 |
+
if payload["tool_choice"] == "none":
|
489 |
+
payload["tool_choice"] = {
|
490 |
+
"type": "any"
|
491 |
+
}
|
492 |
|
493 |
if provider.get("tools") == False:
|
494 |
payload.pop("tools", None)
|
|
|
748 |
tools.append(json_tool)
|
749 |
payload["tools"] = tools
|
750 |
if "tool_choice" in payload:
|
751 |
+
if isinstance(payload["tool_choice"], dict):
|
752 |
+
if payload["tool_choice"]["type"] == "function":
|
753 |
+
payload["tool_choice"] = {
|
754 |
+
"type": "tool",
|
755 |
+
"name": payload["tool_choice"]["function"]["name"]
|
756 |
+
}
|
757 |
+
if isinstance(payload["tool_choice"], str):
|
758 |
+
if payload["tool_choice"] == "auto":
|
759 |
+
payload["tool_choice"] = {
|
760 |
+
"type": "auto"
|
761 |
+
}
|
762 |
+
if payload["tool_choice"] == "none":
|
763 |
+
payload["tool_choice"] = {
|
764 |
+
"type": "any"
|
765 |
+
}
|
766 |
|
767 |
if provider.get("tools") == False:
|
768 |
payload.pop("tools", None)
|
test/test_nostream.py
CHANGED
@@ -45,7 +45,6 @@ def get_model_response(image_base64):
|
|
45 |
]
|
46 |
|
47 |
payload = {
|
48 |
-
|
49 |
"model": "claude-3-5-sonnet",
|
50 |
"messages": [
|
51 |
{
|
@@ -64,7 +63,7 @@ def get_model_response(image_base64):
|
|
64 |
]
|
65 |
}
|
66 |
],
|
67 |
-
"stream": True,
|
68 |
"tools": tools,
|
69 |
"tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
|
70 |
"max_tokens": 300
|
@@ -117,5 +116,5 @@ def main(image_path):
|
|
117 |
print("\n無法解析回應。")
|
118 |
|
119 |
if __name__ == "__main__":
|
120 |
-
image_path = "
|
121 |
main(image_path)
|
|
|
45 |
]
|
46 |
|
47 |
payload = {
|
|
|
48 |
"model": "claude-3-5-sonnet",
|
49 |
"messages": [
|
50 |
{
|
|
|
63 |
]
|
64 |
}
|
65 |
],
|
66 |
+
# "stream": True,
|
67 |
"tools": tools,
|
68 |
"tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
|
69 |
"max_tokens": 300
|
|
|
116 |
print("\n無法解析回應。")
|
117 |
|
118 |
if __name__ == "__main__":
|
119 |
+
image_path = "1.jpg" # 替換為您的圖像路徑
|
120 |
main(image_path)
|