ssaroya committed
Commit a405859 · 1 Parent(s): e418f96

Update llama_inference_class.py

Files changed (1)
  1. llama_inference_class.py +46 -46
llama_inference_class.py CHANGED
@@ -34,52 +34,52 @@ class ModelInference:
         model.seqlen = 2048
         return model
 
-    def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
-        from transformers import LlamaConfig, LlamaForCausalLM
-        config = LlamaConfig.from_pretrained(model)
-
-        def noop(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = noop
-        torch.nn.init.uniform_ = noop
-        torch.nn.init.normal_ = noop
-
-        torch.set_default_dtype(torch.half)
-        transformers.modeling_utils._init_weights = False
-        torch.set_default_dtype(torch.half)
-        model = LlamaForCausalLM(config)
-        torch.set_default_dtype(torch.float)
-        if eval:
-            model = model.eval()
-        layers = find_layers(model)
-        for name in ['lm_head']:
-            if name in layers:
-                del layers[name]
-        quant.make_quant_linear(model, layers, wbits, groupsize)
-
-        del layers
-
-        print('Loading model ...')
-        if checkpoint.endswith('.safetensors'):
-            from safetensors.torch import load_file as safe_load
-            model.load_state_dict(safe_load(checkpoint), strict=False)
-        else:
-            model.load_state_dict(torch.load(checkpoint), strict=False)
-
-        if eval:
-            quant.make_quant_attn(model)
-            quant.make_quant_norm(model)
-            if fused_mlp:
-                quant.make_fused_mlp(model)
-        if warmup_autotune:
-            quant.autotune_warmup_linear(model, transpose=not (eval))
-            if eval and fused_mlp:
-                quant.autotune_warmup_fused(model)
-        model.seqlen = 2048
-        print('Done.')
-
-        return model
+    def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
+        from transformers import LlamaConfig, LlamaForCausalLM
+        config = LlamaConfig.from_pretrained(model)
+
+        def noop(*args, **kwargs):
+            pass
+
+        torch.nn.init.kaiming_uniform_ = noop
+        torch.nn.init.uniform_ = noop
+        torch.nn.init.normal_ = noop
+
+        torch.set_default_dtype(torch.half)
+        transformers.modeling_utils._init_weights = False
+        torch.set_default_dtype(torch.half)
+        model = LlamaForCausalLM(config)
+        torch.set_default_dtype(torch.float)
+        if eval:
+            model = model.eval()
+        layers = find_layers(model)
+        for name in ['lm_head']:
+            if name in layers:
+                del layers[name]
+        quant.make_quant_linear(model, layers, wbits, groupsize)
+
+        del layers
+
+        print('Loading model ...')
+        if checkpoint.endswith('.safetensors'):
+            from safetensors.torch import load_file as safe_load
+            model.load_state_dict(safe_load(checkpoint), strict=False)
+        else:
+            model.load_state_dict(torch.load(checkpoint), strict=False)
+
+        if eval:
+            quant.make_quant_attn(model)
+            quant.make_quant_norm(model)
+            if fused_mlp:
+                quant.make_fused_mlp(model)
+        if warmup_autotune:
+            quant.autotune_warmup_linear(model, transpose=not (eval))
+            if eval and fused_mlp:
+                quant.autotune_warmup_fused(model)
+        model.seqlen = 2048
+        print('Done.')
+
+        return model
 
     def generate_text(self, text, min_length=10, max_length=50, top_p=0.95, temperature=0.8):
         input_ids = self.tokenizer.encode(text, return_tensors="pt").to(DEV)
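Two notes for readers of this diff. First, what load_quant actually does: it monkey-patches torch.nn.init.kaiming_uniform_, uniform_, and normal_ with a no-op and sets transformers.modeling_utils._init_weights = False, so constructing LlamaForCausalLM skips the expensive random weight initialization; find_layers and quant.make_quant_linear then swap every linear layer (except lm_head) for a quantized one before the real weights arrive via load_state_dict. Second, a minimal usage sketch follows. Only load_quant's parameters and generate_text's signature come from the diff above; the ModelInference constructor is not shown here, so its arguments and the paths below are illustrative assumptions, not the repository's documented API.

# Hypothetical driver for llama_inference_class.py -- a sketch under
# assumptions, not the repo's documented API.
from llama_inference_class import ModelInference

# Assumed constructor arguments: the diff only shows that load_quant takes a
# model path (fed to LlamaConfig.from_pretrained), a checkpoint file, the
# quantization bit-width, and an optional groupsize.
inference = ModelInference(
    model="path/to/llama-hf",                     # assumption: HF-format model directory
    checkpoint="path/to/llama-4bit.safetensors",  # assumption: GPTQ checkpoint file
    wbits=4,                                      # bit-width used when quantizing
    groupsize=128,                                # diff's default is -1; 128 is a common choice
)

# generate_text's signature and defaults come directly from the diff.
output = inference.generate_text(
    "The quick brown fox",
    min_length=10,
    max_length=50,
    top_p=0.95,
    temperature=0.8,
)
print(output)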