|
from io import IOBase |
|
from pydantic import BaseModel, Field, model_validator, ConfigDict |
|
from typing import List, Dict, Optional, Union, Tuple, Literal, Any |
|
from log_config import logger |
|
|
|
class FunctionParameter(BaseModel):
    """JSON-Schema-style description of a function tool's parameters.

    Only ``type``, ``properties`` and ``required`` are modelled; any other keys
    in the incoming object are ignored (pydantic's default extra policy).
    """

    # JSON Schema type of the parameters object (presumably "object").
    type: str
    # Parameter name -> JSON Schema fragment describing that parameter.
    properties: Dict[str, Dict[str, Any]]
    # Names of mandatory parameters. In JSON Schema "required" is optional and
    # tool definitions without mandatory parameters commonly omit it, so default
    # to an empty list instead of rejecting the whole request.
    required: List[str] = Field(default_factory=list)
|
|
|
class Function(BaseModel):
    """Definition of a callable function exposed to the model as a tool."""

    # Function identifier.
    name: str
    # Human-readable description of what the function does (required here).
    description: str
    # Parameter schema; omitted/None for functions without declared parameters.
    # (A bare None default is equivalent to Field(default=None, exclude=None),
    # since exclude=None is pydantic's default.)
    parameters: Optional[FunctionParameter] = None
|
|
|
class Tool(BaseModel):
    """One entry of a request's ``tools`` list, wrapping a function definition."""

    # Tool kind as sent by the client; not restricted by this model.
    type: str
    # The function this tool exposes.
    function: Function
|
|
|
class FunctionCall(BaseModel):
    """The function name and argument payload of a tool call."""

    # Name of the function being invoked.
    name: str
    # Arguments as a single string (presumably JSON-encoded; not parsed here).
    arguments: str
|
|
|
class ToolCall(BaseModel):
    """A single tool invocation attached to a message's ``tool_calls`` list."""

    # Identifier of this call (echoed back by a message's tool_call_id field).
    id: str
    # Tool kind as sent by the client; not restricted by this model.
    type: str
    # Which function is called and with what arguments.
    function: FunctionCall
|
|
|
class ImageUrl(BaseModel):
    """Wrapper object holding the URL of an image content part."""

    # Image location; the URL format is not validated here.
    url: str
|
|
|
class ContentItem(BaseModel):
    """One part of a multi-part message ``content`` list.

    ``type`` is free-form. ``text`` and ``image_url`` are both optional; which
    one is populated presumably depends on ``type`` (e.g. "text").
    """

    # Discriminator string for the content part.
    type: str
    # Text payload for text parts.
    text: Optional[str] = None
    # Image payload for image parts.
    image_url: Optional[ImageUrl] = None
|
|
|
class Message(BaseModel):
    """A single entry of a chat request's ``messages`` list.

    Extra keys are allowed and preserved on the instance. For example, a
    legacy ``arguments`` key is still accepted even though it is not a
    declared field.
    """

    # Author role of the message; the allowed values are not restricted here.
    role: str
    # Optional participant/function name.
    name: Optional[str] = None
    # Either a plain string or a list of typed content parts.
    content: Optional[Union[str, List[ContentItem]]] = None
    # Tool invocations attached to the message.
    tool_calls: Optional[List[ToolCall]] = None
    # Id of the tool call this message responds to (presumably set on tool
    # result messages; verify against callers).
    tool_call_id: Optional[str] = None

    # Pydantic v2 configuration. The class-based `Config` form is deprecated in
    # v2; this file already uses ConfigDict elsewhere.
    model_config = ConfigDict(extra="allow")
|
|
|
class FunctionChoice(BaseModel):
    """Names the specific function selected by a structured ``tool_choice``."""

    # Name of the function the client wants invoked.
    name: str
|
|
|
class ToolChoice(BaseModel):
    """Object form of ``tool_choice`` (the string form is handled by RequestModel)."""

    # Tool kind being selected; not restricted by this model.
    type: str
    # The chosen function, when one is named.
    function: Optional[FunctionChoice] = None
|
|
|
class BaseRequest(BaseModel):
    """Common base class for all typed request payloads.

    ``request_type`` is a routing tag that ``UnifiedRequest.set_request_type``
    assigns after construction. It is excluded from serialization
    (``exclude=True``), so it never reaches the dumped payload.
    """

    # Must list every tag that UnifiedRequest assigns. "tts" and "embedding"
    # were missing, so the annotation did not match reality, and any validated
    # path (validate_assignment, or passing request_type explicitly) rejected them.
    request_type: Optional[
        Literal["chat", "image", "audio", "moderation", "tts", "embedding"]
    ] = Field(default=None, exclude=True)
|
|
|
def create_json_schema_class():
    """Build and return the ``JsonSchema`` model used by ``ResponseFormat``.

    The model needs a field literally named ``schema``. That name shadows the
    deprecated ``BaseModel.schema`` classmethod. Pydantic v2 accepts this but
    emits a ``UserWarning`` when the class is created, so that specific warning
    is silenced here.

    Why the field is declared in the class body: pydantic collects fields while
    the class is being created. Assigning ``__annotations__['schema']`` after
    the class exists (the previous approach) never registers a field. A
    client-supplied schema was therefore dropped, because unknown keys are
    ignored by default.

    Returns:
        The ``JsonSchema`` model class.
    """
    import warnings

    with warnings.catch_warnings():
        # Message begins: 'Field name "schema" in "...JsonSchema" shadows an attribute ...'
        warnings.filterwarnings(
            "ignore", message=r'Field name "schema"', category=UserWarning
        )

        class JsonSchema(BaseModel):
            # Name of the response schema (required).
            name: str
            # The JSON Schema document itself. Optional so that payloads that
            # omit it are still accepted, as they were before.
            schema: Optional[Dict[str, Any]] = None

            # protected_namespaces=() is kept from the original configuration.
            # extra="allow" keeps other client-supplied keys (e.g. "strict",
            # "description") instead of silently discarding them.
            model_config = ConfigDict(protected_namespaces=(), extra="allow")

    return JsonSchema


JsonSchema = create_json_schema_class()
|
class ResponseFormat(BaseModel):
    """The ``response_format`` option of a chat request."""

    # Output mode requested by the client.
    type: Literal["text", "json_object", "json_schema"]
    # Schema details; presumably only meaningful when type == "json_schema".
    json_schema: Optional[JsonSchema] = None
|
|
|
class RequestModel(BaseRequest):
    """Chat-completion request payload.

    Optional sampling parameters fall back to the defaults declared below
    when the client omits them.
    """

    # --- required ---
    model: str
    messages: List[Message]

    # --- log-probabilities ---
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None

    # --- streaming ---
    stream: Optional[bool] = None
    include_usage: Optional[bool] = None

    # --- sampling (server-side defaults when omitted) ---
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    n: Optional[int] = 1

    # --- misc ---
    user: Optional[str] = None

    # --- tools / output shape ---
    tool_choice: Optional[Union[str, ToolChoice]] = None
    tools: Optional[List[Tool]] = None
    response_format: Optional[ResponseFormat] = None

    def get_last_text_message(self) -> Optional[str]:
        """Return the most recent non-empty text in the conversation.

        Messages are scanned newest-first:

        - Empty or missing ``content`` is skipped.
        - String content is returned as is.
        - For list content, the text of the last ``"text"`` part with
          non-empty text is returned.
        - A list with no such part is skipped.

        Returns:
            The text found, or ``""`` (never ``None``) if no message matches.
        """
        for msg in self.messages[::-1]:
            body = msg.content
            if not body:
                continue
            if isinstance(body, str):
                return body
            if isinstance(body, list):
                texts = [part.text for part in body if part.type == "text" and part.text]
                if texts:
                    return texts[-1]
        return ""
|
|
|
class ImageGenerationRequest(BaseRequest):
    """Image-generation request payload."""

    # Text description of the image to generate (required).
    prompt: str
    # Target model; "dall-e-3" when omitted.
    model: Optional[str] = Field(default="dall-e-3")
    # Number of images; 1 when omitted.
    n: Optional[int] = Field(default=1)
    # Size string, "1024x1024" when omitted (format not validated here).
    size: Optional[str] = Field(default="1024x1024")
    # Streaming flag; defaults to off.
    stream: bool = Field(default=False)
|
|
|
class EmbeddingRequest(BaseRequest):
    """Embedding request payload."""

    # A single string, or a list whose items are strings, ints, or int lists
    # (presumably pre-tokenized input; not interpreted here).
    input: Union[str, List[Union[str, int, List[int]]]]
    # Embedding model name (required).
    model: str
    # Output encoding; "float" when omitted.
    encoding_format: Optional[str] = Field(default="float")
    # Requested output dimensionality, if any.
    dimensions: Optional[int] = None
    # Optional end-user identifier.
    user: Optional[str] = None
    # Streaming flag; defaults to off.
    stream: bool = Field(default=False)
|
|
|
class AudioTranscriptionRequest(BaseRequest):
    """Audio-transcription request payload."""

    # Presumably (filename, file object, content type); verify against the
    # caller that builds it.
    file: Tuple[str, IOBase, str]
    # Transcription model name (required).
    model: str
    language: Optional[str] = None
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: Optional[float] = None
    # Streaming flag; defaults to off.
    stream: bool = False

    # IOBase is not a pydantic-native type, so arbitrary types must be allowed;
    # pydantic then validates such values with an isinstance() check.
    # Uses v2 ConfigDict instead of the deprecated class-based `Config`.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
class ModerationRequest(BaseRequest):
    """Content-moderation request payload."""

    # One string or a list of strings to classify.
    input: Union[str, List[str]]
    # Moderation model; "text-moderation-latest" when omitted.
    model: Optional[str] = Field(default="text-moderation-latest")
    # Streaming flag; defaults to off.
    stream: bool = Field(default=False)
|
|
|
class TextToSpeechRequest(BaseRequest):
    """Text-to-speech request payload."""

    # Speech model name (required).
    model: str
    # Text to synthesize (required).
    input: str
    # Voice identifier (required; allowed values not checked here).
    voice: str
    # Audio container/codec; "mp3" when omitted.
    response_format: Optional[str] = Field(default="mp3")
    # Playback speed multiplier; 1.0 when omitted.
    speed: Optional[float] = Field(default=1.0)
    # Unlike the other request types this is Optional[bool], so an explicit
    # null is accepted.
    stream: Optional[bool] = Field(default=False)
|
|
|
class UnifiedRequest(BaseModel):
    """Envelope that routes a raw request dict to the matching typed model.

    The chosen model instance is stored in ``data`` with its ``request_type``
    tag set.
    """

    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest, TextToSpeechRequest]

    @model_validator(mode='before')
    @classmethod
    def set_request_type(cls, values):
        """Classify a raw payload and build ``data`` from it.

        Routing is by the first matching rule, checked in this order:

        - a ``messages`` key -> chat
        - a ``file`` key -> audio
        - a ``prompt`` key -> image
        - ``model`` containing "tts" -> tts
        - ``model`` containing "text-embedding" -> embedding
        - an ``input`` key -> moderation

        Non-dict input is passed through unchanged.

        Returns:
            A new dict (the caller's dict is not mutated) with ``data`` set.

        Raises:
            ValueError: if no rule matches. Validation errors from building
            the chosen model also propagate.
        """
        if not isinstance(values, dict):
            return values

        # Guard against a missing, null or non-string "model": the substring
        # checks below would otherwise raise TypeError (e.g. "tts" in None).
        model_name = values.get("model")
        if not isinstance(model_name, str):
            model_name = ""

        if "messages" in values:
            data = RequestModel(**values)
            request_type = "chat"
        # "file" must be tested before "prompt": transcription requests may
        # carry an optional prompt and were previously misrouted to image.
        elif "file" in values:
            data = AudioTranscriptionRequest(**values)
            request_type = "audio"
        elif "prompt" in values:
            data = ImageGenerationRequest(**values)
            request_type = "image"
        elif "tts" in model_name:
            logger.info(f"TextToSpeechRequest: {values}")
            data = TextToSpeechRequest(**values)
            request_type = "tts"
        elif "text-embedding" in model_name:
            data = EmbeddingRequest(**values)
            request_type = "embedding"
        elif "input" in values:
            data = ModerationRequest(**values)
            request_type = "moderation"
        else:
            # Message means "Unable to determine request type".
            raise ValueError("无法确定请求类型")

        # Set after construction; this assignment is not validated.
        data.request_type = request_type
        return {**values, "data": data}
|
|