meyandrei commited on
Commit
538c9ec
·
verified ·
1 Parent(s): 4bc31f1

Delete handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -71
handler.py DELETED
@@ -1,71 +0,0 @@
1
- from typing import Dict, List, Any
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import re
5
-
6
- class EndpointHandler():
7
- def __init__(self, path="meyandrei/bankchat"):
8
- # Load the model and tokenizer
9
- self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left', use_safetensors=True)
10
- self.model = AutoModelForCausalLM.from_pretrained(path, use_safetensors=True)
11
- self.context_token = self.tokenizer.encode('<|context|>', return_tensors='pt')
12
- self.endofcontext_token = self.tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
13
-
14
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
- """
16
- data args:
17
- inputs (:obj: `str`)
18
- context (:obj: `list` of `str`)
19
- Return:
20
- A :obj:`list` | `dict`: will be serialized and returned
21
- """
22
- user_input = data.get('inputs', '')
23
- history = data.get('context', [])
24
-
25
- if history == []:
26
- context_tokenized = torch.LongTensor(history)
27
- else:
28
- history_str = tokenizer.decode(history[0])
29
- turns = re.split('<\|system\|>|<\|user\|>', history_str)[1:]
30
-
31
- for i in range(0, len(turns)-1, 2):
32
- turns[i] = '<|user|>' + turns[i]
33
- turns[i+1] = '<|system|>' + turns[i+1]
34
-
35
- context_tokenized = self.tokenizer.encode(''.join(turns), return_tensors='pt')
36
-
37
- user_input_tokenized = self.tokenizer.encode(' ' + user_input, return_tensors='pt')
38
- model_input = torch.cat([self.context_token, context_tokenized, user_input_tokenized, self.endofcontext_token], dim=-1)
39
- attention_mask = torch.ones_like(model_input)
40
-
41
- out_tokenized = self.model.generate(model_input, max_length=1024, eos_token_id=50258, pad_token_id=50260, attention_mask=attention_mask).tolist()[0]
42
- out_str = self.tokenizer.decode(out_tokenized)
43
- out_str = out_str.split('\n')[0]
44
-
45
- generated_substring = out_str.split('')[1] # belief, actions, system_response
46
-
47
- beliefs_start_index = generated_substring.find('') + len('')
48
- beliefs_end_index = generated_substring.find('', beliefs_start_index)
49
-
50
- actions_start_index = generated_substring.find('') + len('')
51
- actions_end_index = generated_substring.find('', actions_start_index)
52
-
53
- response_start_index = generated_substring.find('') + len('')
54
- response_end_index = generated_substring.find('', response_start_index)
55
-
56
- beliefs_str = generated_substring[beliefs_start_index:beliefs_end_index]
57
- actions_str = generated_substring[actions_start_index:actions_end_index]
58
- system_response_str = generated_substring[response_start_index:response_end_index]
59
-
60
- system_resp_tokenized = self.tokenizer.encode(' ' + system_response_str, return_tensors='pt')
61
- history = torch.cat([torch.LongTensor(history), user_input_tokenized, system_resp_tokenized], dim=-1).tolist()
62
-
63
- # Prepare the output
64
- model_outputs = {
65
- 'response': system_response_str,
66
- 'context': history,
67
- 'beliefs': beliefs_str,
68
- 'actions': actions_str
69
- }
70
-
71
- return model_outputs