from io import IOBase
from pydantic import BaseModel, Field, model_validator, ConfigDict
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
from log_config import logger
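
# Schemas describing OpenAI-style function/tool definitions and tool calls.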
class FunctionParameter(BaseModel):
    type: str
    properties: Dict[str, Dict[str, Any]]
    required: List[str]

class Function(BaseModel):
    name: str
    description: str
    parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)

class Tool(BaseModel):
    type: str
    function: Function

class FunctionCall(BaseModel):
    name: str
    arguments: str

class ToolCall(BaseModel):
    id: str
    type: str
    function: FunctionCall
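
# Message payloads: content is either a plain string or a list of typed items (text / image_url).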
class ImageUrl(BaseModel):
    url: str

class ContentItem(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None
class Message(BaseModel):
    role: str
    name: Optional[str] = None
    arguments: Optional[str] = None
    content: Optional[Union[str, List[ContentItem]]] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    class Config:
        extra = "allow"  # allow extra fields
class FunctionChoice(BaseModel):
    name: str

class ToolChoice(BaseModel):
    type: str
    function: Optional[FunctionChoice] = None
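
# Base class for all request models; request_type is an internal routing tag and is excluded from serialization.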
class BaseRequest(BaseModel):
    request_type: Optional[Literal["chat", "image", "audio", "moderation", "tts", "embedding"]] = Field(default=None, exclude=True)
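
# JsonSchema is built via a factory so the 'schema' annotation can be attached after class creation,
# presumably to sidestep pydantic's warning about shadowing BaseModel.schema.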
def create_json_schema_class():
    class JsonSchema(BaseModel):
        name: str
        model_config = ConfigDict(protected_namespaces=())
    JsonSchema.__annotations__['schema'] = Dict[str, Any]
    return JsonSchema

JsonSchema = create_json_schema_class()

class ResponseFormat(BaseModel):
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchema] = None
class RequestModel(BaseRequest):
    model: str
    messages: List[Message]
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    stream: Optional[bool] = None
    include_usage: Optional[bool] = None
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    n: Optional[int] = 1
    user: Optional[str] = None
    tool_choice: Optional[Union[str, ToolChoice]] = None
    tools: Optional[List[Tool]] = None
    response_format: Optional[ResponseFormat] = None  # newly added field

    def get_last_text_message(self) -> Optional[str]:
        # Walk the messages from newest to oldest and return the first text content found.
        for message in reversed(self.messages):
            if message.content:
                if isinstance(message.content, str):
                    return message.content
                elif isinstance(message.content, list):
                    for item in reversed(message.content):
                        if item.type == "text" and item.text:
                            return item.text
        return ""
class ImageGenerationRequest(BaseRequest):
    prompt: str
    model: Optional[str] = "dall-e-3"
    n: Optional[int] = 1
    size: Optional[str] = "1024x1024"
    stream: bool = False

class EmbeddingRequest(BaseRequest):
    input: Union[str, List[Union[str, int, List[int]]]]  # accepts a string or an array (strings, token IDs, or token arrays)
    model: str
    encoding_format: Optional[str] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    stream: bool = False
class AudioTranscriptionRequest(BaseRequest):
    file: Tuple[str, IOBase, str]  # (filename, file object, content type), the tuple form typically used for multipart uploads
    model: str
    language: Optional[str] = None
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: Optional[float] = None
    stream: bool = False

    class Config:
        arbitrary_types_allowed = True  # IOBase is not a pydantic-native type

class ModerationRequest(BaseRequest):
    input: Union[str, List[str]]
    model: Optional[str] = "text-moderation-latest"
    stream: bool = False
class TextToSpeechRequest(BaseRequest):
    model: str
    input: str
    voice: str
    response_format: Optional[str] = "mp3"
    speed: Optional[float] = 1.0
    stream: Optional[bool] = False
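
# Wrapper model: a 'before' validator inspects the raw payload, instantiates the matching
# request model, and tags it with the corresponding request_type.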
class UnifiedRequest(BaseModel):
    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest, TextToSpeechRequest]

    @model_validator(mode='before')
    @classmethod
    def set_request_type(cls, values):
        if isinstance(values, dict):
            if "messages" in values:
                values["data"] = RequestModel(**values)
                values["data"].request_type = "chat"
            elif "prompt" in values:
                values["data"] = ImageGenerationRequest(**values)
                values["data"].request_type = "image"
            elif "file" in values:
                values["data"] = AudioTranscriptionRequest(**values)
                values["data"].request_type = "audio"
            elif "tts" in values.get("model", ""):
                logger.info(f"TextToSpeechRequest: {values}")
                values["data"] = TextToSpeechRequest(**values)
                values["data"].request_type = "tts"
            elif "text-embedding" in values.get("model", ""):
                values["data"] = EmbeddingRequest(**values)
                values["data"].request_type = "embedding"
            elif "input" in values:
                values["data"] = ModerationRequest(**values)
                values["data"].request_type = "moderation"
            else:
                raise ValueError("Unable to determine the request type")
        return values
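

# --- Usage sketch (illustrative only; payload values below are made up) ---
# Demonstrates how UnifiedRequest routes a raw payload to the matching request
# model and tags it with request_type.
if __name__ == "__main__":
    _chat_payload = {
        "model": "gpt-4o-mini",  # hypothetical model name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    _chat = UnifiedRequest.model_validate(_chat_payload)
    assert isinstance(_chat.data, RequestModel)
    assert _chat.data.request_type == "chat"

    _embedding_payload = {"model": "text-embedding-3-small", "input": "Hello"}
    _emb = UnifiedRequest.model_validate(_embedding_payload)
    assert isinstance(_emb.data, EmbeddingRequest)
    assert _emb.data.request_type == "embedding"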