Resolve - 196 [rank0]: triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 180224, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
""" | |
Author: Eric Lin (xihlin) | |
""" | |
""" | |
... note(bapatra):: | |
This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module | |
imports when splitting into different files. I've tried keeping the logical partitions demarkated with comment blocks, but it is not ideal. | |
In the future, would be really good to revisit this and refactor into a more readable file structure. | |
""" | |
from typing import TypeVar | |
from functools import lru_cache | |
import math | |
import pytest | |
import torch | |
import numpy as np | |
import triton | |
import triton.language as tl | |
import os | |
import dataclasses | |
Phi3SmallConfig = TypeVar('Phi3SmallConfig') | |
# triton 2.0.0: backward fails on A100 for the examples if h_dim=128.
# Done:
# 1. strided qkv
# 2. seq len not a power of 2
# 3. bf16 with Triton (May 2023)
# TODO:
# 1. wip: support non-contiguous backward; also helps reduce memory allocation in training (q, k, v split)
# 2. block sparse with different BLOCK_M, BLOCK_N?
# 3. for Lq not divisible by BLOCK_M/BLOCK_N, only apply the mask to K/V on the last block; the mask on Q is still needed.
#    Attempted, but it fails to compile.
# 4. For the 2nd iteration of inference, BLOCK_M=1: how to make things work? K/V may not be divisible by BLOCK_N.
# 5. The inner loop can also be parallelized via a larger num_stages (better) or on different thread-blocks (via m/L and atomic updates, but then there is no comm/sync between blocks).
########################################################### | |
################### Kernel Parameters ##################### | |
########################################################### | |
@dataclasses.dataclass
class BlockSparseParams(object):
block_size: int | |
kernel_block_size: int | |
num_local_blocks: int | |
vert_stride: int | |
homo_head_pattern: bool = False | |
    @classmethod
    def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
return cls( | |
block_size=config.blocksparse_block_size, | |
kernel_block_size=config.blocksparse_triton_kernel_block_size, | |
num_local_blocks=config.blocksparse_num_local_blocks, | |
vert_stride=config.blocksparse_vert_stride, | |
homo_head_pattern=config.blocksparse_homo_head_pattern, | |
) | |
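# --- Illustrative usage sketch (added for clarity; not part of the original file).
# Shows how `BlockSparseParams` is intended to be constructed; the literal values
# below are hypothetical and do not come from any real Phi3SmallConfig.
def _example_block_sparse_params() -> "BlockSparseParams":
    return BlockSparseParams(
        block_size=64,          # size (in tokens) of one sparsity block
        kernel_block_size=64,   # tile size the triton kernel is launched with
        num_local_blocks=16,    # each query block attends to the 16 nearest blocks
        vert_stride=8,          # plus every 8th earlier block ("vertical" stripes)
    )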
########################################################### | |
########################################################### | |
########################################################### | |
################### Utility Functions ##################### | |
########################################################### | |
# helper functions for the 3D sparse pattern
# These functions are not optimized and are very inefficient. Avoid calling them too frequently.
# Currently they are only called inside `get_local_strided_sparse_attention_op`, which is cached.
def dense_to_crow_col(x):
    '''Turn a 2D/3D torch tensor (x) into CSR crow/col indexing.
    :param x: a 2D (single head) or 3D (per-head) dense block mask.
    TODO:
        1. improve efficiency: would it be faster on CPU, or with a custom CUDA kernel?
    NOTE: col_indices is padded with -1.
    '''
pad = -1 | |
dim = x.dim() | |
assert x.dim() in (2, 3) | |
if x.dim() == 2: | |
x = x[None] | |
x = [xi.to_sparse_csr() for xi in x] | |
crows = torch.vstack([xi.crow_indices() for xi in x]) | |
cols = [xi.col_indices() for xi in x] | |
max_cols = max(len(xi) for xi in cols) | |
cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols] | |
cols = torch.vstack(cols) | |
if dim == 2: | |
crows = crows[0] | |
cols = cols[0] | |
return crows, cols | |
def crow_col_to_dense(crows, cols, dtype=torch.float16): | |
dim = crows.dim() | |
if dim == 1: | |
crows = crows[None] | |
cols = cols[None] | |
device = crows.device | |
crows, cols = crows.cpu(), cols.cpu() # faster in cpu | |
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) | |
x = torch.zeros(shape, dtype=dtype) | |
for i in range(shape[0]): | |
for j in range(shape[1]): | |
x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1 | |
if dim == 1: | |
x = x[0] | |
return x.to(device) | |
def dense_to_ccol_row(x): | |
'''Similar, but to CSC format | |
''' | |
x = x.transpose(-2, -1) | |
return dense_to_crow_col(x) | |
def ccol_row_to_dense(ccol, rows, dtype=torch.float16): | |
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() | |
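# --- Illustrative sketch (added for clarity; not part of the original file).
# Round-trips a tiny causal block mask through the CSR/CSC helpers above; runs on CPU.
def _example_crow_col_round_trip():
    dense = torch.tril(torch.ones(4, 4))    # causal 4x4 block mask
    crows, cols = dense_to_crow_col(dense)  # CSR (row-compressed) layout, used by the forward kernel
    ccols, rows = dense_to_ccol_row(dense)  # CSC (col-compressed) layout, used by the backward kernel
    assert torch.equal(crow_col_to_dense(crows, cols).to(dense.dtype), dense)
    return (crows, cols), (ccols, rows)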
def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False): | |
''' | |
:return: a tuple of 3: | |
- tuple of crow_indices, col_indices representation of CSR format. | |
- block dense mask | |
- all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None | |
''' | |
with torch.no_grad(): | |
N_BLOCK = triton.cdiv(N_CTX, BLOCK) | |
q_pos = torch.arange(N_BLOCK)[:, None] | |
k_pos = torch.arange(N_BLOCK)[None] | |
mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0 | |
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype) | |
N_BLOCK_Q = triton.cdiv(q_len, BLOCK) | |
block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr() | |
if return_dense: | |
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK))) | |
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:] | |
mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask | |
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense | |
else: | |
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None | |
def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False): | |
''' | |
:return: a tuple of 3: | |
- tuple of crow_indices, col_indices representation of CSR format. | |
- block dense mask | |
- all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None | |
''' | |
if homo_head: | |
with torch.no_grad(): | |
(crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense) | |
crow = crow[None].expand(n_heads, crow.shape[0]) | |
col = col[None].expand(n_heads, col.shape[0]) | |
if return_dense: | |
mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape) | |
return (crow, col), block_mask_dense, mask_dense | |
with torch.no_grad(): | |
N_BLOCK = triton.cdiv(N_CTX, BLOCK) | |
q_pos = torch.arange(N_BLOCK)[None, :, None] | |
k_pos = torch.arange(N_BLOCK)[None, None] | |
head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads | |
mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)] | |
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) | |
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype) | |
N_BLOCK_Q = triton.cdiv(q_len, BLOCK) | |
block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:] | |
if return_dense: | |
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK))) | |
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:] | |
mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None] | |
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense | |
else: | |
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None | |
def get_sparse_attn_mask(q, N_CTX, *args, **kwargs): | |
return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs) | |
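# --- Illustrative sketch (added for clarity; not part of the original file).
# Materializes the local + vertical-strided block pattern for a toy configuration on CPU.
def _example_sparse_mask():
    (crow, col), block_mask, dense_mask = _get_sparse_attn_mask(
        n_heads=2, q_len=256, N_CTX=256, dtype=torch.float16, device='cpu',
        BLOCK=64, local_blocks=2, vert_stride=2, homo_head=False, return_dense=True)
    # block_mask: (2 heads, 4 blocks, 4 blocks); dense_mask: (2, 256, 256) token-level mask
    return (crow, col), block_mask, dense_mask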
########################################################### | |
########################################################### | |
########################################################### | |
###################### Training Kernels ################### | |
########################################################### | |
# TODO: only apply the loading/saving mask on the last iteration for EVEN_N_BLOCK; useful for the 1st iteration of inference.
#    The experiment failed when done inside the loop.
#    Another idea: mask only on saving? Load even out of bounds (would it cause an illegal-access error)?
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale, | |
layout_crow_ptr, | |
layout_col_ptr, | |
layout_crow_stride_h, layout_crow_stride_m, | |
layout_col_stride_h, layout_col_stride_m, | |
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug. TMP, L, M are assumed to have contiguous layouts.
Out, | |
stride_qz, stride_qh, stride_qm, stride_qd, | |
stride_kz, stride_kh, stride_kn, stride_kd, | |
stride_vz, stride_vh, stride_vn, stride_vd, | |
stride_oz, stride_oh, stride_om, stride_od, | |
Z, H, N_CTX, | |
PAST_LEN, | |
Q_ROUNDED_LEN, | |
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, | |
BLOCK_N: tl.constexpr, | |
EVEN_M_BLOCK: tl.constexpr, | |
EVEN_N_BLOCK: tl.constexpr, | |
INFERENCE: tl.constexpr, | |
NUM_DBLOCKS: tl.constexpr, | |
): | |
Q_LEN = N_CTX - PAST_LEN | |
start_m = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_h = off_hz % H | |
off_z = off_hz // H | |
Q += off_z * stride_qz + off_h * stride_qh | |
K += off_z * stride_kz + off_h * stride_kh | |
V += off_z * stride_vz + off_h * stride_vh | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
offs_d = tl.arange(0, BLOCK_DMODEL) | |
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd | |
# off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd | |
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd | |
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd | |
# Initialize pointers to Q, K, V | |
q_ptrs = Q + off_q | |
k_ptrs = K + off_k | |
v_ptrs = V + off_v | |
# initialize pointer to m and l | |
t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) | |
if NUM_DBLOCKS >= 2: | |
acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) | |
# load q: it will stay in SRAM throughout | |
if EVEN_M_BLOCK: | |
q = tl.load(q_ptrs) | |
if NUM_DBLOCKS >= 2: | |
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd) | |
else: | |
q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) | |
if NUM_DBLOCKS >= 2: | |
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN) | |
layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m | |
start_l = tl.load(layout_ptr).to(tl.int32) | |
end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32) | |
# loop over k, v and update accumulator | |
for col_idx_idx in range(start_l, end_l): | |
col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32) | |
start_n = col_idx * BLOCK_N | |
# -- compute qk ---- | |
if EVEN_N_BLOCK: | |
k = tl.load(k_ptrs + start_n * stride_kn) | |
else: | |
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX) | |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) | |
qk += tl.dot(q, k) | |
if NUM_DBLOCKS >= 2: | |
if EVEN_N_BLOCK: | |
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd) | |
else: | |
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX) | |
qk += tl.dot(q2, k) | |
qk *= sm_scale | |
qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf')) | |
# -- compute m_ij, p, l_ij | |
m_ij = tl.max(qk, 1) | |
p = tl.exp(qk - m_ij[:, None]) | |
l_ij = tl.sum(p, 1) | |
# -- update m_i and l_i | |
m_i_new = tl.maximum(m_i, m_ij) | |
alpha = tl.exp(m_i - m_i_new) | |
beta = tl.exp(m_ij - m_i_new) | |
l_i_new = alpha * l_i + beta * l_ij | |
# -- update output accumulator -- | |
# scale p | |
p_scale = beta / l_i_new | |
p = p * p_scale[:, None] | |
# scale acc | |
acc_scale = l_i / l_i_new * alpha | |
# tl.store(t_ptrs, acc_scale) | |
# acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load | |
acc = acc * acc_scale[:, None] | |
if NUM_DBLOCKS >= 2: | |
acc2 = acc2 * acc_scale[:, None] | |
p = p.to(Q.dtype.element_ty) | |
# update acc | |
if EVEN_N_BLOCK: | |
v = tl.load(v_ptrs + start_n * stride_vn) | |
else: | |
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX) | |
acc += tl.dot(p, v) | |
if NUM_DBLOCKS >= 2: | |
if EVEN_N_BLOCK: | |
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd) | |
else: | |
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX) | |
acc2 += tl.dot(p, v) | |
# update m_i and l_i | |
l_i = l_i_new | |
m_i = m_i_new | |
# rematerialize offsets to save registers | |
# start_m = tl.program_id(0) | |
# offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
# write back l and m | |
if not INFERENCE: | |
l_ptrs = L + off_hz * N_CTX + offs_m | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
if EVEN_M_BLOCK: | |
tl.store(l_ptrs, l_i) | |
tl.store(m_ptrs, m_i) | |
else: | |
tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN) | |
tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN) | |
# initialize pointers to output | |
# offs_n = tl.arange(0, BLOCK_DMODEL) | |
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od | |
out_ptrs = Out + off_o | |
tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN) | |
if NUM_DBLOCKS >= 2: | |
tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN) | |
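# --- Illustrative sketch (added for clarity; not part of the original file).
# A small CPU check of the streaming-softmax rescaling used in `_fwd_kernel` above:
# applying the m_i / l_i / acc update block by block matches a plain softmax.
def _example_online_softmax_update():
    torch.manual_seed(0)
    scores = torch.randn(128)                      # one query row against 128 keys
    values = torch.randn(128, 2)                   # head_dim = 2 for brevity
    m_i, l_i, acc = float('-inf'), 0.0, torch.zeros(2)
    for s_blk, v_blk in zip(scores.split(32), values.split(32)):
        m_ij = s_blk.max().item()
        p = torch.exp(s_blk - m_ij)
        l_ij = p.sum().item()
        m_i_new = max(m_i, m_ij)
        alpha = math.exp(m_i - m_i_new)            # rescale factor for the old statistics
        beta = math.exp(m_ij - m_i_new)            # rescale factor for the new block
        l_i_new = alpha * l_i + beta * l_ij
        acc = acc * (l_i / l_i_new * alpha) + (p * (beta / l_i_new)) @ v_blk
        m_i, l_i = m_i_new, l_i_new
    assert torch.allclose(acc, torch.softmax(scores, 0) @ values, atol=1e-5)
    return acc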
## backward | |
@triton.jit
def _bwd_preprocess(
Out, DO, L, # assume contiguous for Out, DO, L, NewDO, Delta layout. | |
NewDO, Delta, | |
N_CTX, | |
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, | |
EVEN_M_BLOCK: tl.constexpr, | |
): | |
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) | |
off_d = tl.arange(0, D_HEAD) | |
# load | |
if EVEN_M_BLOCK: | |
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32) | |
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32) | |
else: | |
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32) | |
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32) | |
denom = tl.load(L + off_m).to(tl.float32) | |
# compute | |
do = do / denom[:, None] | |
delta = tl.sum(o * do, axis=1) | |
# write-back | |
if EVEN_M_BLOCK: | |
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do) | |
else: | |
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX) | |
tl.store(Delta + off_m, delta) | |
# Does not support unequal seqlen(q) and seqlen(k).
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, | |
layout_ccol_ptr, | |
layout_row_ptr, | |
layout_ccol_stride_h, layout_ccol_stride_m, | |
layout_row_stride_h, layout_row_stride_m, | |
Out, DO, # assume contigous: Out, Do, DQ, DK, DV, L, M, D, seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od, | |
DQ, DK, DV, | |
L, M, | |
D, | |
stride_qz, stride_qh, stride_qm, stride_qd, | |
stride_kz, stride_kh, stride_kn, stride_kd, | |
stride_vz, stride_vh, stride_vn, stride_vd, | |
stride_oz, stride_oh, stride_om, stride_od, | |
# stride_dz, stride_dh, stride_dm, stride_dd, | |
Z, H, N_CTX, | |
num_block, | |
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, | |
BLOCK_N: tl.constexpr, | |
EVEN_M_BLOCK: tl.constexpr, | |
EVEN_N_BLOCK: tl.constexpr, | |
NUM_DBLOCKS: tl.constexpr, | |
): | |
start_n = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_z = off_hz // H | |
off_h = off_hz % H | |
# offset pointers for batch/head | |
Q += off_z * stride_qz + off_h * stride_qh | |
K += off_z * stride_kz + off_h * stride_kh | |
V += off_z * stride_vz + off_h * stride_vh | |
DO += off_z * stride_oz + off_h * stride_oh | |
DQ += off_z * stride_oz + off_h * stride_oh | |
DK += off_z * stride_oz + off_h * stride_oh | |
DV += off_z * stride_oz + off_h * stride_oh | |
    # Looks like this loop could be parallelized
# for start_n in range(0, num_block): | |
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) | |
offs_m = tl.arange(0, BLOCK_M) | |
offs_d = tl.arange(0, BLOCK_DMODEL) | |
# initialize pointers to value-like data | |
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd) | |
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd) | |
# pointer to row-wise quantities in value-like data | |
D_ptrs = D + off_hz * N_CTX | |
m_ptrs = M + off_hz * N_CTX | |
    # initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) | |
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) | |
# k and v stay in SRAM throughout | |
if EVEN_N_BLOCK: | |
k = tl.load(k_ptrs) | |
v = tl.load(v_ptrs) | |
else: | |
k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX) | |
v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX) | |
if NUM_DBLOCKS >= 2: | |
dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) | |
dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) | |
if EVEN_N_BLOCK: | |
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd) | |
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd) | |
else: | |
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX) | |
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX) | |
# loop over rows | |
layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m | |
start_l = tl.load(layout_ptr).to(tl.int32) | |
end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32) | |
for row_idx_idx in range(start_l, end_l): | |
row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32) | |
start_m = row_idx * BLOCK_M | |
# offs_qm = start_m + tl.arange(0, BLOCK_M) | |
offs_m_curr = start_m + offs_m | |
q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd) | |
do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od) | |
dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od) | |
# load q, k, v, do on-chip | |
if EVEN_M_BLOCK: | |
q = tl.load(q_ptrs) | |
else: | |
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX) | |
# re-compute p = softmax(qk, dim=-1).T | |
# NOTE: `do` is pre-divided by `l`; no normalization here | |
qk = tl.dot(q, tl.trans(k)) | |
if NUM_DBLOCKS >= 2: | |
if EVEN_M_BLOCK: | |
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd) | |
else: | |
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX) | |
qk += tl.dot(q2, tl.trans(k2)) | |
qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf')) | |
if EVEN_M_BLOCK: | |
m = tl.load(m_ptrs + offs_m_curr) | |
else: | |
m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX) | |
p = tl.exp(qk * sm_scale - m[:, None]) | |
# compute dv | |
if EVEN_M_BLOCK: | |
do = tl.load(do_ptrs) | |
else: | |
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX) | |
if NUM_DBLOCKS >= 2: | |
if EVEN_M_BLOCK: | |
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od) | |
else: | |
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX) | |
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) | |
if NUM_DBLOCKS >= 2: | |
dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2) | |
# compute dp = dot(v, do) | |
if EVEN_M_BLOCK: | |
Di = tl.load(D_ptrs + offs_m_curr) | |
else: | |
Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX) | |
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] | |
dp += tl.dot(do, tl.trans(v)) | |
if NUM_DBLOCKS >= 2: | |
dp += tl.dot(do2, tl.trans(v2)) | |
# compute ds = p * (dp - delta[:, None]) | |
ds = p * dp * sm_scale | |
# compute dk = dot(ds.T, q) | |
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) | |
if NUM_DBLOCKS >= 2: | |
dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2) | |
# # compute dq | |
dq = tl.dot(ds.to(Q.dtype.element_ty), k) | |
if EVEN_M_BLOCK: | |
tl.atomic_add(dq_ptrs, dq) | |
else: | |
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX) | |
if NUM_DBLOCKS >= 2: | |
dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2) | |
dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od | |
if EVEN_M_BLOCK: | |
tl.atomic_add(dq_ptrs2, dq2) | |
else: | |
tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX) | |
# write-back | |
dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od) | |
dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od) | |
if EVEN_N_BLOCK: | |
tl.store(dv_ptrs, dv) | |
tl.store(dk_ptrs, dk) | |
else: | |
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX) | |
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX) | |
if NUM_DBLOCKS >= 2: | |
dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od | |
dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od | |
if EVEN_N_BLOCK: | |
tl.store(dv_ptrs2, dv2) | |
tl.store(dk_ptrs2, dk2) | |
else: | |
tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX) | |
tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX) | |
def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None): | |
    '''
    :param q, k, v: [batch, n_heads, seq_len, model_dim]. The length of q is allowed to differ from that of k/v.
    :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices and CSR.col_indices, used to represent a sparse tensor.
        Each element represents a block, i.e., all tokens in a block are either attended to or not attended to at all.
    '''
assert q.shape[-1] == k.shape[-1] == v.shape[-1] | |
assert k.shape[2] == v.shape[2] | |
o = out if out is not None else torch.empty_like(q).contiguous() | |
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) | |
q_rounded_len = grid[0] * BLOCK_M | |
tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32) | |
if inference is None: | |
inference = (not q.requires_grad) and (not k.requires_grad) and (not v.requires_grad) | |
if inference: | |
        L, m = tmp, tmp  # no need to create new tensors
else: | |
L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32) | |
m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32) | |
if layout_col_indices.dim() == 1: | |
layout_crow_indices = layout_crow_indices[None].expand(q.shape[1] , -1) | |
layout_col_indices = layout_col_indices[None].expand(q.shape[1] , -1) | |
assert q.shape[-1] in [64, 128] | |
BLOCK_DMODEL = 64 | |
if num_warps is None: | |
MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL) | |
num_warps = max(1, 2 ** int(math.log2(MIN_D / 16))) | |
# print(f'> {BLOCK_M=}, {BLOCK_N=}, {BLOCK_DMODEL=}, {num_warps=}, {num_stages=}') | |
else: | |
assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.''' | |
## For debugging: | |
# print(f'>> {q.shape=}, {k.shape=}, {BLOCK_M=}, {BLOCK_N=}, {num_warps=}, {BLOCK_DMODEL=}, {q.stride()=}, {k.stride()=}') | |
# print(f'>> {layout_crow_indices=}\n{layout_col_indices=}\n {layout_crow_indices.stride()=}, {layout_crow_indices.stride()=}') | |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \ | |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}') | |
_fwd_kernel[grid]( | |
q, k, v, sm_scale, | |
layout_crow_indices, | |
layout_col_indices, | |
layout_crow_indices.stride(0), layout_crow_indices.stride(1), | |
layout_col_indices.stride(0), layout_col_indices.stride(1), | |
tmp, L, m, | |
o, | |
q.stride(0), q.stride(1), q.stride(2), q.stride(3), | |
k.stride(0), k.stride(1), k.stride(2), k.stride(3), | |
v.stride(0), v.stride(1), v.stride(2), v.stride(3), | |
o.stride(0), o.stride(1), o.stride(2), o.stride(3), | |
q.shape[0], q.shape[1], k.shape[2], | |
k.shape[2] - q.shape[2], | |
q_rounded_len, | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, | |
BLOCK_DMODEL=BLOCK_DMODEL, | |
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0, | |
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 , | |
INFERENCE=inference, | |
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL, | |
num_warps=num_warps, | |
num_stages=num_stages, | |
) | |
if inference: | |
L, m = None, None | |
ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices) | |
ctx.BLOCK_M = BLOCK_M | |
ctx.BLOCK_N = BLOCK_N | |
ctx.BLOCK_DMODEL = BLOCK_DMODEL | |
# ctx.BLOCK = BLOCK | |
ctx.grid = grid | |
ctx.sm_scale = sm_scale | |
ctx.num_warps = num_warps | |
ctx.num_stages = num_stages | |
return o | |
def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None): | |
# q, k, v, o, l, m = ctx.saved_tensors | |
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors | |
    ## The following is too slow to do online, so get it from the inputs, which are cached.
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices)) | |
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices)) | |
if not do.is_contiguous(): | |
do = do.contiguous() | |
## for debugging | |
# print(f'----> do is not contiguous: {do.stride()=}') | |
# raise ValueError(f'>>>> output grad is not contiguous: {do.stride()=}') | |
if not o.is_contiguous(): | |
        # TODO: currently this only works with contiguous q/k/v.
        raise ValueError(f'--> output is not contiguous: {o.stride()=}. This may be caused by q/k/v not being contiguous.')
if layout_ccol_indices.dim() == 1: | |
layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1) | |
layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1) | |
# do = do.contiguous() | |
dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32) | |
dk = dk if dk is not None else torch.empty_like(k) | |
    dv = dv if dv is not None else torch.empty_like(v)
do_scaled = torch.empty_like(do) | |
delta = torch.empty_like(l) | |
assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride() | |
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( | |
o, do, l, | |
do_scaled, delta, | |
k.shape[2], | |
        BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
        EVEN_M_BLOCK=q.shape[2] % ctx.BLOCK_M == 0,
) | |
grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1]) | |
_bwd_kernel[grid]( | |
q, k, v, ctx.sm_scale, | |
layout_ccol_indices, | |
layout_row_indices, | |
layout_ccol_indices.stride(0), layout_ccol_indices.stride(1), | |
layout_row_indices.stride(0), layout_row_indices.stride(1), | |
o, do_scaled, | |
dq, dk, dv, | |
l, m, | |
delta, | |
q.stride(0), q.stride(1), q.stride(2), q.stride(3), | |
k.stride(0), k.stride(1), k.stride(2), k.stride(3), | |
v.stride(0), v.stride(1), v.stride(2), v.stride(3), | |
o.stride(0), o.stride(1), o.stride(2), o.stride(3), | |
q.shape[0], q.shape[1], q.shape[2], | |
ctx.grid[0], | |
BLOCK_M=ctx.BLOCK_M, | |
BLOCK_N=ctx.BLOCK_N, | |
BLOCK_DMODEL=ctx.BLOCK_DMODEL, | |
        EVEN_M_BLOCK=q.shape[2] % ctx.BLOCK_M == 0,
        EVEN_N_BLOCK=k.shape[2] % ctx.BLOCK_N == 0,
        NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
num_warps=ctx.num_warps, | |
num_stages=1, | |
) | |
return dq, dk, dv, None, None, None | |
class _sparse_attention(torch.autograd.Function): | |
    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
BLOCK = 128 | |
# shape constraints | |
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK) | |
    @staticmethod
    def backward(ctx, do):
# q, k, v, o, l, m = ctx.saved_tensors | |
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors | |
# TODO: the following is very inefficient. | |
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices)) | |
layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices)) | |
return _backward(ctx, do, layout_ccol_indices, layout_row_indices) | |
# suppressed | |
class _sparse_attention_inference(_sparse_attention): | |
    # TODO: does not work for now, as BLOCK_M cannot be 1: shapes for tl.dot cannot be smaller than 16.
    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
BLOCK = 128 | |
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK) | |
def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs): | |
class _sparse_attention_config(_sparse_attention): | |
        @staticmethod
        def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
# shape constraints | |
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, | |
**kwargs | |
) | |
return _sparse_attention_config.apply | |
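# --- Illustrative usage sketch (added for clarity; not part of the original file).
# Drives the factory with an explicit CSR block layout; assumes a CUDA device and
# made-up shapes (batch=1, heads=2, seq=256, head_dim=64).
def _example_sparse_attention_factory():
    q = torch.randn(1, 2, 256, 64, dtype=torch.bfloat16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    (crow, col), _, _ = get_sparse_attn_mask(q, 256, BLOCK=64, local_blocks=2,
                                             vert_stride=2, homo_head=False)
    attn = sparse_attention_factory(BLOCK_M=64, BLOCK_N=64)
    return attn(q, k, v, crow, col, 1.0 / math.sqrt(q.size(-1)))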
@lru_cache(maxsize=8)  # cache the constructed ops; `lru_cache` is imported above for this purpose
def get_local_strided_sparse_attention_op(
n_heads: int, | |
max_seq_len:int, | |
sparse_block_size: int=128, | |
local_blocks: int=4, | |
vert_stride: int=4, | |
homo_head: bool=False, | |
dtype=torch.bfloat16, | |
device='cuda', | |
active_head_range=None, | |
verbose=True, | |
**kwargs): | |
    '''
    :param n_heads: total number of attention heads (regardless of tensor/model parallelism).
    :param max_seq_len: max sequence length. Must be greater than or equal to the length of the sequences.
    :param sparse_block_size: sparse block size. Defaults to 128.
    :param local_blocks: number of nearest blocks to attend to. Defaults to 4, i.e., attend to the previous 4 x block_size tokens.
    :param vert_stride: stride of the vertical (strided) pattern. Defaults to 4, i.e., every 4th block is attended to by all later query blocks.
    :param homo_head: whether all heads share the same pattern.
    :param active_head_range: tuple of start & end indices of the heads, e.g., (8, 16). Defaults to using all heads.
            Mainly for tensor/model parallelism where heads are split across different GPUs.
    '''
if verbose: | |
print((f'> new block_sparse_attn op constructed with config: ' | |
f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, ' | |
f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}')) | |
# assert math.log2(max_seq_len) % 2 == 0, f"max_seq_len should be power of 2 to be more efficient" | |
_, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device, | |
BLOCK=sparse_block_size, local_blocks=local_blocks, | |
vert_stride=vert_stride, homo_head=homo_head, | |
return_dense=False) | |
if (not homo_head) and (active_head_range is not None): | |
assert isinstance(active_head_range, tuple) | |
assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.' | |
h_start, h_end = active_head_range | |
block_sparse_pattern = block_sparse_pattern[h_start:h_end] | |
# print(block_sparse_pattern) | |
return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs) | |
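# --- Illustrative usage sketch (added for clarity; not part of the original file).
# The typical entry point: build the op once, then call it like a regular attention
# function. Assumes a CUDA device; the shapes below are made up.
def _example_local_strided_op():
    n_heads, seq_len, head_dim = 8, 1024, 128
    attn_op = get_local_strided_sparse_attention_op(
        n_heads, seq_len, sparse_block_size=64, kernel_block_size=64,
        local_blocks=4, vert_stride=4, homo_head=False, device='cuda', verbose=False)
    q = torch.randn(2, n_heads, seq_len, head_dim, dtype=torch.bfloat16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    return attn_op(q, k, v, 1.0 / math.sqrt(head_dim))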
def get_sparse_attn_op( | |
sparse_pattern: torch.tensor, | |
sparse_block_size: int=128, | |
kernel_block_size=128, | |
qkv_format='q,k,v', | |
**kwargs): | |
    '''
    Create a block-sparse op with a fixed layout. This avoids having to create the CSR layout and convert it to a CSC layout every time,
        which is very inefficient (it uses Python loops on the CPU; PyTorch 1.13 supports CSR->CSC conversion, which may help).
    :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
        This tensor should have lower-triangular matrices in the last 2 dimensions for causal attention.
    :param sparse_block_size: sparse block size. Defaults to 128.
    :param kernel_block_size: the tile/block size used to launch a triton instance. If None, same as `sparse_block_size`.
    :param qkv_format: choices=['q,k,v', 'q, kv', 'qkv'], i.e., separate q, k, v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.
    :param kwargs: keyword arguments passed to `_forward`.
    '''
# assert qkv_format in ('q,k,v', 'q, kv', 'qkv') # to save from running `concat` at forward/backward | |
assert qkv_format == 'q,k,v' | |
if kernel_block_size is None: | |
kernel_block_size = sparse_block_size | |
else: | |
assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}." | |
assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given" | |
# print(f'>> {sparse_pattern.shape=}') | |
# print(f'{sparse_pattern=}') | |
if sparse_block_size // kernel_block_size > 1: | |
_mul = sparse_block_size // kernel_block_size | |
# need to consider if block_m and block_n are different | |
sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul)) | |
num_sparse_blocks = sparse_pattern.size(-1) | |
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None] | |
sparse_pattern *= block_causal_mask.type_as(sparse_pattern) | |
# print(f'>> after: {sparse_pattern.shape=}') | |
# print(f'{sparse_pattern=}') | |
BLOCK_N = kernel_block_size | |
NUM_BLOCK = sparse_pattern.size(-1) | |
MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK | |
grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern) | |
# sparse csc layout for backward | |
grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern) | |
    # Cache the GPU backward layout; limit the cache size to avoid OOM over time.
    # For inference, only one entry needs to be cached, as the sequence length always increases,
    # so the cached entry is reconstructed every `block_size` steps.
    # For training/finetuning, set it to 8 to increase cache hits.
    # Given an input, block_len will be the same for all layers, so the cache is very helpful.
    max_cache_size = 1 if kwargs.get('inference', False) else 8
    @lru_cache(maxsize=max_cache_size)
    def get_backward_layout_by_block_len(block_len):
assert block_len <= NUM_BLOCK | |
if block_len == NUM_BLOCK: | |
return (grand_layout_ccol_indices, grand_layout_row_indices) | |
return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len]) | |
# for debugging | |
# if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: | |
# print(f'> {sparse_pattern.cpu().tolist()=}') | |
# print('----') | |
# print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}') | |
# q, k, v separated | |
class _q_k_v_sparse_attention(torch.autograd.Function): | |
        @staticmethod
        def forward(ctx, q, k, v, sm_scale):
# assert q.shape[2] == 1 or q.shape[2] == k.shape[2] | |
# shape constraints | |
MIN_BLOCK_SIZE = 16 | |
assert BLOCK_N >= MIN_BLOCK_SIZE | |
BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N # BLOCK_M has to be power of 2 | |
# this following code only works for causal attention | |
K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size) | |
# Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0 | |
Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N) | |
# print(Q_START_BLOCKS, K_BLOCKS) | |
layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1] | |
layout_col_indices = grand_layout_col_indices | |
# print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices) | |
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, | |
**kwargs | |
) | |
        @staticmethod
        def backward(ctx, do):
q, k = ctx.saved_tensors[:2] | |
assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.' | |
# assume q, k have same length | |
block_len = triton.cdiv(do.shape[2], kernel_block_size) | |
backward_layout = get_backward_layout_by_block_len(block_len) | |
return _backward(ctx, do, *backward_layout)[:4] | |
def _q_k_v_sparse_attention_fn(*args): | |
return _q_k_v_sparse_attention.apply(*args) | |
_q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern | |
_q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices | |
_q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices | |
_q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices | |
_q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices | |
return _q_k_v_sparse_attention_fn | |
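# --- Illustrative sketch (added for clarity; not part of the original file).
# Builds the op from an explicit block pattern (here generated with the helper above,
# but it could come from any `num_blocks x num_blocks` causal mask). Assumes a CUDA device.
def _example_op_from_explicit_pattern():
    _, pattern, _ = _get_sparse_attn_mask(4, 1024, 1024, torch.bfloat16, 'cuda',
                                          BLOCK=64, local_blocks=4, vert_stride=4,
                                          homo_head=False)
    attn_fn = get_sparse_attn_op(pattern, sparse_block_size=64, kernel_block_size=64)
    q = torch.randn(1, 4, 1024, 64, dtype=torch.bfloat16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    return attn_fn(q, k, v, 1.0 / math.sqrt(q.size(-1)))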
########################################################### | |
########################################################### | |
########################################################### | |
################ Inference Kernels ######################## | |
########################################################### | |
def blocksparse_flash_attn_padded_fwd( | |
q, k, v, # (batch, tokens, n_heads, head_size) | |
sm_scale, | |
sparse_layout, | |
*, | |
left_paddings = None, | |
seqlens = None, | |
block_size = 64, | |
max_seqlen = None | |
): | |
''' | |
q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size) | |
left_paddings: (batch, ), number of left paddings for each sample. | |
seqlens: can be used to specify right padding. No need to specify if left_paddings is used. | |
''' | |
batches, q_len, n_heads, head_size = q.shape | |
_, k_len, n_kv_heads, _ = k.shape | |
assert q.dim() == k.dim() == v.dim() == 4 | |
assert q.size(2) % k.size(2) == 0 | |
assert q.size(0) == k.size(0) and q.size(3) == k.size(3) | |
assert k.shape == v.shape # TODO: allow diff head_size for k, v | |
    assert q_len == 1 or q_len == k_len, \
        'q length can only be 1 (decoding) or the same as the k length (prefilling).'
q_k_ratio = q.size(2) // k.size(2) | |
if max_seqlen: | |
assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.' | |
    # use zeros so that padded positions produce zero output; slightly slower than using empty
out = q.new_zeros(q.shape) | |
layout_crow_indices, layout_col_indices = sparse_layout | |
block_d = triton.next_power_of_2(head_size) | |
if left_paddings is not None: | |
assert left_paddings.shape == (batches,) | |
k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous() | |
else: | |
k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device) | |
if seqlens is not None: | |
k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts) | |
        assert k_batch_ends.max() <= k_len, 'seqlens (+ left_paddings, if any) exceeds the k length.'
else: | |
k_batch_ends = torch.zeros_like(k_batch_starts) + k_len | |
if q_len == 1: | |
q_batch_starts = torch.zeros_like(k_batch_starts) | |
q_batch_ends = q_batch_starts + 1 | |
else: | |
q_batch_starts = k_batch_starts | |
q_batch_ends = k_batch_ends | |
    # switch to CPU to avoid launching too many kernels when iterating over the sequences
q_lens = (q_batch_ends - q_batch_starts).cpu() | |
n_blocks = (q_lens + block_size - 1) // block_size | |
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)], | |
dtype=q_batch_starts.dtype, | |
device=q_batch_starts.device) | |
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)], | |
dtype=q_batch_starts.dtype, | |
device=q_batch_starts.device) | |
grid = (len(q_start_sids), n_heads) | |
_fwd_kernel_batch_inference[grid]( | |
q, k, v, out, | |
sm_scale, | |
q_batch_starts, | |
q_batch_ends, | |
k_batch_starts, | |
k_batch_ends, | |
q_batch_ids, | |
q_start_sids, | |
*q.stride(), | |
*k.stride(), | |
*v.stride(), | |
*out.stride(), | |
layout_crow_indices, | |
layout_col_indices, | |
*layout_crow_indices.stride(), | |
*layout_col_indices.stride(), | |
q_k_ratio, | |
HAS_BATCH_DIM = True, | |
D_HEAD = head_size, | |
BLOCK_M = block_size, | |
BLOCK_N = block_size, | |
BLOCK_D = block_d, | |
BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding | |
        EVEN_D = block_d == head_size,
        M_LT_N = False,  # BLOCK_M == BLOCK_N here, so no extra causal masking inside the loop
        num_warps = 1 if q_len == 1 else 4,
        num_stages = 1  # reduced from 3 to avoid the shared-memory OutOfResources error
    )
return out | |
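# --- Illustrative usage sketch (added for clarity; not part of the original file).
# Batched decoding (q_len == 1) with left padding; assumes a CUDA device and a layout
# built over the full max length with the same block size as the kernel.
def _example_padded_decoding():
    batch, n_heads, head_size, max_len, block = 2, 8, 128, 512, 64
    (crow, col), _, _ = _get_sparse_attn_mask(
        n_heads, max_len, max_len, torch.bfloat16, 'cuda',
        BLOCK=block, local_blocks=4, vert_stride=4, homo_head=False)
    q = torch.randn(batch, 1, n_heads, head_size, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(batch, max_len, n_heads, head_size, dtype=torch.bfloat16, device='cuda')
    v = torch.randn_like(k)
    left_paddings = torch.tensor([0, 128], dtype=torch.int32, device='cuda')
    return blocksparse_flash_attn_padded_fwd(
        q, k, v, 1.0 / math.sqrt(head_size), (crow, col),
        left_paddings=left_paddings, block_size=block)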
def blocksparse_flash_attn_varlen_fwd( | |
q, k, v, # (#tokens, n_heads, head_size) | |
cu_seqlens_k, | |
cu_seqlens_q, | |
sm_scale, | |
sparse_layout, | |
*, | |
block_size=64, | |
max_seqlen = None | |
): | |
# split q to blocks | |
_, n_heads, head_size = q.shape | |
batch_size = cu_seqlens_k.size(0) - 1 | |
# print(f'> {q.shape=}, {k.shape=}') | |
assert q.dim() == k.dim() == v.dim() == 3 | |
assert q.size(1) % k.size(1) == 0 | |
assert q.size(2) == k.size(2) | |
assert k.shape == v.shape # TODO: allow diff head_size for k, v | |
assert cu_seqlens_k.dim() == 1 | |
q_k_ratio = q.size(1) // k.size(1) | |
if cu_seqlens_q is None: | |
if q.size(0) == batch_size: # decoding only | |
cu_seqlens_q = torch.arange(0, batch_size + 1, | |
dtype=cu_seqlens_k.dtype, | |
device=cu_seqlens_k.device) | |
elif q.size(0) == k.size(0): | |
cu_seqlens_q = cu_seqlens_k | |
else: | |
            raise ValueError('cu_seqlens_q must be specified if it is a mix of prefilling and decoding.')
else: | |
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) | |
    # switch to CPU to avoid launching too many kernels when iterating over the sequences
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() | |
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() | |
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \ | |
'length of q should either be 1 (decoding) or same as k (prefilling).' | |
if max_seqlen: | |
assert k_lens.max() <= max_seqlen | |
n_blocks = (q_lens + block_size - 1) // block_size | |
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)], | |
dtype=cu_seqlens_q.dtype, | |
device=cu_seqlens_q.device) | |
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)], | |
dtype=cu_seqlens_q.dtype, | |
device=cu_seqlens_q.device) | |
out = q.new_empty(q.shape) | |
cu_seqlens_q = cu_seqlens_q.contiguous() | |
cu_seqlens_k = cu_seqlens_k.contiguous() | |
layout_crow_indices, layout_col_indices = sparse_layout | |
block_d = triton.next_power_of_2(head_size) | |
decoding_only = (q_lens == 1).all() | |
grid = (len(q_start_sids), n_heads) | |
_fwd_kernel_batch_inference[grid]( | |
q, k, v, out, | |
sm_scale, | |
cu_seqlens_q[:-1], | |
cu_seqlens_q[1:], | |
cu_seqlens_k[:-1], | |
cu_seqlens_k[1:], | |
q_batch_ids, | |
q_start_sids, | |
0, *q.stride(), | |
0, *k.stride(), | |
0, *v.stride(), | |
0, *out.stride(), | |
layout_crow_indices, | |
layout_col_indices, | |
*layout_crow_indices.stride(), | |
*layout_col_indices.stride(), | |
q_k_ratio, | |
HAS_BATCH_DIM = False, | |
D_HEAD = head_size, | |
BLOCK_M = block_size, | |
BLOCK_N = block_size, | |
BLOCK_D = block_d, | |
BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding | |
        EVEN_D = block_d == head_size,
        M_LT_N = False,  # BLOCK_M == BLOCK_N here, so no extra causal masking inside the loop
        num_warps = 1 if decoding_only else 4,
        num_stages = 3
    )
return out | |
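# --- Illustrative usage sketch (added for clarity; not part of the original file).
# Variable-length prefill with two sequences (192 and 320 tokens) packed along the
# token dimension; assumes a CUDA device.
def _example_varlen_prefill():
    n_heads, head_size, block, total = 8, 64, 64, 512
    cu_seqlens_k = torch.tensor([0, 192, 512], dtype=torch.int32, device='cuda')
    q = torch.randn(total, n_heads, head_size, dtype=torch.bfloat16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    (crow, col), _, _ = _get_sparse_attn_mask(
        n_heads, total, total, torch.bfloat16, 'cuda',
        BLOCK=block, local_blocks=4, vert_stride=4, homo_head=False)
    return blocksparse_flash_attn_varlen_fwd(
        q, k, v, cu_seqlens_k, None, 1.0 / math.sqrt(head_size), (crow, col),
        block_size=block)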
@triton.jit
def _fwd_kernel_inner(
acc, l_i, m_i, | |
q, Q, | |
k_block_col_idx, | |
layout_col_ptr, | |
layout_col_stride_h, layout_col_stride_m, | |
k_ptrs, | |
v_ptrs, | |
off_h, offs_m, offs_n, offs_d, | |
stride_kt, stride_vt, | |
sm_scale, | |
k_seqlen, | |
past_len, | |
LAST_K_BLOCK: tl.constexpr, | |
BLOCK_M_LOADING: tl.constexpr, | |
BLOCK_N: tl.constexpr, | |
D_HEAD: tl.constexpr, | |
EVEN_D: tl.constexpr, | |
M_LT_N: tl.constexpr | |
): | |
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32) | |
start_n = k_block_id * BLOCK_N | |
# -- compute qk ---- | |
if LAST_K_BLOCK: | |
if EVEN_D: | |
k = tl.load(k_ptrs + start_n * stride_kt, | |
mask=offs_n[None, :] + start_n < k_seqlen) | |
else: | |
# mask = mask & (offs_d[:, ]) | |
k = tl.load(k_ptrs + start_n * stride_kt, | |
mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD)) | |
else: | |
if EVEN_D: | |
k = tl.load(k_ptrs + start_n * stride_kt) | |
else: | |
k = tl.load(k_ptrs + start_n * stride_kt, | |
mask=offs_d[:, None] < D_HEAD) | |
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) | |
qk += tl.dot(q, k) | |
qk *= sm_scale | |
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N | |
if LAST_K_BLOCK | M_LT_N: | |
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) | |
# -- compute m_ij, p, l_ij | |
m_ij = tl.max(qk, 1) | |
p = tl.exp(qk - m_ij[:, None]) | |
l_ij = tl.sum(p, 1) | |
# -- update m_i and l_i | |
m_i_new = tl.maximum(m_i, m_ij) | |
alpha = tl.exp(m_i - m_i_new) | |
beta = tl.exp(m_ij - m_i_new) | |
l_i_new = alpha * l_i + beta * l_ij | |
# -- update output accumulator -- | |
# scale p | |
p_scale = beta / l_i_new | |
p = p * p_scale[:, None] | |
# scale acc | |
acc_scale = l_i / l_i_new * alpha | |
acc = acc * acc_scale[:, None] | |
p = p.to(Q.dtype.element_ty) | |
# update acc | |
if LAST_K_BLOCK: | |
if EVEN_D: | |
v = tl.load(v_ptrs + start_n * stride_vt, | |
mask=offs_n[:, None] + start_n < k_seqlen) | |
else: | |
v = tl.load(v_ptrs + start_n * stride_vt, | |
mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD)) | |
else: | |
if EVEN_D: | |
v = tl.load(v_ptrs + start_n * stride_vt) | |
else: | |
v = tl.load(v_ptrs + start_n * stride_vt, | |
mask=offs_d[None, :] < D_HEAD) | |
acc += tl.dot(p, v) | |
# update m_i and l_i | |
l_i = l_i_new | |
m_i = m_i_new | |
return acc, l_i, m_i | |
@triton.jit
def _fwd_kernel_batch_inference(
Q, K, V, Out, | |
sm_scale, | |
q_batch_starts, | |
q_batch_ends, | |
k_batch_starts, | |
k_batch_ends, | |
q_batch_ids, | |
q_start_sids, | |
stride_qb, stride_qt, stride_qh, stride_qd, | |
stride_kb, stride_kt, stride_kh, stride_kd, | |
stride_vb, stride_vt, stride_vh, stride_vd, | |
stride_ob, stride_ot, stride_oh, stride_od, | |
layout_crow_ptr, | |
layout_col_ptr, | |
layout_crow_stride_h, layout_crow_stride_m, | |
layout_col_stride_h, layout_col_stride_m, | |
q_k_ratio, | |
HAS_BATCH_DIM: tl.constexpr, | |
D_HEAD: tl.constexpr, | |
BLOCK_M: tl.constexpr, | |
BLOCK_N: tl.constexpr, | |
BLOCK_D: tl.constexpr, | |
BLOCK_M_LOADING: tl.constexpr, | |
EVEN_D: tl.constexpr, | |
M_LT_N: tl.constexpr | |
): | |
    '''
    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id
    offs_m, offs_n: storage offsets of the m-dim (q, rows) and n-dim (k, cols)
    q and the blocks in KV need to be contiguous.
    Arguments:
    kv_seq_lens: used to compute past_len
    kv_storage_offsets: similar to block_tables in vLLM, except that it is dynamic.
        TODO: fix this
    TODO:
    Optimize grouped attention.
    CUDA graph support issues:
        1. The grid is dynamic: vLLM sets up multiple CUDA graphs for the decoding phase, with different max token sizes (16, 32, ...).
           Since we mix the prompt and decoding phases here, it can be more complex:
           a different CUDA graph would be needed for each (off_zm, off_z).
           Indeed, q_batch_ids can be padded to the maximum grid[0], i.e., assuming all decoding;
           hence cu_seqlens_q and kv_seq_lens.
    '''
off_zm = tl.program_id(0) | |
off_h = tl.program_id(1) | |
off_h_for_kv = off_h // q_k_ratio | |
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] | |
q_start_sid = tl.load(q_start_sids + off_zm) | |
start_m = q_start_sid // BLOCK_M | |
if HAS_BATCH_DIM: | |
Q += off_z * stride_qb | |
K += off_z * stride_kb | |
V += off_z * stride_vb | |
Out += off_z * stride_ob | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) | |
offs_n = tl.arange(0, BLOCK_N) | |
offs_d = tl.arange(0, BLOCK_D) | |
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) | |
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start | |
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) | |
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start | |
past_len = k_seqlen - q_seqlen | |
Q += q_cu_start * stride_qt + off_h * stride_qh | |
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh | |
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh | |
Out += q_cu_start * stride_ot + off_h * stride_oh | |
q_pbid = (past_len + q_start_sid) // BLOCK_M | |
if EVEN_D: | |
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, | |
mask=offs_m[:, None] < q_seqlen) | |
else: | |
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, | |
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), | |
other=0) | |
sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m | |
# TODO: load at once, supported in new Triton | |
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) | |
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) | |
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf') | |
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) | |
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) | |
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd | |
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd | |
for k_block_col_idx in range(k_block_start, k_block_end - 1): | |
acc, l_i, m_i = _fwd_kernel_inner( | |
acc, l_i, m_i, | |
q, Q, | |
k_block_col_idx, | |
layout_col_ptr, | |
layout_col_stride_h, layout_col_stride_m, | |
k_ptrs, | |
v_ptrs, | |
off_h, offs_m, offs_n, offs_d, | |
stride_kt, stride_vt, | |
sm_scale, | |
k_seqlen, | |
past_len, | |
False, | |
BLOCK_M_LOADING, | |
BLOCK_N, | |
D_HEAD, | |
EVEN_D, | |
M_LT_N | |
) | |
acc, l_i, m_i = _fwd_kernel_inner( | |
acc, l_i, m_i, | |
q, Q, | |
k_block_end - 1, | |
layout_col_ptr, | |
layout_col_stride_h, layout_col_stride_m, | |
k_ptrs, | |
v_ptrs, | |
off_h, offs_m, offs_n, offs_d, | |
stride_kt, stride_vt, | |
sm_scale, | |
k_seqlen, | |
past_len, | |
True, | |
BLOCK_M_LOADING, | |
BLOCK_N, | |
D_HEAD, | |
EVEN_D, | |
M_LT_N | |
) | |
# write output | |
if EVEN_D: | |
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc, | |
mask=offs_m[:, None] < q_seqlen) | |
else: | |
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc, | |
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD)) | |
########################################################### | |
########################################################### | |
########################################################### | |
################## Testing Utilities ###################### | |
########################################################### | |
def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None): | |
''' | |
q, k, v: shape=(batch, n_heads, seq, dim) | |
''' | |
# for verification | |
if sm_scale is None: | |
sm_scale = math.sqrt(float(q.size(-1))) | |
if block_attn_mask is not None: | |
assert attn_mask is None | |
outs = [] | |
for s in range(0, q.size(2), block_size): | |
e = min(s + block_size, q.size(2)) | |
q_block = q[:, :, s:e] | |
attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale | |
mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)] | |
mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device)) | |
mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0) | |
attn = attn.masked_fill((1 - mask).bool(), float('-inf')) | |
attn = attn.softmax(-1) | |
out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e]) | |
outs.append(out) | |
torch_output = torch.cat(outs, dim=2) | |
else: | |
attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale | |
# import ipdb; ipdb.set_trace() | |
if attn_mask is not None: | |
attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf')) | |
# print(f'> torch attn: {attn.exp().sum(-1)=}') | |
attn = attn.softmax(-1) | |
if do is not None: | |
dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do) | |
print(f'> torch_attn computed dv: {dv=}') | |
torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v) | |
return torch_output | |
########################################################### | |
########################################################### | |
########################################################### | |
#################### Unit Tests ########################### | |
########################################################### | |
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True, | |
sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None): | |
Q_LEN = Q_LEN or N_CTX | |
torch.manual_seed(20) | |
q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_() | |
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_() | |
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_() | |
if sm_scale is None: | |
sm_scale = 1. / math.sqrt(D_HEAD) | |
# for debugging | |
# print(f'>> {q.shape=}, {k.shape=}, {v.shape=}, {homo_head=}, {kernel_block_size=}, {sparse_block_size=}, {local_blocks=}, {vert_stride=}') | |
sm_scale = 0.0078125 | |
if backward: | |
q.requires_grad_(), k.requires_grad_(), v.requires_grad_() | |
# qkv = torch.empty((Z, N_CTX, 3*H*D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) | |
# q = qkv[..., :H*D_HEAD] | |
# k = qkv[..., H*D_HEAD:2*H*D_HEAD] | |
# v = qkv[..., 2*H*D_HEAD:] | |
# q = q.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3) | |
# k = k.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3) | |
# v = v.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3) | |
# if Q_LEN and Q_LEN < N_CTX: | |
# q = q[:, :, -Q_LEN:] # .contiguous() | |
# q = q.requires_grad_() | |
# k = k.requires_grad_() | |
# v = v.requires_grad_() | |
dout = torch.randn_like(q).contiguous() | |
# dout = torch.eye(N_CTX)[:, :D_HEAD][None, None].expand_as(q).type_as(q).contiguous() | |
# print(dout) | |
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size, | |
local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True) | |
if sparse_attention_fn is None: | |
sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX, | |
sparse_block_size=sparse_block_size, | |
local_blocks=local_blocks, | |
vert_stride=vert_stride, | |
homo_head=homo_head, | |
device=q.device, | |
dtype=q.dtype, | |
kernel_block_size=kernel_block_size) | |
# reference implementation | |
ref_out = torch_attention(q, k, v, mask_dense, sm_scale) | |
# lengths = torch.full((Z,), fill_value=N_CTX, device='cuda') | |
# cu_seqlens = torch.zeros((Z + 1,), device='cuda', dtype=torch.int32) | |
# cu_seqlens[1:] = lengths.cumsum(0) | |
# # qkv = torch.randn((Z * N_CTX, 3, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
# qkv_list = list(map(lambda x: x.permute(0, 2, 1, 3).contiguous().view(Z * N_CTX, 1, H, D_HEAD), [q, k, v])) | |
# qkv = torch.cat(qkv_list, dim=1) | |
# ref_out0 = flash_attn_func(qkv, cu_seqlens, dropout_p=0, max_s=N_CTX, softmax_scale=sm_scale, causal=True) | |
# ref_out = ref_out0.view(Z, N_CTX, H, D_HEAD).permute(0, 2, 1, 3).contiguous() | |
if backward: | |
ref_out.backward(dout) | |
ref_dv, v.grad = v.grad.clone(), None | |
ref_dk, k.grad = k.grad.clone(), None | |
ref_dq, q.grad = q.grad.clone(), None | |
tri_out = sparse_attention_fn(q, k, v, sm_scale) | |
decimal = 1 if dtype == torch.bfloat16 else 2 | |
assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}' | |
if backward: | |
tri_out.backward(dout) | |
tri_dv, v.grad = v.grad.clone(), None | |
tri_dk, k.grad = k.grad.clone(), None | |
tri_dq, q.grad = q.grad.clone(), None | |
if backward: | |
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2) | |
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) | |
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) | |
print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}') | |
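# --- Illustrative sketch (added for clarity; not part of the original file).
# A couple of direct `test_op` invocations; requires a CUDA device.
def _example_run_unit_tests():
    test_op(1, 4, 512, 64, kernel_block_size=64, sparse_block_size=64,
            local_blocks=2, vert_stride=2, backward=False)
    test_op(2, 8, 1024, 64, homo_head=False, backward=True)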
########################################################### | |
if __name__ == '__main__': | |
GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip() | |
# print(GPU_TYPE) | |
    support_backward = True  # 'A100' in GPU_TYPE. Backward wasn't supported on a consumer A1000.
############### | |
# benchmarking | |
HAS_DENSE_TRITON_FLASH = False | |
# try: | |
# from triton.ops.flash_attention import attention as triton_attention | |
# HAS_DENSE_TRITON_FLASH = True | |
# except: | |
# HAS_DENSE_TRITON_FLASH = False | |
# print('> cannot import Trition flash attn') | |
try: | |
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func | |
HAS_FLASH = True | |
except BaseException: | |
HAS_FLASH = False | |
print('> cannot import flash_attn') | |
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 | |
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 # 6.7B model, with 4k len | |
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128 # 204m model | |
BLOCK_SIZE = 64 | |
LOCAl_BLOCKS = 8 # 4 | |
VERT_STRIDE = 1 # 16 # 8 | |
HOMO_HEAD = False | |
sparse_type = 'homo' if HOMO_HEAD else 'hetero'
dtype = torch.bfloat16 | |
modes = ['fwd', 'bwd'] if support_backward else ['fwd'] | |
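# Benchmark grid for full-sequence (training-style) attention: sweep SEQ_LEN over
# powers of two and compare dense Triton / dense FlashAttention (when importable)
# against the strided block-sparse Triton kernel, for forward and backward passes.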
configs = [triton.testing.Benchmark( | |
x_names=['SEQ_LEN'], | |
x_vals=[2**i for i in range(8, 16)], | |
line_arg='provider', | |
line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'], | |
line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'], | |
styles=[('red', '-'), ('blue', '-'), ('green', '-')], | |
ylabel='ms', | |
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}', | |
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode} | |
) for mode in modes] | |
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
assert mode in ['fwd', 'bwd'] | |
warmup = 25 | |
rep = 100 | |
N_CTX = SEQ_LEN | |
if provider == 'triton': | |
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
sm_scale = 1.3 | |
fn = lambda: triton_attention(q, k, v, sm_scale) | |
if mode == 'bwd': | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
if provider == 'triton_sparse': | |
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
sm_scale = 1.3 | |
# q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None] | |
# k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None] | |
# local_blocks = 4 # num_block per attn, block_size is tied to BLOCK | |
# vert_stride =N_CTX + 1 # 4 | |
# mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1 | |
# mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q) | |
# mask = mask_dense.to_sparse_csr() | |
# mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD) | |
if sparse_attention_fn is None: | |
# sparse_attention_fn = sparse_attention | |
sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN, | |
local_blocks=LOCAl_BLOCKS, | |
vert_stride=VERT_STRIDE, | |
homo_head=HOMO_HEAD, | |
sparse_block_size=BLOCK_SIZE, | |
kernel_block_size=BLOCK_SIZE, | |
device=q.device) | |
# sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8) | |
# fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale) | |
fn = lambda: sparse_attention_fn(q, k, v, sm_scale) | |
if mode == 'bwd': | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
if provider == 'flash': | |
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) | |
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) | |
cu_seqlens[1:] = lengths.cumsum(0) | |
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) | |
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) | |
if mode == 'bwd': | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
# if provider == 'torch': | |
# q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
# k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
# v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) | |
# sm_scale = 1.3 | |
# causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q) | |
# fn = lambda: torch_attention(q, k, v, causal_mask, sm_scale) | |
# ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) | |
# return ms | |
BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 # 6.7B model, with 4k len | |
BLOCK_SIZE = 64 | |
LOCAl_BLOCKS = 8 # 4 | |
VERT_STRIDE = 16 # 8 | |
HOMO_HEAD = False | |
sparse_type = 'homo' if HOMO_HEAD else 'hetero'
dtype = torch.bfloat16 | |
MAX_N_CTX = 8192 | |
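# Benchmark grid for single-token (Q_LEN=1) decoding: queries attend over a growing
# PAST_LEN of cached keys/values; providers are a dense torch reference, dense
# FlashAttention (unpadded API), the block-sparse Triton kernel, and the same kernel
# configured as a dense baseline.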
configs = [triton.testing.Benchmark( | |
x_names=['PAST_LEN'], | |
x_vals=[2**i - 1 for i in range(8, 14)], | |
line_arg='provider', | |
line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'], | |
line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'], | |
styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')], | |
ylabel='ms', | |
plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}', | |
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode} | |
) for mode in ['fwd']] | |
@triton.testing.perf_report(configs)
def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
assert mode in ['fwd'] | |
warmup = 25 | |
rep = 100 | |
N_CTX = PAST_LEN + Q_LEN | |
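# 'torch' baseline: dense masked reference (torch_attention) with block_size=2048.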
if provider == 'torch': | |
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
sm_scale = 1.3 | |
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE, | |
local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD, return_dense=True)
fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
if provider == 'triton_sparse': | |
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
sm_scale = 1.3 | |
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX, | |
local_blocks=LOCAl_BLOCKS, | |
vert_stride=VERT_STRIDE, | |
homo_head=HOMO_HEAD, | |
sparse_block_size=BLOCK_SIZE, | |
kernel_block_size=BLOCK_SIZE, | |
device=q.device, | |
inference=True) | |
fn = lambda: sparse_attention_fn(q, k, v, sm_scale) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
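# 'triton_dense' baseline: the same sparse kernel with local_blocks=1, vert_stride=1
# and homo_head=True, which degenerates to ordinary dense causal attention.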
if provider == 'triton_dense': | |
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
sm_scale = 1.3 | |
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX, | |
local_blocks=1, | |
vert_stride=1, | |
homo_head=True, | |
sparse_block_size=BLOCK_SIZE, | |
kernel_block_size=BLOCK_SIZE, | |
device=q.device, | |
inference=True) | |
fn = lambda: sparse_attention_fn(q, k, v, sm_scale) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
if provider == 'flash': | |
assert Q_LEN == 1 | |
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) | |
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) | |
cu_seqlens[1:] = lengths.cumsum(0) | |
cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32) | |
# unpadded flash-attn layout: (total_tokens, nheads, headdim) with cu_seqlens offsets; total_q == BATCH since Q_LEN == 1
q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) | |
fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False) | |
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) | |
return ms | |
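###############
# correctness checks & benchmark runs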
test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward) | |
# bench_flash_attention.run(save_path='.', print_data=True) | |
bench_flash_attention_inference.run(save_path='.', print_data=True) | |
exit() | |
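# NOTE: the exit() above stops execution here; comment it out to run the remaining
# test_op variants and the training-length benchmark below.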
# head_dim=64 | |
test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64, | |
dtype=torch.bfloat16, homo_head=False, backward=support_backward) | |
# uneven length, bf16 | |
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128, | |
kernel_block_size=64, local_blocks=8, vert_stride=8) | |
test_op(3, 2, 2047, 128, homo_head=False, backward=False) | |
# diff kernel/sparse block size | |
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64) | |
# inference | |
# test_op(1, 4, 512 + 256, 128, Q_LEN=1, dtype=torch.bfloat16, homo_head=False, backward=support_backward) | |
# dense flash attn | |
test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False, | |
backward=support_backward, local_blocks=1, vert_stride=1) | |
# fp16 | |
test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward) | |
# longer sequence | |
test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward) | |
test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward) | |
# homo head | |
test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False) | |
test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward) | |
# sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True) | |
# test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None) | |
# test_op_inference(3, 2, 2048, 128, 2048) | |
# test_op_inference(3, 2, 2047, 64, 2047) | |
# test_op_inference(3, 2, 256, 64, 128) | |
# test_op_inference(3, 2, 2048, 64, 1) | |
bench_flash_attention.run(save_path='.', print_data=True) | |
# bench_flash_attention_inference.run(save_path='.', print_data=True) | |
# ======================== | |
# Some Benchmark Results # | |
# ======================== | |
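# Times below are in milliseconds per call (lower is better), produced by the
# benchmark harnesses above with the settings encoded in each header:
# batch / heads / head_dim and the sparse pattern (local blocks, vertical stride,
# homo vs. hetero heads).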
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.057184 0.069646 0.052567 | |
# 1 512.0 0.131688 0.187658 0.110212 | |
# 2 1024.0 0.391844 0.524990 0.247875 | |
# 3 2048.0 1.305190 1.456685 0.596506 | |
# 4 4096.0 4.623019 4.968653 1.600277 | |
# 5 8192.0 17.513062 18.332262 4.802458 | |
# 6 16384.0 68.453377 70.337540 16.052908 | |
# 7 32768.0 270.655487 276.020233 57.938946 | |
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8): | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.190120 0.150313 0.181451 | |
# 1 512.0 0.406348 0.391767 0.391177 | |
# 2 1024.0 1.029704 1.182967 0.885741 | |
# 3 2048.0 2.985456 3.843399 2.040469 | |
# 4 4096.0 9.808897 13.073701 5.069609 | |
# 5 8192.0 34.995201 47.863808 13.948782 | |
# 6 16384.0 132.740097 182.579193 42.816513 | |
# 7 32768.0 542.223389 714.820618 147.053574 | |
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero: | |
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.050949 0.032357 0.107513 | |
# 1 512.0 0.073624 0.050651 0.199086 | |
# 2 1024.0 0.107472 0.080379 0.245445 | |
# 3 2048.0 0.178423 0.129448 0.338259 | |
# 4 4096.0 0.327647 0.223106 0.517048 | |
# 5 8192.0 0.588423 0.411263 0.884606 | |
# 6 16384.0 1.098898 0.798941 1.611809 | |
# 7 32768.0 2.094537 1.594726 3.044160 | |
# 6.7B | |
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.069208 0.082156 0.065097 | |
# 1 512.0 0.138271 0.201393 0.144467 | |
# 2 1024.0 0.391521 0.624614 0.322382 | |
# 3 2048.0 1.268443 2.406325 0.784367 | |
# 4 4096.0 4.455703 9.139097 2.100856 | |
# 5 8192.0 16.764315 35.289600 6.328320 | |
# 6 16384.0 65.221634 138.401794 21.069057 | |
# 7 32768.0 257.251343 548.085754 76.111870 | |
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.297118 0.266469 0.255255 | |
# 1 512.0 0.672826 0.613685 0.552954 | |
# 2 1024.0 1.718434 1.705066 1.251953 | |
# 3 2048.0 4.936755 5.403875 2.927895 | |
# 4 4096.0 15.911594 18.959362 7.436288 | |
# 5 8192.0 55.357441 70.808578 21.140224 | |
# 6 16384.0 208.188416 273.617920 68.018173 | |
# 7 32768.0 806.037476 1081.453613 218.720261 | |
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero: | |
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.050151 0.032337 0.107593 | |
# 1 512.0 0.073409 0.051737 0.200200 | |
# 2 1024.0 0.107533 0.082099 0.247067 | |
# 3 2048.0 0.177259 0.128891 0.338510 | |
# 4 4096.0 0.325866 0.223621 0.524842 | |
# 5 8192.0 0.586926 0.408913 0.885490 | |
# 6 16384.0 1.100834 0.793277 1.612271 | |
# 7 32768.0 2.098851 1.595831 3.064544 | |
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.066673 0.082037 0.065085 | |
# 1 512.0 0.137379 0.201880 0.143473 | |
# 2 1024.0 0.390675 0.624234 0.312046 | |
# 3 2048.0 1.267739 2.406950 0.696045 | |
# 4 4096.0 4.445138 9.136333 1.665788 | |
# 5 8192.0 16.768614 35.265533 4.380486 | |
# 6 16384.0 65.235970 138.393600 12.997633 | |
# 7 32768.0 257.317902 550.442993 42.821121 | |
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.296461 0.266581 0.254022 | |
# 1 512.0 0.671427 0.613643 0.551283 | |
# 2 1024.0 1.719918 1.704295 1.229982 | |
# 3 2048.0 4.945305 5.403364 2.721906 | |
# 4 4096.0 15.934293 18.960999 6.259371 | |
# 5 8192.0 55.406593 70.832130 15.676929 | |
# 6 16384.0 208.750595 275.004425 44.837891 | |
# 7 32768.0 808.057861 1080.647705 141.856766 | |
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero: | |
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.050739 0.032886 0.107837 | |
# 1 512.0 0.073507 0.051996 0.200293 | |
# 2 1024.0 0.106394 0.080679 0.240610 | |
# 3 2048.0 0.177659 0.127660 0.287625 | |
# 4 4096.0 0.326326 0.226971 0.377500 | |
# 5 8192.0 0.586339 0.407367 0.559266 | |
# 6 16384.0 1.102279 0.786221 0.920976 | |
# 7 32768.0 2.097370 1.545090 1.644288 | |
################ | |
##### fp16 ##### | |
################ | |
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.032518 0.035472 0.029939 | |
# 1 512.0 0.054266 0.087841 0.054320 | |
# 2 1024.0 0.133447 0.263090 0.102045 | |
# 3 2048.0 0.384615 1.023293 0.201763 | |
# 4 4096.0 1.300890 4.023936 0.449555 | |
# 5 8192.0 4.774144 15.816704 1.150854 | |
# 6 16384.0 18.220032 62.771198 3.356001 | |
# 7 32768.0 71.405571 250.273788 10.976142 | |
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.083342 0.069742 0.079496 | |
# 1 512.0 0.159894 0.170995 0.151705 | |
# 2 1024.0 0.386071 0.522407 0.331443 | |
# 3 2048.0 1.067715 1.737333 0.715248 | |
# 4 4096.0 3.382731 6.219520 1.597457 | |
# 5 8192.0 11.857793 23.560448 3.879035 | |
# 6 16384.0 44.422142 91.251709 10.626843 | |
# 7 32768.0 175.011841 359.473145 32.340992 | |
################ | |
##### bf16 ##### | |
################ | |
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.037636 0.035902 0.031512 | |
# 1 512.0 0.058591 0.087229 0.058125 | |
# 2 1024.0 0.143337 0.263919 0.108443 | |
# 3 2048.0 0.414458 1.025985 0.214114 | |
# 4 4096.0 1.390841 4.020010 0.480550 | |
# 5 8192.0 5.067938 15.808171 1.230874 | |
# 6 16384.0 19.442280 62.765057 3.597274 | |
# 7 32768.0 75.501572 250.443771 11.768959 | |
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd: | |
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse | |
# 0 256.0 0.084404 0.070663 0.082613 | |
# 1 512.0 0.161510 0.172882 0.157661 | |
# 2 1024.0 0.388954 0.526047 0.339855 | |
# 3 2048.0 1.075814 1.736057 0.732420 | |
# 4 4096.0 3.401622 6.221376 1.636039 | |
# 5 8192.0 11.915136 23.483391 3.968725 | |
# 6 16384.0 44.660225 91.302910 10.857130 | |
# 7 32768.0 175.038467 359.048187 32.778240 | |