Andrewwwwww committed on
Commit
3ee4bf3
·
verified ·
1 Parent(s): b28f462

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -30
handler.py CHANGED
@@ -1,22 +1,12 @@
1
- # Code to inference Hermes with HF Transformers
2
- # Requires pytorch, transformers, bitsandbytes, sentencepiece, protobuf, and flash-attn packages
3
 
4
- import torch
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
6
- from transformers import LlamaTokenizer, MixtralForCausalLM
7
- #import bitsandbytes, flash_attn
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
- self.tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
12
- self.model = MixtralForCausalLM.from_pretrained(
13
- path,
14
- torch_dtype=torch.float16,
15
- device_map="auto",
16
- load_in_8bit=False,
17
- load_in_4bit=True,
18
- #use_flash_attention_2=True
19
- )
20
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
21
  sys_prompt=data["prompt"]
22
  list=data["inputs"]
@@ -32,20 +22,8 @@ class EndpointHandler:
32
 
33
  #for chat in prompts:
34
  #print(chat)
 
 
 
35
 
36
- encodeds = self.tokenizer.encode(prompt, return_tensors="pt")
37
- model_inputs = encodeds.to(device)
38
- self.model.to(device)
39
- generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
40
- decoded = self.tokenizer.decode(generated_ids[0])
41
- return decoded
42
-
43
- """
44
- encodeds = self.tokenizer.encode(prompt, return_tensors="pt")
45
- model_inputs = encodeds.to(device)
46
- self.model.to(device)
47
- generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
48
- decoded = self.tokenizer.decode(generated_ids[0])
49
- return decoded
50
- """
51
 
 
 
 
1
 
2
+ from modelscope import AutoModelForCausalLM, AutoTokenizer
3
+
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
+ self.tokenizer =AutoTokenizer.from_pretrained(path)
8
+ self.model = AutoModelForCausalLM.from_pretrained(path, device_map='auto')
9
+
 
 
 
 
 
 
10
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
11
  sys_prompt=data["prompt"]
12
  list=data["inputs"]
 
22
 
23
  #for chat in prompts:
24
  #print(chat)
25
+ inputs = self.tokenizer(prompt, return_tensors="pt")
26
+ outputs = self.model.generate(**inputs, max_new_tokens=20)
27
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29