File size: 5,051 Bytes
17409c4
923b378
599e110
09c584b
888a669
1af48fa
 
972208e
1af48fa
 
 
 
 
6b0ab57
1af48fa
 
 
 
 
0140d23
 
 
 
 
 
 
 
 
1af48fa
 
 
 
 
 
 
 
 
 
 
819dd2f
0140d23
 
 
 
 
 
 
 
 
 
 
 
1af48fa
44caf41
 
 
 
 
 
 
09c584b
 
 
923b378
 
 
 
 
 
 
 
599e110
923b378
599e110
 
 
 
09c584b
1af48fa
 
 
 
 
 
 
 
 
 
 
 
 
44caf41
f156f8a
599e110
f156f8a
 
 
 
 
 
 
 
 
 
c34a2a5
 
09c584b
c34a2a5
 
 
 
 
 
c50b8cc
 
 
 
 
 
09c584b
c34a2a5
 
 
 
 
 
 
 
 
 
 
09c584b
c34a2a5
 
 
 
 
c50b8cc
c34a2a5
 
 
 
 
 
 
09c584b
c34a2a5
 
09c584b
c34a2a5
 
09c584b
c34a2a5
 
09c584b
c50b8cc
 
 
c34a2a5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from io import IOBase
from pydantic import BaseModel, Field, model_validator, ConfigDict
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
from log_config import logger

# Parameter description attached to a Function: a top-level type, a mapping
# of parameter name -> per-parameter spec, and the list of mandatory names.
class FunctionParameter(BaseModel):
    type: str
    properties: Dict[str, Dict[str, Any]]  # parameter name -> its spec dict
    required: List[str]  # names of parameters that must be supplied

# A function offered to the model as part of a Tool definition.
class Function(BaseModel):
    name: str
    description: str
    # Optional parameter spec; a bare None default is the same field as
    # Field(default=None, exclude=None), since exclude=None is the default.
    parameters: Union[FunctionParameter, None] = None

# A tool entry of a chat request: a type tag plus the function it exposes.
class Tool(BaseModel):
    type: str
    function: Function  # the callable being advertised

# Name of the invoked function plus its arguments, carried as a single string.
class FunctionCall(BaseModel):
    name: str
    arguments: str  # kept as a raw string; not parsed here

# One tool invocation inside a message's `tool_calls` list.
class ToolCall(BaseModel):
    id: str  # identifier of this call
    type: str
    function: FunctionCall  # which function was called, and with what

# Wrapper object holding the location of an image referenced by a ContentItem.
class ImageUrl(BaseModel):
    url: str  # image location, stored verbatim

# One part of a multi-part message body. `type` selects the part kind, and
# the matching payload (`text` or `image_url`) is filled in; both are optional.
class ContentItem(BaseModel):
    type: str
    text: Union[str, None] = None
    image_url: Union[ImageUrl, None] = None

class Message(BaseModel):
    # One chat message.
    #
    # This is the single definition of Message. The module used to define the
    # class twice, and the first copy was immediately shadowed, so it was dead code.
    # Keys not declared below (e.g. `arguments`) are still accepted and kept,
    # because extra="allow" stores them as extra attributes.
    role: str
    name: Optional[str] = None
    # Either a plain string or a list of multi-part content items.
    content: Optional[Union[str, List[ContentItem]]] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`):
    # allow and retain fields that are not declared above.
    model_config = ConfigDict(extra="allow")

# Names the specific function selected by a ToolChoice.
class FunctionChoice(BaseModel):
    name: str  # name of the chosen function

# Object form of a chat request's `tool_choice`; the function is optional.
class ToolChoice(BaseModel):
    type: str
    function: Union[FunctionChoice, None] = None

class BaseRequest(BaseModel):
    # Common base of every request model.
    #
    # `request_type` is a tag that UnifiedRequest sets after choosing the
    # concrete model. exclude=True keeps it out of model_dump() output.
    # "embedding" is part of the allowed values because UnifiedRequest assigns
    # it to EmbeddingRequest instances; previously the declared type did not
    # admit the value the code actually stores.
    request_type: Optional[Literal["chat", "image", "audio", "moderation", "embedding"]] = Field(default=None, exclude=True)

def create_json_schema_class():
    """Build the model used for the ``json_schema`` member of ``response_format``.

    The payload carries a key literally named ``schema``. That is also the name
    of a (deprecated) classmethod on pydantic's ``BaseModel``, so declaring the
    field makes pydantic v2 emit a "shadows an attribute in parent" UserWarning.

    The previous implementation avoided the warning by inserting ``schema`` into
    ``JsonSchema.__annotations__`` after the class had been created. Pydantic
    collects fields while the class is being built, so that late annotation never
    became a field. Incoming ``schema`` values were silently discarded (unknown
    keys are ignored by default) and never appeared in ``model_dump()``.

    The field is now declared in the class body, and only that specific
    warning is suppressed while the class is created.

    Returns:
        The ``JsonSchema`` model class with fields ``name`` and ``schema``.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message=r'Field name "schema"',
            category=UserWarning,
        )

        class JsonSchema(BaseModel):
            name: str
            # The JSON Schema document supplied by the client.
            schema: Dict[str, Any]

            # Disable pydantic's protected-namespace checks for this model.
            model_config = ConfigDict(protected_namespaces=())

    return JsonSchema


JsonSchema = create_json_schema_class()
# `response_format` option of a chat request: "text", "json_object", or
# "json_schema" together with the named schema in `json_schema`.
class ResponseFormat(BaseModel):
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Union[JsonSchema, None] = None

# Chat request body: the model, the conversation, and optional sampling,
# tool-calling and output-format settings.
class RequestModel(BaseRequest):
    model: str
    messages: List[Message]
    logprobs: Union[bool, None] = None
    top_logprobs: Union[int, None] = None
    stream: Union[bool, None] = None
    include_usage: Union[bool, None] = None
    temperature: Union[float, None] = 0.5
    top_p: Union[float, None] = 1.0
    max_tokens: Union[int, None] = None
    presence_penalty: Union[float, None] = 0.0
    frequency_penalty: Union[float, None] = 0.0
    n: Union[int, None] = 1
    user: Union[str, None] = None
    tool_choice: Union[str, ToolChoice, None] = None
    tools: Union[List[Tool], None] = None
    response_format: Union[ResponseFormat, None] = None  # newly added field

    def get_last_text_message(self) -> Optional[str]:
        """Return the most recent non-empty text in the conversation.

        Messages are scanned newest first. A string body is returned as-is.
        For a list body, the last ``"text"`` item with non-empty text wins.
        If nothing qualifies, an empty string (not ``None``) is returned.
        """
        for msg in self.messages[::-1]:
            body = msg.content
            if not body:
                continue
            if isinstance(body, str):
                return body
            if isinstance(body, list):
                hit = next(
                    (part.text for part in reversed(body) if part.type == "text" and part.text),
                    None,
                )
                if hit is not None:
                    return hit
        return ""

# Image generation request; by default one 1024x1024 image from "dall-e-3",
# non-streaming.
class ImageGenerationRequest(BaseRequest):
    prompt: str
    model: Union[str, None] = "dall-e-3"
    n: Union[int, None] = 1
    size: Union[str, None] = "1024x1024"
    stream: bool = False

# Embedding request for a single input string; encoding defaults to "float".
class EmbeddingRequest(BaseRequest):
    input: str
    model: str
    encoding_format: Union[str, None] = "float"
    stream: bool = False

class AudioTranscriptionRequest(BaseRequest):
    # Audio transcription request.
    #
    # `file` is a 3-tuple whose middle element is a raw IOBase stream
    # (presumably (filename, file object, content type) -- confirm with caller).
    # Pydantic has no validator for IOBase, hence arbitrary types are enabled.
    file: Tuple[str, IOBase, str]
    model: str
    language: Optional[str] = None
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: Optional[float] = None
    stream: bool = False

    # Pydantic v2 configuration. It replaces the deprecated inner `class Config`,
    # matching the `model_config = ConfigDict(...)` style used elsewhere in this module.
    model_config = ConfigDict(arbitrary_types_allowed=True)

# Moderation request for one input string; model defaults to
# "text-moderation-latest".
class ModerationRequest(BaseRequest):
    input: str
    model: Union[str, None] = "text-moderation-latest"
    stream: bool = False

class UnifiedRequest(BaseModel):
    # Envelope that inspects a raw request body and parses it into the
    # matching typed request model, exposed as `data`.
    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]

    @model_validator(mode='before')
    @classmethod
    def set_request_type(cls, values):
        """Build the typed request for a raw body and store it under ``data``.

        Dispatch is by characteristic keys:
        ``messages`` -> chat, ``prompt`` -> image, ``file`` -> audio,
        ``input`` -> embedding or moderation.

        Embedding and moderation bodies both carry ``input``. Previously the
        second ``"input"`` test was unreachable, so every embedding request was
        parsed as a ModerationRequest. An ``input`` body is now treated as an
        embedding request when it sets ``encoding_format`` (an option only
        EmbeddingRequest declares) or its model name contains "embed".
        Otherwise it stays a moderation request, as before.

        Non-dict input is returned unchanged. The caller's dict is copied
        rather than mutated.

        Raises:
            ValueError: if none of the keys identify the request type.
        """
        if not isinstance(values, dict):
            return values

        # Copy so adding "data" does not mutate the caller's dict.
        values = dict(values)

        model_name = values.get("model")
        wants_embedding = "encoding_format" in values or (
            isinstance(model_name, str) and "embed" in model_name.lower()
        )

        if "messages" in values:
            data = RequestModel(**values)
            data.request_type = "chat"
        elif "prompt" in values:
            data = ImageGenerationRequest(**values)
            data.request_type = "image"
        elif "file" in values:
            data = AudioTranscriptionRequest(**values)
            data.request_type = "audio"
        elif "input" in values and wants_embedding:
            data = EmbeddingRequest(**values)
            data.request_type = "embedding"
        elif "input" in values:
            data = ModerationRequest(**values)
            data.request_type = "moderation"
        else:
            raise ValueError("无法确定请求类型")

        values["data"] = data
        return values