from typing import Dict, List, Union

import torch


class EndpointHandler:
    """Inference handler for a sequence-classification model.

    Loads a Hugging Face model/tokenizer checkpoint from ``path`` and
    serves text-classification requests, returning for each input the
    predicted label id (as a string) and its softmax probability.
    """

    def __init__(self, path: str = ""):
        # Import transformers lazily so this module can be imported
        # (e.g. by unit tests) without the heavyweight dependency
        # installed; behavior for callers constructing the handler is
        # unchanged.
        from transformers import (
            AutoConfig,
            AutoModelForSequenceClassification,
            AutoTokenizer,
        )

        # Load the checkpoint configuration.
        self.config = AutoConfig.from_pretrained(path)
        # Load the classification model and switch to inference mode.
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()
        # Tokenizer matching the model checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Hard cap on tokenized sequence length.
        self.max_length = 512

    def __call__(self, data: List[Dict[str, str]]) -> List[Dict[str, float]]:
        """Handle a batched text-classification request.

        Args:
            data: A list whose items are either raw strings or dicts
                with an ``"inputs"`` key holding the text to classify.

        Returns:
            One ``{"label": str, "score": float}`` dict per input, or a
            single-element ``[{"error": ...}]`` list if anything failed.
        """
        try:
            texts = self._extract_texts(data)

            # Tokenize the whole batch at once; pad to the longest item
            # and truncate anything beyond self.max_length.
            encoded_inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

            # Inference only — no gradients needed.
            with torch.no_grad():
                outputs = self.model(**encoded_inputs)
                probabilities = torch.softmax(outputs.logits, dim=-1)

            results = []
            for probs in probabilities:
                label_id = int(torch.argmax(probs).item())
                results.append({
                    "label": str(label_id),                   # class id as string
                    "score": float(probs[label_id].item()),   # its probability
                })
            return results
        except Exception as e:
            # Report the failure to the caller rather than crashing the
            # endpoint; also print it for the service logs.
            print(f"Error in prediction: {str(e)}")
            return [{"error": str(e)}]

    @staticmethod
    def _extract_texts(data: List[Union[str, Dict[str, str]]]) -> List[str]:
        """Normalize request items (strings or ``{"inputs": ...}`` dicts) to texts.

        Raises:
            ValueError: If an item is neither a string nor a dict with
                an ``"inputs"`` key.
        """
        texts = []
        for item in data:
            if isinstance(item, dict) and "inputs" in item:
                texts.append(item["inputs"])
            elif isinstance(item, str):
                texts.append(item)
            else:
                raise ValueError(f"Unexpected input format: {item}")
        return texts

    def preprocess(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize a single text into model-ready tensors."""
        try:
            encoded = self.tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            return encoded
        except Exception as e:
            print(f"Error in preprocessing: {str(e)}")
            raise  # bare raise preserves the original traceback

    def postprocess(self, model_outputs) -> Dict:
        """Convert raw model outputs into a ``{"label", "score"}`` dict.

        Only the first item of the batch is reported.
        """
        try:
            logits = model_outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            label_id = int(torch.argmax(probabilities[0]).item())
            confidence = float(probabilities[0][label_id].item())
            return {
                "label": str(label_id),
                "score": confidence,
            }
        except Exception as e:
            print(f"Error in postprocessing: {str(e)}")
            raise  # bare raise preserves the original traceback