import ast
import json
from datetime import datetime

import httpx

from log_config import logger
from utils import safe_get

# Terminator appended to every SSE event ("data: ..." line plus a blank line).
end_of_line = "\n\n"


async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
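    """Build one OpenAI-style chat.completion.chunk SSE event string."""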
    sample_data = {
        "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content},
                "logprobs": None,
                "finish_reason": None
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_d576307f90",
    }
    # The delta payload is overwritten by whichever argument applies;
    # later branches take precedence over earlier ones.
    if function_call_content:
        sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "function": {"arguments": function_call_content}}]}
    if tools_id and function_call_name:
        sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "id": tools_id, "type": "function", "function": {"name": function_call_name, "arguments": ""}}]}
    if role:
        sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
    if total_tokens:
        # Usage-only chunk: recompute the total and drop the choices list.
        total_tokens = prompt_tokens + completion_tokens
        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
        sample_data["choices"] = []
    json_data = json.dumps(sample_data, ensure_ascii=False)

    sse_response = f"data: {json_data}" + end_of_line

    return sse_response


async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
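    """Build a complete OpenAI-style chat.completion JSON body (non-streaming)."""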
    sample_data = {
        "id": "chatcmpl-ALGS9hpJBb8xVAe62DRriY2SpoT4L",
        "object": "chat.completion",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": role,
                    "content": content,
                    "refusal": None
                },
                "logprobs": None,
                "finish_reason": "stop"
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_a7d06e42a7"
    }
    if total_tokens:
        total_tokens = prompt_tokens + completion_tokens
        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
    json_data = json.dumps(sample_data, ensure_ascii=False)

    return json_data


async def check_response(response, error_log):
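    """Return an error dict when the upstream response is not HTTP 200, otherwise None."""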
    if response and response.status_code != 200:
        error_message = await response.aread()
        error_str = error_message.decode('utf-8', errors='replace')
        try:
            error_json = json.loads(error_str)
        except json.JSONDecodeError:
            error_json = error_str
        return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
    return None


async def fetch_gemini_response_stream(client, url, headers, payload, model):
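    """Proxy a Gemini streaming response, re-emitting it as OpenAI-style SSE events."""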
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_gemini_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        receiving_function_call = False
        function_full_response = "{"
        need_function_call = False
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)

                if line and '"text": "' in line:
                    try:
                        json_data = json.loads("{" + line + "}")
                        content = json_data.get('text', '')
                        # Unescape literal "\n" sequences into real newlines.
                        content = content.replace("\\n", "\n")
                        sse_string = await generate_sse_response(timestamp, model, content=content)
                        yield sse_string
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse JSON: {line}")

                # Accumulate the raw functionCall object line by line until the
                # closing bracket is reached.
                if line and ('"functionCall": {' in line or receiving_function_call):
                    receiving_function_call = True
                    need_function_call = True
                    if ']' in line:
                        receiving_function_call = False
                        continue

                    function_full_response += line

        if need_function_call:
            function_call = json.loads(function_full_response)
            function_call_name = function_call["functionCall"]["name"]
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
            yield sse_string
            function_full_response = json.dumps(function_call["functionCall"]["args"])
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
            yield sse_string
        yield "data: [DONE]" + end_of_line


async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
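    """Proxy a Vertex AI Claude streaming response, re-emitting it as OpenAI-style SSE events."""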
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_vertex_claude_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        receiving_function_call = False
        function_full_response = "{"
        need_function_call = False
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)

                if line and '"text": "' in line:
                    try:
                        json_data = json.loads("{" + line + "}")
                        content = json_data.get('text', '')
                        content = content.replace("\\n", "\n")
                        sse_string = await generate_sse_response(timestamp, model, content=content)
                        yield sse_string
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse JSON: {line}")

                # Accumulate the raw tool_use block line by line until the
                # closing bracket is reached.
                if line and ('"type": "tool_use"' in line or receiving_function_call):
                    receiving_function_call = True
                    need_function_call = True
                    if ']' in line:
                        receiving_function_call = False
                        continue

                    function_full_response += line

        if need_function_call:
            function_call = json.loads(function_full_response)
            function_call_name = function_call["name"]
            function_call_id = function_call["id"]
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
            yield sse_string
            function_full_response = json.dumps(function_call["input"])
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
            yield sse_string
        yield "data: [DONE]" + end_of_line


async def fetch_gpt_response_stream(client, url, headers, payload):
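    """Pass through an already OpenAI-compatible SSE stream, skipping empty and comment lines."""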
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_gpt_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)

                # Skip empty data lines and SSE comment/keep-alive lines (": ...").
                if line and line != "data: " and line != "data:" and not line.startswith(": "):
                    yield line.strip() + end_of_line


async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
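    """Proxy a Cloudflare Workers AI stream, re-emitting it as OpenAI-style SSE events."""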
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_cloudflare_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)

                if line.startswith("data:"):
                    # Strip the "data:" prefix; str.lstrip would eat leading
                    # characters from the payload itself.
                    line = line[len("data:"):].strip()
                    if line == "[DONE]":
                        yield "data: [DONE]" + end_of_line
                        return
                    resp: dict = json.loads(line)
                    message = resp.get("response")
                    if message:
                        sse_string = await generate_sse_response(timestamp, model, content=message)
                        yield sse_string


async def fetch_cohere_response_stream(client, url, headers, payload, model):
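    """Proxy a Cohere streaming response, re-emitting it as OpenAI-style SSE events."""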
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_cohere_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # Guard against blank lines, which are not valid JSON.
                if not line.strip():
                    continue

                resp: dict = json.loads(line)
                if resp.get("is_finished"):
                    yield "data: [DONE]" + end_of_line
                    return
                if resp.get("event_type") == "text-generation":
                    message = resp.get("text")
                    sse_string = await generate_sse_response(timestamp, model, content=message)
                    yield sse_string


async def fetch_claude_response_stream(client, url, headers, payload, model):
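    """Proxy an Anthropic Claude SSE stream, re-emitting it as OpenAI-style SSE events."""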
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_claude_response_stream")
        if error_message:
            yield error_message
            return
        buffer = ""
        input_tokens = 0
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)

                if line.startswith("data:"):
                    # Strip the "data:" prefix; str.lstrip would eat leading
                    # characters from the payload itself.
                    line = line[len("data:"):].strip()
                    resp: dict = json.loads(line)
                    message = resp.get("message")
                    if message:
                        role = message.get("role")
                        if role:
                            sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
                            yield sse_string
                        tokens_use = message.get("usage")
                        if tokens_use:
                            input_tokens = tokens_use.get("input_tokens", 0)
                    usage = resp.get("usage")
                    if usage:
                        output_tokens = usage.get("output_tokens", 0)
                        total_tokens = input_tokens + output_tokens
                        sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
                        yield sse_string

                    # A content_block of type "tool_use" opens a tool call.
                    tool_use = resp.get("content_block")
                    tools_id = None
                    function_call_name = None
                    if tool_use and "tool_use" == tool_use['type']:
                        tools_id = tool_use["id"]
                        if "name" in tool_use:
                            function_call_name = tool_use["name"]
                        sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
                        yield sse_string

                    delta = resp.get("delta")
                    if not delta:
                        continue
                    if "text" in delta:
                        content = delta["text"]
                        sse_string = await generate_sse_response(timestamp, model, content, None, None)
                        yield sse_string
                    if "partial_json" in delta:
                        # Incremental tool-call arguments arrive as partial_json.
                        function_call_content = delta["partial_json"]
                        sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
                        yield sse_string
        yield "data: [DONE]" + end_of_line


async def fetch_response(client, url, headers, payload, engine, model):
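    """Make a non-streaming request and yield the response; Gemini results are converted to an OpenAI-style body."""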
    if payload.get("file"):
        file = payload.pop("file")
        response = await client.post(url, headers=headers, data=payload, files={"file": file})
    else:
        response = await client.post(url, headers=headers, json=payload)
    error_message = await check_response(response, "fetch_response")
    if error_message:
        yield error_message
        return
    response_json = response.json()
    if engine == "gemini" or engine == "vertex-gemini":

        if isinstance(response_json, str):
            parsed_data = ast.literal_eval(response_json)
        elif isinstance(response_json, list):
            parsed_data = response_json
        else:
            logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
            parsed_data = response_json

        # Concatenate the text parts of every candidate chunk.
        content = ""
        for item in parsed_data:
            chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
            if chunk:
                content += chunk

        usage_metadata = safe_get(parsed_data, -1, "usageMetadata") or {}
        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
        candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
        total_tokens = usage_metadata.get("totalTokenCount", 0)

        # Gemini reports the assistant role as "model"; map it to "assistant".
        role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
        if role == "model":
            role = "assistant"
        else:
            logger.error(f"Unknown role: {role}")
            role = "assistant"

        timestamp = int(datetime.timestamp(datetime.now()))
        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)

    else:
        yield response_json


async def fetch_response_stream(client, url, headers, payload, engine, model):
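    """Dispatch to the engine-specific streaming fetcher and relay its chunks."""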
    try:
        if engine == "gemini" or engine == "vertex-gemini":
            async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                yield chunk
        elif engine == "claude" or engine == "vertex-claude":
            async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                yield chunk
        elif engine == "gpt" or engine == "openrouter":
            # Both speak the OpenAI SSE format and can share one fetcher.
            async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
                yield chunk
        elif engine == "cloudflare":
            async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
                yield chunk
        elif engine == "cohere":
            async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
                yield chunk
        else:
            raise ValueError(f"Unknown engine: {engine}")
    except httpx.ConnectError:
        yield {"error": "500", "details": "fetch_response_stream Connect Error"}
    except httpx.ReadTimeout:
        yield {"error": "500", "details": "fetch_response_stream Read Response Timeout"}