Upload 10 files
- LICENSE.txt +201 -0
- README.md +145 -3
- convert_llama_weights_to_hf.py +34 -0
- gptq.py +236 -0
- llama.py +515 -0
- llama_inference.py +123 -0
- llama_inference_offload.py +279 -0
- neox.py +430 -0
- opt.py +446 -0
- requirements.txt +11 -0
LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,3 +1,145 @@
# GPTQ-for-LLaMA
<img src="https://user-images.githubusercontent.com/64115820/235287009-2d07bba8-9b85-4973-9e06-2a3c28777f06.png" width="50%" height="50%">

4-bit quantization of [LLaMA](https://arxiv.org/abs/2302.13971) using [GPTQ](https://arxiv.org/abs/2210.17323).

GPTQ is a SOTA one-shot weight quantization method.

**It can be used universally, but it is not the [fastest](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/old-cuda) option and it only supports Linux.**

**Triton only supports Linux, so if you are a Windows user, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install).**

## News or Updates
**AutoGPTQ-triton, a packaged version of GPTQ with Triton, has been integrated into [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ).**

## Results
In the tables below, RTN is the round-to-nearest baseline and the Wikitext2 column reports perplexity (lower is better).

<details>
<summary>LLaMA-7B (click me)</summary>

| [LLaMA-7B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 PPL | checkpoint size (GB) |
| --- | --- | --- | --- | --- | --- |
| FP16 | 16 | - | 13940 | 5.68 | 12.5 |
| RTN | 4 | - | - | 6.29 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 4740 | 6.09 | 3.5 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 4891 | 5.85 | 3.6 |
| RTN | 3 | - | - | 25.54 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 3852 | 8.07 | 2.7 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 4116 | 6.61 | 3.0 |

</details>

<details>
<summary>LLaMA-13B</summary>

| [LLaMA-13B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 PPL | checkpoint size (GB) |
| --- | --- | --- | --- | --- | --- |
| FP16 | 16 | - | OOM | 5.09 | 24.2 |
| RTN | 4 | - | - | 5.53 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 8410 | 5.36 | 6.5 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 8747 | 5.20 | 6.7 |
| RTN | 3 | - | - | 11.40 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 6870 | 6.63 | 5.1 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 7277 | 5.62 | 5.4 |

</details>

<details>
<summary>LLaMA-33B</summary>

| [LLaMA-33B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 PPL | checkpoint size (GB) |
| --- | --- | --- | --- | --- | --- |
| FP16 | 16 | - | OOM | 4.10 | 60.5 |
| RTN | 4 | - | - | 4.54 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 19493 | 4.45 | 15.7 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 20570 | 4.23 | 16.3 |
| RTN | 3 | - | - | 14.89 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 15493 | 5.69 | 12.0 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 16566 | 4.80 | 13.0 |

</details>

<details>
<summary>LLaMA-65B</summary>

| [LLaMA-65B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory (MiB) | Wikitext2 PPL | checkpoint size (GB) |
| --- | --- | --- | --- | --- | --- |
| FP16 | 16 | - | OOM | 3.53 | 121.0 |
| RTN | 4 | - | - | 3.92 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | OOM | 3.84 | 31.1 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | OOM | 3.65 | 32.3 |
| RTN | 3 | - | - | 10.59 | - |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | OOM | 5.04 | 23.6 |
| [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | OOM | 4.17 | 25.6 |

</details>

Quantization requires a large amount of CPU memory; the requirement can, however, be reduced by using swap memory.

Depending on the GPUs/drivers there may be a difference in performance, and the difference decreases as the model size increases ([IST-DASLab/gptq#1](https://github.com/IST-DASLab/gptq/issues/1)).

According to the [GPTQ paper](https://arxiv.org/abs/2210.17323), the gap between FP16 and GPTQ also shrinks as the model size grows.
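For intuition, GPTQ quantizes the model one layer at a time: given calibration inputs $X$ for a layer with weights $W$, it chooses quantized weights $\hat{W}$ that approximately minimize the layer-output error

$$\min_{\hat{W}} \lVert WX - \hat{W}X \rVert_2^2 ,$$

using second-order information from the Hessian $H = 2XX^\top$ (the quantity accumulated in `gptq.py` below) to order the columns and compensate each rounding error. This is a rough sketch of the method from the GPTQ paper, not an exact statement of the implementation.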
## Installation
If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first.
```
conda create --name gptq python=3.9 -y
conda activate gptq
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
# Or, if you're having trouble with conda, use pip with python3.9:
# pip3 install torch torchvision torchaudio

git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa
cd GPTQ-for-LLaMa
pip install -r requirements.txt
```
## Dependencies

* `torch`: tested on v2.0.0+cu117
* `transformers`: tested on v4.28.0.dev0
* `datasets`: tested on v2.10.1
* `safetensors`: tested on v0.3.0

All experiments were run on a single NVIDIA RTX 3090.

# Language Generation
## LLaMA

```
# Convert LLaMA weights to the HF format
python convert_llama_weights_to_hf.py --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir ./llama-hf

# Benchmark language generation with 4-bit LLaMA-7B:

# Save the compressed model
CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save llama7b-4bit-128g.pt

# Or save a compressed `.safetensors` model
CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save_safetensors llama7b-4bit-128g.safetensors

# Benchmark generating a 2048-token sequence with the saved model
CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --benchmark 2048 --check

# Benchmark the FP16 baseline; note that the model will be split across all listed GPUs
CUDA_VISIBLE_DEVICES=0,1,2,3,4 python llama.py ${MODEL_DIR} c4 --benchmark 2048 --check

# Model inference with the saved model
CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama"

# Model inference with the saved model, using safetensors loaded directly to the GPU
CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.safetensors --text "this is llama" --device=0

# Model inference with the saved model with CPU offload (this is very slow)
CUDA_VISIBLE_DEVICES=0 python llama_inference_offload.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama" --pre_layer 16
# With LLaMA-65B and --pre_layer 50, it takes about 180 seconds to generate 45 tokens (5 -> 50 tokens) on a single RTX 3090.
```
In general, 4-bit quantization with groupsize 128 is recommended.

You can also export the quantization parameters in toml+numpy format.
```
CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --quant-directory ${TOML_DIR}
```
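If you prefer to drive inference from Python rather than from the command line, the following is a minimal sketch of what `llama_inference.py` does internally; the paths and the prompt are placeholders, not files shipped with the repo.
```
import torch
from transformers import AutoTokenizer
from llama_inference import load_quant

MODEL_DIR = "./llama-hf/llama-7b"              # converted HF model (placeholder path)
CHECKPOINT = "llama7b-4bit-128g.safetensors"   # packed GPTQ checkpoint (placeholder path)

# Rebuild the LLaMA skeleton with QuantLinear modules and load the packed weights.
model = load_quant(MODEL_DIR, CHECKPOINT, 4, 128)
model.to("cuda:0")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
input_ids = tokenizer.encode("this is llama", return_tensors="pt").to("cuda:0")

with torch.no_grad():
    generated = model.generate(input_ids, do_sample=True, min_length=10,
                               max_length=50, top_p=0.95, temperature=0.8)
print(tokenizer.decode(generated[0]))
```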
# Acknowledgements
This code is based on [GPTQ](https://github.com/IST-DASLab/gptq).

Thanks to Meta AI for releasing [LLaMA](https://arxiv.org/abs/2302.13971), a powerful LLM.

The Triton GPTQ kernel code is based on [GPTQ-triton](https://github.com/fpgaminer/GPTQ-triton).
convert_llama_weights_to_hf.py
ADDED
@@ -0,0 +1,34 @@
import argparse
import os
from transformers.models.llama.convert_llama_weights_to_hf import write_model, write_tokenizer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
        choices=["7B", "13B", "30B", "65B", "tokenizer_only"],
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    args = parser.parse_args()
    if args.model_size != "tokenizer_only":
        write_model(
            model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
            input_base_path=os.path.join(args.input_dir, args.model_size),
            model_size=args.model_size,
        )
    write_tokenizer(
        tokenizer_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
        input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
    )


if __name__ == "__main__":
    main()
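As the script above shows, the converted model is written to `output_dir/llama-<model_size>` (lower-cased) and the tokenizer is read from `input_dir/tokenizer.model`; with the README's example invocation, the 7B shards are read from `/path/to/downloaded/llama/weights/7B` and the HF checkpoint ends up in `./llama-hf/llama-7b`.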
gptq.py
ADDED
@@ -0,0 +1,236 @@
import math
import time

import torch
import torch.nn as nn
import transformers
import quant
from texttable import Texttable
from utils import torch_snr_error

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class Observer:

    def __init__(self, topk=32):
        self.loss_list = []
        self.topk = topk

    def submit(self, name: str, layerid: int, gptq, error: float):

        item = (name, layerid, {'gptq': gptq, 'error': error})

        if len(self.loss_list) < self.topk:
            self.loss_list.append(item)
            return

        min_error = error
        min_idx = -1
        for idx, data in enumerate(self.loss_list):
            if min_error > data[2]['error']:
                min_idx = idx
                min_error = data[2]['error']

        if min_idx >= 0:
            self.loss_list[min_idx] = item

    def print(self):
        self.loss_list = sorted(self.loss_list, key=lambda s: s[2]['error'], reverse=True)

        table = Texttable()

        table.header(['name', 'error'])
        table.set_cols_dtype(['t', 'f'])

        for item in self.loss_list:
            table.add_row([f"{item[0]}.{item[1]}", item[2]['error']])
        print(table.draw())
        print('\n')

    def items(self):
        return self.loss_list


class GPTQ:

    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = quant.Quantizer()
        self.observe = observe

    def add_batch(self, inp, out):
        # Hessian H = 2 X XT + λ I
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def print_loss(self, name, q_weight, weight_error, timecost):
        table = Texttable()
        name += ' ' * (16 - len(name))

        table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = quant.Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = '-'
            fp_SNR = '-'

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split('\n')[-2])

    def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q)**2 / d**2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
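For orientation, this is roughly how `llama.py` drives the class above for a single linear layer; `layer` and the captured `(input, output)` calibration pairs are assumed to already exist and are not part of this repository.
```
import torch
from gptq import GPTQ

gptq = GPTQ(layer)                                   # layer: an nn.Linear to quantize (assumed)
gptq.quantizer.configure(4, perchannel=True, sym=False, mse=False)

# Accumulate the Hessian H ~ (2/n) * sum(x x^T) from calibration activations.
for inp, out in calib_pairs:                         # pairs captured via forward hooks (assumed)
    gptq.add_batch(inp, out)

# Quantize in place; returns per-group scales/zeros, the group index map, and the error.
scale, zero, g_idx, error = gptq.fasterquant(
    blocksize=128, percdamp=0.01, groupsize=128, actorder=True, name='mlp.down_proj')
gptq.free()
```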
llama.py
ADDED
@@ -0,0 +1,515 @@
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import quant

from gptq import GPTQ, Observer
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
from texttable import Texttable


def get_llama(model):

    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    from transformers import LlamaForCausalLM
    model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
    model.seqlen = 2048
    return model


@torch.no_grad()
def llama_sequential(model, dataloader, dev):
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    quantizers = {}
    observer = Observer()
    for i in range(len(layers)):

        print(f'Quantizing layer {i+1}/{len(layers)}..')
        print('+------------------+--------------+------------+-----------+-------+')
        print('|       name       | weight_error | fp_inp_SNR | q_inp_SNR |  time |')
        print('+==================+==============+============+===========+=======+')

        layer = layers[i].to(dev)
        full = find_layers(layer)
        if args.true_sequential:
            sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']]
        else:
            sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}
            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name], observe=args.observe)
                gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)

            def add_batch(name):

                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)

                return tmp

            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
            for h in handles:
                h.remove()

            for name in subset:
                scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name)
                quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize)

                if args.observe:
                    observer.submit(name=name, layerid=i, gptq=gptq[name], error=error)
                else:
                    gptq[name].free()

        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps
        print('+------------------+--------------+------------+-----------+-------+')
        print('\n')

    if args.observe:
        observer.print()
        conditions = gen_conditions(args.wbits, args.groupsize)
        for item in observer.items():
            name = item[0]
            layerid = item[1]
            gptq = item[2]['gptq']
            error = item[2]['error']
            target = error / 2

            table = Texttable()
            table.header(['wbits', 'groupsize', 'error'])
            table.set_cols_dtype(['i', 'i', 'f'])
            table.add_row([args.wbits, args.groupsize, error])

            print('Optimizing {} {} ..'.format(name, layerid))
            for wbits, groupsize in conditions:

                if error < target:
                    # if the error dropped by 50%, skip
                    break

                gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False)

                scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)

                table.add_row([wbits, groupsize, error])
                quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)

            print(table.draw())
            print('\n')
            gptq.layer.to('cpu')
            gptq.free()

    model.config.use_cache = use_cache

    return quantizers


@torch.no_grad()
def llama_eval(model, testenc, dev):
    print('Evaluating ...')

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.nearest:
            subset = find_layers(layer)
            for name in subset:
                quantizer = quant.Quantizer()
                quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
                W = subset[name].weight.data
                quantizer.find_params(W, weight=True)
                subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps

    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache


# TODO: perform packing on GPU
def llama_pack(model, quantizers, wbits, groupsize):
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    quant.make_quant_linear(model, quantizers, wbits, groupsize)
    qlayers = find_layers(model, [quant.QuantLinear])
    print('Packing ...')
    for name in qlayers:
        print(name)
        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print('Done.')
    return model


def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
    from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    if eval:
        model = model.eval()
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    quant.make_quant_linear(model, layers, wbits, groupsize)

    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))

    if eval:
        quant.make_quant_attn(model)
        quant.make_quant_norm(model)
        if fused_mlp:
            quant.make_fused_mlp(model)

    if warmup_autotune:
        quant.autotune_warmup_linear(model, transpose=not (eval))
        if eval and fused_mlp:
            quant.autotune_warmup_fused(model)
    model.seqlen = 2048
    print('Done.')

    return model


def llama_multigpu(model, gpus, gpu_dist):
    model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
    if hasattr(model.model, 'norm') and model.model.norm:
        model.model.norm = model.model.norm.to(gpus[-1])
    import copy
    model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1])

    cache = {'mask': None}

    class MoveModule(nn.Module):

        def __init__(self, module):
            super().__init__()
            self.module = module
            self.dev = next(iter(self.module.parameters())).device

        def forward(self, *inp, **kwargs):
            inp = list(inp)
            if inp[0].device != self.dev:
                inp[0] = inp[0].to(self.dev)
            if cache['mask'] is None or cache['mask'].device != self.dev:
                cache['mask'] = kwargs['attention_mask'].to(self.dev)
            kwargs['attention_mask'] = cache['mask']
            tmp = self.module(*inp, **kwargs)
            return tmp

    layers = model.model.layers
    from math import ceil
    if not gpu_dist:
        pergpu = ceil(len(layers) / len(gpus))
        for i in range(len(layers)):
            layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
    else:
        assigned_gpus = []
        for i in range(len(gpu_dist)):
            assigned_gpus = assigned_gpus + [i] * gpu_dist[i]

        remaining_assignments = len(layers) - len(assigned_gpus)
        if remaining_assignments > 0:
            assigned_gpus = assigned_gpus + [-1] * remaining_assignments

        for i in range(len(layers)):
            layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]))

    model.gpus = gpus


def benchmark(model, input_ids, check=False):
    input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
    torch.cuda.synchronize()

    cache = {'past': None}

    def clear_past(i):

        def tmp(layer, inp, out):
            if cache['past']:
                cache['past'][i] = None

        return tmp

    for i, layer in enumerate(model.model.layers):
        layer.register_forward_hook(clear_past(i))

    print('Benchmarking ...')

    if check:
        loss = nn.CrossEntropyLoss()
        tot = 0.

    def sync():
        if hasattr(model, 'gpus'):
            for gpu in model.gpus:
                torch.cuda.synchronize(gpu)
        else:
            torch.cuda.synchronize()

    max_memory = 0
    with torch.no_grad():
        attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
        times = []
        for i in range(input_ids.numel()):
            tick = time.time()
            out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
            sync()
            times.append(time.time() - tick)
            print(i, times[-1])
            if hasattr(model, 'gpus'):
                mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
            else:
                mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
            max_memory = max(max_memory, mem_allocated)
            if check and i != input_ids.numel() - 1:
                tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
            cache['past'] = list(out.past_key_values)
            del out
        sync()
        print('Median:', np.median(times))
        if check:
            print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
        print('max memory(MiB):', max_memory)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('model', type=str, help='llama model to load')
    parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
    parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
    parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
    parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
    parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
    parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
    parser.add_argument('--eval', action='store_true', help='Evaluate the quantized model.')
    parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
    parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
    parser.add_argument('--load', type=str, default='', help='Load quantized model.')
    parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
    parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
    parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
    parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic.')
    parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential mode.')
    parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval.')
    parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.')
    parser.add_argument('--observe',
                        action='store_true',
                        help='Auto upgrade layer precision to a higher precision, for example int2 to int4, groupsize 128 to 64. '
                        'When this feature is enabled, `--save` and `--save_safetensors` are disabled.')
    parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for exporting quantization parameters in toml format. `None` means no export by default.')

    args = parser.parse_args()

    if args.layers_dist:
        gpu_dist = [int(x) for x in args.layers_dist.split(':')]
    else:
        gpu_dist = []

    if type(args.load) is not str:
        args.load = args.load.as_posix()

    if args.load:
        model = load_quant(args.model, args.load, args.wbits, args.groupsize)
    else:
        model = get_llama(args.model)
        model.eval()

    dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)

    if not args.load and args.wbits < 16 and not args.nearest:
        tick = time.time()
        quantizers = llama_sequential(model, dataloader, DEV)
        print(time.time() - tick)

    if args.benchmark:
        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
        if len(gpus) > 1:
            llama_multigpu(model, gpus, gpu_dist)
        else:
            model = model.to(DEV)
        if args.benchmark:
            input_ids = next(iter(dataloader))[0][:, :args.benchmark]
            benchmark(model, input_ids, check=args.check)

    if args.eval:
        datasets = ['wikitext2', 'ptb', 'c4']
        if args.new_eval:
            datasets = ['wikitext2', 'ptb-new', 'c4-new']
        for dataset in datasets:
            dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
            print(dataset)
            llama_eval(model, testloader, DEV)

    if args.quant_directory is not None:
        export_quant_table(quantizers, args.quant_directory)

    if not args.observe and args.save:
        llama_pack(model, quantizers, args.wbits, args.groupsize)
        torch.save(model.state_dict(), args.save)

    if not args.observe and args.save_safetensors:
        llama_pack(model, quantizers, args.wbits, args.groupsize)
        from safetensors.torch import save_file as safe_save
        state_dict = model.state_dict()
        state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
        safe_save(state_dict, args.save_safetensors)
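One detail worth noting from the file above: `llama_sequential` returns a dictionary keyed by module path (for example `model.layers.0.self_attn.q_proj`), and each value is the `(quantizer, scale, zero, g_idx, wbits, groupsize)` tuple that `llama_pack` later unpacks when packing the `QuantLinear` modules. That is why saving a packed checkpoint with `--save` or `--save_safetensors` has to happen in the same invocation that performed quantization, and why those flags are disabled together with `--observe`.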
llama_inference.py
ADDED
@@ -0,0 +1,123 @@
import argparse

import torch
import torch.nn as nn
import quant

from gptq import GPTQ
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import transformers
from transformers import AutoTokenizer


def get_llama(model):

    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    from transformers import LlamaForCausalLM
    model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
    model.seqlen = 2048
    return model


def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
    from transformers import LlamaConfig, LlamaForCausalLM
    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    if eval:
        model = model.eval()
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    quant.make_quant_linear(model, layers, wbits, groupsize)

    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint), strict=False)
    else:
        model.load_state_dict(torch.load(checkpoint), strict=False)

    if eval:
        quant.make_quant_attn(model)
        quant.make_quant_norm(model)
        if fused_mlp:
            quant.make_fused_mlp(model)
    if warmup_autotune:
        quant.autotune_warmup_linear(model, transpose=not (eval))
        if eval and fused_mlp:
            quant.autotune_warmup_fused(model)
    model.seqlen = 2048
    print('Done.')

    return model


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('model', type=str, help='llama model to load')
    parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
    parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
    parser.add_argument('--load', type=str, default='', help='Load quantized model.')

    parser.add_argument('--text', type=str, help='input text')

    parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')

    parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')

    parser.add_argument('--top_p',
                        type=float,
                        default=0.95,
                        help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')

    parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.')

    parser.add_argument('--device', type=int, default=-1, help='The device used to load the model when using safetensors. Default device is "cpu" or specify, 0,1,2,3,... for GPU device.')

    args = parser.parse_args()
args = parser.parse_args()
|
100 |
+
|
101 |
+
if type(args.load) is not str:
|
102 |
+
args.load = args.load.as_posix()
|
103 |
+
|
104 |
+
if args.load:
|
105 |
+
model = load_quant(args.model, args.load, args.wbits, args.groupsize)
|
106 |
+
else:
|
107 |
+
model = get_llama(args.model)
|
108 |
+
model.eval()
|
109 |
+
|
110 |
+
model.to(DEV)
|
111 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
|
112 |
+
input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)
|
113 |
+
|
114 |
+
with torch.no_grad():
|
115 |
+
generated_ids = model.generate(
|
116 |
+
input_ids,
|
117 |
+
do_sample=True,
|
118 |
+
min_length=args.min_length,
|
119 |
+
max_length=args.max_length,
|
120 |
+
top_p=args.top_p,
|
121 |
+
temperature=args.temperature,
|
122 |
+
)
|
123 |
+
print(tokenizer.decode([el.item() for el in generated_ids[0]]))
|
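A minimal sketch of driving the same load-then-generate flow programmatically instead of through the CLI above; the model id, checkpoint path, and prompt are placeholders, not files shipped with this commit:

```python
# Hypothetical usage of load_quant from llama_inference.py in another script.
import torch
from transformers import AutoTokenizer

from llama_inference import load_quant
from utils import DEV

model = load_quant('decapoda-research/llama-7b-hf',          # example model id
                   'llama7b-4bit-128g.safetensors', 4, 128)  # example checkpoint, wbits, groupsize
model.to(DEV)

tokenizer = AutoTokenizer.from_pretrained('decapoda-research/llama-7b-hf', use_fast=False)
input_ids = tokenizer.encode('The capital of France is', return_tensors='pt').to(DEV)
with torch.no_grad():
    generated = model.generate(input_ids, do_sample=True, max_length=32,
                               top_p=0.95, temperature=0.8)
print(tokenizer.decode(generated[0]))
```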
llama_inference_offload.py
ADDED
@@ -0,0 +1,279 @@
import torch
import torch.nn as nn

from gptq import GPTQ
import argparse
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import quant

import transformers
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from typing import List, Optional, Tuple, Union
from accelerate import cpu_offload_with_hook, load_checkpoint_in_model


class Offload_LlamaModel(LlamaModel):

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

    def cpu_offload(self, preload):
        # Keep the first `preload` decoder layers on the GPU; hook the remaining layers so
        # accelerate streams them from CPU to GPU on demand during the forward pass.
        hook = None
        for cpu_offloaded_model in self.layers[preload:]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, DEV, prev_module_hook=hook)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                `[0, config.n_positions - 1]`.
                [What are position IDs?](../glossary#position-ids)
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            # Only reachable during training with gradient checkpointing enabled.
            if use_cache:
                logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx in range(len(self.layers)):
            decoder_layer = self.layers[idx]

            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


def load_quant(model, checkpoint, wbits, groupsize, pre_layer, fused_mlp=True, warmup_autotune=True):
    transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel
    from transformers import LlamaConfig, LlamaForCausalLM
    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    quant.make_quant_linear(model, layers, wbits, groupsize)

    print('Loading model ...')
    load_checkpoint_in_model(model, checkpoint, dtype='float16')
    model.seqlen = 2048

    # NOTE: unlike llama_inference.load_quant, `eval` is not a parameter here, so this
    # tests the Python builtin and is always true; the fused attention/norm path always runs.
    if eval:
        quant.make_quant_attn(model)
        quant.make_quant_norm(model)
        if fused_mlp:
            quant.make_fused_mlp(model)

    if warmup_autotune:
        quant.autotune_warmup_linear(model)
        if fused_mlp:
            quant.autotune_warmup_fused(model)

    for i in range(pre_layer):
        model.model.layers[i].to(DEV)
    model.model.embed_tokens.to(DEV)
    model.model.norm.to(DEV)
    model.lm_head.to(DEV)
    model.model.cpu_offload(pre_layer)
    print('Done.')
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('model', type=str, help='llama model to load')
    parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8], help='#bits to use for quantization')
    parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
    parser.add_argument('--load', type=str, default='', help='Load quantized model.')
    parser.add_argument('--text', type=str, help='input text')

    parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')

    parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')

    parser.add_argument('--top_p',
                        type=float,
                        default=0.95,
                        help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')

    parser.add_argument('--temperature', type=float, default=0.8, help='The value used to modulate the next token probabilities.')

    parser.add_argument('--pre_layer', type=int, default=50, help='The number of layers to preload')

    args = parser.parse_args()

    if type(args.load) is not str:
        args.load = args.load.as_posix()

    model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            do_sample=True,
            min_length=args.min_length,
            max_length=args.max_length,
            top_p=args.top_p,
            temperature=args.temperature,
        )
    print(tokenizer.decode([el.item() for el in generated_ids[0]]))
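A short note on the offload split (values below are examples only): `--pre_layer` is the number of decoder layers kept resident on the GPU; the rest are wrapped by `cpu_offload_with_hook` and streamed in per forward pass, so a smaller value lowers VRAM use at the cost of speed. A hedged usage sketch:

```python
# Sketch only: load with 20 resident layers instead of the default 50 to fit a
# smaller GPU; the model id and checkpoint path are placeholders.
from llama_inference_offload import load_quant

model = load_quant('decapoda-research/llama-7b-hf',     # example model id
                   'llama7b-4bit-128g.safetensors',     # example checkpoint
                   wbits=4, groupsize=128, pre_layer=20)
```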
neox.py
ADDED
@@ -0,0 +1,430 @@
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import quant
|
7 |
+
|
8 |
+
from gptq import GPTQ, Observer
|
9 |
+
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
|
10 |
+
from texttable import Texttable
|
11 |
+
|
12 |
+
|
13 |
+
def get_neox(model, seqlen=-1):
|
14 |
+
|
15 |
+
def skip(*args, **kwargs):
|
16 |
+
pass
|
17 |
+
|
18 |
+
torch.nn.init.kaiming_uniform_ = skip
|
19 |
+
torch.nn.init.uniform_ = skip
|
20 |
+
torch.nn.init.normal_ = skip
|
21 |
+
from transformers import GPTNeoXForCausalLM
|
22 |
+
model = GPTNeoXForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
|
23 |
+
model.seqlen = seqlen if seqlen != -1 else model.config.max_position_embeddings
|
24 |
+
return model
|
25 |
+
|
26 |
+
|
27 |
+
@torch.no_grad()
|
28 |
+
def neox_sequential(model, dataloader, dev):
|
29 |
+
print('Starting ...')
|
30 |
+
|
31 |
+
use_cache = model.config.use_cache
|
32 |
+
model.config.use_cache = False
|
33 |
+
layers = model.gpt_neox.layers
|
34 |
+
|
35 |
+
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev)
|
36 |
+
layers[0] = layers[0].to(dev)
|
37 |
+
|
38 |
+
dtype = next(iter(model.parameters())).dtype
|
39 |
+
inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
|
40 |
+
cache = {'i': 0, 'attention_mask': None}
|
41 |
+
|
42 |
+
class Catcher(nn.Module):
|
43 |
+
|
44 |
+
def __init__(self, module):
|
45 |
+
super().__init__()
|
46 |
+
self.module = module
|
47 |
+
|
48 |
+
def forward(self, inp, **kwargs):
|
49 |
+
inps[cache['i']] = inp
|
50 |
+
cache['i'] += 1
|
51 |
+
cache['attention_mask'] = kwargs['attention_mask']
|
52 |
+
cache['position_ids'] = kwargs['position_ids']
|
53 |
+
raise ValueError
|
54 |
+
|
55 |
+
layers[0] = Catcher(layers[0])
|
56 |
+
for batch in dataloader:
|
57 |
+
try:
|
58 |
+
model(batch[0].to(dev))
|
59 |
+
except ValueError:
|
60 |
+
pass
|
61 |
+
layers[0] = layers[0].module
|
62 |
+
|
63 |
+
layers[0] = layers[0].cpu()
|
64 |
+
model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu()
|
65 |
+
torch.cuda.empty_cache()
|
66 |
+
|
67 |
+
outs = torch.zeros_like(inps)
|
68 |
+
attention_mask = cache['attention_mask']
|
69 |
+
position_ids = cache['position_ids']
|
70 |
+
|
71 |
+
print('Ready.')
|
72 |
+
|
73 |
+
quantizers = {}
|
74 |
+
observer = Observer()
|
75 |
+
for i in range(len(layers)):
|
76 |
+
|
77 |
+
print(f'Quantizing layer {i+1}/{len(layers)}..')
|
78 |
+
print('+------------------+--------------+------------+-----------+-------+')
|
79 |
+
print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
|
80 |
+
print('+==================+==============+============+===========+=======+')
|
81 |
+
|
82 |
+
layer = layers[i].to(dev)
|
83 |
+
full = find_layers(layer)
|
84 |
+
sequential = [list(full.keys())]
|
85 |
+
|
86 |
+
for names in sequential:
|
87 |
+
subset = {n: full[n] for n in names}
|
88 |
+
gptq = {}
|
89 |
+
for name in subset:
|
90 |
+
gptq[name] = GPTQ(subset[name], observe=False)
|
91 |
+
gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
|
92 |
+
|
93 |
+
def add_batch(name):
|
94 |
+
|
95 |
+
def tmp(_, inp, out):
|
96 |
+
gptq[name].add_batch(inp[0].data, out.data)
|
97 |
+
|
98 |
+
return tmp
|
99 |
+
|
100 |
+
handles = []
|
101 |
+
for name in subset:
|
102 |
+
handles.append(subset[name].register_forward_hook(add_batch(name)))
|
103 |
+
for j in range(args.nsamples):
|
104 |
+
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
|
105 |
+
for h in handles:
|
106 |
+
h.remove()
|
107 |
+
|
108 |
+
for name in subset:
|
109 |
+
scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name)
|
110 |
+
quantizers['gpt_neox.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize)
|
111 |
+
gptq[name].free()
|
112 |
+
|
113 |
+
for j in range(args.nsamples):
|
114 |
+
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
|
115 |
+
|
116 |
+
layers[i] = layer.cpu()
|
117 |
+
del layer
|
118 |
+
del gptq
|
119 |
+
torch.cuda.empty_cache()
|
120 |
+
|
121 |
+
inps, outs = outs, inps
|
122 |
+
print('+------------------+--------------+------------+-----------+-------+')
|
123 |
+
print('\n')
|
124 |
+
|
125 |
+
model.config.use_cache = use_cache
|
126 |
+
|
127 |
+
return quantizers
|
128 |
+
|
129 |
+
|
130 |
+
@torch.no_grad()
|
131 |
+
def neox_eval(model, testenc, dev):
|
132 |
+
print('Evaluating ...')
|
133 |
+
|
134 |
+
testenc = testenc.input_ids
|
135 |
+
nsamples = testenc.numel() // model.seqlen
|
136 |
+
|
137 |
+
use_cache = model.config.use_cache
|
138 |
+
model.config.use_cache = False
|
139 |
+
layers = model.gpt_neox.layers
|
140 |
+
|
141 |
+
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev)
|
142 |
+
layers[0] = layers[0].to(dev)
|
143 |
+
|
144 |
+
dtype = next(iter(model.parameters())).dtype
|
145 |
+
inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
|
146 |
+
cache = {'i': 0, 'attention_mask': None}
|
147 |
+
|
148 |
+
class Catcher(nn.Module):
|
149 |
+
|
150 |
+
def __init__(self, module):
|
151 |
+
super().__init__()
|
152 |
+
self.module = module
|
153 |
+
|
154 |
+
def forward(self, inp, **kwargs):
|
155 |
+
inps[cache['i']] = inp
|
156 |
+
cache['i'] += 1
|
157 |
+
cache['attention_mask'] = kwargs['attention_mask']
|
158 |
+
cache['position_ids'] = kwargs['position_ids']
|
159 |
+
raise ValueError
|
160 |
+
|
161 |
+
layers[0] = Catcher(layers[0])
|
162 |
+
for i in range(nsamples):
|
163 |
+
batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
|
164 |
+
try:
|
165 |
+
model(batch)
|
166 |
+
except ValueError:
|
167 |
+
pass
|
168 |
+
layers[0] = layers[0].module
|
169 |
+
|
170 |
+
layers[0] = layers[0].cpu()
|
171 |
+
model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu()
|
172 |
+
torch.cuda.empty_cache()
|
173 |
+
|
174 |
+
outs = torch.zeros_like(inps)
|
175 |
+
attention_mask = cache['attention_mask']
|
176 |
+
position_ids = cache['position_ids']
|
177 |
+
|
178 |
+
for i in range(len(layers)):
|
179 |
+
print(i)
|
180 |
+
layer = layers[i].to(dev)
|
181 |
+
|
182 |
+
if args.nearest:
|
183 |
+
subset = find_layers(layer)
|
184 |
+
for name in subset:
|
185 |
+
quantizer = quant.Quantizer()
|
186 |
+
quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
|
187 |
+
W = subset[name].weight.data
|
188 |
+
quantizer.find_params(W, weight=True)
|
189 |
+
subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
|
190 |
+
|
191 |
+
for j in range(nsamples):
|
192 |
+
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
|
193 |
+
layers[i] = layer.cpu()
|
194 |
+
del layer
|
195 |
+
torch.cuda.empty_cache()
|
196 |
+
inps, outs = outs, inps
|
197 |
+
|
198 |
+
model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(dev)
|
199 |
+
model.embed_out = model.embed_out.to(dev)
|
200 |
+
|
201 |
+
testenc = testenc.to(dev)
|
202 |
+
nlls = []
|
203 |
+
for i in range(nsamples):
|
204 |
+
hidden_states = inps[i].unsqueeze(0)
|
205 |
+
hidden_states = model.gpt_neox.final_layer_norm(hidden_states)
|
206 |
+
lm_logits = model.embed_out(hidden_states)
|
207 |
+
shift_logits = lm_logits[:, :-1, :].contiguous()
|
208 |
+
shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
|
209 |
+
loss_fct = nn.CrossEntropyLoss()
|
210 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
211 |
+
neg_log_likelihood = loss.float() * model.seqlen
|
212 |
+
nlls.append(neg_log_likelihood)
|
213 |
+
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
|
214 |
+
print(ppl.item())
|
215 |
+
|
216 |
+
model.config.use_cache = use_cache
|
217 |
+
|
218 |
+
|
219 |
+
# TODO: perform packing on GPU
|
220 |
+
def neox_pack(model, quantizers, wbits, groupsize):
|
221 |
+
layers = find_layers(model)
|
222 |
+
layers = {n: layers[n] for n in quantizers}
|
223 |
+
quant.make_quant_linear(model, quantizers, wbits, groupsize)
|
224 |
+
qlayers = find_layers(model, [quant.QuantLinear])
|
225 |
+
print('Packing ...')
|
226 |
+
for name in qlayers:
|
227 |
+
print(name)
|
228 |
+
quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
|
229 |
+
qlayers[name].pack(layers[name], scale, zero, g_idx)
|
230 |
+
print('Done.')
|
231 |
+
return model
|
232 |
+
|
233 |
+
|
234 |
+
def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
|
235 |
+
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, modeling_utils
|
236 |
+
config = GPTNeoXConfig.from_pretrained(model)
|
237 |
+
|
238 |
+
def noop(*args, **kwargs):
|
239 |
+
pass
|
240 |
+
|
241 |
+
torch.nn.init.kaiming_uniform_ = noop
|
242 |
+
torch.nn.init.uniform_ = noop
|
243 |
+
torch.nn.init.normal_ = noop
|
244 |
+
|
245 |
+
torch.set_default_dtype(torch.half)
|
246 |
+
modeling_utils._init_weights = False
|
247 |
+
torch.set_default_dtype(torch.half)
|
248 |
+
model = GPTNeoXForCausalLM(config)
|
249 |
+
torch.set_default_dtype(torch.float)
|
250 |
+
if eval:
|
251 |
+
model = model.eval()
|
252 |
+
layers = find_layers(model)
|
253 |
+
for name in ['embed_in','embed_out']:
|
254 |
+
if name in layers:
|
255 |
+
del layers[name]
|
256 |
+
quant.make_quant_linear(model, layers, wbits, groupsize)
|
257 |
+
|
258 |
+
del layers
|
259 |
+
|
260 |
+
print('Loading model ...')
|
261 |
+
if checkpoint.endswith('.safetensors'):
|
262 |
+
from safetensors.torch import load_file as safe_load
|
263 |
+
model.load_state_dict(safe_load(checkpoint))
|
264 |
+
else:
|
265 |
+
model.load_state_dict(torch.load(checkpoint))
|
266 |
+
|
267 |
+
if warmup_autotune:
|
268 |
+
quant.autotune_warmup_linear(model, transpose=not (eval))
|
269 |
+
|
270 |
+
model.seqlen = model.config.max_position_embeddings
|
271 |
+
print('Done.')
|
272 |
+
|
273 |
+
return model
|
274 |
+
|
275 |
+
|
276 |
+
def neox_multigpu(model, gpus):
|
277 |
+
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(gpus[0])
|
278 |
+
model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(gpus[-1])
|
279 |
+
import copy
|
280 |
+
model.embed_out = copy.deepcopy(model.embed_out).to(gpus[-1])
|
281 |
+
|
282 |
+
cache = {'mask': None}
|
283 |
+
|
284 |
+
class MoveModule(nn.Module):
|
285 |
+
|
286 |
+
def __init__(self, module):
|
287 |
+
super().__init__()
|
288 |
+
self.module = module
|
289 |
+
self.dev = next(iter(self.module.parameters())).device
|
290 |
+
|
291 |
+
def forward(self, *inp, **kwargs):
|
292 |
+
inp = list(inp)
|
293 |
+
if inp[0].device != self.dev:
|
294 |
+
inp[0] = inp[0].to(self.dev)
|
295 |
+
if cache['mask'] is None or cache['mask'].device != self.dev:
|
296 |
+
cache['mask'] = kwargs['attention_mask'].to(self.dev)
|
297 |
+
kwargs['attention_mask'] = cache['mask']
|
298 |
+
tmp = self.module(*inp, **kwargs)
|
299 |
+
return tmp
|
300 |
+
|
301 |
+
layers = model.gpt_neox.layers
|
302 |
+
pergpu = math.ceil(len(layers) / len(gpus))
|
303 |
+
for i in range(len(layers)):
|
304 |
+
layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
|
305 |
+
|
306 |
+
model.gpus = gpus
|
307 |
+
|
308 |
+
|
309 |
+
def benchmark(model, input_ids, check=False):
|
310 |
+
input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
|
311 |
+
torch.cuda.synchronize()
|
312 |
+
|
313 |
+
cache = {'past': None}
|
314 |
+
|
315 |
+
def clear_past(i):
|
316 |
+
|
317 |
+
def tmp(layer, inp, out):
|
318 |
+
if cache['past']:
|
319 |
+
cache['past'][i] = None
|
320 |
+
|
321 |
+
return tmp
|
322 |
+
|
323 |
+
for i, layer in enumerate(model.gpt_neox.layers):
|
324 |
+
layer.register_forward_hook(clear_past(i))
|
325 |
+
|
326 |
+
print('Benchmarking ...')
|
327 |
+
|
328 |
+
if check:
|
329 |
+
loss = nn.CrossEntropyLoss()
|
330 |
+
tot = 0.
|
331 |
+
|
332 |
+
def sync():
|
333 |
+
if hasattr(model, 'gpus'):
|
334 |
+
for gpu in model.gpus:
|
335 |
+
torch.cuda.synchronize(gpu)
|
336 |
+
else:
|
337 |
+
torch.cuda.synchronize()
|
338 |
+
|
339 |
+
max_memory = 0
|
340 |
+
with torch.no_grad():
|
341 |
+
attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
|
342 |
+
times = []
|
343 |
+
for i in range(input_ids.numel()):
|
344 |
+
tick = time.time()
|
345 |
+
out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
|
346 |
+
sync()
|
347 |
+
times.append(time.time() - tick)
|
348 |
+
print(i, times[-1])
|
349 |
+
max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024)
|
350 |
+
if check and i != input_ids.numel() - 1:
|
351 |
+
tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
|
352 |
+
cache['past'] = list(out.past_key_values)
|
353 |
+
del out
|
354 |
+
sync()
|
355 |
+
print('Median:', np.median(times))
|
356 |
+
if check:
|
357 |
+
print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
|
358 |
+
print('max memory(MiB):', max_memory)
|
359 |
+
|
360 |
+
|
361 |
+
if __name__ == '__main__':
|
362 |
+
|
363 |
+
parser = argparse.ArgumentParser()
|
364 |
+
|
365 |
+
parser.add_argument('model', type=str, help='llama model to load')
|
366 |
+
parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
|
367 |
+
parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
|
368 |
+
parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
|
369 |
+
parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
|
370 |
+
parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
|
371 |
+
parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='bits to use for quantization; use 16 for evaluating base model.')
|
372 |
+
parser.add_argument('--seqlen', type=int, default=-1, help='seqlen to use for quantization; default uses full seqlen')
|
373 |
+
parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
|
374 |
+
parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
|
375 |
+
parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
|
376 |
+
parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
|
377 |
+
parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
|
378 |
+
parser.add_argument('--load', type=str, default='', help='Load quantized model.')
|
379 |
+
parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
|
380 |
+
parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
|
381 |
+
parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
|
382 |
+
parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
|
383 |
+
parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
|
384 |
+
args = parser.parse_args()
|
385 |
+
|
386 |
+
if type(args.load) is not str:
|
387 |
+
args.load = args.load.as_posix()
|
388 |
+
|
389 |
+
if args.load:
|
390 |
+
model = load_quant(args.model, args.load, args.wbits, args.groupsize)
|
391 |
+
else:
|
392 |
+
model = get_neox(args.model)
|
393 |
+
model.eval()
|
394 |
+
|
395 |
+
dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
|
396 |
+
|
397 |
+
if not args.load and args.wbits < 16 and not args.nearest:
|
398 |
+
tick = time.time()
|
399 |
+
quantizers = neox_sequential(model, dataloader, DEV)
|
400 |
+
print(time.time() - tick)
|
401 |
+
|
402 |
+
if args.benchmark:
|
403 |
+
gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
|
404 |
+
if len(gpus) > 1:
|
405 |
+
neox_multigpu(model, gpus)
|
406 |
+
else:
|
407 |
+
model = model.to(DEV)
|
408 |
+
if args.benchmark:
|
409 |
+
input_ids = next(iter(dataloader))[0][:, :args.benchmark]
|
410 |
+
benchmark(model, input_ids, check=args.check)
|
411 |
+
|
412 |
+
if args.eval:
|
413 |
+
datasets = ['wikitext2', 'ptb', 'c4']
|
414 |
+
if args.new_eval:
|
415 |
+
datasets = ['wikitext2', 'ptb-new', 'c4-new']
|
416 |
+
for dataset in datasets:
|
417 |
+
dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
|
418 |
+
print(dataset)
|
419 |
+
neox_eval(model, testloader, DEV)
|
420 |
+
|
421 |
+
if args.save:
|
422 |
+
neox_pack(model, quantizers, args.wbits, args.groupsize)
|
423 |
+
torch.save(model.state_dict(), args.save)
|
424 |
+
|
425 |
+
if args.save_safetensors:
|
426 |
+
neox_pack(model, quantizers, args.wbits, args.groupsize)
|
427 |
+
from safetensors.torch import save_file as safe_save
|
428 |
+
state_dict = model.state_dict()
|
429 |
+
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
430 |
+
safe_save(state_dict, args.save_safetensors)
|
opt.py
ADDED
@@ -0,0 +1,446 @@
1 |
+
import time
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import transformers
|
8 |
+
from gptq import GPTQ
|
9 |
+
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
|
10 |
+
import quant
|
11 |
+
|
12 |
+
|
13 |
+
def get_opt(model):
|
14 |
+
import torch
|
15 |
+
|
16 |
+
def skip(*args, **kwargs):
|
17 |
+
pass
|
18 |
+
|
19 |
+
torch.nn.init.kaiming_uniform_ = skip
|
20 |
+
torch.nn.init.uniform_ = skip
|
21 |
+
torch.nn.init.normal_ = skip
|
22 |
+
from transformers import OPTForCausalLM
|
23 |
+
model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
|
24 |
+
model.seqlen = model.config.max_position_embeddings
|
25 |
+
return model
|
26 |
+
|
27 |
+
|
28 |
+
@torch.no_grad()
|
29 |
+
def opt_sequential(model, dataloader, dev):
|
30 |
+
print('Starting ...')
|
31 |
+
|
32 |
+
use_cache = model.config.use_cache
|
33 |
+
model.config.use_cache = False
|
34 |
+
layers = model.model.decoder.layers
|
35 |
+
|
36 |
+
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
|
37 |
+
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
|
38 |
+
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
|
39 |
+
model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
|
40 |
+
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
|
41 |
+
model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
|
42 |
+
layers[0] = layers[0].to(dev)
|
43 |
+
|
44 |
+
dtype = next(iter(model.parameters())).dtype
|
45 |
+
inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
|
46 |
+
cache = {'i': 0, 'attention_mask': None}
|
47 |
+
|
48 |
+
class Catcher(nn.Module):
|
49 |
+
|
50 |
+
def __init__(self, module):
|
51 |
+
super().__init__()
|
52 |
+
self.module = module
|
53 |
+
|
54 |
+
def forward(self, inp, **kwargs):
|
55 |
+
inps[cache['i']] = inp
|
56 |
+
cache['i'] += 1
|
57 |
+
cache['attention_mask'] = kwargs['attention_mask']
|
58 |
+
raise ValueError
|
59 |
+
|
60 |
+
layers[0] = Catcher(layers[0])
|
61 |
+
for batch in dataloader:
|
62 |
+
try:
|
63 |
+
model(batch[0].to(dev))
|
64 |
+
except ValueError:
|
65 |
+
pass
|
66 |
+
layers[0] = layers[0].module
|
67 |
+
|
68 |
+
layers[0] = layers[0].cpu()
|
69 |
+
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
|
70 |
+
model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
|
71 |
+
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
|
72 |
+
model.model.decoder.project_out = model.model.decoder.project_out.cpu()
|
73 |
+
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
|
74 |
+
model.model.decoder.project_in = model.model.decoder.project_in.cpu()
|
75 |
+
torch.cuda.empty_cache()
|
76 |
+
|
77 |
+
outs = torch.zeros_like(inps)
|
78 |
+
attention_mask = cache['attention_mask']
|
79 |
+
|
80 |
+
print('Ready.')
|
81 |
+
|
82 |
+
quantizers = {}
|
83 |
+
for i in range(len(layers)):
|
84 |
+
layer = layers[i].to(dev)
|
85 |
+
|
86 |
+
subset = find_layers(layer)
|
87 |
+
gptq = {}
|
88 |
+
for name in subset:
|
89 |
+
gptq[name] = GPTQ(subset[name])
|
90 |
+
gptq[name].quantizer = quant.Quantizer()
|
91 |
+
gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits)
|
92 |
+
|
93 |
+
def add_batch(name):
|
94 |
+
|
95 |
+
def tmp(_, inp, out):
|
96 |
+
gptq[name].add_batch(inp[0].data, out.data)
|
97 |
+
|
98 |
+
return tmp
|
99 |
+
|
100 |
+
handles = []
|
101 |
+
for name in subset:
|
102 |
+
handles.append(subset[name].register_forward_hook(add_batch(name)))
|
103 |
+
|
104 |
+
for j in range(args.nsamples):
|
105 |
+
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
|
106 |
+
|
107 |
+
for h in handles:
|
108 |
+
h.remove()
|
109 |
+
|
110 |
+
for name in subset:
|
111 |
+
print(f'Quantizing {name} in layer {i+1}/{len(layers)}...')
|
112 |
+
scale, zero, g_idx, _ = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
|
113 |
+
quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
|
114 |
+
gptq[name].free()
|
115 |
+
|
116 |
+
for j in range(args.nsamples):
|
117 |
+
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
|
118 |
+
|
119 |
+
layers[i] = layer.cpu()
|
120 |
+
del layer
|
121 |
+
del gptq
|
122 |
+
torch.cuda.empty_cache()
|
123 |
+
|
124 |
+
inps, outs = outs, inps
|
125 |
+
|
126 |
+
model.config.use_cache = use_cache
|
127 |
+
|
128 |
+
return quantizers
|
129 |
+
|
130 |
+
|
131 |
+
@torch.no_grad()
|
132 |
+
def opt_eval(model, testenc, dev):
|
133 |
+
print('Evaluating ...')
|
134 |
+
|
135 |
+
testenc = testenc.input_ids
|
136 |
+
nsamples = testenc.numel() // model.seqlen
|
137 |
+
|
138 |
+
use_cache = model.config.use_cache
|
139 |
+
model.config.use_cache = False
|
140 |
+
layers = model.model.decoder.layers
|
141 |
+
|
142 |
+
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
|
143 |
+
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
|
144 |
+
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
|
145 |
+
model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
|
146 |
+
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
|
147 |
+
model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
|
148 |
+
layers[0] = layers[0].to(dev)
|
149 |
+
|
150 |
+
dtype = next(iter(model.parameters())).dtype
|
151 |
+
inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
|
152 |
+
cache = {'i': 0, 'attention_mask': None}
|
153 |
+
|
154 |
+
class Catcher(nn.Module):
|
155 |
+
|
156 |
+
def __init__(self, module):
|
157 |
+
super().__init__()
|
158 |
+
self.module = module
|
159 |
+
|
160 |
+
def forward(self, inp, **kwargs):
|
161 |
+
inps[cache['i']] = inp
|
162 |
+
cache['i'] += 1
|
163 |
+
cache['attention_mask'] = kwargs['attention_mask']
|
164 |
+
raise ValueError
|
165 |
+
|
166 |
+
layers[0] = Catcher(layers[0])
|
167 |
+
for i in range(nsamples):
|
168 |
+
batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
|
169 |
+
try:
|
170 |
+
model(batch)
|
171 |
+
except ValueError:
|
172 |
+
pass
|
173 |
+
layers[0] = layers[0].module
|
174 |
+
|
175 |
+
layers[0] = layers[0].cpu()
|
176 |
+
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
|
177 |
+
model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
|
178 |
+
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
|
179 |
+
model.model.decoder.project_out = model.model.decoder.project_out.cpu()
|
180 |
+
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
|
181 |
+
model.model.decoder.project_in = model.model.decoder.project_in.cpu()
|
182 |
+
torch.cuda.empty_cache()
|
183 |
+
|
184 |
+
outs = torch.zeros_like(inps)
|
185 |
+
attention_mask = cache['attention_mask']
|
186 |
+
|
187 |
+
for i in range(len(layers)):
|
188 |
+
print(i)
|
189 |
+
layer = layers[i].to(dev)
|
190 |
+
|
191 |
+
if args.nearest:
|
192 |
+
subset = find_layers(layer)
|
193 |
+
for name in subset:
|
194 |
+
quantizer = quant.Quantizer()
|
195 |
+
quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
|
196 |
+
W = subset[name].weight.data
|
197 |
+
quantizer.find_params(W, weight=True)
|
198 |
+
subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
|
199 |
+
|
200 |
+
for j in range(nsamples):
|
201 |
+
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
|
202 |
+
layers[i] = layer.cpu()
|
203 |
+
del layer
|
204 |
+
torch.cuda.empty_cache()
|
205 |
+
inps, outs = outs, inps
|
206 |
+
|
207 |
+
if model.model.decoder.final_layer_norm is not None:
|
208 |
+
model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
|
209 |
+
if model.model.decoder.project_out is not None:
|
210 |
+
model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
|
211 |
+
model.lm_head = model.lm_head.to(dev)
|
212 |
+
|
213 |
+
testenc = testenc.to(dev)
|
214 |
+
nlls = []
|
215 |
+
for i in range(nsamples):
|
216 |
+
hidden_states = inps[i].unsqueeze(0)
|
217 |
+
if model.model.decoder.final_layer_norm is not None:
|
218 |
+
hidden_states = model.model.decoder.final_layer_norm(hidden_states)
|
219 |
+
if model.model.decoder.project_out is not None:
|
220 |
+
hidden_states = model.model.decoder.project_out(hidden_states)
|
221 |
+
lm_logits = model.lm_head(hidden_states)
|
222 |
+
shift_logits = lm_logits[:, :-1, :].contiguous()
|
223 |
+
shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
|
224 |
+
loss_fct = nn.CrossEntropyLoss()
|
225 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
226 |
+
neg_log_likelihood = loss.float() * model.seqlen
|
227 |
+
nlls.append(neg_log_likelihood)
|
228 |
+
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
|
229 |
+
print(ppl.item())
|
230 |
+
|
231 |
+
model.config.use_cache = use_cache
|
232 |
+
|
233 |
+
|
234 |
+
# TODO: perform packing on GPU
|
235 |
+
def opt_pack(model, quantizers, wbits, groupsize):
|
236 |
+
layers = find_layers(model)
|
237 |
+
layers = {n: layers[n] for n in quantizers}
|
238 |
+
quant.make_quant_linear(model, quantizers, wbits, groupsize)
|
239 |
+
qlayers = find_layers(model, [quant.QuantLinear])
|
240 |
+
print('Packing ...')
|
241 |
+
for name in qlayers:
|
242 |
+
print(name)
|
243 |
+
quantizers[name], scale, zero, g_idx = quantizers[name]
|
244 |
+
qlayers[name].pack(layers[name], scale, zero, g_idx)
|
245 |
+
print('Done.')
|
246 |
+
return model
|
247 |
+
|
248 |
+
|
249 |
+
def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
|
250 |
+
from transformers import OPTConfig, OPTForCausalLM
|
251 |
+
config = OPTConfig.from_pretrained(model)
|
252 |
+
|
253 |
+
def noop(*args, **kwargs):
|
254 |
+
pass
|
255 |
+
|
256 |
+
torch.nn.init.kaiming_uniform_ = noop
|
257 |
+
torch.nn.init.uniform_ = noop
|
258 |
+
torch.nn.init.normal_ = noop
|
259 |
+
|
260 |
+
torch.set_default_dtype(torch.half)
|
261 |
+
transformers.modeling_utils._init_weights = False
|
262 |
+
torch.set_default_dtype(torch.half)
|
263 |
+
model = OPTForCausalLM(config)
|
264 |
+
torch.set_default_dtype(torch.float)
|
265 |
+
model = model.eval()
|
266 |
+
layers = find_layers(model)
|
267 |
+
for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']:
|
268 |
+
if name in layers:
|
269 |
+
del layers[name]
|
270 |
+
quant.make_quant_linear(model, layers, wbits, groupsize)
|
271 |
+
|
272 |
+
del layers
|
273 |
+
|
274 |
+
print('Loading model ...')
|
275 |
+
if checkpoint.endswith('.safetensors'):
|
276 |
+
from safetensors.torch import load_file as safe_load
|
277 |
+
model.load_state_dict(safe_load(checkpoint))
|
278 |
+
else:
|
279 |
+
model.load_state_dict(torch.load(checkpoint))
|
280 |
+
|
281 |
+
if warmup_autotune:
|
282 |
+
quant.autotune_warmup_linear(model, transpose=not (eval))
|
283 |
+
model.seqlen = model.config.max_position_embeddings
|
284 |
+
print('Done.')
|
285 |
+
return model
|
286 |
+
|
287 |
+
|
288 |
+
def opt_multigpu(model, gpus):
|
289 |
+
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0])
|
290 |
+
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0])
|
291 |
+
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
|
292 |
+
model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0])
|
293 |
+
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
|
294 |
+
model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1])
|
295 |
+
if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm:
|
296 |
+
model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1])
|
297 |
+
import copy
|
298 |
+
model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1])
|
299 |
+
|
300 |
+
cache = {'mask': None}
|
301 |
+
|
302 |
+
class MoveModule(nn.Module):
|
303 |
+
|
304 |
+
def __init__(self, module):
|
305 |
+
super().__init__()
|
306 |
+
self.module = module
|
307 |
+
self.dev = next(iter(self.module.parameters())).device
|
308 |
+
|
309 |
+
def forward(self, *inp, **kwargs):
|
310 |
+
inp = list(inp)
|
311 |
+
if inp[0].device != self.dev:
|
312 |
+
inp[0] = inp[0].to(self.dev)
|
313 |
+
if cache['mask'] is None or cache['mask'].device != self.dev:
|
314 |
+
cache['mask'] = kwargs['attention_mask'].to(self.dev)
|
315 |
+
kwargs['attention_mask'] = cache['mask']
|
316 |
+
tmp = self.module(*inp, **kwargs)
|
317 |
+
return tmp
|
318 |
+
|
319 |
+
layers = model.model.decoder.layers
|
320 |
+
pergpu = math.ceil(len(layers) / len(gpus))
|
321 |
+
for i in range(len(layers)):
|
322 |
+
layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
|
323 |
+
|
324 |
+
model.gpus = gpus
|
325 |
+
|
326 |
+
|
327 |
+
def benchmark(model, input_ids, check=False):
|
328 |
+
input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
|
329 |
+
torch.cuda.synchronize()
|
330 |
+
|
331 |
+
cache = {'past': None}
|
332 |
+
|
333 |
+
def clear_past(i):
|
334 |
+
|
335 |
+
def tmp(layer, inp, out):
|
336 |
+
if cache['past']:
|
337 |
+
cache['past'][i] = None
|
338 |
+
|
339 |
+
return tmp
|
340 |
+
|
341 |
+
for i, layer in enumerate(model.model.decoder.layers):
|
342 |
+
layer.register_forward_hook(clear_past(i))
|
343 |
+
|
344 |
+
print('Benchmarking ...')
|
345 |
+
|
346 |
+
if check:
|
347 |
+
loss = nn.CrossEntropyLoss()
|
348 |
+
tot = 0.
|
349 |
+
|
350 |
+
def sync():
|
351 |
+
if hasattr(model, 'gpus'):
|
352 |
+
for gpu in model.gpus:
|
353 |
+
torch.cuda.synchronize(gpu)
|
354 |
+
else:
|
355 |
+
torch.cuda.synchronize()
|
356 |
+
|
357 |
+
with torch.no_grad():
|
358 |
+
attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
|
359 |
+
times = []
|
360 |
+
for i in range(input_ids.numel()):
|
361 |
+
tick = time.time()
|
362 |
+
out = model(input_ids[:, i].reshape(-1), past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
|
363 |
+
sync()
|
364 |
+
times.append(time.time() - tick)
|
365 |
+
print(i, times[-1])
|
366 |
+
if check and i != input_ids.numel() - 1:
|
367 |
+
tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
|
368 |
+
cache['past'] = list(out.past_key_values)
|
369 |
+
del out
|
370 |
+
sync()
|
371 |
+
import numpy as np
|
372 |
+
print('Median:', np.median(times))
|
373 |
+
if check:
|
374 |
+
print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
|
375 |
+
|
376 |
+
|
377 |
+
if __name__ == '__main__':
|
378 |
+
|
379 |
+
parser = argparse.ArgumentParser()
|
380 |
+
|
381 |
+
parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
|
382 |
+
parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
|
383 |
+
parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
|
384 |
+
parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
|
385 |
+
parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
|
386 |
+
parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
|
387 |
+
parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
|
388 |
+
parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
|
389 |
+
parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
|
390 |
+
parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
|
391 |
+
parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
|
392 |
+
parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
|
393 |
+
parser.add_argument('--load', type=str, default='', help='Load quantized model.')
|
394 |
+
parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
|
395 |
+
parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
|
396 |
+
parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
|
397 |
+
parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
|
398 |
+
parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
|
399 |
+
|
400 |
+
args = parser.parse_args()
|
401 |
+
|
402 |
+
if type(args.load) is not str:
|
403 |
+
args.load = args.load.as_posix()
|
404 |
+
|
405 |
+
if args.load:
|
406 |
+
model = load_quant(args.model, args.load, args.wbits, args.groupsize)
|
407 |
+
else:
|
408 |
+
model = get_opt(args.model)
|
409 |
+
model.eval()
|
410 |
+
|
411 |
+
dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
|
412 |
+
|
413 |
+
if not args.load and args.wbits < 16 and not args.nearest:
|
414 |
+
tick = time.time()
|
415 |
+
quantizers = opt_sequential(model, dataloader, DEV)
|
416 |
+
print(time.time() - tick)
|
417 |
+
|
418 |
+
if args.benchmark:
|
419 |
+
gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
|
420 |
+
if len(gpus) > 1:
|
421 |
+
opt_multigpu(model, gpus)
|
422 |
+
else:
|
423 |
+
model = model.to(DEV)
|
424 |
+
if args.benchmark:
|
425 |
+
input_ids = next(iter(dataloader))[0][:, :args.benchmark]
|
426 |
+
benchmark(model, input_ids, check=args.check)
|
427 |
+
|
428 |
+
if args.eval:
|
429 |
+
datasets = ['wikitext2', 'ptb', 'c4']
|
430 |
+
if args.new_eval:
|
431 |
+
datasets = ['wikitext2', 'ptb-new', 'c4-new']
|
432 |
+
for dataset in datasets:
|
433 |
+
dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
|
434 |
+
print(dataset)
|
435 |
+
opt_eval(model, testloader, DEV)
|
436 |
+
|
437 |
+
if args.save:
|
438 |
+
opt_pack(model, quantizers, args.wbits, args.groupsize)
|
439 |
+
torch.save(model.state_dict(), args.save)
|
440 |
+
|
441 |
+
if args.save_safetensors:
|
442 |
+
opt_pack(model, quantizers, args.wbits, args.groupsize)
|
443 |
+
from safetensors.torch import save_file as safe_save
|
444 |
+
state_dict = model.state_dict()
|
445 |
+
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
446 |
+
safe_save(state_dict, args.save_safetensors)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
safetensors==0.3.0
datasets==2.10.1
sentencepiece
git+https://github.com/huggingface/transformers
accelerate==0.17.1
triton==2.0.0
texttable
toml
numpy
protobuf==3.20.2