yujiepan committed
Commit
038c7fd
1 Parent(s): 927ea0e

Upload run_autoawq.py

Files changed (1)
  1. run_autoawq.py +111 -0
run_autoawq.py ADDED
@@ -0,0 +1,111 @@
'''
Tested on: transformers==4.38.1, autoawq==0.2.3
Run on a single GPU card (>= 18 GB of GPU memory).
'''
import torch

from awq.quantize.quantizer import AwqQuantizer
from awq.quantize.quantizer import *
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from unittest.mock import patch
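
# Note: the wildcard import above brings the helpers used below into scope
# (tqdm, get_named_linears, exclude_layers_to_not_quantize, clear_memory,
# apply_scale, apply_clip, append_str_prefix, get_op_name, get_best_device,
# and the List/Dict typing aliases); only AwqQuantizer is imported explicitly.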


class FalconAwqQuantizer(AwqQuantizer):
    def quantize(self):
        print('Patched!')
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()

                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)
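
            # Falcon-specific handling (the reason for this subclass): besides
            # position_ids and attention_mask, Falcon's decoder layers receive
            # an `alibi` kwarg, which must live on the same device as the layer,
            # or be explicitly None when the model runs with use_alibi=False.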
            # include alibi
            if self.module_kwargs.get("alibi") is not None:
                self.module_kwargs["alibi"] = self.module_kwargs[
                    "alibi"
                ].to(common_device)
            else:
                self.module_kwargs['alibi'] = None
                print(f'alibi=None in layer {i}, this is expected if use_alibi=False.')

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()
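
            # Steps 2-3 below are the standard AWQ recipe (roughly): search
            # per-channel scales that protect activation-salient weights and
            # fold them into the preceding op via apply_scale, then search
            # clipping thresholds for each linear's weights before quantizing.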
            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(
                self.modules[i], named_linears, input_feat
            )
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(
                clip_list, get_op_name(self.model, self.modules[i]) + "."
            )
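
            # With the default export_compatible=False, the packed 4-bit weights
            # are written in step 4; export_compatible=True would keep only the
            # scaled/clipped fp16 weights so the model could be exported to
            # another format instead.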
            # [STEP 4]: Quantize weights
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()


model_path = 'tiiuae/falcon-40b'
# model_path = 'yujiepan/falcon-new-tiny-random'
quant_path = 'falcon-40b-autoawq-w4g128'
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
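# quant_config fields: w_bit=4 stores weights in 4 bits, q_group_size=128 shares
# one scale/zero-point per group of 128 weights, zero_point=True uses asymmetric
# (zero-point) quantization, and version="GEMM" selects the GEMM packing/kernel.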


# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, device_map='cpu', trust_remote_code=False,
    low_cpu_mem_usage=True, use_cache=False,
)
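# The full model is materialized on the CPU here; during quantization the patched
# quantizer moves one decoder block at a time onto the GPU, which is why a single
# card with ~18 GB of memory is enough even for Falcon-40B.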
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Quantize
with patch('awq.models.base.AwqQuantizer', FalconAwqQuantizer):
    model.quantize(tokenizer, quant_config=quant_config)
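# patch() swaps the AwqQuantizer name in the awq.models.base namespace (where
# model.quantize() looks it up) for FalconAwqQuantizer, so the patched loop runs
# without editing library code; the original class is restored on exit.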

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
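
A minimal follow-up sketch (not part of the uploaded script) of how the saved checkpoint could be loaded back for inference with AutoAWQ's from_quantized; the prompt and generation settings are illustrative only:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'falcon-40b-autoawq-w4g128'

# Load the packed 4-bit checkpoint onto the GPU (layer fusion left off for simplicity)
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

inputs = tokenizer("The quick brown fox", return_tensors='pt').input_ids.cuda()
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))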