taylorj94 committed
Commit 33d2c2b · verified · 1 Parent(s): 9bf0388

Update handler.py

Files changed (1)
  1. handler.py +31 -47
handler.py CHANGED
@@ -1,52 +1,45 @@
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    pipeline,
-    LogitsProcessor,
-    LogitsProcessorList
-)
+from llama_cpp import Llama  # Library for GGUF model handling
 from typing import Any, List, Dict


-class FixedVocabLogitsProcessor(LogitsProcessor):
+class FixedVocabLogitsProcessor:
     """
-    A custom LogitsProcessor that restricts the vocabulary
-    to a fixed set of token IDs, masking out everything else.
+    A custom logits processor for GGUF-compatible models.
     """

     def __init__(self, allowed_ids: set[int], fill_value=float('-inf')):
-        """
-        Args:
-            allowed_ids (set[int]): Token IDs allowed for generation.
-            fill_value (float): Value used to mask disallowed tokens, default -inf.
-        """
         self.allowed_ids = allowed_ids
         self.fill_value = fill_value

-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+    def apply(self, logits: torch.FloatTensor):
         """
-        Args:
-            input_ids: shape (batch_size, sequence_length)
-            scores: shape (batch_size, vocab_size) - pre-softmax logits for the next token
-        Returns:
-            scores: shape (batch_size, vocab_size) with masked logits
+        Modify logits to restrict to allowed token IDs.
         """
-        batch_size, vocab_size = scores.size()
-        for b in range(batch_size):
-            for token_id in range(vocab_size):
-                if token_id not in self.allowed_ids:
-                    scores[b, token_id] = self.fill_value
-        return scores
+        for token_id in range(len(logits)):
+            if token_id not in self.allowed_ids:
+                logits[token_id] = self.fill_value
+        return logits


 class EndpointHandler:
     def __init__(self, path=""):
-        # Load tokenizer and model
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.float16)
+        """
+        Initialize the GGUF model handler.
+        Args:
+            path (str): Path to the GGUF file.
+        """
+        self.model = Llama(model_path=path)
+        self.tokenizer = self.model.tokenizer  # GGUF-specific tokenizer, if available

     def __call__(self, data: Any) -> List[Dict[str, str]]:
+        """
+        Handle the request, performing inference with a restricted vocabulary.
+        Args:
+            data (Any): Input data.
+        Returns:
+            List[Dict[str, str]]: Generated output.
+        """
         # Extract inputs and parameters
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
@@ -58,29 +51,20 @@ class EndpointHandler:
         # Define allowed tokens dynamically
         allowed_ids = set()
         for word in vocab_list:
-            for tid in self.tokenizer.encode(word, add_special_tokens=False):
-                allowed_ids.add(tid)
-            for tid in self.tokenizer.encode(" " + word, add_special_tokens=False):
+            for tid in self.model.tokenize(word):
                 allowed_ids.add(tid)

-        # Create custom logits processor
-        logits_processors = LogitsProcessorList([FixedVocabLogitsProcessor(allowed_ids=allowed_ids)])
-
-        # Prepare input IDs
-        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.model.device)
+        # Tokenize input
+        input_ids = self.model.tokenize(inputs)

-        # Generate output
+        # Perform inference
         output_ids = self.model.generate(
-            input_ids=input_ids,
-            logits_processor=logits_processors,
-            max_length=parameters.get("max_length", 30),
-            num_beams=parameters.get("num_beams", 1),
-            do_sample=parameters.get("do_sample", False),
-            pad_token_id=self.tokenizer.eos_token_id,
-            no_repeat_ngram_size=parameters.get("no_repeat_ngram_size", 3)
+            input_ids,
+            max_tokens=parameters.get("max_length", 30),
+            logits_processor=lambda logits: FixedVocabLogitsProcessor(allowed_ids).apply(logits)
         )

         # Decode the output
-        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        generated_text = self.model.detokenize(output_ids)

         return [{"generated_text": generated_text}]