Kororinpa commited on
Commit
8f5c615
·
verified ·
1 Parent(s): 10b2f97

Create handler

Browse files
Files changed (1) hide show
  1. handler +78 -0
handler ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ import torch
3
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
4
+ from .modeling import BinaryClassifier # 你的模型类
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ # 加载配置
9
+ self.config = AutoConfig.from_pretrained(path)
10
+
11
+ # 初始化模型
12
+ self.model = BinaryClassifier.from_pretrained(path)
13
+ self.model.eval()
14
+
15
+ # 初始化tokenizer
16
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
17
+
18
+ # 设置最大长度,可以根据你的需求调整
19
+ self.max_length = 512
20
+
21
+ def __call__(self, data: List[Dict[str, str]]) -> List[Dict[str, float]]:
22
+ """
23
+ 处理文本推理请求
24
+ Args:
25
+ data: 输入数据列表,每个元素是一个字典
26
+ 例如:[{"inputs": "这是一段测试文本"}]
27
+ Returns:
28
+ 预测结果列表
29
+ """
30
+ # 获取所有输入文本
31
+ texts = [item["inputs"] for item in data]
32
+
33
+ # tokenization
34
+ encoded_inputs = self.tokenizer(
35
+ texts,
36
+ padding=True,
37
+ truncation=True,
38
+ max_length=self.max_length,
39
+ return_tensors="pt"
40
+ )
41
+
42
+ # 进行预测
43
+ with torch.no_grad():
44
+ outputs = self.model(**encoded_inputs)
45
+ predictions = (outputs >= 0.5).float()
46
+
47
+ # 格式化输出
48
+ results = []
49
+ for pred, score in zip(predictions, outputs):
50
+ results.append({
51
+ "label": str(int(pred.item())), # 0 或 1
52
+ "score": float(score.item()) # 预测概率
53
+ })
54
+
55
+ return results
56
+
57
+ def preprocess(self, text: str) -> Dict[str, torch.Tensor]:
58
+ """
59
+ 可选的预处理方法,如果需要更复杂的预处理可以使用
60
+ """
61
+ encoded = self.tokenizer(
62
+ text,
63
+ padding=True,
64
+ truncation=True,
65
+ max_length=self.max_length,
66
+ return_tensors="pt"
67
+ )
68
+ return encoded
69
+
70
+ def postprocess(self, model_outputs: torch.Tensor) -> Dict:
71
+ """
72
+ 可选的后处理方法,如果需要更复杂的后处理可以使用
73
+ """
74
+ predictions = (model_outputs >= 0.5).float()
75
+ return {
76
+ "label": str(int(predictions[0].item())),
77
+ "score": float(model_outputs[0].item())
78
+ }