ssaroya committed
Commit 401522d · 1 Parent(s): 310935e

Upload 7 files

quant/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .quantizer import Quantizer
+ from .fused_attn import QuantLlamaAttention, make_quant_attn
+ from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused
+ from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear
+ from .triton_norm import TritonLlamaRMSNorm, make_quant_norm
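These exports are typically combined in the order sketched below. This is a minimal, hypothetical usage sketch (not a script from this commit): it assumes `model` is a Hugging Face LlamaForCausalLM whose Linear projections were already replaced by QuantLinear and loaded with GPTQ-packed weights.

# Minimal sketch, assuming `model` already holds QuantLinear layers with packed 4-bit weights.
from quant import make_quant_attn, make_fused_mlp, make_quant_norm
from quant import autotune_warmup_linear, autotune_warmup_fused

model = model.cuda()
make_quant_attn(model)         # fuse q/k/v QuantLinear layers into one QuantLlamaAttention
make_fused_mlp(model)          # replace LlamaMLP with the fused silu(A*B1) * (A*B2) kernel
make_quant_norm(model)         # swap LlamaRMSNorm for the Triton RMSNorm
autotune_warmup_linear(model)  # pre-populate the autotuner cache for matmul248
autotune_warmup_fused(model)   # ... and for the fused MLP kernel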
quant/custom_autotune.py ADDED
@@ -0,0 +1,193 @@
+ #https://github.com/fpgaminer/GPTQ-triton
+ """
+ Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
+ """
+
+ import builtins
+ import math
+ import time
+ from typing import Dict
+
+ import triton
+
+
+ class Autotuner(triton.KernelInterface):
+
+     def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
+         '''
+         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+             'perf_model': performance model used to predict running time with different configs, returns running time
+             'top_k': number of configs to bench
+             'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+             'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
+         '''
+         if not configs:
+             self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
+         else:
+             self.configs = configs
+         self.key_idx = [arg_names.index(k) for k in key]
+         self.nearest_power_of_two = nearest_power_of_two
+         self.cache = {}
+         # hook to reset all required tensors to zero before relaunching a kernel
+         self.hook = lambda args: 0
+         if reset_to_zero is not None:
+             self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+             def _hook(args):
+                 for i in self.reset_idx:
+                     args[i].zero_()
+
+             self.hook = _hook
+         self.arg_names = arg_names
+         # prune configs
+         if prune_configs_by:
+             perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+             # default to None so the attribute assignment below never hits a NameError
+             early_config_prune = prune_configs_by.get('early_config_prune')
+         else:
+             perf_model, top_k, early_config_prune = None, None, None
+         self.perf_model, self.configs_top_k = perf_model, top_k
+         self.early_config_prune = early_config_prune
+         self.fn = fn
+
+     def _bench(self, *args, config, **meta):
+         # check for conflicts, i.e. meta-parameters both provided
+         # as kwargs and by the autotuner
+         conflicts = meta.keys() & config.kwargs.keys()
+         if conflicts:
+             raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                              " Make sure that you don't re-define auto-tuned symbols.")
+         # augment meta-parameters with tunable ones
+         current = dict(meta, **config.kwargs)
+
+         def kernel_call():
+             if config.pre_hook:
+                 config.pre_hook(self.nargs)
+             self.hook(args)
+             self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+
+         try:
+             # In testing, using only 40 reps seems to be close enough and it appears to be what PyTorch uses
+             # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
+             return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
+         except triton.compiler.OutOfResources:
+             return (float('inf'), float('inf'), float('inf'))
+
+     def run(self, *args, **kwargs):
+         self.nargs = dict(zip(self.arg_names, args))
+         if len(self.configs) > 1:
+             key = tuple(args[i] for i in self.key_idx)
+
+             # This reduces the amount of autotuning by rounding the keys to the nearest power of two
+             # In my testing this gives decent results, and greatly reduces the amount of tuning required
+             if self.nearest_power_of_two:
+                 key = tuple([2**int(math.log2(x) + 0.5) for x in key])
+
+             if key not in self.cache:
+                 # prune configs
+                 pruned_configs = self.prune_configs(kwargs)
+                 bench_start = time.time()
+                 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                 bench_end = time.time()
+                 self.bench_time = bench_end - bench_start
+                 self.cache[key] = builtins.min(timings, key=timings.get)
+                 self.hook(args)
+                 self.configs_timings = timings
+             config = self.cache[key]
+         else:
+             config = self.configs[0]
+         self.best_config = config
+         if config.pre_hook is not None:
+             config.pre_hook(self.nargs)
+         return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+
+     def prune_configs(self, kwargs):
+         pruned_configs = self.configs
+         if self.early_config_prune:
+             pruned_configs = self.early_config_prune(self.configs, self.nargs)
+         if self.perf_model:
+             top_k = self.configs_top_k
+             if isinstance(top_k, float) and top_k <= 1.0:
+                 top_k = int(len(self.configs) * top_k)
+             if len(pruned_configs) > top_k:
+                 est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                 pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+         return pruned_configs
+
+     def warmup(self, *args, **kwargs):
+         self.nargs = dict(zip(self.arg_names, args))
+         for config in self.prune_configs(kwargs):
+             self.fn.warmup(
+                 *args,
+                 num_warps=config.num_warps,
+                 num_stages=config.num_stages,
+                 **kwargs,
+                 **config.kwargs,
+             )
+         self.nargs = None
+
+
+ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
+     """
+     Decorator for auto-tuning a :code:`triton.jit`'d function.
+     .. highlight:: python
+     .. code-block:: python
+         @triton.autotune(configs=[
+             triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
+             triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
+           ],
+           key=['x_size'] # the two above configs will be evaluated anytime
+                          # the value of x_size changes
+         )
+         @triton.jit
+         def kernel(x_ptr, x_size, **META):
+             BLOCK_SIZE = META['BLOCK_SIZE']
+     :note: When all the configurations are evaluated, the kernel will run multiple times.
+            This means that whatever value the kernel updates will be updated multiple times.
+            To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+            resets the value of the provided tensor to `zero` before running any configuration.
+     :param configs: a list of :code:`triton.Config` objects
+     :type configs: list[triton.Config]
+     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+     :type key: list[str]
+     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+         'perf_model': performance model used to predict running time with different configs, returns running time
+         'top_k': number of configs to bench
+         'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+     :type reset_to_zero: list[str]
+     """
+
+     def decorator(fn):
+         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
+
+     return decorator
+
+
+ def matmul248_kernel_config_pruner(configs, nargs):
+     """
+     The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
+     """
+     m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
+     n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
+     k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
+
+     used = set()
+     for config in configs:
+         block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
+         block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
+         block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
+         group_size_m = config.kwargs['GROUP_SIZE_M']
+
+         if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
+             continue
+
+         used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
+         yield triton.Config({
+             'BLOCK_SIZE_M': block_size_m,
+             'BLOCK_SIZE_N': block_size_n,
+             'BLOCK_SIZE_K': block_size_k,
+             'GROUP_SIZE_M': group_size_m
+         },
+                             num_stages=config.num_stages,
+                             num_warps=config.num_warps)
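With `nearest_power_of_two=True` the tuning cache is keyed on log-rounded problem sizes rather than exact ones, and the pruner above clamps the block sizes to the (power-of-two, at least 16) extent of each dimension. A small illustrative sketch of that arithmetic (plain Python, no kernel launch, the sample values are arbitrary):

import math

# Cache-key rounding used when nearest_power_of_two=True: M = 1..2048 collapses into ~12 buckets.
for m in (1, 3, 24, 100, 700, 2048):
    key = 2**int(math.log2(m) + 0.5)
    print(m, '->', key)          # 1->1, 3->4, 24->32, 100->128, 700->512, 2048->2048

# The pruner shrinks BLOCK_SIZE_M toward the rounded extent of M, so a single-token decode
# step (M=7 here) does not benchmark configs with BLOCK_SIZE_M=128.
m_extent = max(2**int(math.ceil(math.log2(7))), 16)   # -> 16
print(min(m_extent, 128))                             # a BLOCK_SIZE_M=128 config shrinks to 16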
quant/fused_attn.py ADDED
@@ -0,0 +1,203 @@
+ from torch.nn import functional as F
+ from transformers.models.llama.modeling_llama import LlamaAttention
+ from .quant_linear import *
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def rotate_half_kernel(
+     qk_seq_ptr,
+     position_ids_ptr,
+     qk_seq_stride,
+     position_ids_batch_stride,
+     seq_len,
+     HEAD_DIM: tl.constexpr,
+     BLOCK_HEIGHT: tl.constexpr,
+     BLOCK_WIDTH: tl.constexpr,
+     INV_BASE: tl.constexpr
+ ):
+     # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.
+     # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.
+
+     HALF_HEAD: tl.constexpr = HEAD_DIM // 2
+     STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH
+
+     batch_seq = tl.program_id(axis=0)
+     row_blk_x_col_blk = tl.program_id(axis=1)
+
+     row_blk = row_blk_x_col_blk // STEPS_PER_ROW
+     row = row_blk * BLOCK_HEIGHT
+     if BLOCK_WIDTH < HALF_HEAD:
+         col_blk = row_blk_x_col_blk % STEPS_PER_ROW
+         col = col_blk * BLOCK_WIDTH
+     else:
+         col: tl.constexpr = 0
+
+     # A block will never cross a sequence boundary, which simplifies things a lot.
+     batch = batch_seq // seq_len
+     seq = batch_seq % seq_len
+     position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)
+     # As sometimes happens, just calculating this on the fly is faster than loading it from memory.
+     # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.
+     freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id
+     cos = tl.cos(freq).to(tl.float32)
+     sin = tl.sin(freq).to(tl.float32)
+
+     col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)
+     embed_offsets = (row * HEAD_DIM + col) + col_offsets
+     x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets
+
+     for k in range(0, BLOCK_HEIGHT):
+         x = tl.load(x_ptrs).to(tl.float32)
+         y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)
+         out_x = x * cos - y * sin
+         tl.store(x_ptrs, out_x)
+         out_y = x * sin + y * cos
+         tl.store(x_ptrs + HALF_HEAD, out_y)
+         x_ptrs += HEAD_DIM
+
+
+ def triton_rotate_half_(qk, position_ids, config=None):
+     batch_size, seq_len, qandk, num_heads, head_dim = qk.shape
+
+     # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.
+     config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1}
+     config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads)
+
+     assert qk.stride(3) == head_dim
+     assert qk.stride(4) == 1
+     assert position_ids.shape == (batch_size, seq_len)
+     assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension'
+     assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}'
+     assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}'
+
+     qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)
+     grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH']))
+
+     # Must be the same as the theta of the frequencies used to train the model.
+     BASE = 10000.0
+
+     rotate_half_kernel[grid](
+         qk_by_seq,
+         position_ids,
+         qk_by_seq.stride(0),
+         position_ids.stride(0),
+         seq_len,
+         HEAD_DIM=head_dim,
+         BLOCK_HEIGHT=config['BLOCK_HEIGHT'],
+         BLOCK_WIDTH=config['BLOCK_WIDTH'],
+         INV_BASE=-2.0 * math.log(BASE) / head_dim,
+         num_warps=config['num_warps']
+     )
+
+
+ class QuantLlamaAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(
+         self,
+         hidden_size,
+         num_heads,
+         qkv_proj,
+         o_proj
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.head_dim = hidden_size // num_heads
+
+         if (self.head_dim * num_heads) != self.hidden_size:
+             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                              f" and `num_heads`: {num_heads}).")
+         self.qkv_proj = qkv_proj
+         self.o_proj = o_proj
+
+     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
+         """Input shape: Batch x Time x Channel"""
+
+         bsz, q_len, _ = hidden_states.size()
+
+         qkv_states = self.qkv_proj(hidden_states)
+         qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+
+         # This updates the query and key states in-place, saving VRAM.
+         triton_rotate_half_(qkv_states[:, :, :2], position_ids)
+
+         query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
+         del qkv_states
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         is_causal = past_key_value is None
+
+         kv_seq_len = q_len
+         if past_key_value is not None:
+             kv_seq_len += past_key_value[0].shape[-2]
+
+         if past_key_value is not None:
+             # reuse k, v, self_attention
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         if use_cache:
+             # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
+             key_states = key_states.contiguous()
+             value_states = value_states.contiguous()
+             query_states = query_states.contiguous()
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         with torch.backends.cuda.sdp_kernel(enable_math=False):
+             attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
+         del query_states, key_states, value_states
+
+         attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, None, past_key_value
+
+
+ def make_quant_attn(model):
+     """
+     Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
+     """
+
+     for name, m in model.named_modules():
+         if not isinstance(m, LlamaAttention):
+             continue
+
+         q_proj = m.q_proj
+         k_proj = m.k_proj
+         v_proj = m.v_proj
+
+         qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+         qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+         scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+         g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+         qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False)
+         qkv_layer.qweight = qweights
+         qkv_layer.qzeros = qzeros
+         qkv_layer.scales = scales
+         qkv_layer.g_idx = g_idx
+         qkv_layer.bias = bias
+         # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.
+
+         attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj)
+
+         if '.' in name:
+             parent_name = name.rsplit('.', 1)[0]
+             child_name = name[len(parent_name) + 1:]
+             parent = model.get_submodule(parent_name)
+         else:
+             parent_name = ''
+             parent = model
+             child_name = name
+
+         #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
+
+         setattr(parent, child_name, attn)
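A minimal sketch of the intended call pattern, assuming `model` is a Hugging Face LlamaForCausalLM whose q/k/v/o projections are already QuantLinear layers with matching bits and groupsize (the function silently skips anything that is not a LlamaAttention); the attribute path used in the check is the usual transformers layout and is an assumption here:

from quant import make_quant_attn

make_quant_attn(model)

# Each decoder layer's self_attn is now a QuantLlamaAttention whose qkv_proj concatenates the
# packed q/k/v weights along the output dimension, so one matmul248 launch serves all three.
attn = model.model.layers[0].self_attn
print(type(attn).__name__, attn.qkv_proj.outfeatures)   # QuantLlamaAttention, 3 * hidden_size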
quant/fused_mlp.py ADDED
@@ -0,0 +1,288 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.cuda.amp import custom_bwd, custom_fwd
+ from transformers.models.llama.modeling_llama import LlamaMLP
+
+ try:
+     import triton
+     import triton.language as tl
+     from . import custom_autotune
+
+     # code based on https://github.com/fpgaminer/GPTQ-triton
+     @custom_autotune.autotune(
+         configs=[
+             triton.Config({
+                 'BLOCK_SIZE_M': 256,
+                 'BLOCK_SIZE_N': 64,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 256,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 128,
+                 'BLOCK_SIZE_N': 128,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 128,
+                 'BLOCK_SIZE_N': 64,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 128,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 128,
+                 'BLOCK_SIZE_N': 32,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),  # 3090
+             triton.Config({
+                 'BLOCK_SIZE_M': 128,
+                 'BLOCK_SIZE_N': 16,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),  # 3090
+             triton.Config({
+                 'BLOCK_SIZE_M': 32,
+                 'BLOCK_SIZE_N': 32,
+                 'BLOCK_SIZE_K': 128,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=2, num_warps=4),  # 3090
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 16,
+                 'BLOCK_SIZE_K': 64,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),  # 3090
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 32,
+                 'BLOCK_SIZE_K': 64,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),  # 3090
+         ],
+         key=['M', 'N', 'K'],
+         nearest_power_of_two=True,
+         prune_configs_by={
+             'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
+             'perf_model': None,
+             'top_k': None,
+         },
+     )
+     @triton.jit
+     def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,
+                                stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+         """
+         Computes: C = silu(A * B1) * (A * B2)
+         A is of shape (M, K) float16
+         B is of shape (K//8, N) int32
+         C is of shape (M, N) float16
+         scales is of shape (1, N) float16
+         zeros is of shape (1, N//8) int32
+         """
+         infearure_per_bits = 32 // bits
+
+         pid = tl.program_id(axis=0)
+         num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+         num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+         num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+         num_pid_in_group = GROUP_SIZE_M * num_pid_n
+         group_id = pid // num_pid_in_group
+         first_pid_m = group_id * GROUP_SIZE_M
+         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+         pid_m = first_pid_m + (pid % group_size_m)
+         pid_n = (pid % num_pid_in_group) // group_size_m
+
+         offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+         offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+         offs_k = tl.arange(0, BLOCK_SIZE_K)
+         a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+         a_mask = (offs_am[:, None] < M)
+         # b_ptrs is set up such that it repeats elements along the K axis 8 times
+         b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
+         b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
+         g1_ptrs = g1_ptr + offs_k
+         g2_ptrs = g2_ptr + offs_k
+         # shifter is used to extract the N bits of each element in the 32-bit word from B
+         scales1_ptrs = scales1_ptr + offs_bn[None, :]
+         scales2_ptrs = scales2_ptr + offs_bn[None, :]
+         zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
+         zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)
+
+         shifter = (offs_k % infearure_per_bits) * bits
+         zeros_shifter = (offs_bn % infearure_per_bits) * bits
+         accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+         accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+         for k in range(0, num_pid_k):
+             g1_idx = tl.load(g1_ptrs)
+             g2_idx = tl.load(g2_ptrs)
+
+             # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+             scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+             scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
+
+             zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+             zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
+             zeros1 = (zeros1 + 1)
+
+             zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+             zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
+             zeros2 = (zeros2 + 1)
+
+             a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+             b1 = tl.load(b1_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+             b2 = tl.load(b2_ptrs)
+
+             # Now we need to unpack b (which is N-bit values) into 32-bit values
+             b1 = (b1 >> shifter[:, None]) & maxq  # Extract the N-bit values
+             b1 = (b1 - zeros1) * scales1  # Scale and shift
+             accumulator1 += tl.dot(a, b1)
+
+             b2 = (b2 >> shifter[:, None]) & maxq
+             b2 = (b2 - zeros2) * scales2
+             accumulator2 += tl.dot(a, b2)
+
+             a_ptrs += BLOCK_SIZE_K
+             b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+             b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+             g1_ptrs += BLOCK_SIZE_K
+             g2_ptrs += BLOCK_SIZE_K
+
+         accumulator1 = silu(accumulator1)
+         c = accumulator1 * accumulator2
+         c = c.to(tl.float16)
+         c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+         c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+         tl.store(c_ptrs, c, mask=c_mask)
+
+     @triton.jit
+     def silu(x):
+         return x * tl.sigmoid(x)
+ except:
+     print('triton not installed.')
+
+
+ class QuantLlamaMLP(nn.Module):
+
+     def __init__(
+         self,
+         gate_proj,
+         down_proj,
+         up_proj,
+     ):
+         super().__init__()
+         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
+         self.register_buffer('gate_proj_scales', gate_proj.scales)
+         self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
+         self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
+         self.register_buffer('up_proj_qweight', up_proj.qweight)
+         self.register_buffer('up_proj_scales', up_proj.scales)
+         self.register_buffer('up_proj_qzeros', up_proj.qzeros)
+         self.register_buffer('up_proj_g_idx', up_proj.g_idx)
+
+         self.infeatures = gate_proj.infeatures
+         self.intermediate_size = gate_proj.outfeatures
+         self.outfeatures = down_proj.outfeatures
+         self.bits = gate_proj.bits
+         self.maxq = gate_proj.maxq
+
+         self.down_proj = down_proj
+
+     def forward(self, x):
+         return self.down_proj(self.triton_llama_mlp(x))
+
+     def triton_llama_mlp(self, x):
+         with torch.cuda.device(x.device):
+             out_shape = x.shape[:-1] + (self.intermediate_size, )
+             x = x.reshape(-1, x.shape[-1])
+             M, K = x.shape
+             N = self.intermediate_size
+             c = torch.empty((M, N), device=x.device, dtype=torch.float16)
+             grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+             fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,
+                                          self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),
+                                          self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))
+             c = c.reshape(out_shape)
+             return c
+
+     def fused2cuda(self):
+         self.gate_proj_qweight = self.gate_proj_qweight.cuda()
+         self.gate_proj_scales = self.gate_proj_scales.cuda()
+         self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
+         self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
+         self.up_proj_qweight = self.up_proj_qweight.cuda()
+         self.up_proj_scales = self.up_proj_scales.cuda()
+         self.up_proj_qzeros = self.up_proj_qzeros.cuda()
+         self.up_proj_g_idx = self.up_proj_g_idx.cuda()
+
+     def fused2cpu(self):
+         self.gate_proj_qweight = self.gate_proj_qweight.cpu()
+         self.gate_proj_scales = self.gate_proj_scales.cpu()
+         self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
+         self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
+         self.up_proj_qweight = self.up_proj_qweight.cpu()
+         self.up_proj_scales = self.up_proj_scales.cpu()
+         self.up_proj_qzeros = self.up_proj_qzeros.cpu()
+         self.up_proj_g_idx = self.up_proj_g_idx.cpu()
+
+
+ def make_fused_mlp(m, parent_name=''):
+     """
+     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
+     """
+     if isinstance(m, LlamaMLP):
+         return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
+
+     for name, child in m.named_children():
+         child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
+
+         if isinstance(child, QuantLlamaMLP):
+             setattr(m, name, child)
+     return m
+
+
+ def autotune_warmup_fused(model):
+     """
+     Pre-tunes the quantized kernel
+     """
+     from tqdm import tqdm
+
+     kn_values = {}
+
+     for _, m in model.named_modules():
+         if not isinstance(m, QuantLlamaMLP):
+             continue
+
+         k = m.infeatures
+         n = m.intermediate_size
+
+         m.fused2cuda()
+         if (k, n) not in kn_values:
+             kn_values[(k, n)] = m
+
+     print(f'Found {len(kn_values)} unique fused mlp KN values.')
+
+     print('Warming up autotune cache ...')
+     with torch.no_grad():
+         for m in tqdm(range(0, 12)):
+             m = 2**m  # [1, 2048]
+             for (k, n), (modules) in kn_values.items():
+                 a = torch.randn(m, k, dtype=torch.float16, device='cuda')
+                 modules.triton_llama_mlp(a)
+
+     for (k, n), (modules) in kn_values.items():
+         a = torch.randn(m, k, dtype=torch.float16, device='cuda')
+         modules.fused2cpu()
+     del kn_values
+
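The fused kernel evaluates both input projections of the LLaMA MLP in one launch, so the module as a whole computes down_proj(silu(gate_proj(x)) * up_proj(x)). A plain-PyTorch reference of that semantics (hypothetical dense weights W_gate, W_up, W_down; the Triton path reads the packed int32 buffers directly), useful as a sanity check:

import torch

def llama_mlp_reference(x, W_gate, W_up, W_down):
    # Same math as QuantLlamaMLP.forward, with dequantized weights.
    h = torch.nn.functional.silu(x @ W_gate.T) * (x @ W_up.T)   # = silu(A * B1) * (A * B2)
    return h @ W_down.T                                         # down_proj is applied outside the kernel

# Typical wiring, assuming `model` already contains QuantLinear MLP projections:
# model = make_fused_mlp(model)
# autotune_warmup_fused(model)   # benchmarks the kernel for M = 1, 2, 4, ..., 2048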
quant/quant_linear.py ADDED
@@ -0,0 +1,423 @@
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+ try:
+     import triton
+     import triton.language as tl
+     from . import custom_autotune
+
+     # code based on https://github.com/fpgaminer/GPTQ-triton
+     @custom_autotune.autotune(
+         configs=[
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 256,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 128,
+                 'BLOCK_SIZE_N': 128,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 128,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 128,
+                 'BLOCK_SIZE_N': 32,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 64,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=4, num_warps=4),
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 128,
+                 'BLOCK_SIZE_K': 32,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=2, num_warps=8),
+             triton.Config({
+                 'BLOCK_SIZE_M': 64,
+                 'BLOCK_SIZE_N': 64,
+                 'BLOCK_SIZE_K': 64,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=3, num_warps=8),
+             triton.Config({
+                 'BLOCK_SIZE_M': 32,
+                 'BLOCK_SIZE_N': 32,
+                 'BLOCK_SIZE_K': 128,
+                 'GROUP_SIZE_M': 8
+             }, num_stages=2, num_warps=4),
+         ],
+         key=['M', 'N', 'K'],
+         nearest_power_of_two=True,
+         prune_configs_by={
+             'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
+             'perf_model': None,
+             'top_k': None,
+         },
+     )
+     @triton.jit
+     def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
+                           BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+         """
+         Compute the matrix multiplication C = A x B.
+         A is of shape (M, K) float16
+         B is of shape (K//8, N) int32
+         C is of shape (M, N) float16
+         scales is of shape (G, N) float16
+         zeros is of shape (G, N) float16
+         g_ptr is of shape (K) int32
+         """
+         infearure_per_bits = 32 // bits
+
+         pid = tl.program_id(axis=0)
+         num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+         num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+         num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+         num_pid_in_group = GROUP_SIZE_M * num_pid_n
+         group_id = pid // num_pid_in_group
+         first_pid_m = group_id * GROUP_SIZE_M
+         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+         pid_m = first_pid_m + (pid % group_size_m)
+         pid_n = (pid % num_pid_in_group) // group_size_m
+
+         offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+         offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+         offs_k = tl.arange(0, BLOCK_SIZE_K)
+         a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+         a_mask = (offs_am[:, None] < M)
+         # b_ptrs is set up such that it repeats elements along the K axis 8 times
+         b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+         g_ptrs = g_ptr + offs_k
+         # shifter is used to extract the N bits of each element in the 32-bit word from B
+         scales_ptrs = scales_ptr + offs_bn[None, :]
+         zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
+
+         shifter = (offs_k % infearure_per_bits) * bits
+         zeros_shifter = (offs_bn % infearure_per_bits) * bits
+         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+         for k in range(0, num_pid_k):
+             g_idx = tl.load(g_ptrs)
+
+             # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+             scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+             zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+             zeros = (zeros >> zeros_shifter[None, :]) & maxq
+             zeros = (zeros + 1)
+
+             a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+             # Now we need to unpack b (which is N-bit values) into 32-bit values
+             b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
+             b = (b - zeros) * scales  # Scale and shift
+
+             accumulator += tl.dot(a, b)
+             a_ptrs += BLOCK_SIZE_K
+             b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+             g_ptrs += BLOCK_SIZE_K
+
+         c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+         c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+         tl.store(c_ptrs, accumulator, mask=c_mask)
+
+     @custom_autotune.autotune(configs=[
+         triton.Config({
+             'BLOCK_SIZE_M': 64,
+             'BLOCK_SIZE_N': 32,
+             'BLOCK_SIZE_K': 256,
+             'GROUP_SIZE_M': 8
+         }, num_stages=4, num_warps=4),
+         triton.Config({
+             'BLOCK_SIZE_M': 128,
+             'BLOCK_SIZE_N': 32,
+             'BLOCK_SIZE_K': 128,
+             'GROUP_SIZE_M': 8
+         }, num_stages=4, num_warps=4),
+         triton.Config({
+             'BLOCK_SIZE_M': 64,
+             'BLOCK_SIZE_N': 32,
+             'BLOCK_SIZE_K': 128,
+             'GROUP_SIZE_M': 8
+         }, num_stages=4, num_warps=4),
+         triton.Config({
+             'BLOCK_SIZE_M': 128,
+             'BLOCK_SIZE_N': 32,
+             'BLOCK_SIZE_K': 32,
+             'GROUP_SIZE_M': 8
+         }, num_stages=4, num_warps=4),
+         triton.Config({
+             'BLOCK_SIZE_M': 64,
+             'BLOCK_SIZE_N': 32,
+             'BLOCK_SIZE_K': 64,
+             'GROUP_SIZE_M': 8
+         }, num_stages=4, num_warps=4),
+         triton.Config({
+             'BLOCK_SIZE_M': 64,
+             'BLOCK_SIZE_N': 32,
+             'BLOCK_SIZE_K': 128,
+             'GROUP_SIZE_M': 8
+         }, num_stages=2, num_warps=8),
+         triton.Config({
+             'BLOCK_SIZE_M': 64,
+             'BLOCK_SIZE_N': 64,
+             'BLOCK_SIZE_K': 64,
+             'GROUP_SIZE_M': 8
+         }, num_stages=3, num_warps=8),
+         triton.Config({
+             'BLOCK_SIZE_M': 32,
+             'BLOCK_SIZE_N': 128,
+             'BLOCK_SIZE_K': 32,
+             'GROUP_SIZE_M': 8
+         }, num_stages=2, num_warps=4),
+     ],
+                               key=['M', 'N', 'K'],
+                               nearest_power_of_two=True)
+     @triton.jit
+     def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
+                                     stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+         """
+         Compute the matrix multiplication C = A x B.
+         A is of shape (M, N) float16
+         B is of shape (K//8, N) int32
+         C is of shape (M, K) float16
+         scales is of shape (G, N) float16
+         zeros is of shape (G, N) float16
+         g_ptr is of shape (K) int32
+         """
+         infearure_per_bits = 32 // bits
+
+         pid = tl.program_id(axis=0)
+         num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+         num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+         num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+         num_pid_in_group = GROUP_SIZE_M * num_pid_k
+         group_id = pid // num_pid_in_group
+         first_pid_m = group_id * GROUP_SIZE_M
+         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+         pid_m = first_pid_m + (pid % group_size_m)
+         pid_k = (pid % num_pid_in_group) // group_size_m
+
+         offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+         offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+         offs_n = tl.arange(0, BLOCK_SIZE_N)
+         a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+         a_mask = (offs_am[:, None] < M)
+         # b_ptrs is set up such that it repeats elements along the K axis 8 times
+         b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+         g_ptrs = g_ptr + offs_bk
+         g_idx = tl.load(g_ptrs)
+
+         # shifter is used to extract the N bits of each element in the 32-bit word from B
+         scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
+         zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
+
+         shifter = (offs_bk % infearure_per_bits) * bits
+         zeros_shifter = (offs_n % infearure_per_bits) * bits
+         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+
+         for n in range(0, num_pid_n):
+             # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+             scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+             zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+             zeros = (zeros >> zeros_shifter[None, :]) & maxq
+             zeros = (zeros + 1)
+
+             a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+             # Now we need to unpack b (which is N-bit values) into 32-bit values
+             b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
+             b = (b - zeros) * scales  # Scale and shift
+             b = tl.trans(b)
+
+             accumulator += tl.dot(a, b)
+             a_ptrs += BLOCK_SIZE_N
+             b_ptrs += BLOCK_SIZE_N
+             scales_ptrs += BLOCK_SIZE_N
+             zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
+
+         c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
+         c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
+         tl.store(c_ptrs, accumulator, mask=c_mask)
+ except:
+     print('triton not installed.')
+
+
+ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+     with torch.cuda.device(input.device):
+         output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
+         grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
+         matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
+                                 qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
+         return output
+
+
+ def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+     with torch.cuda.device(input.device):
+         output_dim = (qweight.shape[0] * 32) // bits
+         output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
+         grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
+         transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
+                                           qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
+         return output
+
+
+ class QuantLinearFunction(torch.autograd.Function):
+
+     @staticmethod
+     @custom_fwd(cast_inputs=torch.float16)
+     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+         output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
+         ctx.bits, ctx.maxq = bits, maxq
+         return output
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output):
+         qweight, scales, qzeros, g_idx = ctx.saved_tensors
+         bits, maxq = ctx.bits, ctx.maxq
+         grad_input = None
+
+         if ctx.needs_input_grad[0]:
+             grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
+         return grad_input, None, None, None, None, None, None
+
+
+ class QuantLinear(nn.Module):
+
+     def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+         super().__init__()
+         if bits not in [2, 4, 8]:
+             raise NotImplementedError("Only 2,4,8 bits are supported.")
+         self.infeatures = infeatures
+         self.outfeatures = outfeatures
+         self.bits = bits
+         self.maxq = 2**self.bits - 1
+         self.groupsize = groupsize if groupsize != -1 else infeatures
+
+         self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+         self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+         self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+         self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+         if bias:
+             self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+         else:
+             self.bias = None
+
+     def pack(self, linear, scales, zeros, g_idx=None):
+         self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+         scales = scales.t().contiguous()
+         zeros = zeros.t().contiguous()
+         scale_zeros = zeros * scales
+         self.scales = scales.clone().half()
+         if linear.bias is not None:
+             self.bias = linear.bias.clone().half()
+
+         intweight = []
+         for idx in range(self.infeatures):
+             intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
+         intweight = torch.cat(intweight, dim=1)
+         intweight = intweight.t().contiguous()
+         intweight = intweight.numpy().astype(np.uint32)
+         qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
+         i = 0
+         row = 0
+         while row < qweight.shape[0]:
+             if self.bits in [2, 4, 8]:
+                 for j in range(i, i + (32 // self.bits)):
+                     qweight[row] |= intweight[j] << (self.bits * (j - i))
+                 i += 32 // self.bits
+                 row += 1
+             else:
+                 raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+         qweight = qweight.astype(np.int32)
+         self.qweight = torch.from_numpy(qweight)
+
+         zeros -= 1
+         zeros = zeros.numpy().astype(np.uint32)
+         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+         i = 0
+         col = 0
+         while col < qzeros.shape[1]:
+             if self.bits in [2, 4, 8]:
+                 for j in range(i, i + (32 // self.bits)):
+                     qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                 i += 32 // self.bits
+                 col += 1
+             else:
+                 raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+         qzeros = qzeros.astype(np.int32)
+         self.qzeros = torch.from_numpy(qzeros)
+
+     def forward(self, x):
+         out_shape = x.shape[:-1] + (self.outfeatures, )
+         out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
+         out = out + self.bias if self.bias is not None else out
+         return out.reshape(out_shape)
+
+
+ def make_quant_linear(module, names, bits, groupsize, name=''):
+     if isinstance(module, QuantLinear):
+         return
+     for attr in dir(module):
+         tmp = getattr(module, attr)
+         name1 = name + '.' + attr if name != '' else attr
+         if name1 in names:
+             delattr(module, attr)
+             setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
+     for name1, child in module.named_children():
+         make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+
+
+ def autotune_warmup_linear(model, transpose=False):
+     """
+     Pre-tunes the quantized kernel
+     """
+     from tqdm import tqdm
+
+     kn_values = {}
+
+     for _, m in model.named_modules():
+         if not isinstance(m, QuantLinear):
+             continue
+
+         k = m.infeatures
+         n = m.outfeatures
+
+         if (k, n) not in kn_values:
+             kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
+
+     print(f'Found {len(kn_values)} unique KN Linear values.')
+
+     print('Warming up autotune cache ...')
+     with torch.no_grad():
+         for m in tqdm(range(0, 12)):
+             m = 2**m  # [1, 2048]
+             for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
+                 a = torch.randn(m, k, dtype=torch.float16, device='cuda')
+                 matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+                 if transpose:
+                     a = torch.randn(m, n, dtype=torch.float16, device='cuda')
+                     transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+     del kn_values
+
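The packing arithmetic above makes the int32 buffers a factor of 32/bits smaller than the dense weight along the packed axis. A small CPU-only sketch that just instantiates the module and checks the buffer shapes (the 4-bit layer sizes are hypothetical, roughly LLaMA-7B's MLP up-projection):

import torch
from quant import QuantLinear

bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 11008   # hypothetical sizes
layer = QuantLinear(bits, groupsize, infeatures, outfeatures, bias=False)

print(layer.qweight.shape)   # (infeatures // 32 * bits, outfeatures)                 -> (512, 11008)
print(layer.qzeros.shape)    # (ceil(infeatures / groupsize), outfeatures // 32 * bits) -> (32, 1376)
print(layer.scales.shape)    # (ceil(infeatures / groupsize), outfeatures)            -> (32, 11008)
print(layer.g_idx[:130])     # group index per input feature: 0 repeated 128 times, then 1, 1, ...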
quant/quantizer.py ADDED
@@ -0,0 +1,127 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class Quantizer(nn.Module):
+
+     def __init__(self, shape=1):
+         super(Quantizer, self).__init__()
+         self.register_buffer('maxq', torch.tensor(0))
+         self.register_buffer('scale', torch.zeros(shape))
+         self.register_buffer('zero', torch.zeros(shape))
+
+     def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
+
+         self.maxq = torch.tensor(2**bits - 1)
+         self.perchannel = perchannel
+         self.sym = sym
+         self.mse = mse
+         self.norm = norm
+         self.grid = grid
+         self.maxshrink = maxshrink
+         if trits:
+             self.maxq = torch.tensor(-1)
+         self.scale = torch.zeros_like(self.scale)
+
+     def _quantize(self, x, scale, zero, maxq):
+         if maxq < 0:
+             return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
+         q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+         return scale * (q - zero)
+
+     def find_params(self, x, weight=False):
+         dev = x.device
+         self.maxq = self.maxq.to(dev)
+
+         shape = x.shape
+         if self.perchannel:
+             if weight:
+                 x = x.flatten(1)
+             else:
+                 if len(shape) == 4:
+                     x = x.permute([1, 0, 2, 3])
+                     x = x.flatten(1)
+                 if len(shape) == 3:
+                     x = x.reshape((-1, shape[-1])).t()
+                 if len(shape) == 2:
+                     x = x.t()
+         else:
+             x = x.flatten().unsqueeze(0)
+
+         tmp = torch.zeros(x.shape[0], device=dev)
+         xmin = torch.minimum(x.min(1)[0], tmp)
+         xmax = torch.maximum(x.max(1)[0], tmp)
+
+         if self.sym:
+             xmax = torch.maximum(torch.abs(xmin), xmax)
+             tmp = xmin < 0
+             if torch.any(tmp):
+                 xmin[tmp] = -xmax[tmp]
+         tmp = (xmin == 0) & (xmax == 0)
+         xmin[tmp] = -1
+         xmax[tmp] = +1
+
+         if self.maxq < 0:
+             self.scale = xmax
+             self.zero = xmin
+         else:
+             self.scale = (xmax - xmin) / self.maxq
+             if self.sym:
+                 self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
+             else:
+                 self.zero = torch.round(-xmin / self.scale)
+
+         if self.mse:
+             best = torch.full([x.shape[0]], float('inf'), device=dev)
+             for i in range(int(self.maxshrink * self.grid)):
+                 p = 1 - i / self.grid
+                 xmin1 = p * xmin
+                 xmax1 = p * xmax
+                 scale1 = (xmax1 - xmin1) / self.maxq
+                 zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
+                 q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
+                 q -= x
+                 q.abs_()
+                 q.pow_(self.norm)
+                 err = torch.sum(q, 1)
+                 tmp = err < best
+                 if torch.any(tmp):
+                     best[tmp] = err[tmp]
+                     self.scale[tmp] = scale1[tmp]
+                     self.zero[tmp] = zero1[tmp]
+         if not self.perchannel:
+             if weight:
+                 tmp = shape[0]
+             else:
+                 tmp = shape[1] if len(shape) != 3 else shape[2]
+             self.scale = self.scale.repeat(tmp)
+             self.zero = self.zero.repeat(tmp)
+
+         if weight:
+             shape = [-1] + [1] * (len(shape) - 1)
+             self.scale = self.scale.reshape(shape)
+             self.zero = self.zero.reshape(shape)
+             return
+         if len(shape) == 4:
+             self.scale = self.scale.reshape((1, -1, 1, 1))
+             self.zero = self.zero.reshape((1, -1, 1, 1))
+         if len(shape) == 3:
+             self.scale = self.scale.reshape((1, 1, -1))
+             self.zero = self.zero.reshape((1, 1, -1))
+         if len(shape) == 2:
+             self.scale = self.scale.unsqueeze(0)
+             self.zero = self.zero.unsqueeze(0)
+
+     def quantize(self, x):
+         if self.ready():
+             return self._quantize(x, self.scale, self.zero, self.maxq)
+
+         return x
+
+     def enabled(self):
+         return self.maxq > 0
+
+     def ready(self):
+         return torch.all(self.scale != 0)
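A short sketch of how the quantizer is typically driven (CPU-only; the 4-bit, per-channel, asymmetric settings are only illustrative). quantize() returns the round-trip scale * (clamp(round(x / scale) + zero, 0, maxq) - zero), i.e. the dequantized approximation of x:

import torch
from quant import Quantizer

W = torch.randn(256, 512)            # a hypothetical weight matrix (out_features, in_features)

q = Quantizer()
q.configure(bits=4, perchannel=True, sym=False, mse=False)
q.find_params(W, weight=True)        # per-output-row scale/zero, reshaped to broadcast over rows

W_q = q.quantize(W)                  # dequantized 4-bit approximation of W
print(q.scale.shape, q.zero.shape)   # torch.Size([256, 1]) torch.Size([256, 1])
print((W - W_q).abs().max() / q.scale.max())   # reconstruction error is on the order of one step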
quant/triton_norm.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ from torch import nn
+ import triton
+ import triton.language as tl
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+ @triton.jit
+ def rms_norm_fwd_fused(
+     X,  # pointer to the input
+     Y,  # pointer to the output
+     W,  # pointer to the weights
+     stride,  # how much to increase the pointer when moving by 1 row
+     N,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     # Map the program id to the row of X and Y it should compute.
+     row = tl.program_id(0)
+     Y += row * stride
+     X += row * stride
+     # Compute variance
+     _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+     for off in range(0, N, BLOCK_SIZE):
+         cols = off + tl.arange(0, BLOCK_SIZE)
+         x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+         x = tl.where(cols < N, x, 0.)
+         _var += x * x
+     var = tl.sum(_var, axis=0) / N
+     rstd = 1 / tl.sqrt(var + eps)
+     # Normalize and apply linear transformation
+     for off in range(0, N, BLOCK_SIZE):
+         cols = off + tl.arange(0, BLOCK_SIZE)
+         mask = cols < N
+         w = tl.load(W + cols, mask=mask)
+         x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+         x_hat = x * rstd
+         y = x_hat * w
+         # Write output
+         tl.store(Y + cols, y, mask=mask)
+
+ class TritonLlamaRMSNorm(nn.Module):
+     def __init__(self, weight, eps=1e-6):
+         """
+         LlamaRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = weight
+         self.variance_epsilon = eps
+
+     def forward(self, x):
+         y = torch.empty_like(x)
+         # reshape input data into 2D tensor
+         x_arg = x.reshape(-1, x.shape[-1])
+         M, N = x_arg.shape
+         # Less than 64KB per feature: enqueue fused kernel
+         MAX_FUSED_SIZE = 65536 // x.element_size()
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+         if N > BLOCK_SIZE:
+             raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+         # heuristics for number of warps
+         num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+         # enqueue kernel
+         rms_norm_fwd_fused[(M,)](x_arg, y, self.weight,
+                                  x_arg.stride(0), N, self.variance_epsilon,
+                                  BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+         return y
+
+
+ def make_quant_norm(model):
+     """
+     Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules
+     """
+
+     for name, m in model.named_modules():
+         if not isinstance(m, LlamaRMSNorm):
+             continue
+
+         norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon)
+
+         if '.' in name:
+             parent_name = name.rsplit('.', 1)[0]
+             child_name = name[len(parent_name) + 1:]
+             parent = model.get_submodule(parent_name)
+         else:
+             parent_name = ''
+             parent = model
+             child_name = name
+
+         #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
+
+         setattr(parent, child_name, norm)
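What the fused kernel computes is ordinary RMSNorm. A plain-PyTorch reference (float32 accumulation, mirroring the kernel) that TritonLlamaRMSNorm.forward should match up to float16 rounding; the commented check at the end is a hypothetical usage and needs a CUDA device with triton installed:

import torch

def rms_norm_reference(x, weight, eps=1e-6):
    # Same math as rms_norm_fwd_fused: normalize each row by sqrt(mean(x^2) + eps), then scale by w.
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rstd * weight.float()).to(x.dtype)

# x = torch.randn(2, 16, 4096, dtype=torch.float16, device='cuda')
# w = torch.ones(4096, dtype=torch.float16, device='cuda')
# torch.testing.assert_close(TritonLlamaRMSNorm(w)(x), rms_norm_reference(x, w), rtol=1e-2, atol=1e-2)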