ssaroya committed
Commit e418f96 · 1 Parent(s): 3eb845e

Create llama_inference_class.py

Files changed (1)
  1. llama_inference_class.py +96 -0
llama_inference_class.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torch.nn as nn
+ import quant
+
+ from gptq import GPTQ
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
+ import transformers
+ from transformers import AutoTokenizer
+
+
+ class ModelInference:
+
+     def __init__(self, model_name, load=None, wbits=16, groupsize=-1):
+         self.model_name = model_name
+         self.load = load
+         self.wbits = wbits
+         self.groupsize = groupsize
+         # Load a pre-quantized GPTQ checkpoint if one is given, otherwise load the full-precision model.
+         if self.load:
+             self.model = self.load_quant(self.model_name, self.load, self.wbits, self.groupsize)
+         else:
+             self.model = self.get_llama(self.model_name)
+         self.model.eval()
+         self.model.to(DEV)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
+
+     def get_llama(self, model):
+         """Load a full-precision LLaMA model, skipping the default weight initialization."""
+
+         def skip(*args, **kwargs):
+             pass
+
+         # from_pretrained overwrites the weights anyway, so skip the slow random init.
+         torch.nn.init.kaiming_uniform_ = skip
+         torch.nn.init.uniform_ = skip
+         torch.nn.init.normal_ = skip
+         from transformers import LlamaForCausalLM
+         model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
+         model.seqlen = 2048
+         return model
+
+     def load_quant(self, model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
+         """Build the model with quantized linear layers and load a GPTQ checkpoint into it."""
+         from transformers import LlamaConfig, LlamaForCausalLM
+         config = LlamaConfig.from_pretrained(model)
+
+         def noop(*args, **kwargs):
+             pass
+
+         # The checkpoint supplies all weights, so skip random initialization.
+         torch.nn.init.kaiming_uniform_ = noop
+         torch.nn.init.uniform_ = noop
+         torch.nn.init.normal_ = noop
+
+         torch.set_default_dtype(torch.half)
+         transformers.modeling_utils._init_weights = False
+         model = LlamaForCausalLM(config)
+         torch.set_default_dtype(torch.float)
+         if eval:
+             model = model.eval()
+         # Swap every linear layer except the output head for a quantized equivalent.
+         layers = find_layers(model)
+         for name in ['lm_head']:
+             if name in layers:
+                 del layers[name]
+         quant.make_quant_linear(model, layers, wbits, groupsize)
+
+         del layers
+
+         print('Loading model ...')
+         if checkpoint.endswith('.safetensors'):
+             from safetensors.torch import load_file as safe_load
+             model.load_state_dict(safe_load(checkpoint), strict=False)
+         else:
+             model.load_state_dict(torch.load(checkpoint), strict=False)
+
+         # Optional fused kernels and autotune warmup for the quantized layers.
+         if eval:
+             quant.make_quant_attn(model)
+             quant.make_quant_norm(model)
+             if fused_mlp:
+                 quant.make_fused_mlp(model)
+         if warmup_autotune:
+             quant.autotune_warmup_linear(model, transpose=not (eval))
+             if eval and fused_mlp:
+                 quant.autotune_warmup_fused(model)
+         model.seqlen = 2048
+         print('Done.')
+
+         return model
+
+     def generate_text(self, text, min_length=10, max_length=50, top_p=0.95, temperature=0.8):
+         # Tokenize the prompt and sample a continuation with nucleus sampling.
+         input_ids = self.tokenizer.encode(text, return_tensors="pt").to(DEV)
+
+         with torch.no_grad():
+             generated_ids = self.model.generate(
+                 input_ids,
+                 do_sample=True,
+                 min_length=min_length,
+                 max_length=max_length,
+                 top_p=top_p,
+                 temperature=temperature,
+             )
+         return self.tokenizer.decode([el.item() for el in generated_ids[0]])
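
For reference, a minimal usage sketch of the class added in this commit. The model name and checkpoint path below are placeholders, not part of the commit, and the GPTQ dependencies (quant, utils, gptq) must be importable alongside this file.

from llama_inference_class import ModelInference

# Full-precision inference (wbits defaults to 16, no quantized checkpoint).
llm = ModelInference("decapoda-research/llama-7b-hf")  # placeholder model name
print(llm.generate_text("The quick brown fox", max_length=40))

# With a pre-quantized GPTQ checkpoint (placeholder path), e.g. 4-bit weights, groupsize 128:
# llm = ModelInference("decapoda-research/llama-7b-hf", load="llama-7b-4bit-128g.safetensors", wbits=4, groupsize=128)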