import json
from datetime import datetime

import httpx

from log_config import logger
from utils import safe_get

# Terminator appended after every SSE event. In the SSE format an event ends
# with a blank line, i.e. two consecutive newlines.
end_of_line = "\n\n"


def _strip_sse_data_prefix(line):
    """Return the payload of an SSE ``data:`` line, trimmed of whitespace.

    ``str.lstrip("data: ")`` removes any run of the characters
    ``d``, ``a``, ``t``, ``:`` and space rather than the literal prefix, and
    leaves a trailing ``\\r`` in place; slicing the prefix off explicitly and
    stripping both ends avoids both problems.
    """
    return line[len("data:"):].strip()


async def _aiter_lines(response):
    """Yield complete ``\\n``-terminated lines from ``response.aiter_text()``.

    Text is buffered until a newline arrives; each yielded line excludes the
    newline. A trailing fragment that never receives a newline is not yielded.
    """
    buffer = ""
    async for chunk in response.aiter_text():
        buffer += chunk
        while "\n" in buffer:
            line, buffer = buffer.split("\n", 1)
            yield line


async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
    """Build one OpenAI-style ``chat.completion.chunk`` SSE event.

    Args:
        timestamp: Value used for the ``created`` field.
        model: Model name echoed in the chunk.
        content: Text placed in ``delta.content`` for ordinary chunks.
        tools_id: Tool-call id; together with ``function_call_name`` it
            produces the chunk that opens a tool call.
        function_call_name: Name of the called function.
        function_call_content: Fragment of the tool-call ``arguments`` string.
        role: When set, the delta becomes ``{"role": role, "content": ""}``.
        total_tokens: When truthy, the chunk carries ``usage`` and an empty
            ``choices`` list; the reported total is recomputed as
            ``prompt_tokens + completion_tokens``.
        prompt_tokens: Prompt token count for the usage chunk.
        completion_tokens: Completion token count for the usage chunk.

    Returns:
        ``"data: <json>"`` followed by ``end_of_line``.

    Each delta-setting option overwrites the previous one, applied in the
    order function_call_content, tools_id/function_call_name, role.
    """
    sample_data = {
        # Fixed placeholder id and fingerprint.
        "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content},
                "logprobs": None,
                "finish_reason": None,
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_d576307f90",
    }
    if function_call_content:
        sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "function": {"arguments": function_call_content}}]}
    if tools_id and function_call_name:
        sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "id": tools_id, "type": "function", "function": {"name": function_call_name, "arguments": ""}}]}
    if role:
        sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
    if total_tokens:
        total_tokens = prompt_tokens + completion_tokens
        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
        # A usage-only chunk carries no choices.
        sample_data["choices"] = []
    json_data = json.dumps(sample_data, ensure_ascii=False)
    # Build the SSE event line.
    sse_response = f"data: {json_data}" + end_of_line
    return sse_response


async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
    """Build a complete OpenAI-style ``chat.completion`` JSON string.

    ``tools_id``, ``function_call_name`` and ``function_call_content`` are
    accepted for signature parity with ``generate_sse_response`` but are not
    used. When ``total_tokens`` is truthy, ``usage`` is filled in and its
    total is recomputed as ``prompt_tokens + completion_tokens``.
    """
    sample_data = {
        # Fixed placeholder id and fingerprint.
        "id": "chatcmpl-ALGS9hpJBb8xVAe62DRriY2SpoT4L",
        "object": "chat.completion",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": role,
                    "content": content,
                    "refusal": None,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_a7d06e42a7",
    }
    if total_tokens:
        total_tokens = prompt_tokens + completion_tokens
        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
    json_data = json.dumps(sample_data, ensure_ascii=False)
    return json_data


async def check_response(response, error_log):
    """Return an error dict for a non-200 ``response``, otherwise ``None``.

    The body is read and decoded as UTF-8 (undecodable bytes replaced). It is
    returned parsed as JSON when possible, else as the raw string.
    ``error_log`` prefixes the ``error`` field so the caller can be identified.
    """
    if response is not None and response.status_code != 200:
        error_message = await response.aread()
        error_str = error_message.decode('utf-8', errors='replace')
        try:
            error_json = json.loads(error_str)
        except json.JSONDecodeError:
            error_json = error_str
        return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
    return None


async def fetch_gemini_response_stream(client, url, headers, payload, model):
    """Stream a Gemini response and re-emit it as OpenAI-style SSE chunks.

    The upstream body is processed line by line; it is presumably a
    pretty-printed JSON array (NOTE(review): confirm against the request URL).
    Each line containing ``"text": "`` is parsed as a one-key JSON object and
    emitted as a content chunk. Lines from ``"functionCall": {`` up to (not
    including) the next line containing ``]`` are accumulated. After the
    stream ends they are emitted as two chunks: the call header, then the
    JSON-encoded ``args``.

    Yields:
        The error dict from ``check_response`` on a non-200 response,
        otherwise SSE strings ending with ``data: [DONE]``.
    """
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_gemini_response_stream")
        if error_message:
            yield error_message
            return

        receiving_function_call = False
        # Opening brace so the accumulated lines form one JSON object.
        function_full_response = "{"
        need_function_call = False
        async for line in _aiter_lines(response):
            if line and '"text": "' in line:
                try:
                    json_data = json.loads("{" + line + "}")
                except json.JSONDecodeError:
                    logger.error(f"无法解析JSON: {line}")
                else:
                    # json.loads has already decoded escape sequences; do not
                    # unescape a literal backslash-n a second time.
                    content = json_data.get('text', '')
                    yield await generate_sse_response(timestamp, model, content=content)
            if line and ('"functionCall": {' in line or receiving_function_call):
                receiving_function_call = True
                need_function_call = True
                if ']' in line:
                    receiving_function_call = False
                    continue
                function_full_response += line

        if need_function_call:
            function_call = json.loads(function_full_response)
            function_call_name = function_call["functionCall"]["name"]
            yield await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
            function_full_response = json.dumps(function_call["functionCall"]["args"])
            yield await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
        yield "data: [DONE]" + end_of_line


async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
    """Stream a Vertex-hosted Claude response as OpenAI-style SSE chunks.

    Works like ``fetch_gemini_response_stream``, with one difference in tool
    calls. Lines are accumulated from the first line containing
    ``"type": "tool_use"`` up to the next line containing ``]``. The resulting
    object's ``name``, ``id`` and ``input`` fields are emitted after the
    stream ends.

    Yields:
        The error dict from ``check_response`` on a non-200 response,
        otherwise SSE strings ending with ``data: [DONE]``.
    """
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_vertex_claude_response_stream")
        if error_message:
            yield error_message
            return

        receiving_function_call = False
        # Opening brace so the accumulated lines form one JSON object.
        function_full_response = "{"
        need_function_call = False
        async for line in _aiter_lines(response):
            if line and '"text": "' in line:
                try:
                    json_data = json.loads("{" + line + "}")
                except json.JSONDecodeError:
                    logger.error(f"无法解析JSON: {line}")
                else:
                    # json.loads has already decoded escape sequences; do not
                    # unescape a literal backslash-n a second time.
                    content = json_data.get('text', '')
                    yield await generate_sse_response(timestamp, model, content=content)
            if line and ('"type": "tool_use"' in line or receiving_function_call):
                receiving_function_call = True
                need_function_call = True
                if ']' in line:
                    receiving_function_call = False
                    continue
                function_full_response += line

        if need_function_call:
            function_call = json.loads(function_full_response)
            function_call_name = function_call["name"]
            function_call_id = function_call["id"]
            yield await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
            function_full_response = json.dumps(function_call["input"])
            yield await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
        yield "data: [DONE]" + end_of_line


async def fetch_gpt_response_stream(client, url, headers, payload):
    """Relay an OpenAI-compatible SSE stream, dropping empty and comment lines.

    Each remaining line is stripped and re-terminated with ``end_of_line``.
    The upstream ``data: [DONE]`` line is relayed like any other line.
    """
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_gpt_response_stream")
        if error_message:
            yield error_message
            return

        async for line in _aiter_lines(response):
            # Skip blank lines, empty "data:" lines and ": " SSE comments.
            if line and line not in ("data: ", "data:") and not line.startswith(": "):
                yield line.strip() + end_of_line


async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
    """Convert a Cloudflare SSE stream to OpenAI-style chunks.

    Only ``data:`` lines are handled. Their JSON ``response`` field becomes a
    content chunk, and a ``[DONE]`` payload ends the stream.
    ``data: [DONE]`` is always emitted exactly once, even if the upstream
    never sends its sentinel.
    """
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_cloudflare_response_stream")
        if error_message:
            yield error_message
            return

        async for line in _aiter_lines(response):
            if not line.startswith("data:"):
                continue
            data = _strip_sse_data_prefix(line)
            if not data:
                continue
            if data == "[DONE]":
                yield "data: [DONE]" + end_of_line
                return
            resp: dict = json.loads(data)
            message = resp.get("response")
            if message:
                yield await generate_sse_response(timestamp, model, content=message)
        yield "data: [DONE]" + end_of_line


async def fetch_cohere_response_stream(client, url, headers, payload, model):
    """Convert a Cohere newline-delimited JSON stream to OpenAI-style chunks.

    ``text-generation`` events become content chunks. An event with
    ``is_finished`` true ends the stream. ``data: [DONE]`` is always emitted
    exactly once, even if no finishing event arrives.
    """
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_cohere_response_stream")
        if error_message:
            yield error_message
            return

        async for line in _aiter_lines(response):
            # Blank separator lines are not JSON.
            if not line.strip():
                continue
            resp: dict = json.loads(line)
            if resp.get("is_finished") is True:
                yield "data: [DONE]" + end_of_line
                return
            if resp.get("event_type") == "text-generation":
                message = resp.get("text")
                yield await generate_sse_response(timestamp, model, content=message)
        yield "data: [DONE]" + end_of_line


async def fetch_claude_response_stream(client, url, headers, payload, model):
    """Convert a Claude SSE stream into OpenAI-style chunks.

    For each ``data:`` event:
      * ``message.role`` produces a role chunk; ``message.usage.input_tokens``
        is remembered for the later usage chunk.
      * a top-level ``usage`` produces a usage chunk combining the remembered
        input tokens with ``usage.output_tokens``.
      * a ``content_block`` of type ``tool_use`` opens a tool call.
      * ``delta.text`` becomes content, ``delta.partial_json`` becomes
        tool-call argument fragments.

    Yields:
        The error dict from ``check_response`` on a non-200 response,
        otherwise SSE strings ending with ``data: [DONE]``.
    """
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_claude_response_stream")
        if error_message:
            yield error_message
            return

        input_tokens = 0
        async for line in _aiter_lines(response):
            # "event:" lines and blank separators carry nothing to convert.
            if not line.startswith("data:"):
                continue
            data = _strip_sse_data_prefix(line)
            if not data:
                continue
            resp: dict = json.loads(data)

            message = resp.get("message")
            if message:
                role = message.get("role")
                if role:
                    yield await generate_sse_response(timestamp, model, role=role)
                tokens_use = message.get("usage")
                if tokens_use:
                    input_tokens = tokens_use.get("input_tokens", 0)

            usage = resp.get("usage")
            if usage:
                output_tokens = usage.get("output_tokens", 0)
                total_tokens = input_tokens + output_tokens
                yield await generate_sse_response(timestamp, model, total_tokens=total_tokens, prompt_tokens=input_tokens, completion_tokens=output_tokens)

            tool_use = resp.get("content_block")
            if tool_use and tool_use.get("type") == "tool_use":
                tools_id = tool_use["id"]
                if "name" in tool_use:
                    yield await generate_sse_response(timestamp, model, tools_id=tools_id, function_call_name=tool_use["name"])

            delta = resp.get("delta")
            if not delta:
                continue
            if "text" in delta:
                yield await generate_sse_response(timestamp, model, content=delta["text"])
            if "partial_json" in delta:
                # e.g. {"type":"input_json_delta","partial_json":""}
                yield await generate_sse_response(timestamp, model, function_call_content=delta["partial_json"])
        yield "data: [DONE]" + end_of_line


async def fetch_response(client, url, headers, payload, engine, model):
    """Send a non-streaming request and yield a single result.

    If ``payload`` has a truthy ``"file"`` entry, it is popped from
    ``payload`` (mutating it) and sent as multipart form data. Otherwise
    ``payload`` is sent as JSON.

    Yields:
        The error dict from ``check_response`` on a non-200 response. For the
        ``gemini``/``vertex-gemini`` engines, a ``chat.completion`` JSON
        string assembled from the response items. For any other engine, the
        decoded response JSON.
    """
    upload = payload.pop("file") if payload.get("file") else None
    if upload:
        response = await client.post(url, headers=headers, data=payload, files={"file": upload})
    else:
        response = await client.post(url, headers=headers, json=payload)

    error_message = await check_response(response, "fetch_response")
    if error_message:
        yield error_message
        return

    response_json = response.json()
    if engine not in ("gemini", "vertex-gemini"):
        yield response_json
        return

    if isinstance(response_json, str):
        import ast
        # literal_eval only evaluates Python literals, so it cannot run code.
        parsed_data = ast.literal_eval(str(response_json))
    elif isinstance(response_json, list):
        parsed_data = response_json
    elif isinstance(response_json, dict):
        # A single response object is treated as a one-element list so it is
        # not iterated key by key below.
        parsed_data = [response_json]
    else:
        logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
        parsed_data = response_json

    text_parts = []
    for item in parsed_data:
        chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
        if chunk:
            text_parts.append(chunk)
    content = "".join(text_parts)

    # Fall back to an empty dict so a missing usageMetadata yields zero counts.
    usage_metadata = safe_get(parsed_data, -1, "usageMetadata") or {}
    prompt_tokens = usage_metadata.get("promptTokenCount", 0)
    candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
    total_tokens = usage_metadata.get("totalTokenCount", 0)

    role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
    if role != "model":
        logger.error(f"Unknown role: {role}")
    # Gemini's "model" role maps to OpenAI's "assistant"; anything else is logged and coerced.
    role = "assistant"

    timestamp = int(datetime.timestamp(datetime.now()))
    yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)


async def fetch_response_stream(client, url, headers, payload, engine, model):
    """Dispatch a streaming request to the adapter for ``engine``.

    Raises:
        ValueError: if ``engine`` has no stream adapter.

    Connection errors and read timeouts are not raised. Each is yielded as an
    error dict with ``"error": "500"``.
    """
    try:
        if engine in ("gemini", "vertex-gemini"):
            stream = fetch_gemini_response_stream(client, url, headers, payload, model)
        elif engine in ("claude", "vertex-claude"):
            stream = fetch_claude_response_stream(client, url, headers, payload, model)
        elif engine in ("gpt", "openrouter"):
            stream = fetch_gpt_response_stream(client, url, headers, payload)
        elif engine == "cloudflare":
            stream = fetch_cloudflare_response_stream(client, url, headers, payload, model)
        elif engine == "cohere":
            stream = fetch_cohere_response_stream(client, url, headers, payload, model)
        else:
            raise ValueError("Unknown response")
        async for chunk in stream:
            yield chunk
    except httpx.ConnectError:
        yield {"error": "500", "details": "fetch_response_stream Connect Error"}
    except httpx.ReadTimeout:
        yield {"error": "500", "details": "fetch_response_stream Read Response Timeout"}