yym68686 commited on
Commit
44caf41
·
1 Parent(s): 0ce2715

🐛 Bug: Fix the bug of tool use request body format error

Browse files
Files changed (4) hide show
  1. .gitignore +3 -1
  2. models.py +9 -4
  3. request.py +30 -26
  4. test/test_nostream.py +2 -3
.gitignore CHANGED
@@ -5,4 +5,6 @@ __pycache__
5
  .vscode
6
  node_modules
7
  .wrangler
8
- .pytest_cache
 
 
 
5
  .vscode
6
  node_modules
7
  .wrangler
8
+ .pytest_cache
9
+ *.jpg
10
+ *.json
models.py CHANGED
@@ -10,16 +10,14 @@ class ImageGenerationRequest(BaseModel):
10
 
11
  class FunctionParameter(BaseModel):
12
  type: str
13
- properties: Dict[str, Dict[str, str]]
14
  required: List[str]
15
 
16
- # 定义 Function 模型
17
  class Function(BaseModel):
18
  name: str
19
  description: str
20
  parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)
21
 
22
- # 定义 Tool 模型
23
  class Tool(BaseModel):
24
  type: str
25
  function: Function
@@ -58,6 +56,13 @@ class Message(BaseModel):
58
  class Config:
59
  extra = "allow" # 允许额外的字段
60
 
 
 
 
 
 
 
 
61
  class RequestModel(BaseModel):
62
  model: str
63
  messages: List[Message]
@@ -72,5 +77,5 @@ class RequestModel(BaseModel):
72
  frequency_penalty: Optional[float] = 0.0
73
  n: Optional[int] = 1
74
  user: Optional[str] = None
75
- tool_choice: Optional[str] = None
76
  tools: Optional[List[Tool]] = None
 
10
 
11
  class FunctionParameter(BaseModel):
12
  type: str
13
+ properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
14
  required: List[str]
15
 
 
16
  class Function(BaseModel):
17
  name: str
18
  description: str
19
  parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)
20
 
 
21
  class Tool(BaseModel):
22
  type: str
23
  function: Function
 
56
  class Config:
57
  extra = "allow" # 允许额外的字段
58
 
59
+ class FunctionChoice(BaseModel):
60
+ name: str
61
+
62
+ class ToolChoice(BaseModel):
63
+ type: str
64
+ function: Optional[FunctionChoice] = None
65
+
66
  class RequestModel(BaseModel):
67
  model: str
68
  messages: List[Message]
 
77
  frequency_penalty: Optional[float] = 0.0
78
  n: Optional[int] = 1
79
  user: Optional[str] = None
80
+ tool_choice: Optional[Union[str, ToolChoice]] = None
81
  tools: Optional[List[Tool]] = None
request.py CHANGED
@@ -474,19 +474,21 @@ async def get_vertex_claude_payload(request, engine, provider):
474
  tools.append(json_tool)
475
  payload["tools"] = tools
476
  if "tool_choice" in payload:
477
- if payload["tool_choice"]["type"] == "auto":
478
- payload["tool_choice"] = {
479
- "type": "auto"
480
- }
481
- if payload["tool_choice"]["type"] == "any":
482
- payload["tool_choice"] = {
483
- "type": "any"
484
- }
485
- if payload["tool_choice"]["type"] == "function":
486
- payload["tool_choice"] = {
487
- "type": "tool",
488
- "name": payload["tool_choice"]["function"]["name"]
489
- }
 
 
490
 
491
  if provider.get("tools") == False:
492
  payload.pop("tools", None)
@@ -746,19 +748,21 @@ async def get_claude_payload(request, engine, provider):
746
  tools.append(json_tool)
747
  payload["tools"] = tools
748
  if "tool_choice" in payload:
749
- if payload["tool_choice"]["type"] == "auto":
750
- payload["tool_choice"] = {
751
- "type": "auto"
752
- }
753
- if payload["tool_choice"]["type"] == "any":
754
- payload["tool_choice"] = {
755
- "type": "any"
756
- }
757
- if payload["tool_choice"]["type"] == "function":
758
- payload["tool_choice"] = {
759
- "type": "tool",
760
- "name": payload["tool_choice"]["function"]["name"]
761
- }
 
 
762
 
763
  if provider.get("tools") == False:
764
  payload.pop("tools", None)
 
474
  tools.append(json_tool)
475
  payload["tools"] = tools
476
  if "tool_choice" in payload:
477
+ if isinstance(payload["tool_choice"], dict):
478
+ if payload["tool_choice"]["type"] == "function":
479
+ payload["tool_choice"] = {
480
+ "type": "tool",
481
+ "name": payload["tool_choice"]["function"]["name"]
482
+ }
483
+ if isinstance(payload["tool_choice"], str):
484
+ if payload["tool_choice"] == "auto":
485
+ payload["tool_choice"] = {
486
+ "type": "auto"
487
+ }
488
+ if payload["tool_choice"] == "none":
489
+ payload["tool_choice"] = {
490
+ "type": "any"
491
+ }
492
 
493
  if provider.get("tools") == False:
494
  payload.pop("tools", None)
 
748
  tools.append(json_tool)
749
  payload["tools"] = tools
750
  if "tool_choice" in payload:
751
+ if isinstance(payload["tool_choice"], dict):
752
+ if payload["tool_choice"]["type"] == "function":
753
+ payload["tool_choice"] = {
754
+ "type": "tool",
755
+ "name": payload["tool_choice"]["function"]["name"]
756
+ }
757
+ if isinstance(payload["tool_choice"], str):
758
+ if payload["tool_choice"] == "auto":
759
+ payload["tool_choice"] = {
760
+ "type": "auto"
761
+ }
762
+ if payload["tool_choice"] == "none":
763
+ payload["tool_choice"] = {
764
+ "type": "any"
765
+ }
766
 
767
  if provider.get("tools") == False:
768
  payload.pop("tools", None)
test/test_nostream.py CHANGED
@@ -45,7 +45,6 @@ def get_model_response(image_base64):
45
  ]
46
 
47
  payload = {
48
-
49
  "model": "claude-3-5-sonnet",
50
  "messages": [
51
  {
@@ -64,7 +63,7 @@ def get_model_response(image_base64):
64
  ]
65
  }
66
  ],
67
- "stream": True,
68
  "tools": tools,
69
  "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
70
  "max_tokens": 300
@@ -117,5 +116,5 @@ def main(image_path):
117
  print("\n無法解析回應。")
118
 
119
  if __name__ == "__main__":
120
- image_path = "00001 (8).jpg" # 替換為您的圖像路徑
121
  main(image_path)
 
45
  ]
46
 
47
  payload = {
 
48
  "model": "claude-3-5-sonnet",
49
  "messages": [
50
  {
 
63
  ]
64
  }
65
  ],
66
+ # "stream": True,
67
  "tools": tools,
68
  "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
69
  "max_tokens": 300
 
116
  print("\n無法解析回應。")
117
 
118
  if __name__ == "__main__":
119
+ image_path = "1.jpg" # 替換為您的圖像路徑
120
  main(image_path)