ssaroya committed
Commit e5bdf53 · 1 Parent(s): 51260d5

Upload 10 files

Files changed (10)
  1. LICENSE.txt +201 -0
  2. README.md +145 -3
  3. convert_llama_weights_to_hf.py +34 -0
  4. gptq.py +236 -0
  5. llama.py +515 -0
  6. llama_inference.py +123 -0
  7. llama_inference_offload.py +279 -0
  8. neox.py +430 -0
  9. opt.py +446 -0
  10. requirements.txt +11 -0
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,145 @@
- ---
- license: other
- ---
+ # GPTQ-for-LLaMA
+ <img src="https://user-images.githubusercontent.com/64115820/235287009-2d07bba8-9b85-4973-9e06-2a3c28777f06.png" width="50%" height="50%">
+
+ 4-bit quantization of [LLaMA](https://arxiv.org/abs/2302.13971) using [GPTQ](https://arxiv.org/abs/2210.17323)
+
+ GPTQ is a SOTA one-shot weight quantization method.
+
+ **It can be used universally, but it is not the [fastest](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/old-cuda) and only supports Linux.**
+
+ **Triton only supports Linux, so if you are a Windows user, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install).**
+
+ ## News or Update
+ **AutoGPTQ-triton, a packaged version of GPTQ with Triton, has been integrated into [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ).**
+ ## Result
+ <details>
+ <summary>LLaMA-7B (click me)</summary>
+
+ | [LLaMA-7B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 | checkpoint size (GB) |
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
+ | FP16 | 16 | - | 13940 | 5.68 | 12.5 |
+ | RTN | 4 | - | - | 6.29 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 4740 | 6.09 | 3.5 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 4891 | 5.85 | 3.6 |
+ | RTN | 3 | - | - | 25.54 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 3852 | 8.07 | 2.7 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 4116 | 6.61 | 3.0 |
+
+ </details>
+
+ <details>
+ <summary>LLaMA-13B</summary>
+
+ | [LLaMA-13B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 | checkpoint size (GB) |
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
+ | FP16 | 16 | - | OOM | 5.09 | 24.2 |
+ | RTN | 4 | - | - | 5.53 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 8410 | 5.36 | 6.5 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 8747 | 5.20 | 6.7 |
+ | RTN | 3 | - | - | 11.40 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 6870 | 6.63 | 5.1 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 7277 | 5.62 | 5.4 |
+
+ </details>
+
+ <details>
+ <summary>LLaMA-33B</summary>
+
+ | [LLaMA-33B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 | checkpoint size (GB) |
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
+ | FP16 | 16 | - | OOM | 4.10 | 60.5 |
+ | RTN | 4 | - | - | 4.54 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 19493 | 4.45 | 15.7 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 20570 | 4.23 | 16.3 |
+ | RTN | 3 | - | - | 14.89 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 15493 | 5.69 | 12.0 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 16566 | 4.80 | 13.0 |
+
+ </details>
+
+ <details>
+ <summary>LLaMA-65B</summary>
+
+ | [LLaMA-65B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 | checkpoint size (GB) |
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
+ | FP16 | 16 | - | OOM | 3.53 | 121.0 |
+ | RTN | 4 | - | - | 3.92 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | OOM | 3.84 | 31.1 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | OOM | 3.65 | 32.3 |
+ | RTN | 3 | - | - | 10.59 | - |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | OOM | 5.04 | 23.6 |
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | OOM | 4.17 | 25.6 |
+ </details>
+
+ Quantization requires a large amount of CPU memory. However, the memory required can be reduced by using swap memory.
+
+ Depending on the GPUs/drivers, there may be a difference in performance, which decreases as the model size increases ([IST-DASLab/gptq#1](https://github.com/IST-DASLab/gptq/issues/1)).
+
+ According to the [GPTQ paper](https://arxiv.org/abs/2210.17323), as the size of the model increases, the difference in performance between FP16 and GPTQ decreases.
+
+ ## Installation
+ If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first.
+ ```
+ conda create --name gptq python=3.9 -y
+ conda activate gptq
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
+ # Or, if you're having trouble with conda, use pip with python3.9:
+ # pip3 install torch torchvision torchaudio
+
+ git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa
+ cd GPTQ-for-LLaMa
+ pip install -r requirements.txt
+ ```
+ ## Dependencies
+
+ * `torch`: tested on v2.0.0+cu117
+ * `transformers`: tested on v4.28.0.dev0
+ * `datasets`: tested on v2.10.1
+ * `safetensors`: tested on v0.3.0
+
+ All experiments were run on a single NVIDIA RTX 3090.
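A quick way to confirm the environment matches the versions above, assuming the packages are already installed (this check is illustrative and not part of the repository):
```
import torch, transformers, datasets, safetensors

# Print the installed versions and check that a CUDA device is visible.
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("safetensors:", safetensors.__version__)
print("CUDA available:", torch.cuda.is_available())
```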
+
+ # Language Generation
+ ## LLaMA
+
+ ```
+ # Convert LLaMA to HF format
+ python convert_llama_weights_to_hf.py --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir ./llama-hf
+
+ # Benchmark language generation with 4-bit LLaMA-7B:
+
+ # Save compressed model
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save llama7b-4bit-128g.pt
+
+ # Or save compressed `.safetensors` model
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save_safetensors llama7b-4bit-128g.safetensors
+
+ # Benchmark generating a 2048-token sequence with the saved model
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --benchmark 2048 --check
+
+ # Benchmark the FP16 baseline; note that the model will be split across all listed GPUs
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4 python llama.py ${MODEL_DIR} c4 --benchmark 2048 --check
+
+ # Model inference with the saved model
+ CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama"
+
+ # Model inference with the saved model, using safetensors loaded directly to the GPU
+ CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.safetensors --text "this is llama" --device=0
+
+ # Model inference with the saved model with offload (this is very slow)
+ CUDA_VISIBLE_DEVICES=0 python llama_inference_offload.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama" --pre_layer 16
+ # It takes about 180 seconds to generate 45 tokens (5->50 tokens) on a single RTX 3090 with LLaMA-65B and pre_layer set to 50.
+ ```
+ Basically, 4-bit quantization and a groupsize of 128 are recommended.
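The saved checkpoint can also be used from Python instead of the CLI. A minimal sketch that mirrors `llama_inference.py`, assuming the full GPTQ-for-LLaMa repository (including its `utils` and `quant` modules) is importable and that `./llama-hf/llama-7b` and `llama7b-4bit-128g.pt` were produced by the commands above:
```
import torch
from transformers import AutoTokenizer
from llama_inference import load_quant
from utils import DEV

model_dir = "./llama-hf/llama-7b"          # assumed HF model directory
model = load_quant(model_dir, "llama7b-4bit-128g.pt", 4, groupsize=128)
model.to(DEV)

tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
input_ids = tokenizer.encode("this is llama", return_tensors="pt").to(DEV)

with torch.no_grad():
    generated_ids = model.generate(input_ids, do_sample=True, min_length=10,
                                   max_length=50, top_p=0.95, temperature=0.8)
print(tokenizer.decode(generated_ids[0]))
```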
+
+ You can also export quantization parameters in toml+numpy format.
+ ```
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --quant-directory ${TOML_DIR}
+ ```
+
+ # Acknowledgements
+ This code is based on [GPTQ](https://github.com/IST-DASLab/gptq).
+
+ Thanks to Meta AI for releasing [LLaMA](https://arxiv.org/abs/2302.13971), a powerful LLM.
+
+ Triton GPTQ kernel code is based on [GPTQ-triton](https://github.com/fpgaminer/GPTQ-triton).
convert_llama_weights_to_hf.py ADDED
@@ -0,0 +1,34 @@
+ import argparse
+ import os
+ from transformers.models.llama.convert_llama_weights_to_hf import write_model, write_tokenizer
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--input_dir",
+         help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+     )
+     parser.add_argument(
+         "--model_size",
+         choices=["7B", "13B", "30B", "65B", "tokenizer_only"],
+     )
+     parser.add_argument(
+         "--output_dir",
+         help="Location to write HF model and tokenizer",
+     )
+     args = parser.parse_args()
+     if args.model_size != "tokenizer_only":
+         write_model(
+             model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
+             input_base_path=os.path.join(args.input_dir, args.model_size),
+             model_size=args.model_size,
+         )
+     write_tokenizer(
+         tokenizer_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
+         input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
+     )
+
+
+ if __name__ == "__main__":
+     main()
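As a sanity check after conversion, the resulting directory can be loaded with `transformers`. A minimal sketch, assuming the 7B weights were converted into `./llama-hf/llama-7b` as in the README example:
```
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

path = "./llama-hf/llama-7b"  # hypothetical output path from the conversion command
# Load in half precision, the same way llama.py's get_llama() loads the model.
model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
tokenizer = LlamaTokenizer.from_pretrained(path)
print(model.config.num_hidden_layers, "layers,", model.config.hidden_size, "hidden size")
```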
gptq.py ADDED
@@ -0,0 +1,236 @@
1
+ import math
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import transformers
7
+ import quant
8
+ from texttable import Texttable
9
+ from utils import torch_snr_error
10
+
11
+ torch.backends.cuda.matmul.allow_tf32 = False
12
+ torch.backends.cudnn.allow_tf32 = False
13
+
14
+
15
+ class Observer:
16
+
17
+ def __init__(self, topk=32):
18
+ self.loss_list = []
19
+ self.topk = topk
20
+
21
+ def submit(self, name: str, layerid: int, gptq, error: float):
22
+
23
+ item = (name, layerid, {'gptq': gptq, 'error': error})
24
+
25
+ if len(self.loss_list) < self.topk:
26
+ self.loss_list.append(item)
27
+ return
28
+
29
+ min_error = error
30
+ min_idx = -1
31
+ for idx, data in enumerate(self.loss_list):
32
+ if min_error > data[2]['error']:
33
+ min_idx = idx
34
+ min_error = data[2]['error']
35
+
36
+ if min_idx >= 0:
37
+ self.loss_list[min_idx] = item
38
+
39
+ def print(self):
40
+ self.loss_list = sorted(self.loss_list, key=lambda s: s[2]['error'], reverse=True)
41
+
42
+ table = Texttable()
43
+
44
+ table.header(['name', 'error'])
45
+ table.set_cols_dtype(['t', 'f'])
46
+
47
+ for item in self.loss_list:
48
+ table.add_row([f"{item[0]}.{item[1]}", item[2]['error']])
49
+ print(table.draw())
50
+ print('\n')
51
+
52
+ def items(self):
53
+ return self.loss_list
54
+
55
+
56
+ class GPTQ:
57
+
58
+ def __init__(self, layer, observe=False):
59
+ self.layer = layer
60
+ self.dev = self.layer.weight.device
61
+ W = layer.weight.data.clone()
62
+ if isinstance(self.layer, nn.Conv2d):
63
+ W = W.flatten(1)
64
+ if isinstance(self.layer, transformers.Conv1D):
65
+ W = W.t()
66
+ self.rows = W.shape[0]
67
+ self.columns = W.shape[1]
68
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
69
+ self.nsamples = 0
70
+ self.quantizer = quant.Quantizer()
71
+ self.observe = observe
72
+
73
+ def add_batch(self, inp, out):
74
+ # Hessian H = 2 X XT + λ I
75
+ if self.observe:
76
+ self.inp1 = inp
77
+ self.out1 = out
78
+ else:
79
+ self.inp1 = None
80
+ self.out1 = None
81
+
82
+ if len(inp.shape) == 2:
83
+ inp = inp.unsqueeze(0)
84
+ tmp = inp.shape[0]
85
+ if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
86
+ if len(inp.shape) == 3:
87
+ inp = inp.reshape((-1, inp.shape[-1]))
88
+ inp = inp.t()
89
+ if isinstance(self.layer, nn.Conv2d):
90
+ unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
91
+ inp = unfold(inp)
92
+ inp = inp.permute([1, 0, 2])
93
+ inp = inp.flatten(1)
94
+ self.H *= self.nsamples / (self.nsamples + tmp)
95
+ self.nsamples += tmp
96
+ # inp = inp.float()
97
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
98
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
99
+ self.H += inp.matmul(inp.t())
100
+
101
+ def print_loss(self, name, q_weight, weight_error, timecost):
102
+ table = Texttable()
103
+ name += ' ' * (16 - len(name))
104
+
105
+ table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])
106
+
107
+ # assign weight
108
+ self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
109
+
110
+ if self.inp1 is not None:
111
+ # quantize input to int8
112
+ quantizer = quant.Quantizer()
113
+ quantizer.configure(8, perchannel=False, sym=True, mse=False)
114
+ quantizer.find_params(self.inp1)
115
+ q_in = quantizer.quantize(self.inp1).type(torch.float16)
116
+ q_out = self.layer(q_in)
117
+
118
+ # get kinds of SNR
119
+ q_SNR = torch_snr_error(q_out, self.out1).item()
120
+ fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
121
+ else:
122
+ q_SNR = '-'
123
+ fp_SNR = '-'
124
+
125
+ table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
126
+ print(table.draw().split('\n')[-2])
127
+
128
+ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
129
+ self.layer.to(self.dev)
130
+
131
+ W = self.layer.weight.data.clone()
132
+ if isinstance(self.layer, nn.Conv2d):
133
+ W = W.flatten(1)
134
+ if isinstance(self.layer, transformers.Conv1D):
135
+ W = W.t()
136
+ W = W.float()
137
+
138
+ tick = time.time()
139
+
140
+ if not self.quantizer.ready():
141
+ self.quantizer.find_params(W, weight=True)
142
+
143
+ H = self.H
144
+ if not self.observe:
145
+ del self.H
146
+ dead = torch.diag(H) == 0
147
+ H[dead, dead] = 1
148
+ W[:, dead] = 0
149
+
150
+ if actorder:
151
+ perm = torch.argsort(torch.diag(H), descending=True)
152
+ W = W[:, perm]
153
+ H = H[perm][:, perm]
154
+
155
+ Losses = torch.zeros_like(W)
156
+ Q = torch.zeros_like(W)
157
+
158
+ damp = percdamp * torch.mean(torch.diag(H))
159
+ diag = torch.arange(self.columns, device=self.dev)
160
+ H[diag, diag] += damp
161
+ H = torch.linalg.cholesky(H)
162
+ H = torch.cholesky_inverse(H)
163
+ H = torch.linalg.cholesky(H, upper=True)
164
+ Hinv = H
165
+
166
+ g_idx = []
167
+ scale = []
168
+ zero = []
169
+ now_idx = 1
170
+
171
+ for i1 in range(0, self.columns, blocksize):
172
+ i2 = min(i1 + blocksize, self.columns)
173
+ count = i2 - i1
174
+
175
+ W1 = W[:, i1:i2].clone()
176
+ Q1 = torch.zeros_like(W1)
177
+ Err1 = torch.zeros_like(W1)
178
+ Losses1 = torch.zeros_like(W1)
179
+ Hinv1 = Hinv[i1:i2, i1:i2]
180
+
181
+ for i in range(count):
182
+ w = W1[:, i]
183
+ d = Hinv1[i, i]
184
+
185
+ if groupsize != -1:
186
+ if (i1 + i) % groupsize == 0:
187
+ self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
188
+
189
+ if ((i1 + i) // groupsize) - now_idx == -1:
190
+ scale.append(self.quantizer.scale)
191
+ zero.append(self.quantizer.zero)
192
+ now_idx += 1
193
+
194
+ q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
195
+ Q1[:, i] = q
196
+ Losses1[:, i] = (w - q)**2 / d**2
197
+
198
+ err1 = (w - q) / d
199
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
200
+ Err1[:, i] = err1
201
+
202
+ Q[:, i1:i2] = Q1
203
+ Losses[:, i1:i2] = Losses1 / 2
204
+
205
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
206
+
207
+ torch.cuda.synchronize()
208
+ error = torch.sum(Losses).item()
209
+
210
+ groupsize = groupsize if groupsize != -1 else self.columns
211
+ g_idx = [i // groupsize for i in range(self.columns)]
212
+ g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
213
+ if actorder:
214
+ invperm = torch.argsort(perm)
215
+ Q = Q[:, invperm]
216
+ g_idx = g_idx[invperm]
217
+
218
+ if isinstance(self.layer, transformers.Conv1D):
219
+ Q = Q.t()
220
+
221
+ self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))
222
+
223
+ if scale == []:
224
+ scale.append(self.quantizer.scale)
225
+ zero.append(self.quantizer.zero)
226
+ scale = torch.cat(scale, dim=1)
227
+ zero = torch.cat(zero, dim=1)
228
+ return scale, zero, g_idx, error
229
+
230
+ def free(self):
231
+ self.inp1 = None
232
+ self.out1 = None
233
+ self.H = None
234
+ self.Losses = None
235
+ self.Trace = None
236
+ torch.cuda.empty_cache()
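The running update in `GPTQ.add_batch` (rescale `H` by `nsamples / (nsamples + tmp)`, then add the outer product of `sqrt(2 / nsamples) * inp`) is an incremental way of computing `H = (2/N) * sum(X @ X.T)` over all calibration inputs. A small self-contained check of that equivalence, using random data rather than real activations:
```
import torch

d = 16                      # number of input features (columns of the layer weight)
X1 = torch.randn(d, 32)     # first calibration batch, one column per token
X2 = torch.randn(d, 48)     # second calibration batch
t1, t2 = 1, 1               # add_batch counts whole batches via inp.shape[0]

# Incremental accumulation, mirroring GPTQ.add_batch
H = torch.zeros(d, d)
n = 0
for X, t in ((X1, t1), (X2, t2)):
    H *= n / (n + t)
    n += t
    Xs = (2 / n) ** 0.5 * X
    H += Xs @ Xs.t()

# Direct computation over all samples at once
H_direct = (2 / n) * (X1 @ X1.t() + X2 @ X2.t())
print(torch.allclose(H, H_direct, atol=1e-5))   # True
```
The dampening term (the `λI` in the `Hessian H = 2 X XT + λ I` comment) is added later in `fasterquant` via `percdamp`, not here.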
llama.py ADDED
@@ -0,0 +1,515 @@
1
+ import argparse
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import quant
7
+
8
+ from gptq import GPTQ, Observer
9
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
10
+ from texttable import Texttable
11
+
12
+
13
+ def get_llama(model):
14
+
15
+ def skip(*args, **kwargs):
16
+ pass
17
+
18
+ torch.nn.init.kaiming_uniform_ = skip
19
+ torch.nn.init.uniform_ = skip
20
+ torch.nn.init.normal_ = skip
21
+ from transformers import LlamaForCausalLM
22
+ model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
23
+ model.seqlen = 2048
24
+ return model
25
+
26
+
27
+ @torch.no_grad()
28
+ def llama_sequential(model, dataloader, dev):
29
+ print('Starting ...')
30
+
31
+ use_cache = model.config.use_cache
32
+ model.config.use_cache = False
33
+ layers = model.model.layers
34
+
35
+ model.model.embed_tokens = model.model.embed_tokens.to(dev)
36
+ model.model.norm = model.model.norm.to(dev)
37
+ layers[0] = layers[0].to(dev)
38
+
39
+ dtype = next(iter(model.parameters())).dtype
40
+ inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
41
+ cache = {'i': 0, 'attention_mask': None}
42
+
43
+ class Catcher(nn.Module):
44
+
45
+ def __init__(self, module):
46
+ super().__init__()
47
+ self.module = module
48
+
49
+ def forward(self, inp, **kwargs):
50
+ inps[cache['i']] = inp
51
+ cache['i'] += 1
52
+ cache['attention_mask'] = kwargs['attention_mask']
53
+ cache['position_ids'] = kwargs['position_ids']
54
+ raise ValueError
55
+
56
+ layers[0] = Catcher(layers[0])
57
+ for batch in dataloader:
58
+ try:
59
+ model(batch[0].to(dev))
60
+ except ValueError:
61
+ pass
62
+ layers[0] = layers[0].module
63
+
64
+ layers[0] = layers[0].cpu()
65
+ model.model.embed_tokens = model.model.embed_tokens.cpu()
66
+ model.model.norm = model.model.norm.cpu()
67
+ torch.cuda.empty_cache()
68
+
69
+ outs = torch.zeros_like(inps)
70
+ attention_mask = cache['attention_mask']
71
+ position_ids = cache['position_ids']
72
+
73
+ print('Ready.')
74
+
75
+ quantizers = {}
76
+ observer = Observer()
77
+ for i in range(len(layers)):
78
+
79
+ print(f'Quantizing layer {i+1}/{len(layers)}..')
80
+ print('+------------------+--------------+------------+-----------+-------+')
81
+ print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
82
+ print('+==================+==============+============+===========+=======+')
83
+
84
+ layer = layers[i].to(dev)
85
+ full = find_layers(layer)
86
+ if args.true_sequential:
87
+ sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']]
88
+ else:
89
+ sequential = [list(full.keys())]
90
+
91
+ for names in sequential:
92
+ subset = {n: full[n] for n in names}
93
+ gptq = {}
94
+ for name in subset:
95
+ gptq[name] = GPTQ(subset[name], observe=args.observe)
96
+ gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
97
+
98
+ def add_batch(name):
99
+
100
+ def tmp(_, inp, out):
101
+ gptq[name].add_batch(inp[0].data, out.data)
102
+
103
+ return tmp
104
+
105
+ handles = []
106
+ for name in subset:
107
+ handles.append(subset[name].register_forward_hook(add_batch(name)))
108
+ for j in range(args.nsamples):
109
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
110
+ for h in handles:
111
+ h.remove()
112
+
113
+ for name in subset:
114
+ scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name)
115
+ quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize)
116
+
117
+ if args.observe:
118
+ observer.submit(name=name, layerid=i, gptq=gptq[name], error=error)
119
+ else:
120
+ gptq[name].free()
121
+
122
+ for j in range(args.nsamples):
123
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
124
+
125
+ layers[i] = layer.cpu()
126
+ del layer
127
+ del gptq
128
+ torch.cuda.empty_cache()
129
+
130
+ inps, outs = outs, inps
131
+ print('+------------------+--------------+------------+-----------+-------+')
132
+ print('\n')
133
+
134
+ if args.observe:
135
+ observer.print()
136
+ conditions = gen_conditions(args.wbits, args.groupsize)
137
+ for item in observer.items():
138
+ name = item[0]
139
+ layerid = item[1]
140
+ gptq = item[2]['gptq']
141
+ error = item[2]['error']
142
+ target = error / 2
143
+
144
+ table = Texttable()
145
+ table.header(['wbits', 'groupsize', 'error'])
146
+ table.set_cols_dtype(['i', 'i', 'f'])
147
+ table.add_row([args.wbits, args.groupsize, error])
148
+
149
+ print('Optimizing {} {} ..'.format(name, layerid))
150
+ for wbits, groupsize in conditions:
151
+
152
+ if error < target:
153
+ # if error dropped 50%, skip
154
+ break
155
+
156
+ gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False)
157
+
158
+ scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)
159
+
160
+ table.add_row([wbits, groupsize, error])
161
+ quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)
162
+
163
+ print(table.draw())
164
+ print('\n')
165
+ gptq.layer.to('cpu')
166
+ gptq.free()
167
+
168
+ model.config.use_cache = use_cache
169
+
170
+ return quantizers
171
+
172
+
173
+ @torch.no_grad()
174
+ def llama_eval(model, testenc, dev):
175
+ print('Evaluating ...')
176
+
177
+ testenc = testenc.input_ids
178
+ nsamples = testenc.numel() // model.seqlen
179
+
180
+ use_cache = model.config.use_cache
181
+ model.config.use_cache = False
182
+ layers = model.model.layers
183
+
184
+ model.model.embed_tokens = model.model.embed_tokens.to(dev)
185
+ layers[0] = layers[0].to(dev)
186
+
187
+ dtype = next(iter(model.parameters())).dtype
188
+ inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
189
+ cache = {'i': 0, 'attention_mask': None}
190
+
191
+ class Catcher(nn.Module):
192
+
193
+ def __init__(self, module):
194
+ super().__init__()
195
+ self.module = module
196
+
197
+ def forward(self, inp, **kwargs):
198
+ inps[cache['i']] = inp
199
+ cache['i'] += 1
200
+ cache['attention_mask'] = kwargs['attention_mask']
201
+ cache['position_ids'] = kwargs['position_ids']
202
+ raise ValueError
203
+
204
+ layers[0] = Catcher(layers[0])
205
+ for i in range(nsamples):
206
+ batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
207
+ try:
208
+ model(batch)
209
+ except ValueError:
210
+ pass
211
+ layers[0] = layers[0].module
212
+
213
+ layers[0] = layers[0].cpu()
214
+ model.model.embed_tokens = model.model.embed_tokens.cpu()
215
+ torch.cuda.empty_cache()
216
+
217
+ outs = torch.zeros_like(inps)
218
+ attention_mask = cache['attention_mask']
219
+ position_ids = cache['position_ids']
220
+
221
+ for i in range(len(layers)):
222
+ print(i)
223
+ layer = layers[i].to(dev)
224
+
225
+ if args.nearest:
226
+ subset = find_layers(layer)
227
+ for name in subset:
228
+ quantizer = quant.Quantizer()
229
+ quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
230
+ W = subset[name].weight.data
231
+ quantizer.find_params(W, weight=True)
232
+ subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
233
+
234
+ for j in range(nsamples):
235
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
236
+ layers[i] = layer.cpu()
237
+ del layer
238
+ torch.cuda.empty_cache()
239
+ inps, outs = outs, inps
240
+
241
+ if model.model.norm is not None:
242
+ model.model.norm = model.model.norm.to(dev)
243
+ model.lm_head = model.lm_head.to(dev)
244
+
245
+ testenc = testenc.to(dev)
246
+ nlls = []
247
+ for i in range(nsamples):
248
+ hidden_states = inps[i].unsqueeze(0)
249
+ if model.model.norm is not None:
250
+ hidden_states = model.model.norm(hidden_states)
251
+ lm_logits = model.lm_head(hidden_states)
252
+ shift_logits = lm_logits[:, :-1, :].contiguous()
253
+ shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
254
+ loss_fct = nn.CrossEntropyLoss()
255
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
256
+ neg_log_likelihood = loss.float() * model.seqlen
257
+ nlls.append(neg_log_likelihood)
258
+ ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
259
+ print(ppl.item())
260
+
261
+ model.config.use_cache = use_cache
262
+
263
+
264
+ # TODO: perform packing on GPU
265
+ def llama_pack(model, quantizers, wbits, groupsize):
266
+ layers = find_layers(model)
267
+ layers = {n: layers[n] for n in quantizers}
268
+ quant.make_quant_linear(model, quantizers, wbits, groupsize)
269
+ qlayers = find_layers(model, [quant.QuantLinear])
270
+ print('Packing ...')
271
+ for name in qlayers:
272
+ print(name)
273
+ quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
274
+ qlayers[name].pack(layers[name], scale, zero, g_idx)
275
+ print('Done.')
276
+ return model
277
+
278
+
279
+ def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
280
+ from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
281
+ config = LlamaConfig.from_pretrained(model)
282
+
283
+ def noop(*args, **kwargs):
284
+ pass
285
+
286
+ torch.nn.init.kaiming_uniform_ = noop
287
+ torch.nn.init.uniform_ = noop
288
+ torch.nn.init.normal_ = noop
289
+
290
+ torch.set_default_dtype(torch.half)
291
+ modeling_utils._init_weights = False
292
+ torch.set_default_dtype(torch.half)
293
+ model = LlamaForCausalLM(config)
294
+ torch.set_default_dtype(torch.float)
295
+ if eval:
296
+ model = model.eval()
297
+ layers = find_layers(model)
298
+ for name in ['lm_head']:
299
+ if name in layers:
300
+ del layers[name]
301
+ quant.make_quant_linear(model, layers, wbits, groupsize)
302
+
303
+ del layers
304
+
305
+ print('Loading model ...')
306
+ if checkpoint.endswith('.safetensors'):
307
+ from safetensors.torch import load_file as safe_load
308
+ model.load_state_dict(safe_load(checkpoint))
309
+ else:
310
+ model.load_state_dict(torch.load(checkpoint))
311
+
312
+ if eval:
313
+ quant.make_quant_attn(model)
314
+ quant.make_quant_norm(model)
315
+ if fused_mlp:
316
+ quant.make_fused_mlp(model)
317
+
318
+ if warmup_autotune:
319
+ quant.autotune_warmup_linear(model, transpose=not (eval))
320
+ if eval and fused_mlp:
321
+ quant.autotune_warmup_fused(model)
322
+ model.seqlen = 2048
323
+ print('Done.')
324
+
325
+ return model
326
+
327
+
328
+ def llama_multigpu(model, gpus, gpu_dist):
329
+ model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
330
+ if hasattr(model.model, 'norm') and model.model.norm:
331
+ model.model.norm = model.model.norm.to(gpus[-1])
332
+ import copy
333
+ model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1])
334
+
335
+ cache = {'mask': None}
336
+
337
+ class MoveModule(nn.Module):
338
+
339
+ def __init__(self, module):
340
+ super().__init__()
341
+ self.module = module
342
+ self.dev = next(iter(self.module.parameters())).device
343
+
344
+ def forward(self, *inp, **kwargs):
345
+ inp = list(inp)
346
+ if inp[0].device != self.dev:
347
+ inp[0] = inp[0].to(self.dev)
348
+ if cache['mask'] is None or cache['mask'].device != self.dev:
349
+ cache['mask'] = kwargs['attention_mask'].to(self.dev)
350
+ kwargs['attention_mask'] = cache['mask']
351
+ tmp = self.module(*inp, **kwargs)
352
+ return tmp
353
+
354
+ layers = model.model.layers
355
+ from math import ceil
356
+ if not gpu_dist:
357
+ pergpu = ceil(len(layers) / len(gpus))
358
+ for i in range(len(layers)):
359
+ layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
360
+ else:
361
+ assigned_gpus = []
362
+ for i in range(len(gpu_dist)):
363
+ assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
364
+
365
+ remaining_assignments = len(layers)-len(assigned_gpus)
366
+ if remaining_assignments > 0:
367
+ assigned_gpus = assigned_gpus + [-1] * remaining_assignments
368
+
369
+ for i in range(len(layers)):
370
+ layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]))
371
+
372
+ model.gpus = gpus
373
+
374
+
375
+ def benchmark(model, input_ids, check=False):
376
+ input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
377
+ torch.cuda.synchronize()
378
+
379
+ cache = {'past': None}
380
+
381
+ def clear_past(i):
382
+
383
+ def tmp(layer, inp, out):
384
+ if cache['past']:
385
+ cache['past'][i] = None
386
+
387
+ return tmp
388
+
389
+ for i, layer in enumerate(model.model.layers):
390
+ layer.register_forward_hook(clear_past(i))
391
+
392
+ print('Benchmarking ...')
393
+
394
+ if check:
395
+ loss = nn.CrossEntropyLoss()
396
+ tot = 0.
397
+
398
+ def sync():
399
+ if hasattr(model, 'gpus'):
400
+ for gpu in model.gpus:
401
+ torch.cuda.synchronize(gpu)
402
+ else:
403
+ torch.cuda.synchronize()
404
+
405
+ max_memory = 0
406
+ with torch.no_grad():
407
+ attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
408
+ times = []
409
+ for i in range(input_ids.numel()):
410
+ tick = time.time()
411
+ out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
412
+ sync()
413
+ times.append(time.time() - tick)
414
+ print(i, times[-1])
415
+ if hasattr(model, 'gpus'):
416
+ mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
417
+ else:
418
+ mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
419
+ max_memory = max(max_memory, mem_allocated)
420
+ if check and i != input_ids.numel() - 1:
421
+ tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
422
+ cache['past'] = list(out.past_key_values)
423
+ del out
424
+ sync()
425
+ print('Median:', np.median(times))
426
+ if check:
427
+ print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
428
+ print('max memory(MiB):', max_memory)
429
+
430
+
431
+ if __name__ == '__main__':
432
+
433
+ parser = argparse.ArgumentParser()
434
+
435
+ parser.add_argument('model', type=str, help='llama model to load')
436
+ parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
437
+ parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
438
+ parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
439
+ parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
440
+ parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
441
+ parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
442
+ parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
443
+ parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
444
+ parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
445
+ parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
446
+ parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
447
+ parser.add_argument('--load', type=str, default='', help='Load quantized model.')
448
+ parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
449
+ parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
450
+ parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
451
+ parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
452
+ parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.')
453
+ parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
454
+ parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.')
455
+ parser.add_argument('--observe',
456
+ action='store_true',
457
+ help='Auto upgrade layer precision to higher precision, for example int2 to int4, groupsize 128 to 64. \
458
+ When this feature enabled, `--save` or `--save_safetensors` would be disable.')
459
+ parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for export quantization parameters to toml format. `None` means no export by default.')
460
+
461
+ args = parser.parse_args()
462
+
463
+ if args.layers_dist:
464
+ gpu_dist = [int(x) for x in args.layers_dist.split(':')]
465
+ else:
466
+ gpu_dist = []
467
+
468
+ if type(args.load) is not str:
469
+ args.load = args.load.as_posix()
470
+
471
+ if args.load:
472
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
473
+ else:
474
+ model = get_llama(args.model)
475
+ model.eval()
476
+
477
+ dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
478
+
479
+ if not args.load and args.wbits < 16 and not args.nearest:
480
+ tick = time.time()
481
+ quantizers = llama_sequential(model, dataloader, DEV)
482
+ print(time.time() - tick)
483
+
484
+ if args.benchmark:
485
+ gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
486
+ if len(gpus) > 1:
487
+ llama_multigpu(model, gpus, gpu_dist)
488
+ else:
489
+ model = model.to(DEV)
490
+ if args.benchmark:
491
+ input_ids = next(iter(dataloader))[0][:, :args.benchmark]
492
+ benchmark(model, input_ids, check=args.check)
493
+
494
+ if args.eval:
495
+ datasets = ['wikitext2', 'ptb', 'c4']
496
+ if args.new_eval:
497
+ datasets = ['wikitext2', 'ptb-new', 'c4-new']
498
+ for dataset in datasets:
499
+ dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
500
+ print(dataset)
501
+ llama_eval(model, testloader, DEV)
502
+
503
+ if args.quant_directory is not None:
504
+ export_quant_table(quantizers, args.quant_directory)
505
+
506
+ if not args.observe and args.save:
507
+ llama_pack(model, quantizers, args.wbits, args.groupsize)
508
+ torch.save(model.state_dict(), args.save)
509
+
510
+ if not args.observe and args.save_safetensors:
511
+ llama_pack(model, quantizers, args.wbits, args.groupsize)
512
+ from safetensors.torch import save_file as safe_save
513
+ state_dict = model.state_dict()
514
+ state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
515
+ safe_save(state_dict, args.save_safetensors)
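A detail of `llama_sequential` and `llama_eval` above worth calling out: the `Catcher` wrapper records the inputs reaching the first decoder layer and then raises `ValueError` on purpose, so the forward pass stops before the remaining (possibly CPU-resident) layers ever run. A self-contained sketch of the pattern using toy modules rather than the repository code:
```
import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
captured = []

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        captured.append(x)      # stash the calibration input for this block
        raise ValueError        # abort: only the input is needed, not the output

blocks[0] = Catcher(blocks[0])
try:
    x = torch.randn(2, 8)
    for b in blocks:
        x = b(x)
except ValueError:
    pass
blocks[0] = blocks[0].module    # unwrap afterwards, as llama.py does
print(len(captured), captured[0].shape)
```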
llama_inference.py ADDED
@@ -0,0 +1,123 @@
1
+ import argparse
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import quant
6
+
7
+ from gptq import GPTQ
8
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
9
+ import transformers
10
+ from transformers import AutoTokenizer
11
+
12
+
13
+ def get_llama(model):
14
+
15
+ def skip(*args, **kwargs):
16
+ pass
17
+
18
+ torch.nn.init.kaiming_uniform_ = skip
19
+ torch.nn.init.uniform_ = skip
20
+ torch.nn.init.normal_ = skip
21
+ from transformers import LlamaForCausalLM
22
+ model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
23
+ model.seqlen = 2048
24
+ return model
25
+
26
+
27
+ def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
28
+ from transformers import LlamaConfig, LlamaForCausalLM
29
+ config = LlamaConfig.from_pretrained(model)
30
+
31
+ def noop(*args, **kwargs):
32
+ pass
33
+
34
+ torch.nn.init.kaiming_uniform_ = noop
35
+ torch.nn.init.uniform_ = noop
36
+ torch.nn.init.normal_ = noop
37
+
38
+ torch.set_default_dtype(torch.half)
39
+ transformers.modeling_utils._init_weights = False
40
+ torch.set_default_dtype(torch.half)
41
+ model = LlamaForCausalLM(config)
42
+ torch.set_default_dtype(torch.float)
43
+ if eval:
44
+ model = model.eval()
45
+ layers = find_layers(model)
46
+ for name in ['lm_head']:
47
+ if name in layers:
48
+ del layers[name]
49
+ quant.make_quant_linear(model, layers, wbits, groupsize)
50
+
51
+ del layers
52
+
53
+ print('Loading model ...')
54
+ if checkpoint.endswith('.safetensors'):
55
+ from safetensors.torch import load_file as safe_load
56
+ model.load_state_dict(safe_load(checkpoint), strict=False)
57
+ else:
58
+ model.load_state_dict(torch.load(checkpoint), strict=False)
59
+
60
+ if eval:
61
+ quant.make_quant_attn(model)
62
+ quant.make_quant_norm(model)
63
+ if fused_mlp:
64
+ quant.make_fused_mlp(model)
65
+ if warmup_autotune:
66
+ quant.autotune_warmup_linear(model, transpose=not (eval))
67
+ if eval and fused_mlp:
68
+ quant.autotune_warmup_fused(model)
69
+ model.seqlen = 2048
70
+ print('Done.')
71
+
72
+ return model
73
+
74
+
75
+ if __name__ == '__main__':
76
+
77
+ parser = argparse.ArgumentParser()
78
+
79
+ parser.add_argument('model', type=str, help='llama model to load')
80
+ parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
81
+ parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
82
+ parser.add_argument('--load', type=str, default='', help='Load quantized model.')
83
+
84
+ parser.add_argument('--text', type=str, help='input text')
85
+
86
+ parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')
87
+
88
+ parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')
89
+
90
+ parser.add_argument('--top_p',
91
+ type=float,
92
+ default=0.95,
93
+ help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')
94
+
95
+ parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.')
96
+
97
+ parser.add_argument('--device', type=int, default=-1, help='The device used to load the model when using safetensors. Default device is "cpu" or specify, 0,1,2,3,... for GPU device.')
98
+
99
+ args = parser.parse_args()
100
+
101
+ if type(args.load) is not str:
102
+ args.load = args.load.as_posix()
103
+
104
+ if args.load:
105
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
106
+ else:
107
+ model = get_llama(args.model)
108
+ model.eval()
109
+
110
+ model.to(DEV)
111
+ tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
112
+ input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)
113
+
114
+ with torch.no_grad():
115
+ generated_ids = model.generate(
116
+ input_ids,
117
+ do_sample=True,
118
+ min_length=args.min_length,
119
+ max_length=args.max_length,
120
+ top_p=args.top_p,
121
+ temperature=args.temperature,
122
+ )
123
+ print(tokenizer.decode([el.item() for el in generated_ids[0]]))
llama_inference_offload.py ADDED
@@ -0,0 +1,279 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from gptq import GPTQ
5
+ import argparse
6
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
7
+ import quant
8
+
9
+ import transformers
10
+ from transformers import AutoTokenizer
11
+ from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast
13
+ from typing import List, Optional, Tuple, Union
14
+ from accelerate import cpu_offload_with_hook, load_checkpoint_in_model
15
+
16
+
17
+ class Offload_LlamaModel(LlamaModel):
18
+
19
+ def __init__(self, config: LlamaConfig):
20
+ super().__init__(config)
21
+
22
+ def cpu_offload(self, preload):
23
+ hook = None
24
+ for cpu_offloaded_model in self.layers[preload:]:
25
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, DEV, prev_module_hook=hook)
26
+
27
+ def forward(
28
+ self,
29
+ input_ids: torch.LongTensor = None,
30
+ attention_mask: Optional[torch.Tensor] = None,
31
+ position_ids: Optional[torch.LongTensor] = None,
32
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
33
+ inputs_embeds: Optional[torch.FloatTensor] = None,
34
+ use_cache: Optional[bool] = None,
35
+ output_attentions: Optional[bool] = None,
36
+ output_hidden_states: Optional[bool] = None,
37
+ return_dict: Optional[bool] = None,
38
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
39
+ r"""
40
+ Args:
41
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
42
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
43
+ provide it.
44
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
45
+ [`PreTrainedTokenizer.__call__`] for details.
46
+ [What are input IDs?](../glossary#input-ids)
47
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
48
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
49
+ - 1 for tokens that are **not masked**,
50
+ - 0 for tokens that are **masked**.
51
+ [What are attention masks?](../glossary#attention-mask)
52
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
53
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
54
+ `[0, config.n_positions - 1]`.
55
+ [What are position IDs?](../glossary#position-ids)
56
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
57
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
58
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
60
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
61
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
62
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
63
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
64
+ use_cache (`bool`, *optional*):
65
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
66
+ (see `past_key_values`).
67
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
68
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
69
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
70
+ than the model's internal embedding lookup matrix.
71
+ output_attentions (`bool`, *optional*):
72
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
73
+ returned tensors for more detail.
74
+ output_hidden_states (`bool`, *optional*):
75
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
76
+ for more detail.
77
+ return_dict (`bool`, *optional*):
78
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
79
+ """
80
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
81
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
82
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
83
+
84
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
85
+
86
+ # retrieve input_ids and inputs_embeds
87
+ if input_ids is not None and inputs_embeds is not None:
88
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
89
+ elif input_ids is not None:
90
+ batch_size, seq_length = input_ids.shape
91
+ elif inputs_embeds is not None:
92
+ batch_size, seq_length, _ = inputs_embeds.shape
93
+ else:
94
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
95
+ seq_length_with_past = seq_length
96
+ past_key_values_length = 0
97
+ if past_key_values is not None:
98
+ past_key_values_length = past_key_values[0][0].shape[2]
99
+ seq_length_with_past = seq_length_with_past + past_key_values_length
100
+
101
+ if position_ids is None:
102
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
103
+ position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
104
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
105
+ else:
106
+ position_ids = position_ids.view(-1, seq_length).long()
107
+
108
+ if inputs_embeds is None:
109
+ inputs_embeds = self.embed_tokens(input_ids)
110
+
111
+ # embed positions
112
+ if attention_mask is None:
113
+ attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
114
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)
115
+
116
+ hidden_states = inputs_embeds
117
+
118
+ if self.gradient_checkpointing and self.training:
119
+ if use_cache:
120
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
121
+ use_cache = False
122
+
123
+ # decoder layers
124
+ all_hidden_states = () if output_hidden_states else None
125
+ all_self_attns = () if output_attentions else None
126
+ next_decoder_cache = () if use_cache else None
127
+
128
+ for idx in range(len(self.layers)):
129
+ decoder_layer = self.layers[idx]
130
+
131
+ if output_hidden_states:
132
+ all_hidden_states += (hidden_states, )
133
+
134
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
135
+
136
+ if self.gradient_checkpointing and self.training:
137
+
138
+ def create_custom_forward(module):
139
+
140
+ def custom_forward(*inputs):
141
+ # None for past_key_value
142
+ return module(*inputs, output_attentions, None)
143
+
144
+ return custom_forward
145
+
146
+ layer_outputs = torch.utils.checkpoint.checkpoint(
147
+ create_custom_forward(decoder_layer),
148
+ hidden_states,
149
+ attention_mask,
150
+ position_ids,
151
+ None,
152
+ )
153
+ else:
154
+ layer_outputs = decoder_layer(
155
+ hidden_states,
156
+ attention_mask=attention_mask,
157
+ position_ids=position_ids,
158
+ past_key_value=past_key_value,
159
+ output_attentions=output_attentions,
160
+ use_cache=use_cache,
161
+ )
162
+
163
+ hidden_states = layer_outputs[0]
164
+
165
+ if use_cache:
166
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1], )
167
+
168
+ if output_attentions:
169
+ all_self_attns += (layer_outputs[1], )
170
+
171
+ hidden_states = self.norm(hidden_states)
172
+
173
+ # add hidden states from the last decoder layer
174
+ if output_hidden_states:
175
+ all_hidden_states += (hidden_states, )
176
+
177
+ next_cache = next_decoder_cache if use_cache else None
178
+ if not return_dict:
179
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
180
+ return BaseModelOutputWithPast(
181
+ last_hidden_state=hidden_states,
182
+ past_key_values=next_cache,
183
+ hidden_states=all_hidden_states,
184
+ attentions=all_self_attns,
185
+ )
186
+
187
+
188
+ def load_quant(model, checkpoint, wbits, groupsize, pre_layer, fused_mlp=True, warmup_autotune=True):
189
+ transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel
190
+ from transformers import LlamaConfig, LlamaForCausalLM
191
+ config = LlamaConfig.from_pretrained(model)
192
+
193
+ def noop(*args, **kwargs):
194
+ pass
195
+
196
+ torch.nn.init.kaiming_uniform_ = noop
197
+ torch.nn.init.uniform_ = noop
198
+ torch.nn.init.normal_ = noop
199
+
200
+ torch.set_default_dtype(torch.half)
201
+ transformers.modeling_utils._init_weights = False
202
+ torch.set_default_dtype(torch.half)
203
+ model = LlamaForCausalLM(config)
204
+ torch.set_default_dtype(torch.float)
205
+ model = model.eval()
206
+ layers = find_layers(model)
207
+ for name in ['lm_head']:
208
+ if name in layers:
209
+ del layers[name]
210
+ quant.make_quant_linear(model, layers, wbits, groupsize)
211
+
212
+ print('Loading model ...')
213
+ load_checkpoint_in_model(model, checkpoint, dtype='float16')
214
+ model.seqlen = 2048
215
+
216
+ # Note: the original `if eval:` tested the Python builtin `eval` (always truthy), since this
+ # function takes no `eval` parameter; the fused attention/norm kernels therefore always run.
+ quant.make_quant_attn(model)
+ quant.make_quant_norm(model)
219
+ if fused_mlp:
220
+ quant.make_fused_mlp(model)
221
+
222
+
223
+ if warmup_autotune:
224
+ quant.autotune_warmup_linear(model)
225
+ if fused_mlp:
226
+ quant.autotune_warmup_fused(model)
227
+
228
+ for i in range(pre_layer):
229
+ model.model.layers[i].to(DEV)
230
+ model.model.embed_tokens.to(DEV)
231
+ model.model.norm.to(DEV)
232
+ model.lm_head.to(DEV)
233
+ model.model.cpu_offload(pre_layer)
234
+ print('Done.')
235
+ return model
236
+
237
+
238
+ if __name__ == '__main__':
239
+ parser = argparse.ArgumentParser()
240
+
241
+ parser.add_argument('model', type=str, help='llama model to load')
242
+ parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8], help='#bits to use for quantization')
243
+ parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
244
+ parser.add_argument('--load', type=str, default='', help='Load quantized model.')
245
+ parser.add_argument('--text', type=str, help='input text')
246
+
247
+ parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')
248
+
249
+ parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')
250
+
251
+ parser.add_argument('--top_p',
252
+ type=float,
253
+ default=0.95,
254
+ help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')
255
+
256
+ parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.')
257
+
258
+ parser.add_argument('--pre_layer', type=int, default=50, help='The number of layers to preload')
259
+
260
+ args = parser.parse_args()
261
+
262
+ if type(args.load) is not str:
263
+ args.load = args.load.as_posix()
264
+
265
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer)
266
+
267
+ tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
268
+ input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)
269
+
270
+ with torch.no_grad():
271
+ generated_ids = model.generate(
272
+ input_ids,
273
+ do_sample=True,
274
+ min_length=args.min_length,
275
+ max_length=args.max_length,
276
+ top_p=args.top_p,
277
+ temperature=args.temperature,
278
+ )
279
+ print(tokenizer.decode([el.item() for el in generated_ids[0]]))
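llama_inference_offload.py keeps the first `--pre_layer` decoder blocks resident on the GPU and wraps the remaining blocks with accelerate's `cpu_offload_with_hook`, so each offloaded block is copied to the GPU only for its own forward pass. A minimal sketch of driving it from Python instead of the CLI, assuming the repo root is on the import path; the model path, checkpoint name and prompt are placeholders:

    import torch
    from transformers import AutoTokenizer

    from llama_inference_offload import load_quant
    from utils import DEV

    # Keep the first 20 decoder layers on the GPU; the rest are offloaded to CPU RAM.
    model = load_quant("path/to/llama-13b-hf", "llama13b-4bit-128g.pt", wbits=4, groupsize=128, pre_layer=20)

    tokenizer = AutoTokenizer.from_pretrained("path/to/llama-13b-hf", use_fast=False)
    input_ids = tokenizer.encode("Tell me about quantization.", return_tensors="pt").to(DEV)

    with torch.no_grad():
        generated_ids = model.generate(input_ids, do_sample=True, max_length=64, top_p=0.95, temperature=0.8)
    print(tokenizer.decode(generated_ids[0]))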
neox.py ADDED
@@ -0,0 +1,430 @@
1
+ import argparse
+ import math  # needed by neox_multigpu below (math.ceil)
+ import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import quant
+
+ from gptq import GPTQ, Observer
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
+ from texttable import Texttable
11
+
12
+
13
+ def get_neox(model, seqlen=-1):
14
+
15
+ def skip(*args, **kwargs):
16
+ pass
17
+
18
+ torch.nn.init.kaiming_uniform_ = skip
19
+ torch.nn.init.uniform_ = skip
20
+ torch.nn.init.normal_ = skip
21
+ from transformers import GPTNeoXForCausalLM
22
+ model = GPTNeoXForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
23
+ model.seqlen = seqlen if seqlen != -1 else model.config.max_position_embeddings
24
+ return model
25
+
26
+
27
+ @torch.no_grad()
28
+ def neox_sequential(model, dataloader, dev):
29
+ print('Starting ...')
30
+
31
+ use_cache = model.config.use_cache
32
+ model.config.use_cache = False
33
+ layers = model.gpt_neox.layers
34
+
35
+ model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev)
36
+ layers[0] = layers[0].to(dev)
37
+
38
+ dtype = next(iter(model.parameters())).dtype
39
+ inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
40
+ cache = {'i': 0, 'attention_mask': None}
41
+
42
+ class Catcher(nn.Module):
43
+
44
+ def __init__(self, module):
45
+ super().__init__()
46
+ self.module = module
47
+
48
+ def forward(self, inp, **kwargs):
49
+ inps[cache['i']] = inp
50
+ cache['i'] += 1
51
+ cache['attention_mask'] = kwargs['attention_mask']
52
+ cache['position_ids'] = kwargs['position_ids']
53
+ raise ValueError
54
+
55
+ layers[0] = Catcher(layers[0])
56
+ for batch in dataloader:
57
+ try:
58
+ model(batch[0].to(dev))
59
+ except ValueError:
60
+ pass
61
+ layers[0] = layers[0].module
62
+
63
+ layers[0] = layers[0].cpu()
64
+ model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu()
65
+ torch.cuda.empty_cache()
66
+
67
+ outs = torch.zeros_like(inps)
68
+ attention_mask = cache['attention_mask']
69
+ position_ids = cache['position_ids']
70
+
71
+ print('Ready.')
72
+
73
+ quantizers = {}
74
+ observer = Observer()
75
+ for i in range(len(layers)):
76
+
77
+ print(f'Quantizing layer {i+1}/{len(layers)}..')
78
+ print('+------------------+--------------+------------+-----------+-------+')
79
+ print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
80
+ print('+==================+==============+============+===========+=======+')
81
+
82
+ layer = layers[i].to(dev)
83
+ full = find_layers(layer)
84
+ sequential = [list(full.keys())]
85
+
86
+ for names in sequential:
87
+ subset = {n: full[n] for n in names}
88
+ gptq = {}
89
+ for name in subset:
90
+ gptq[name] = GPTQ(subset[name], observe=False)
91
+ gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
92
+
93
+ def add_batch(name):
94
+
95
+ def tmp(_, inp, out):
96
+ gptq[name].add_batch(inp[0].data, out.data)
97
+
98
+ return tmp
99
+
100
+ handles = []
101
+ for name in subset:
102
+ handles.append(subset[name].register_forward_hook(add_batch(name)))
103
+ for j in range(args.nsamples):
104
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
105
+ for h in handles:
106
+ h.remove()
107
+
108
+ for name in subset:
109
+ scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name)
110
+ quantizers['gpt_neox.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize)
111
+ gptq[name].free()
112
+
113
+ for j in range(args.nsamples):
114
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
115
+
116
+ layers[i] = layer.cpu()
117
+ del layer
118
+ del gptq
119
+ torch.cuda.empty_cache()
120
+
121
+ inps, outs = outs, inps
122
+ print('+------------------+--------------+------------+-----------+-------+')
123
+ print('\n')
124
+
125
+ model.config.use_cache = use_cache
126
+
127
+ return quantizers
128
+
129
+
130
+ @torch.no_grad()
131
+ def neox_eval(model, testenc, dev):
132
+ print('Evaluating ...')
133
+
134
+ testenc = testenc.input_ids
135
+ nsamples = testenc.numel() // model.seqlen
136
+
137
+ use_cache = model.config.use_cache
138
+ model.config.use_cache = False
139
+ layers = model.gpt_neox.layers
140
+
141
+ model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev)
142
+ layers[0] = layers[0].to(dev)
143
+
144
+ dtype = next(iter(model.parameters())).dtype
145
+ inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
146
+ cache = {'i': 0, 'attention_mask': None}
147
+
148
+ class Catcher(nn.Module):
149
+
150
+ def __init__(self, module):
151
+ super().__init__()
152
+ self.module = module
153
+
154
+ def forward(self, inp, **kwargs):
155
+ inps[cache['i']] = inp
156
+ cache['i'] += 1
157
+ cache['attention_mask'] = kwargs['attention_mask']
158
+ cache['position_ids'] = kwargs['position_ids']
159
+ raise ValueError
160
+
161
+ layers[0] = Catcher(layers[0])
162
+ for i in range(nsamples):
163
+ batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
164
+ try:
165
+ model(batch)
166
+ except ValueError:
167
+ pass
168
+ layers[0] = layers[0].module
169
+
170
+ layers[0] = layers[0].cpu()
171
+ model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu()
172
+ torch.cuda.empty_cache()
173
+
174
+ outs = torch.zeros_like(inps)
175
+ attention_mask = cache['attention_mask']
176
+ position_ids = cache['position_ids']
177
+
178
+ for i in range(len(layers)):
179
+ print(i)
180
+ layer = layers[i].to(dev)
181
+
182
+ if args.nearest:
183
+ subset = find_layers(layer)
184
+ for name in subset:
185
+ quantizer = quant.Quantizer()
186
+ quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
187
+ W = subset[name].weight.data
188
+ quantizer.find_params(W, weight=True)
189
+ subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
190
+
191
+ for j in range(nsamples):
192
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
193
+ layers[i] = layer.cpu()
194
+ del layer
195
+ torch.cuda.empty_cache()
196
+ inps, outs = outs, inps
197
+
198
+ model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(dev)
199
+ model.embed_out = model.embed_out.to(dev)
200
+
201
+ testenc = testenc.to(dev)
202
+ nlls = []
203
+ for i in range(nsamples):
204
+ hidden_states = inps[i].unsqueeze(0)
205
+ hidden_states = model.gpt_neox.final_layer_norm(hidden_states)
206
+ lm_logits = model.embed_out(hidden_states)
207
+ shift_logits = lm_logits[:, :-1, :].contiguous()
208
+ shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
209
+ loss_fct = nn.CrossEntropyLoss()
210
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
211
+ neg_log_likelihood = loss.float() * model.seqlen
212
+ nlls.append(neg_log_likelihood)
213
+ ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
214
+ print(ppl.item())
215
+
216
+ model.config.use_cache = use_cache
217
+
218
+
219
+ # TODO: perform packing on GPU
220
+ def neox_pack(model, quantizers, wbits, groupsize):
221
+ layers = find_layers(model)
222
+ layers = {n: layers[n] for n in quantizers}
223
+ quant.make_quant_linear(model, quantizers, wbits, groupsize)
224
+ qlayers = find_layers(model, [quant.QuantLinear])
225
+ print('Packing ...')
226
+ for name in qlayers:
227
+ print(name)
228
+ quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
229
+ qlayers[name].pack(layers[name], scale, zero, g_idx)
230
+ print('Done.')
231
+ return model
232
+
233
+
234
+ def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
235
+ from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, modeling_utils
236
+ config = GPTNeoXConfig.from_pretrained(model)
237
+
238
+ def noop(*args, **kwargs):
239
+ pass
240
+
241
+ torch.nn.init.kaiming_uniform_ = noop
242
+ torch.nn.init.uniform_ = noop
243
+ torch.nn.init.normal_ = noop
244
+
245
+ torch.set_default_dtype(torch.half)
246
+ modeling_utils._init_weights = False
247
+ torch.set_default_dtype(torch.half)
248
+ model = GPTNeoXForCausalLM(config)
249
+ torch.set_default_dtype(torch.float)
250
+ if eval:
251
+ model = model.eval()
252
+ layers = find_layers(model)
253
+ for name in ['embed_in','embed_out']:
254
+ if name in layers:
255
+ del layers[name]
256
+ quant.make_quant_linear(model, layers, wbits, groupsize)
257
+
258
+ del layers
259
+
260
+ print('Loading model ...')
261
+ if checkpoint.endswith('.safetensors'):
262
+ from safetensors.torch import load_file as safe_load
263
+ model.load_state_dict(safe_load(checkpoint))
264
+ else:
265
+ model.load_state_dict(torch.load(checkpoint))
266
+
267
+ if warmup_autotune:
268
+ quant.autotune_warmup_linear(model, transpose=not (eval))
269
+
270
+ model.seqlen = model.config.max_position_embeddings
271
+ print('Done.')
272
+
273
+ return model
274
+
275
+
276
+ def neox_multigpu(model, gpus):
277
+ model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(gpus[0])
278
+ model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(gpus[-1])
279
+ import copy
280
+ model.embed_out = copy.deepcopy(model.embed_out).to(gpus[-1])
281
+
282
+ cache = {'mask': None}
283
+
284
+ class MoveModule(nn.Module):
285
+
286
+ def __init__(self, module):
287
+ super().__init__()
288
+ self.module = module
289
+ self.dev = next(iter(self.module.parameters())).device
290
+
291
+ def forward(self, *inp, **kwargs):
292
+ inp = list(inp)
293
+ if inp[0].device != self.dev:
294
+ inp[0] = inp[0].to(self.dev)
295
+ if cache['mask'] is None or cache['mask'].device != self.dev:
296
+ cache['mask'] = kwargs['attention_mask'].to(self.dev)
297
+ kwargs['attention_mask'] = cache['mask']
298
+ tmp = self.module(*inp, **kwargs)
299
+ return tmp
300
+
301
+ layers = model.gpt_neox.layers
302
+ pergpu = math.ceil(len(layers) / len(gpus))
303
+ for i in range(len(layers)):
304
+ layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
305
+
306
+ model.gpus = gpus
307
+
308
+
309
+ def benchmark(model, input_ids, check=False):
310
+ input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
311
+ torch.cuda.synchronize()
312
+
313
+ cache = {'past': None}
314
+
315
+ def clear_past(i):
316
+
317
+ def tmp(layer, inp, out):
318
+ if cache['past']:
319
+ cache['past'][i] = None
320
+
321
+ return tmp
322
+
323
+ for i, layer in enumerate(model.gpt_neox.layers):
324
+ layer.register_forward_hook(clear_past(i))
325
+
326
+ print('Benchmarking ...')
327
+
328
+ if check:
329
+ loss = nn.CrossEntropyLoss()
330
+ tot = 0.
331
+
332
+ def sync():
333
+ if hasattr(model, 'gpus'):
334
+ for gpu in model.gpus:
335
+ torch.cuda.synchronize(gpu)
336
+ else:
337
+ torch.cuda.synchronize()
338
+
339
+ max_memory = 0
340
+ with torch.no_grad():
341
+ attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
342
+ times = []
343
+ for i in range(input_ids.numel()):
344
+ tick = time.time()
345
+ out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
346
+ sync()
347
+ times.append(time.time() - tick)
348
+ print(i, times[-1])
349
+ max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024)
350
+ if check and i != input_ids.numel() - 1:
351
+ tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
352
+ cache['past'] = list(out.past_key_values)
353
+ del out
354
+ sync()
355
+ print('Median:', np.median(times))
356
+ if check:
357
+ print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
358
+ print('max memory(MiB):', max_memory)
359
+
360
+
361
+ if __name__ == '__main__':
362
+
363
+ parser = argparse.ArgumentParser()
364
+
365
+ parser.add_argument('model', type=str, help='GPT-NeoX model to load')
366
+ parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
367
+ parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
368
+ parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
369
+ parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
370
+ parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
371
+ parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='bits to use for quantization; use 16 for evaluating base model.')
372
+ parser.add_argument('--seqlen', type=int, default=-1, help='seqlen to use for quantization; default uses full seqlen')
373
+ parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
374
+ parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
375
+ parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
376
+ parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
377
+ parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
378
+ parser.add_argument('--load', type=str, default='', help='Load quantized model.')
379
+ parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
380
+ parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
381
+ parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
382
+ parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
383
+ parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
384
+ args = parser.parse_args()
385
+
386
+ if type(args.load) is not str:
387
+ args.load = args.load.as_posix()
388
+
389
+ if args.load:
390
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
391
+ else:
392
+ model = get_neox(args.model)
393
+ model.eval()
394
+
395
+ dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
396
+
397
+ if not args.load and args.wbits < 16 and not args.nearest:
398
+ tick = time.time()
399
+ quantizers = neox_sequential(model, dataloader, DEV)
400
+ print(time.time() - tick)
401
+
402
+ if args.benchmark:
403
+ gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
404
+ if len(gpus) > 1:
405
+ neox_multigpu(model, gpus)
406
+ else:
407
+ model = model.to(DEV)
408
+ if args.benchmark:
409
+ input_ids = next(iter(dataloader))[0][:, :args.benchmark]
410
+ benchmark(model, input_ids, check=args.check)
411
+
412
+ if args.eval:
413
+ datasets = ['wikitext2', 'ptb', 'c4']
414
+ if args.new_eval:
415
+ datasets = ['wikitext2', 'ptb-new', 'c4-new']
416
+ for dataset in datasets:
417
+ dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
418
+ print(dataset)
419
+ neox_eval(model, testloader, DEV)
420
+
421
+ if args.save:
422
+ neox_pack(model, quantizers, args.wbits, args.groupsize)
423
+ torch.save(model.state_dict(), args.save)
424
+
425
+ if args.save_safetensors:
426
+ neox_pack(model, quantizers, args.wbits, args.groupsize)
427
+ from safetensors.torch import save_file as safe_save
428
+ state_dict = model.state_dict()
429
+ state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
430
+ safe_save(state_dict, args.save_safetensors)
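`neox_eval` collects one mean cross-entropy loss per `seqlen`-token window, scales it back to a total negative log-likelihood, and reports perplexity as exp(total NLL / total tokens). A self-contained illustration of that final step with made-up loss values:

    import torch

    seqlen = 2048
    # One mean cross-entropy per evaluated window, as accumulated inside neox_eval (illustrative values).
    losses = torch.tensor([2.31, 2.28, 2.35, 2.30])

    nlls = losses * seqlen                                 # per-window total negative log-likelihood
    ppl = torch.exp(nlls.sum() / (len(losses) * seqlen))   # perplexity over all evaluated tokens
    print(round(ppl.item(), 2))                            # ~10.07, i.e. exp of the mean loss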
opt.py ADDED
@@ -0,0 +1,446 @@
1
+ import math  # needed by opt_multigpu below (math.ceil)
+ import time
+
+ import torch
+ import torch.nn as nn
+ import argparse
+
+ import transformers
+ from gptq import GPTQ
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
+ import quant
11
+
12
+
13
+ def get_opt(model):
14
+ import torch
15
+
16
+ def skip(*args, **kwargs):
17
+ pass
18
+
19
+ torch.nn.init.kaiming_uniform_ = skip
20
+ torch.nn.init.uniform_ = skip
21
+ torch.nn.init.normal_ = skip
22
+ from transformers import OPTForCausalLM
23
+ model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
24
+ model.seqlen = model.config.max_position_embeddings
25
+ return model
26
+
27
+
28
+ @torch.no_grad()
29
+ def opt_sequential(model, dataloader, dev):
30
+ print('Starting ...')
31
+
32
+ use_cache = model.config.use_cache
33
+ model.config.use_cache = False
34
+ layers = model.model.decoder.layers
35
+
36
+ model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
37
+ model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
38
+ if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
39
+ model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
40
+ if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
41
+ model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
42
+ layers[0] = layers[0].to(dev)
43
+
44
+ dtype = next(iter(model.parameters())).dtype
45
+ inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
46
+ cache = {'i': 0, 'attention_mask': None}
47
+
48
+ class Catcher(nn.Module):
49
+
50
+ def __init__(self, module):
51
+ super().__init__()
52
+ self.module = module
53
+
54
+ def forward(self, inp, **kwargs):
55
+ inps[cache['i']] = inp
56
+ cache['i'] += 1
57
+ cache['attention_mask'] = kwargs['attention_mask']
58
+ raise ValueError
59
+
60
+ layers[0] = Catcher(layers[0])
61
+ for batch in dataloader:
62
+ try:
63
+ model(batch[0].to(dev))
64
+ except ValueError:
65
+ pass
66
+ layers[0] = layers[0].module
67
+
68
+ layers[0] = layers[0].cpu()
69
+ model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
70
+ model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
71
+ if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
72
+ model.model.decoder.project_out = model.model.decoder.project_out.cpu()
73
+ if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
74
+ model.model.decoder.project_in = model.model.decoder.project_in.cpu()
75
+ torch.cuda.empty_cache()
76
+
77
+ outs = torch.zeros_like(inps)
78
+ attention_mask = cache['attention_mask']
79
+
80
+ print('Ready.')
81
+
82
+ quantizers = {}
83
+ for i in range(len(layers)):
84
+ layer = layers[i].to(dev)
85
+
86
+ subset = find_layers(layer)
87
+ gptq = {}
88
+ for name in subset:
89
+ gptq[name] = GPTQ(subset[name])
90
+ gptq[name].quantizer = quant.Quantizer()
91
+ gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits)
92
+
93
+ def add_batch(name):
94
+
95
+ def tmp(_, inp, out):
96
+ gptq[name].add_batch(inp[0].data, out.data)
97
+
98
+ return tmp
99
+
100
+ handles = []
101
+ for name in subset:
102
+ handles.append(subset[name].register_forward_hook(add_batch(name)))
103
+
104
+ for j in range(args.nsamples):
105
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
106
+
107
+ for h in handles:
108
+ h.remove()
109
+
110
+ for name in subset:
111
+ print(f'Quantizing {name} in layer {i+1}/{len(layers)}...')
112
+ scale, zero, g_idx, _ = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
113
+ quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
114
+ gptq[name].free()
115
+
116
+ for j in range(args.nsamples):
117
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
118
+
119
+ layers[i] = layer.cpu()
120
+ del layer
121
+ del gptq
122
+ torch.cuda.empty_cache()
123
+
124
+ inps, outs = outs, inps
125
+
126
+ model.config.use_cache = use_cache
127
+
128
+ return quantizers
129
+
130
+
131
+ @torch.no_grad()
132
+ def opt_eval(model, testenc, dev):
133
+ print('Evaluating ...')
134
+
135
+ testenc = testenc.input_ids
136
+ nsamples = testenc.numel() // model.seqlen
137
+
138
+ use_cache = model.config.use_cache
139
+ model.config.use_cache = False
140
+ layers = model.model.decoder.layers
141
+
142
+ model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
143
+ model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
144
+ if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
145
+ model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
146
+ if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
147
+ model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
148
+ layers[0] = layers[0].to(dev)
149
+
150
+ dtype = next(iter(model.parameters())).dtype
151
+ inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
152
+ cache = {'i': 0, 'attention_mask': None}
153
+
154
+ class Catcher(nn.Module):
155
+
156
+ def __init__(self, module):
157
+ super().__init__()
158
+ self.module = module
159
+
160
+ def forward(self, inp, **kwargs):
161
+ inps[cache['i']] = inp
162
+ cache['i'] += 1
163
+ cache['attention_mask'] = kwargs['attention_mask']
164
+ raise ValueError
165
+
166
+ layers[0] = Catcher(layers[0])
167
+ for i in range(nsamples):
168
+ batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
169
+ try:
170
+ model(batch)
171
+ except ValueError:
172
+ pass
173
+ layers[0] = layers[0].module
174
+
175
+ layers[0] = layers[0].cpu()
176
+ model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
177
+ model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
178
+ if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
179
+ model.model.decoder.project_out = model.model.decoder.project_out.cpu()
180
+ if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
181
+ model.model.decoder.project_in = model.model.decoder.project_in.cpu()
182
+ torch.cuda.empty_cache()
183
+
184
+ outs = torch.zeros_like(inps)
185
+ attention_mask = cache['attention_mask']
186
+
187
+ for i in range(len(layers)):
188
+ print(i)
189
+ layer = layers[i].to(dev)
190
+
191
+ if args.nearest:
192
+ subset = find_layers(layer)
193
+ for name in subset:
194
+ quantizer = quant.Quantizer()
195
+ quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
196
+ W = subset[name].weight.data
197
+ quantizer.find_params(W, weight=True)
198
+ subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
199
+
200
+ for j in range(nsamples):
201
+ outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
202
+ layers[i] = layer.cpu()
203
+ del layer
204
+ torch.cuda.empty_cache()
205
+ inps, outs = outs, inps
206
+
207
+ if model.model.decoder.final_layer_norm is not None:
208
+ model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
209
+ if model.model.decoder.project_out is not None:
210
+ model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
211
+ model.lm_head = model.lm_head.to(dev)
212
+
213
+ testenc = testenc.to(dev)
214
+ nlls = []
215
+ for i in range(nsamples):
216
+ hidden_states = inps[i].unsqueeze(0)
217
+ if model.model.decoder.final_layer_norm is not None:
218
+ hidden_states = model.model.decoder.final_layer_norm(hidden_states)
219
+ if model.model.decoder.project_out is not None:
220
+ hidden_states = model.model.decoder.project_out(hidden_states)
221
+ lm_logits = model.lm_head(hidden_states)
222
+ shift_logits = lm_logits[:, :-1, :].contiguous()
223
+ shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
224
+ loss_fct = nn.CrossEntropyLoss()
225
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
226
+ neg_log_likelihood = loss.float() * model.seqlen
227
+ nlls.append(neg_log_likelihood)
228
+ ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
229
+ print(ppl.item())
230
+
231
+ model.config.use_cache = use_cache
232
+
233
+
234
+ # TODO: perform packing on GPU
235
+ def opt_pack(model, quantizers, wbits, groupsize):
236
+ layers = find_layers(model)
237
+ layers = {n: layers[n] for n in quantizers}
238
+ quant.make_quant_linear(model, quantizers, wbits, groupsize)
239
+ qlayers = find_layers(model, [quant.QuantLinear])
240
+ print('Packing ...')
241
+ for name in qlayers:
242
+ print(name)
243
+ quantizers[name], scale, zero, g_idx = quantizers[name]
244
+ qlayers[name].pack(layers[name], scale, zero, g_idx)
245
+ print('Done.')
246
+ return model
247
+
248
+
249
+ def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
250
+ from transformers import OPTConfig, OPTForCausalLM
251
+ config = OPTConfig.from_pretrained(model)
252
+
253
+ def noop(*args, **kwargs):
254
+ pass
255
+
256
+ torch.nn.init.kaiming_uniform_ = noop
257
+ torch.nn.init.uniform_ = noop
258
+ torch.nn.init.normal_ = noop
259
+
260
+ torch.set_default_dtype(torch.half)
261
+ transformers.modeling_utils._init_weights = False
262
+ torch.set_default_dtype(torch.half)
263
+ model = OPTForCausalLM(config)
264
+ torch.set_default_dtype(torch.float)
265
+ model = model.eval()
266
+ layers = find_layers(model)
267
+ for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']:
268
+ if name in layers:
269
+ del layers[name]
270
+ quant.make_quant_linear(model, layers, wbits, groupsize)
271
+
272
+ del layers
273
+
274
+ print('Loading model ...')
275
+ if checkpoint.endswith('.safetensors'):
276
+ from safetensors.torch import load_file as safe_load
277
+ model.load_state_dict(safe_load(checkpoint))
278
+ else:
279
+ model.load_state_dict(torch.load(checkpoint))
280
+
281
+ if warmup_autotune:
282
+ quant.autotune_warmup_linear(model, transpose=not (eval))
283
+ model.seqlen = model.config.max_position_embeddings
284
+ print('Done.')
285
+ return model
286
+
287
+
288
+ def opt_multigpu(model, gpus):
289
+ model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0])
290
+ model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0])
291
+ if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
292
+ model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0])
293
+ if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
294
+ model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1])
295
+ if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm:
296
+ model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1])
297
+ import copy
298
+ model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1])
299
+
300
+ cache = {'mask': None}
301
+
302
+ class MoveModule(nn.Module):
303
+
304
+ def __init__(self, module):
305
+ super().__init__()
306
+ self.module = module
307
+ self.dev = next(iter(self.module.parameters())).device
308
+
309
+ def forward(self, *inp, **kwargs):
310
+ inp = list(inp)
311
+ if inp[0].device != self.dev:
312
+ inp[0] = inp[0].to(self.dev)
313
+ if cache['mask'] is None or cache['mask'].device != self.dev:
314
+ cache['mask'] = kwargs['attention_mask'].to(self.dev)
315
+ kwargs['attention_mask'] = cache['mask']
316
+ tmp = self.module(*inp, **kwargs)
317
+ return tmp
318
+
319
+ layers = model.model.decoder.layers
320
+ pergpu = math.ceil(len(layers) / len(gpus))
321
+ for i in range(len(layers)):
322
+ layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
323
+
324
+ model.gpus = gpus
325
+
326
+
327
+ def benchmark(model, input_ids, check=False):
328
+ input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
329
+ torch.cuda.synchronize()
330
+
331
+ cache = {'past': None}
332
+
333
+ def clear_past(i):
334
+
335
+ def tmp(layer, inp, out):
336
+ if cache['past']:
337
+ cache['past'][i] = None
338
+
339
+ return tmp
340
+
341
+ for i, layer in enumerate(model.model.decoder.layers):
342
+ layer.register_forward_hook(clear_past(i))
343
+
344
+ print('Benchmarking ...')
345
+
346
+ if check:
347
+ loss = nn.CrossEntropyLoss()
348
+ tot = 0.
349
+
350
+ def sync():
351
+ if hasattr(model, 'gpus'):
352
+ for gpu in model.gpus:
353
+ torch.cuda.synchronize(gpu)
354
+ else:
355
+ torch.cuda.synchronize()
356
+
357
+ with torch.no_grad():
358
+ attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
359
+ times = []
360
+ for i in range(input_ids.numel()):
361
+ tick = time.time()
362
+ out = model(input_ids[:, i].reshape(-1), past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
363
+ sync()
364
+ times.append(time.time() - tick)
365
+ print(i, times[-1])
366
+ if check and i != input_ids.numel() - 1:
367
+ tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
368
+ cache['past'] = list(out.past_key_values)
369
+ del out
370
+ sync()
371
+ import numpy as np
372
+ print('Median:', np.median(times))
373
+ if check:
374
+ print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
375
+
376
+
377
+ if __name__ == '__main__':
378
+
379
+ parser = argparse.ArgumentParser()
380
+
381
+ parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
382
+ parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
383
+ parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
384
+ parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
385
+ parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
386
+ parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
387
+ parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
388
+ parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
389
+ parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
390
+ parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
391
+ parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
392
+ parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
393
+ parser.add_argument('--load', type=str, default='', help='Load quantized model.')
394
+ parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
395
+ parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
396
+ parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
397
+ parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
398
+ parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
399
+
400
+ args = parser.parse_args()
401
+
402
+ if type(args.load) is not str:
403
+ args.load = args.load.as_posix()
404
+
405
+ if args.load:
406
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
407
+ else:
408
+ model = get_opt(args.model)
409
+ model.eval()
410
+
411
+ dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
412
+
413
+ if not args.load and args.wbits < 16 and not args.nearest:
414
+ tick = time.time()
415
+ quantizers = opt_sequential(model, dataloader, DEV)
416
+ print(time.time() - tick)
417
+
418
+ if args.benchmark:
419
+ gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
420
+ if len(gpus) > 1:
421
+ opt_multigpu(model, gpus)
422
+ else:
423
+ model = model.to(DEV)
424
+ if args.benchmark:
425
+ input_ids = next(iter(dataloader))[0][:, :args.benchmark]
426
+ benchmark(model, input_ids, check=args.check)
427
+
428
+ if args.eval:
429
+ datasets = ['wikitext2', 'ptb', 'c4']
430
+ if args.new_eval:
431
+ datasets = ['wikitext2', 'ptb-new', 'c4-new']
432
+ for dataset in datasets:
433
+ dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
434
+ print(dataset)
435
+ opt_eval(model, testloader, DEV)
436
+
437
+ if args.save:
438
+ opt_pack(model, quantizers, args.wbits, args.groupsize)
439
+ torch.save(model.state_dict(), args.save)
440
+
441
+ if args.save_safetensors:
442
+ opt_pack(model, quantizers, args.wbits, args.groupsize)
443
+ from safetensors.torch import save_file as safe_save
444
+ state_dict = model.state_dict()
445
+ state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
446
+ safe_save(state_dict, args.save_safetensors)
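`opt_sequential`, `neox_sequential` and the LLaMA equivalent all gather calibration activations with the same trick: the first decoder block is temporarily replaced by a `Catcher` wrapper that records its input and raises `ValueError` so the rest of the forward pass is skipped. A stripped-down sketch of the pattern on a toy two-layer model, not tied to any of the files above:

    import torch
    import torch.nn as nn

    captured = []

    class Catcher(nn.Module):
        """Record the input of the wrapped module, then abort the forward pass."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            captured.append(inp.detach())
            raise ValueError  # nothing after the first block needs to run during calibration

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    model[0] = Catcher(model[0])

    for _ in range(3):                       # three "calibration batches"
        try:
            model(torch.randn(1, 8))
        except ValueError:
            pass

    model[0] = model[0].module               # put the original first block back
    print(len(captured), captured[0].shape)  # 3 torch.Size([1, 8])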
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ safetensors==0.3.0
+ datasets==2.10.1
+ sentencepiece
+ git+https://github.com/huggingface/transformers
+ accelerate==0.17.1
+ triton==2.0.0
+ texttable
+ toml
+ numpy
+ protobuf==3.20.2
+
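requirements.txt pins most dependencies but installs `transformers` straight from the GitHub main branch, presumably because the LLaMA classes these scripts import had not yet shipped in a tagged release when this was uploaded. An optional sanity check that the installed build exposes the model classes used above:

    import transformers
    # Model classes imported by llama.py, neox.py and opt.py in this upload.
    from transformers import LlamaConfig, LlamaForCausalLM, GPTNeoXForCausalLM, OPTForCausalLM

    print(transformers.__version__)  # a dev build such as '4.28.0.dev0' when installed from git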