Htzhang commited on
Commit
18ca2f4
·
1 Parent(s): 5917a9e

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +52 -0
handler.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Tuple
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import torch
4
+ from subprocess import run
5
+
6
+ # set device
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
+ # self.pipeline = pipeline("text-classification", model=path)
13
+ # self.holidays = holidays.US()
14
+ self.query_model = AutoModelForMaskedLM.from_pretrained(path).to(device)
15
+ self.query_tokenizer = AutoTokenizer.from_pretrained(path)
16
+
17
+
18
+ def __call__(self, data: Dict[str, Any]) -> Tuple[List[List[int]], List[List[float]]]:
19
+ """
20
+ data args:
21
+ inputs (:obj: `str`)
22
+ date (:obj: `str`)
23
+ Return:
24
+ A :obj:`list` | `dict`: will be serialized and returned
25
+ """
26
+ # get inputs
27
+ texts = data.pop("inputs", data)
28
+
29
+ tokens = self.query_tokenizer(
30
+ texts, truncation=True, padding=True, return_tensors="pt"
31
+ )
32
+
33
+ tokens = self.query_tokenizer(
34
+ texts, truncation=True, padding=True, return_tensors="pt"
35
+ )
36
+ if torch.cuda.is_available():
37
+ tokens = tokens.to("cuda")
38
+
39
+ output = self.query_model(**tokens)
40
+ logits, attention_mask = output.logits, tokens.attention_mask
41
+ relu_log = torch.log(1 + torch.relu(logits))
42
+ weighted_log = relu_log * attention_mask.unsqueeze(-1)
43
+ tvecs, _ = torch.max(weighted_log, dim=1)
44
+
45
+ # extract the vectors that are non-zero and their indices
46
+ indices = []
47
+ vecs = []
48
+ for batch in tvecs:
49
+ indices.append(batch.nonzero(as_tuple=True)[0].tolist())
50
+ vecs.append(batch[indices[-1]].tolist())
51
+
52
+ return [indices, vecs]