Upload 7 files
- quant/__init__.py +5 -0
- quant/custom_autotune.py +193 -0
- quant/fused_attn.py +203 -0
- quant/fused_mlp.py +288 -0
- quant/quant_linear.py +423 -0
- quant/quantizer.py +127 -0
- quant/triton_norm.py +91 -0
quant/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .quantizer import Quantizer
from .fused_attn import QuantLlamaAttention, make_quant_attn
from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused
from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear
from .triton_norm import TritonLlamaRMSNorm, make_quant_norm
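For context, a hedged sketch of how these entry points are typically wired together. The toy LlamaConfig, the `layer_names` selection and the commented-out checkpoint load are illustrative assumptions, not part of this commit, and the Triton path itself needs a CUDA GPU.

    # Hypothetical end-to-end usage; sizes and names below are placeholders.
    import torch
    from transformers import LlamaConfig, LlamaForCausalLM
    import quant

    config = LlamaConfig(hidden_size=256, intermediate_size=512, num_hidden_layers=2,
                         num_attention_heads=8, vocab_size=1000)   # toy sizes for illustration
    model = LlamaForCausalLM(config)

    # Swap every targeted nn.Linear (everything except lm_head) for a 4-bit QuantLinear shell.
    layer_names = [n for n, m in model.named_modules()
                   if isinstance(m, torch.nn.Linear) and 'lm_head' not in n]
    quant.make_quant_linear(model, layer_names, bits=4, groupsize=128)

    # model.load_state_dict(torch.load('quantized-checkpoint.pt'))   # packed GPTQ weights

    # Fuse attention, MLP and RMSNorm for the Triton path, then optionally pre-tune.
    quant.make_quant_attn(model)
    model = quant.make_fused_mlp(model)
    quant.make_quant_norm(model)
    # quant.autotune_warmup_linear(model)
    # quant.autotune_warmup_fused(model)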
quant/custom_autotune.py
ADDED
@@ -0,0 +1,193 @@
#https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""

import builtins
import math
import time
from typing import Dict

import triton


class Autotuner(triton.KernelInterface):

    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
        '''
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predict running time with different configs, returns running time
            'top_k': number of configs to bench
            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
            'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
        '''
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.nearest_power_of_two = nearest_power_of_two
        self.cache = {}
        # hook to reset all required tensor to zeros before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()

            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            if 'early_config_prune' in prune_configs_by:
                early_config_prune = prune_configs_by['early_config_prune']
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
                             " Make sure that you don't re-define auto-tuned symbols.")
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

        try:
            # In testing, using only 40 reps seems to be close enough and it appears to be what PyTorch uses
            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
            return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
        except triton.compiler.OutOfResources:
            return (float('inf'), float('inf'), float('inf'))

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            key = tuple(args[i] for i in self.key_idx)

            # This reduces the amount of autotuning by rounding the keys to the nearest power of two
            # In my testing this gives decent results, and greatly reduces the amount of tuning required
            if self.nearest_power_of_two:
                key = tuple([2**int(math.log2(x) + 0.5) for x in key])

            if key not in self.cache:
                # prune configs
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.hook(args)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

    def prune_configs(self, kwargs):
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        for config in self.prune_configs(kwargs):
            self.fn.warmup(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **kwargs,
                **config.kwargs,
            )
        self.nargs = None


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.
    .. highlight:: python
    .. code-block:: python
        @triton.autotune(configs=[
            triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size']  # the two above configs will be evaluated anytime
                          # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']
    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to `zero` before running any configuration.
    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predict running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """

    def decorator(fn):
        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)

    return decorator


def matmul248_kernel_config_pruner(configs, nargs):
    """
    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
    """
    m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
    n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
    k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)

    used = set()
    for config in configs:
        block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
        block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
        block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
        group_size_m = config.kwargs['GROUP_SIZE_M']

        if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
            continue

        used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
        yield triton.Config({
            'BLOCK_SIZE_M': block_size_m,
            'BLOCK_SIZE_N': block_size_n,
            'BLOCK_SIZE_K': block_size_k,
            'GROUP_SIZE_M': group_size_m
        },
                            num_stages=config.num_stages,
                            num_warps=config.num_warps)
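A side note on the `nearest_power_of_two` option above: the cache key is rounded with `2**int(math.log2(x) + 0.5)`, so nearby problem sizes reuse one tuning result instead of re-benchmarking. A minimal standalone illustration of just that arithmetic (plain Python, no Triton needed; `round_key` is a helper name introduced here, not part of the commit):

    import math

    def round_key(x: int) -> int:
        # same rounding as Autotuner.run() applies to each key value
        return 2 ** int(math.log2(x) + 0.5)

    for m in (1, 3, 24, 48, 100, 2048):
        print(m, '->', round_key(m))
    # 1 -> 1, 3 -> 4, 24 -> 32, 48 -> 64, 100 -> 128, 2048 -> 2048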
quant/fused_attn.py
ADDED
@@ -0,0 +1,203 @@
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import LlamaAttention
from .quant_linear import *
import triton
import triton.language as tl


@triton.jit
def rotate_half_kernel(
        qk_seq_ptr,
        position_ids_ptr,
        qk_seq_stride,
        position_ids_batch_stride,
        seq_len,
        HEAD_DIM: tl.constexpr,
        BLOCK_HEIGHT: tl.constexpr,
        BLOCK_WIDTH: tl.constexpr,
        INV_BASE: tl.constexpr
):
    # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.
    # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.

    HALF_HEAD: tl.constexpr = HEAD_DIM // 2
    STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH

    batch_seq = tl.program_id(axis=0)
    row_blk_x_col_blk = tl.program_id(axis=1)

    row_blk = row_blk_x_col_blk // STEPS_PER_ROW
    row = row_blk * BLOCK_HEIGHT
    if BLOCK_WIDTH < HALF_HEAD:
        col_blk = row_blk_x_col_blk % STEPS_PER_ROW
        col = col_blk * BLOCK_WIDTH
    else:
        col: tl.constexpr = 0

    # A block will never cross a sequence boundary, which simplifies things a lot.
    batch = batch_seq // seq_len
    seq = batch_seq % seq_len
    position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)
    # As sometimes happens, just calculating this on the fly is faster than loading it from memory.
    # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.
    freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id
    cos = tl.cos(freq).to(tl.float32)
    sin = tl.sin(freq).to(tl.float32)

    col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)
    embed_offsets = (row * HEAD_DIM + col) + col_offsets
    x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets

    for k in range(0, BLOCK_HEIGHT):
        x = tl.load(x_ptrs).to(tl.float32)
        y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)
        out_x = x * cos - y * sin
        tl.store(x_ptrs, out_x)
        out_y = x * sin + y * cos
        tl.store(x_ptrs + HALF_HEAD, out_y)
        x_ptrs += HEAD_DIM


def triton_rotate_half_(qk, position_ids, config=None):
    batch_size, seq_len, qandk, num_heads, head_dim = qk.shape

    # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.
    config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1}
    config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads)

    assert qk.stride(3) == head_dim
    assert qk.stride(4) == 1
    assert position_ids.shape == (batch_size, seq_len)
    assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension'
    assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}'
    assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}'

    qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)
    grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH']))

    # Must be the same as the theta of the frequencies used to train the model.
    BASE = 10000.0

    rotate_half_kernel[grid](
        qk_by_seq,
        position_ids,
        qk_by_seq.stride(0),
        position_ids.stride(0),
        seq_len,
        HEAD_DIM=head_dim,
        BLOCK_HEIGHT=config['BLOCK_HEIGHT'],
        BLOCK_WIDTH=config['BLOCK_WIDTH'],
        INV_BASE=-2.0 * math.log(BASE) / head_dim,
        num_warps=config['num_warps']
    )


class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
            self,
            hidden_size,
            num_heads,
            qkv_proj,
            o_proj
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)

        # This updates the query and key states in-place, saving VRAM.
        triton_rotate_half_(qkv_states[:, :, :2], position_ids)

        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
        del qkv_states
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        is_causal = past_key_value is None

        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
            query_states = query_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
        del query_states, key_states, value_states

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


def make_quant_attn(model):
    """
    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
    """

    for name, m in model.named_modules():
        if not isinstance(m, LlamaAttention):
            continue

        q_proj = m.q_proj
        k_proj = m.k_proj
        v_proj = m.v_proj

        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False)
        qkv_layer.qweight = qweights
        qkv_layer.qzeros = qzeros
        qkv_layer.scales = scales
        qkv_layer.g_idx = g_idx
        qkv_layer.bias = bias
        # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.

        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj)

        if '.' in name:
            parent_name = name.rsplit('.', 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.get_submodule(parent_name)
        else:
            parent_name = ''
            parent = model
            child_name = name

        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")

        setattr(parent, child_name, attn)
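For readers checking the `rotate_half_kernel` math: it applies rotary position embeddings in place, with per-index frequencies `base**(-2i/head_dim)` baked in via `INV_BASE`, rotating each (first-half, second-half) pair of the head dimension. Below is a hedged, non-Triton reference sketch of the same transform; the function name and test shapes are illustrative, not part of the commit.

    import torch

    def rotate_half_reference(qk: torch.Tensor, position_ids: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
        # qk: (bsz, seq_len, 2, num_heads, head_dim); position_ids: (bsz, seq_len)
        bsz, seq_len, _, num_heads, head_dim = qk.shape
        half = head_dim // 2
        inv_freq = base ** (-torch.arange(0, half, dtype=torch.float32) * 2.0 / head_dim)
        theta = position_ids.to(torch.float32)[..., None] * inv_freq      # (bsz, seq_len, half)
        cos = theta.cos()[:, :, None, None, :]                            # broadcast over (2, num_heads)
        sin = theta.sin()[:, :, None, None, :]
        x, y = qk[..., :half].float(), qk[..., half:].float()
        out = torch.cat([x * cos - y * sin, x * sin + y * cos], dim=-1)   # same rotation as the kernel
        return out.to(qk.dtype)

    qk = torch.randn(1, 4, 2, 8, 64)
    pos = torch.arange(4)[None, :]
    print(rotate_half_reference(qk, pos).shape)   # torch.Size([1, 4, 2, 8, 64])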
quant/fused_mlp.py
ADDED
@@ -0,0 +1,288 @@
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaMLP

try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config({
                'BLOCK_SIZE_M': 256,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 16,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 32,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 128,
                'GROUP_SIZE_M': 8
            }, num_stages=2, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 16,
                'BLOCK_SIZE_K': 64,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 64,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True,
        prune_configs_by={
            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
            'perf_model': None,
            'top_k': None,
        },
    )
    @triton.jit
    def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,
                               stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Computes: C = silu(A * B1) * (A * B2)
        A is of shape (M, K) float16
        B is of shape (K//8, N) int32
        C is of shape (M, N) float16
        scales is of shape (1, N) float16
        zeros is of shape (1, N//8) int32
        """
        infearure_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
        b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
        g1_ptrs = g1_ptr + offs_k
        g2_ptrs = g2_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales1_ptrs = scales1_ptr + offs_bn[None, :]
        scales2_ptrs = scales2_ptr + offs_bn[None, :]
        zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
        zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)

        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, num_pid_k):
            g1_idx = tl.load(g1_ptrs)
            g2_idx = tl.load(g2_ptrs)

            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)

            zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
            zeros1 = (zeros1 + 1)

            zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
            zeros2 = (zeros2 + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b1 = tl.load(b1_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
            b2 = tl.load(b2_ptrs)

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b1 = (b1 >> shifter[:, None]) & maxq  # Extract the N-bit values
            b1 = (b1 - zeros1) * scales1  # Scale and shift
            accumulator1 += tl.dot(a, b1)

            b2 = (b2 >> shifter[:, None]) & maxq
            b2 = (b2 - zeros2) * scales2
            accumulator2 += tl.dot(a, b2)

            a_ptrs += BLOCK_SIZE_K
            b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g1_ptrs += BLOCK_SIZE_K
            g2_ptrs += BLOCK_SIZE_K

        accumulator1 = silu(accumulator1)
        c = accumulator1 * accumulator2
        c = c.to(tl.float16)
        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    @triton.jit
    def silu(x):
        return x * tl.sigmoid(x)
except:
    print('triton not installed.')


class QuantLlamaMLP(nn.Module):

    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj,
    ):
        super().__init__()
        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
        self.register_buffer('gate_proj_scales', gate_proj.scales)
        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
        self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
        self.register_buffer('up_proj_g_idx', up_proj.g_idx)

        self.infeatures = gate_proj.infeatures
        self.intermediate_size = gate_proj.outfeatures
        self.outfeatures = down_proj.outfeatures
        self.bits = gate_proj.bits
        self.maxq = gate_proj.maxq

        self.down_proj = down_proj

    def forward(self, x):
        return self.down_proj(self.triton_llama_mlp(x))

    def triton_llama_mlp(self, x):
        with torch.cuda.device(x.device):
            out_shape = x.shape[:-1] + (self.intermediate_size, )
            x = x.reshape(-1, x.shape[-1])
            M, K = x.shape
            N = self.intermediate_size
            c = torch.empty((M, N), device=x.device, dtype=torch.float16)
            grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
            fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,
                                         self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),
                                         self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))
            c = c.reshape(out_shape)
            return c

    def fused2cuda(self):
        self.gate_proj_qweight = self.gate_proj_qweight.cuda()
        self.gate_proj_scales = self.gate_proj_scales.cuda()
        self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
        self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
        self.up_proj_qweight = self.up_proj_qweight.cuda()
        self.up_proj_scales = self.up_proj_scales.cuda()
        self.up_proj_qzeros = self.up_proj_qzeros.cuda()
        self.up_proj_g_idx = self.up_proj_g_idx.cuda()

    def fused2cpu(self):
        self.gate_proj_qweight = self.gate_proj_qweight.cpu()
        self.gate_proj_scales = self.gate_proj_scales.cpu()
        self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
        self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
        self.up_proj_qweight = self.up_proj_qweight.cpu()
        self.up_proj_scales = self.up_proj_scales.cpu()
        self.up_proj_qzeros = self.up_proj_qzeros.cpu()
        self.up_proj_g_idx = self.up_proj_g_idx.cpu()


def make_fused_mlp(m, parent_name=''):
    """
    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
    """
    if isinstance(m, LlamaMLP):
        return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)

    for name, child in m.named_children():
        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")

        if isinstance(child, QuantLlamaMLP):
            setattr(m, name, child)
    return m


def autotune_warmup_fused(model):
    """
    Pre-tunes the quantized kernel
    """
    from tqdm import tqdm

    kn_values = {}

    for _, m in model.named_modules():
        if not isinstance(m, QuantLlamaMLP):
            continue

        k = m.infeatures
        n = m.intermediate_size

        m.fused2cuda()
        if (k, n) not in kn_values:
            kn_values[(k, n)] = m

    print(f'Found {len(kn_values)} unique fused mlp KN values.')

    print('Warming up autotune cache ...')
    with torch.no_grad():
        for m in tqdm(range(0, 12)):
            m = 2**m  # [1, 2048]
            for (k, n), (modules) in kn_values.items():
                a = torch.randn(m, k, dtype=torch.float16, device='cuda')
                modules.triton_llama_mlp(a)

    for (k, n), (modules) in kn_values.items():
        a = torch.randn(m, k, dtype=torch.float16, device='cuda')
        modules.fused2cpu()
    del kn_values
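What `fusedmatmul_248_kernel` ultimately computes is the standard LLaMA MLP with the gate and up projections fused into one dequantize-and-matmul pass; `QuantLlamaMLP.forward` then applies `down_proj`. A dense, unquantized reference sketch of that computation follows; the weight matrices here are stand-ins for the dequantized ones and the sizes are toy values, not part of the commit.

    import torch
    import torch.nn.functional as F

    def llama_mlp_reference(x, W_gate, W_up, W_down):
        # silu(x @ W_gate) * (x @ W_up) is what the fused kernel produces; @ W_down is down_proj.
        return (F.silu(x @ W_gate) * (x @ W_up)) @ W_down

    x = torch.randn(2, 16, 512)
    W_gate = torch.randn(512, 1376)
    W_up = torch.randn(512, 1376)
    W_down = torch.randn(1376, 512)
    print(llama_mlp_reference(x, W_gate, W_up, W_down).shape)   # torch.Size([2, 16, 512])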
quant/quant_linear.py
ADDED
@@ -0,0 +1,423 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=2, num_warps=8),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 64,
                'GROUP_SIZE_M': 8
            }, num_stages=3, num_warps=8),
            triton.Config({
                'BLOCK_SIZE_M': 32,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 128,
                'GROUP_SIZE_M': 8
            }, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True,
        prune_configs_by={
            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
            'perf_model': None,
            'top_k': None,
        },
    )
    @triton.jit
    def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
                          BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B.
        A is of shape (M, K) float16
        B is of shape (K//8, N) int32
        C is of shape (M, N) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N) float16
        g_ptr is of shape (K) int32
        """
        infearure_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_bn[None, :]
        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, num_pid_k):
            g_idx = tl.load(g_ptrs)

            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift

            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g_ptrs += BLOCK_SIZE_K

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)

    @custom_autotune.autotune(configs=[
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 256,
            'GROUP_SIZE_M': 8
        }, num_stages=4, num_warps=4),
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 128,
            'GROUP_SIZE_M': 8
        }, num_stages=4, num_warps=4),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 128,
            'GROUP_SIZE_M': 8
        }, num_stages=4, num_warps=4),
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 32,
            'GROUP_SIZE_M': 8
        }, num_stages=4, num_warps=4),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 64,
            'GROUP_SIZE_M': 8
        }, num_stages=4, num_warps=4),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 128,
            'GROUP_SIZE_M': 8
        }, num_stages=2, num_warps=8),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 64,
            'GROUP_SIZE_M': 8
        }, num_stages=3, num_warps=8),
        triton.Config({
            'BLOCK_SIZE_M': 32,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'GROUP_SIZE_M': 8
        }, num_stages=2, num_warps=4),
    ],
                              key=['M', 'N', 'K'],
                              nearest_power_of_two=True)
    @triton.jit
    def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
                                    stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B.
        A is of shape (M, N) float16
        B is of shape (K//8, N) int32
        C is of shape (M, K) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N) float16
        g_ptr is of shape (K) int32
        """
        infearure_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_k
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_k = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        offs_n = tl.arange(0, BLOCK_SIZE_N)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_bk
        g_idx = tl.load(g_ptrs)

        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
        zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros

        shifter = (offs_bk % infearure_per_bits) * bits
        zeros_shifter = (offs_n % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

        for n in range(0, num_pid_n):
            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift
            b = tl.trans(b)

            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_N
            b_ptrs += BLOCK_SIZE_N
            scales_ptrs += BLOCK_SIZE_N
            zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
        tl.store(c_ptrs, accumulator, mask=c_mask)
except:
    print('triton not installed.')


def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
        grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
        matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
                                qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
        return output


def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output_dim = (qweight.shape[0] * 32) // bits
        output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
        grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
        transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
                                          qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
        return output


class QuantLinearFunction(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits, ctx.maxq = bits, maxq
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        qweight, scales, qzeros, g_idx = ctx.saved_tensors
        bits, maxq = ctx.bits, ctx.maxq
        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
        return grad_input, None, None, None, None, None, None


class QuantLinear(nn.Module):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

    def pack(self, linear, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures, )
        out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)


def make_quant_linear(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
    for name1, child in module.named_children():
        make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)


def autotune_warmup_linear(model, transpose=False):
    """
    Pre-tunes the quantized kernel
    """
    from tqdm import tqdm

    kn_values = {}

    for _, m in model.named_modules():
        if not isinstance(m, QuantLinear):
            continue

        k = m.infeatures
        n = m.outfeatures

        if (k, n) not in kn_values:
            kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)

    print(f'Found {len(kn_values)} unique KN Linear values.')

    print('Warming up autotune cache ...')
    with torch.no_grad():
        for m in tqdm(range(0, 12)):
            m = 2**m  # [1, 2048]
            for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
                a = torch.randn(m, k, dtype=torch.float16, device='cuda')
                matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                if transpose:
                    a = torch.randn(m, n, dtype=torch.float16, device='cuda')
                    transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
    del kn_values
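The packing in `QuantLinear.pack()` stores `32 // bits` quantized integers per int32 word; the kernels recover them with the `shifter`/`maxq` shift-and-mask seen above. A small standalone illustration of that bit layout (plain Python, values chosen for readability; not part of the commit):

    bits = 4
    values_per_word = 32 // bits      # 8 quantized values packed into each int32
    maxq = 2 ** bits - 1

    vals = list(range(8))             # quantized integers 0..7

    packed = 0
    for j, v in enumerate(vals):
        packed |= v << (bits * j)     # same shift pattern as QuantLinear.pack()

    # Unpack the way matmul_248_kernel does: shift by `shifter`, then mask with maxq.
    unpacked = [(packed >> (bits * j)) & maxq for j in range(values_per_word)]
    print(hex(packed))                # 0x76543210
    print(unpacked)                   # [0, 1, 2, 3, 4, 5, 6, 7]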
quant/quantizer.py
ADDED
@@ -0,0 +1,127 @@
import numpy as np
import torch
import torch.nn as nn
import math


class Quantizer(nn.Module):

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):

        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        if maxq < 0:
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)
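The quantize/dequantize round trip used throughout `Quantizer` is the usual affine scheme: `scale = (xmax - xmin) / maxq`, `zero = round(-xmin / scale)`, `q = clamp(round(x / scale) + zero, 0, maxq)`, `x_hat = scale * (q - zero)`. A hedged standalone sketch of those formulas on random data, not part of the commit:

    import torch

    bits = 4
    maxq = 2 ** bits - 1

    x = torch.randn(256) * 0.1
    xmin, xmax = x.min().clamp(max=0), x.max().clamp(min=0)   # same min/max-vs-zero rule as find_params

    scale = (xmax - xmin) / maxq
    zero = torch.round(-xmin / scale)

    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    x_hat = scale * (q - zero)

    print('max abs error:', (x - x_hat).abs().max().item())   # bounded by roughly scale / 2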
quant/triton_norm.py
ADDED
@@ -0,0 +1,91 @@
import torch
from torch import nn
import triton
import triton.language as tl
from transformers.models.llama.modeling_llama import LlamaRMSNorm


@triton.jit
def rms_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = x * rstd
        y = x_hat * w
        # Write output
        tl.store(Y + cols, y, mask=mask)


class TritonLlamaRMSNorm(nn.Module):

    def __init__(self, weight, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, x):
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        rms_norm_fwd_fused[(M,)](x_arg, y, self.weight,
                                 x_arg.stride(0), N, self.variance_epsilon,
                                 BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        return y


def make_quant_norm(model):
    """
    Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules
    """

    for name, m in model.named_modules():
        if not isinstance(m, LlamaRMSNorm):
            continue

        norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon)

        if '.' in name:
            parent_name = name.rsplit('.', 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.get_submodule(parent_name)
        else:
            parent_name = ''
            parent = model
            child_name = name

        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")

        setattr(parent, child_name, norm)
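For reference, the fused kernel computes plain RMSNorm: each row is scaled by the reciprocal root-mean-square and then multiplied by the weight. Below is a pure-PyTorch sketch that can serve as a CPU cross-check; the function name and test shapes are illustrative, and the comparison against `TritonLlamaRMSNorm` is left commented out because it needs a CUDA GPU.

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)    # var = sum(x*x) / N, as in the kernel
        x_hat = x.float() * torch.rsqrt(variance + eps)       # rstd = 1 / sqrt(var + eps)
        return (x_hat * weight.float()).to(x.dtype)

    x = torch.randn(2, 5, 64)
    w = torch.ones(64)
    print(rms_norm_reference(x, w).shape)   # torch.Size([2, 5, 64])

    # On CUDA one could compare against the fused path, e.g.:
    # norm = TritonLlamaRMSNorm(w.cuda().half(), eps=1e-6)
    # torch.testing.assert_close(norm(x.cuda().half()),
    #                            rms_norm_reference(x.cuda().half(), w.cuda()),
    #                            rtol=1e-2, atol=1e-2)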