File size: 3,349 Bytes
8f5c615 ac4dae6 8f5c615 ac4dae6 8f5c615 2ddb15e 8f5c615 2ddb15e 8f5c615 ac4dae6 8f5c615 2ddb15e 8f5c615 ac4dae6 8f5c615 ac4dae6 8f5c615 2ddb15e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from typing import List, Dict
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for sequence classification.

    Loads a tokenizer and an ``AutoModelForSequenceClassification`` checkpoint
    from ``path`` and serves softmax-normalized label predictions.
    """

    def __init__(self, path: str = ""):
        # Load the model configuration (kept for introspection by callers).
        self.config = AutoConfig.from_pretrained(path)
        # Load the classification model and switch it to inference mode.
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()
        # Tokenizer matching the model checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Hard cap on tokenized sequence length.
        self.max_length = 512

    def __call__(self, data: List[Dict[str, str]]) -> List[Dict[str, float]]:
        """Run batched text-classification inference.

        Args:
            data: A list whose items are either raw strings or dicts carrying
                the text under an ``"inputs"`` key.

        Returns:
            One ``{"label": str, "score": float}`` dict per input text, in
            order; or a single ``[{"error": ...}]`` entry if the request
            failed (best-effort endpoint — it reports rather than crashes).
        """
        try:
            # Normalize the request into a flat list of strings.
            texts = []
            for item in data:
                if isinstance(item, dict) and "inputs" in item:
                    texts.append(item["inputs"])
                elif isinstance(item, str):
                    texts.append(item)
                else:
                    raise ValueError(f"Unexpected input format: {item}")

            # Nothing to classify -> nothing to return (the tokenizer would
            # choke on an empty batch otherwise).
            if not texts:
                return []

            # Tokenize the whole batch at once.
            encoded_inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

            # Forward pass without autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model(**encoded_inputs)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)

            # One result dict per input row.
            results = []
            for probs in probabilities:
                label_id = int(torch.argmax(probs).item())
                confidence = float(probs[label_id].item())
                results.append({
                    "label": str(label_id),  # label id rendered as a string
                    "score": confidence,     # softmax probability of that label
                })
            return results
        except Exception as e:
            # Best-effort endpoint: log and report the failure in-band.
            print(f"Error in prediction: {str(e)}")
            return [{"error": str(e)}]

    def preprocess(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize a single text into model-ready tensors."""
        try:
            encoded = self.tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            return encoded
        except Exception as e:
            print(f"Error in preprocessing: {str(e)}")
            # Bare ``raise`` preserves the original traceback (was ``raise e``).
            raise

    def postprocess(self, model_outputs) -> Dict:
        """Convert raw model outputs into a single label/score dict.

        Only the first row of ``model_outputs.logits`` is used — this method
        is intended for single-example outputs.
        """
        try:
            logits = model_outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            label_id = int(torch.argmax(probabilities[0]).item())
            confidence = float(probabilities[0][label_id].item())
            return {
                "label": str(label_id),
                "score": confidence,
            }
        except Exception as e:
            print(f"Error in postprocessing: {str(e)}")
            # Bare ``raise`` preserves the original traceback (was ``raise e``).
            raise