Phi-3-small-128k-instruct / triton_flash_blocksparse_attn.py
moidhassan's picture
Resolve - 196 [rank0]: triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 180224, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
794ffcf verified
raw
history blame
82.5 kB
"""
Author: Eric Lin (xihlin)
"""
"""
... note(bapatra)::
This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module
imports when splitting into different files. I've tried keeping the logical partitions demarkated with comment blocks, but it is not ideal.
In the future, would be really good to revisit this and refactor into a more readable file structure.
"""
from typing import TypeVar
from functools import lru_cache
import math
import pytest
import torch
import numpy as np
import triton
import triton.language as tl
import os
import dataclasses
Phi3SmallConfig = TypeVar('Phi3SmallConfig')
# triton 2.0.0: fail at backward on A100, for the examples, if h_dim=128.
# Done
# 1. strided of qkv
# 2. seq len not power of 2
# 3. bf16 with Triton May, 2023
# TODO:
# 1. wip: support non-contiguous backward, also help reduce memory allocation in training (q, k, v split)
# 2. block sparse with different BLOCK_M, BLOCK_N?
# 3. for Lq not divisible by BLOCK_M, BLOCK_N, only apply mask to K/V on last batch, still need to apply mask on Q.
#    Attempted, failed to compile
# 4. For 2nd iter of inference, BLOCK_M=1, how to make things work? K/V may not be divisible by BLOCK_N.
# 5. The inner loop can also be parallelized via bigger num_stages (better) or on different thread-blocks (via m/L and atomic update, but there is no comm/sync between blocks)
###########################################################
################### Kernel Parameters #####################
###########################################################
@dataclasses.dataclass
class BlockSparseParams:
    """Bundle of the block-sparse attention settings read from a model config."""

    # Copied from `config.blocksparse_block_size`.
    block_size: int
    # Copied from `config.blocksparse_triton_kernel_block_size`.
    kernel_block_size: int
    # Copied from `config.blocksparse_num_local_blocks`.
    num_local_blocks: int
    # Copied from `config.blocksparse_vert_stride`.
    vert_stride: int
    # Copied from `config.blocksparse_homo_head_pattern`.
    homo_head_pattern: bool = False

    @classmethod
    def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
        """Build the parameter bundle from the `blocksparse_*` attributes of `config`."""
        field_values = {
            "block_size": config.blocksparse_block_size,
            "kernel_block_size": config.blocksparse_triton_kernel_block_size,
            "num_local_blocks": config.blocksparse_num_local_blocks,
            "vert_stride": config.blocksparse_vert_stride,
            "homo_head_pattern": config.blocksparse_homo_head_pattern,
        }
        return cls(**field_values)
###########################################################
###########################################################
###########################################################
################### Utility Functions #####################
###########################################################
# helper functions for 3D sparse pattern
# these functions are not optimized and are very inefficient. Avoid calling them too frequently.
# currently, it is only called within `get_local_strided_sparse_attention_op`, which is cached.
def dense_to_crow_col(x):
    '''Convert a 2D/3D torch tensor (x) into CSR row-pointer / column-index tensors.

    A 3D input is converted slice by slice along its first dimension and the
    per-slice results are stacked. Slices with fewer non-zeros than the widest
    slice have their column indices right-padded with -1.

    :param x: 2D or 3D tensor.
    :return: (crow_indices, col_indices); each has one fewer dimension than `x`
        plus the index dimension (1D for 2D input, 2D for 3D input).
    TODO:
        1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?
    '''
    pad = -1
    ndim = x.dim()
    assert ndim in (2, 3)
    batch = x[None] if ndim == 2 else x

    csr_slices = [dense_slice.to_sparse_csr() for dense_slice in batch]
    crows = torch.vstack([csr.crow_indices() for csr in csr_slices])

    col_lists = [csr.col_indices() for csr in csr_slices]
    width = max(len(col) for col in col_lists)
    padded_cols = []
    for col in col_lists:
        # `pad + zeros` yields a run of -1 in the same dtype/device as `col`.
        filler = pad + col.new_zeros(width - col.shape[0])
        padded_cols.append(torch.cat([col, filler]))
    cols = torch.vstack(padded_cols)

    if ndim == 2:
        # Drop the batch dimension that was added for 2D input.
        return crows[0], cols[0]
    return crows, cols
def crow_col_to_dense(crows, cols, dtype=torch.float16):
    '''Rebuild a dense 0/1 tensor from CSR row pointers and column indices.

    1-D `crows`/`cols` describe a single matrix and yield a 2D result; 2-D inputs
    are treated as a batch and yield a 3D result. The number of columns is taken
    as `cols.max() + 1`. The loops run on CPU and the result is moved back to
    the device `crows` came from.

    :param crows: CSR row pointers.
    :param cols: CSR column indices (entries beyond each row-pointer range are never read).
    :param dtype: dtype of the dense result.
    '''
    unbatched = crows.dim() == 1
    if unbatched:
        crows, cols = crows[None], cols[None]

    target_device = crows.device
    # faster in cpu
    crows = crows.cpu()
    cols = cols.cpu()

    n_batch = crows.shape[0]
    n_rows = crows.shape[1] - 1
    n_cols = cols.max() + 1
    dense = torch.zeros((n_batch, n_rows, n_cols), dtype=dtype)

    for b in range(n_batch):
        for r in range(n_rows):
            lo, hi = crows[b, r], crows[b, r + 1]
            dense[b, r, cols[b, lo:hi]] = 1

    if unbatched:
        dense = dense[0]
    return dense.to(target_device)
def dense_to_ccol_row(x):
    '''Convert a dense 2D/3D tensor to CSC-style (ccol_indices, row_indices).

    Implemented by swapping the last two dimensions and reusing the CSR conversion.
    '''
    transposed = x.transpose(-2, -1)
    ccol_indices, row_indices = dense_to_crow_col(transposed)
    return ccol_indices, row_indices
def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
    '''Rebuild a dense 0/1 tensor from CSC-style column pointers and row indices.

    The indices are decoded as if they were CSR, then the last two dimensions
    are swapped back.

    :param ccol: column pointers (1-D for a single matrix, 2-D for a batch).
    :param rows: row indices matching `ccol`.
    :param dtype: dtype of the dense result.
    :return: contiguous dense tensor; 2D for 1-D inputs, 3D for 2-D inputs.
    '''
    # transpose(-2, -1) handles both the 2D and 3D results of crow_col_to_dense;
    # the previous permute(0, 2, 1) raised on the 2D (unbatched) case.
    return crow_col_to_dense(ccol, rows, dtype).transpose(-2, -1).contiguous()
def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
    '''Build the head-shared local + vertical-strided causal block mask.

    A (query block, key block) pair is kept when the key block is not after the
    query block and it is either within `local_blocks` of the query block or its
    1-based index is a multiple of `vert_stride`.

    :return: a tuple of 3:
        - tuple of crow_indices, col_indices (CSR) for the last cdiv(q_len, BLOCK) block rows.
        - block dense mask
        - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
    '''
    with torch.no_grad():
        n_kv_blocks = triton.cdiv(N_CTX, BLOCK)
        block_ids = torch.arange(n_kv_blocks)
        q_idx = block_ids[:, None]
        k_idx = block_ids[None]

        is_causal = q_idx >= k_idx
        is_local = q_idx - k_idx < local_blocks
        is_strided = (block_ids + 1) % vert_stride == 0
        block_mask_dense = (is_causal & (is_local | is_strided)).to(device).to(dtype)

        n_q_blocks = triton.cdiv(q_len, BLOCK)
        csr = block_mask_dense[-n_q_blocks:].contiguous().to_sparse_csr()

        mask_dense = None
        if return_dense:
            # Expand every block entry to BLOCK x BLOCK tokens, then apply the token-level causal mask.
            token_mask = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
            causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(token_mask)[-q_len:]
            mask_dense = token_mask[-q_len:, :N_CTX] * causal_mask

        return (csr.crow_indices(), csr.col_indices()), block_mask_dense, mask_dense
def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
    '''Build the local + vertical-strided causal block mask for `n_heads` heads.

    With `homo_head=True` every head shares one pattern (the CSR indices and the
    optional token mask are expanded across heads; the block mask stays 2D).
    Otherwise the strided columns are shifted per head by
    `max(1, int(vert_stride / n_heads))` blocks and all outputs carry a head dimension.

    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation of CSR format.
        - block dense mask
        - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
    '''
    if homo_head:
        with torch.no_grad():
            (crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(
                q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
            shared_layout = (
                crow[None].expand(n_heads, crow.shape[0]),
                col[None].expand(n_heads, col.shape[0]),
            )
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
            return shared_layout, block_mask_dense, mask_dense

    with torch.no_grad():
        n_kv_blocks = triton.cdiv(N_CTX, BLOCK)
        block_ids = torch.arange(n_kv_blocks)
        q_idx = block_ids[None, :, None]
        k_idx = block_ids[None, None]

        # if vert_stride <= n_heads, rotating the heads
        head_sliding_step = max(1, int(vert_stride / n_heads))
        per_head_strided = []
        for head in range(n_heads):
            shifted_ids = block_ids + head * head_sliding_step + 1
            per_head_strided.append(shifted_ids % vert_stride == 0)
        is_strided = torch.vstack(per_head_strided).unsqueeze(1)

        is_causal = q_idx >= k_idx
        is_local = q_idx - k_idx < local_blocks
        block_mask_dense = (is_causal & (is_local | is_strided)).to(device).to(dtype)

        n_q_blocks = triton.cdiv(q_len, BLOCK)
        block_mask_dense_output = block_mask_dense[:, -n_q_blocks:]

        mask_dense = None
        if return_dense:
            token_mask = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
            causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(token_mask)[-q_len:]
            mask_dense = token_mask[..., -q_len:, :N_CTX] * causal_mask[None]

        return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
    '''Build the sparse attention mask using q's dim-1 size as the head count and
    dim-2 size as the query length, with q's dtype and device.

    Extra positional/keyword arguments are forwarded to `_get_sparse_attn_mask`.
    '''
    n_heads = q.size(1)
    q_len = q.size(2)
    return _get_sparse_attn_mask(n_heads, q_len, N_CTX, q.dtype, q.device, *args, **kwargs)
###########################################################
###########################################################
###########################################################
###################### Training Kernels ###################
###########################################################
# TODO: only apply the loading/saving mask on the last iteration for EVEN_N_BLOCK; useful for the 1st iteration of inference.
# Experiment failed inside the loop.
# Another idea: mask only on saving? Load even when out of bounds (will it cause an illegal access error)?
# Block-sparse causal attention forward kernel.
#
# Work decomposition (as indexed in the body):
#   - program_id(0) selects one block of BLOCK_M query rows;
#   - program_id(1) is a flattened (batch, head) index, decoded with H.
# For its query block the program reads a [start_l, end_l) range from the CSR row-pointer
# array and visits only the key/value blocks listed in the CSR column array, keeping a
# running row max (m_i), a running normalizer (l_i) and an output accumulator that is
# re-normalized on every iteration (acc, plus acc2 for the second feature tile).
#
# Arguments:
#   Q, K, V, Out       : base pointers, addressed with the stride_* arguments
#                        (z = batch, h = head, m/n = row, d = feature).
#   sm_scale           : multiplier applied to the q.k scores.
#   layout_crow_ptr / layout_col_ptr and their strides: CSR row pointers / column (key block) indices,
#                        indexed per head (stride_h) and per entry (stride_m).
#   TMP                : scratch buffer; only used to form `t_ptrs` (the store/load workaround is commented out).
#   L, M               : per-row l_i / m_i outputs, written only when INFERENCE is False,
#                        addressed as off_hz * N_CTX + row.
#   Z                  : not referenced in the body.
#   H                  : number of heads used to decode program_id(1).
#   N_CTX              : bound used for key/value rows; PAST_LEN is subtracted to get the query length.
#   PAST_LEN           : offset added to query row indices for the causal comparison.
#   Q_ROUNDED_LEN      : row stride used for TMP.
#   BLOCK_M, BLOCK_N   : query / key tile sizes.
#   BLOCK_DMODEL       : feature tile width; NUM_DBLOCKS >= 2 enables a second tile at feature offset BLOCK_DMODEL.
#   EVEN_M_BLOCK / EVEN_N_BLOCK : when True the corresponding loads (and L/M stores) skip bounds masks.
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h, layout_crow_stride_m,
    layout_col_stride_h, layout_col_stride_m,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug. TMP, L, M are assumed to have contiguous layouts
    Out,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    Z, H, N_CTX,
    PAST_LEN,
    Q_ROUNDED_LEN,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
    EVEN_N_BLOCK: tl.constexpr,
    INFERENCE: tl.constexpr,
    NUM_DBLOCKS: tl.constexpr,
):
    # Number of query rows; keys/values cover N_CTX rows.
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    # Move the base pointers to this program's (batch, head) slice.
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    # K is addressed as [feature, key] so that tl.dot(q, k) needs no explicit transpose.
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    if NUM_DBLOCKS >= 2:
        # Accumulator for the second feature tile.
        acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    if EVEN_M_BLOCK:
        q = tl.load(q_ptrs)
        if NUM_DBLOCKS >= 2:
            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
    else:
        q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
        if NUM_DBLOCKS >= 2:
            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)
    # CSR row pointers for this (head, query block): entries [start_l, end_l) of the column array.
    layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
    start_l = tl.load(layout_ptr).to(tl.int32)
    end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)
    # loop over k, v and update accumulator
    for col_idx_idx in range(start_l, end_l):
        col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
        start_n = col_idx * BLOCK_N
        # -- compute qk ----
        if EVEN_N_BLOCK:
            k = tl.load(k_ptrs + start_n * stride_kn)
        else:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if NUM_DBLOCKS >= 2:
            # Add the contribution of the second feature tile.
            if EVEN_N_BLOCK:
                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
            qk += tl.dot(q2, k)
        qk *= sm_scale
        # Causal mask: query row (shifted by PAST_LEN) may only attend to keys at or before it.
        qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        # tl.store(t_ptrs, acc_scale)
        # acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        if NUM_DBLOCKS >= 2:
            acc2 = acc2 * acc_scale[:, None]
        # Cast probabilities to the input element type before the second matmul.
        p = p.to(Q.dtype.element_ty)
        # update acc
        if EVEN_N_BLOCK:
            v = tl.load(v_ptrs + start_n * stride_vn)
        else:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
        acc += tl.dot(p, v)
        if NUM_DBLOCKS >= 2:
            if EVEN_N_BLOCK:
                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
            acc2 += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    # start_m = tl.program_id(0)
    # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    if not INFERENCE:
        # Rows are addressed with a per-(batch, head) stride of N_CTX.
        l_ptrs = L + off_hz * N_CTX + offs_m
        m_ptrs = M + off_hz * N_CTX + offs_m
        if EVEN_M_BLOCK:
            tl.store(l_ptrs, l_i)
            tl.store(m_ptrs, m_i)
        else:
            tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
            tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)
    # initialize pointers to output
    # offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    # Output stores are always masked by the query length, regardless of EVEN_M_BLOCK.
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
    if NUM_DBLOCKS >= 2:
        tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
## backward
# Backward pre-processing kernel.
#
# Each program handles BLOCK_M consecutive rows of flat [rows, D_HEAD] buffers. For every row it
#   - divides the incoming output gradient DO by the normalizer stored in L and writes it to NewDO;
#   - writes the row-wise sum of Out * (scaled DO) to Delta.
# EVEN_M_BLOCK is derived by the heuristic from N_CTX % BLOCK_M and selects unmasked loads/stores.
#
# NOTE(review): off_m is a flat row index over the whole launch, while the non-even masks compare it
# against N_CTX; the L load and Delta store are unmasked. Confirm this is intended when N_CTX is not
# a multiple of BLOCK_M.
@triton.heuristics(
    {
        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
    }
)
@triton.jit
def _bwd_preprocess(
    Out, DO, L,  # assume contiguous for Out, DO, L, NewDO, Delta layout.
    NewDO, Delta,
    N_CTX,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
):
    # Flat row indices handled by this program and the feature indices of one row.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, D_HEAD)
    # load
    if EVEN_M_BLOCK:
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
    else:
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    if EVEN_M_BLOCK:
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
    else:
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX)
    tl.store(Delta + off_m, delta)
# Does not support unequal seqlen(q) and seqlen(k)
# Block-sparse causal attention backward kernel.
#
# Work decomposition (as indexed in the body):
#   - program_id(0) selects one block of BLOCK_N key/value rows;
#   - program_id(1) is a flattened (batch, head) index, decoded with H.
# The program keeps its K/V block loaded, walks the query blocks listed for this key block in the
# CSC layout (layout_ccol_ptr = column pointers, layout_row_ptr = query block indices), recomputes the
# attention probabilities from the saved row maxima M, accumulates dK/dV locally and adds each
# query block's dQ contribution with atomic adds.
#
# Arguments:
#   Q, K, V            : inputs, addressed with their own strides.
#   DO                 : output gradient already divided by the softmax normalizer (see the NOTE in the loop).
#   DQ, DK, DV         : gradient outputs; all addressed with the *output* strides (stride_o*).
#   L                  : not referenced in the body. M: saved row maxima. D: per-row delta values.
#   Out, Z, num_block  : not referenced in the body.
#   N_CTX              : sequence length used for all bounds masks and as the row stride of M and D.
#   NUM_DBLOCKS >= 2   : enables a second feature tile at feature offset BLOCK_DMODEL.
#   EVEN_M_BLOCK / EVEN_N_BLOCK : derived by the heuristics from N_CTX; select unmasked loads/stores.
@triton.heuristics(
    {
        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
        'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale,
    layout_ccol_ptr,
    layout_row_ptr,
    layout_ccol_stride_h, layout_ccol_stride_m,
    layout_row_stride_h, layout_row_stride_m,
    Out, DO,  # assume contigous: Out, Do, DQ, DK, DV, L, M, D, seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    # stride_dz, stride_dh, stride_dm, stride_dd,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
    EVEN_N_BLOCK: tl.constexpr,
    NUM_DBLOCKS: tl.constexpr,
):
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    # DO and all gradient buffers share the output strides.
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_oz + off_h * stride_oh
    DK += off_z * stride_oz + off_h * stride_oh
    DV += off_z * stride_oz + off_h * stride_oh
    # Look like this loop can be parallelled
    # for start_n in range(0, num_block):
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # initialize pointers to value-like data
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    m_ptrs = M + off_hz * N_CTX
    # initialize dv amd dk
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    # k and v stay in SRAM throughout
    if EVEN_N_BLOCK:
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
    else:
        k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
        v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)
    if NUM_DBLOCKS >= 2:
        # Second feature tile: separate accumulators and K/V tiles.
        dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        if EVEN_N_BLOCK:
            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
        else:
            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)
    # loop over rows
    # CSC column pointers for this (head, key block): entries [start_l, end_l) of the row-index array.
    layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
    start_l = tl.load(layout_ptr).to(tl.int32)
    end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)
    for row_idx_idx in range(start_l, end_l):
        row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
        start_m = row_idx * BLOCK_M
        # offs_qm = start_m + tl.arange(0, BLOCK_M)
        offs_m_curr = start_m + offs_m
        q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
        do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
        dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
        # load q, k, v, do on-chip
        if EVEN_M_BLOCK:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)
        # re-compute p = softmax(qk, dim=-1).T
        # NOTE: `do` is pre-divided by `l`; no normalization here
        qk = tl.dot(q, tl.trans(k))
        if NUM_DBLOCKS >= 2:
            if EVEN_M_BLOCK:
                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
            else:
                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
            qk += tl.dot(q2, tl.trans(k2))
        # Causal mask (no PAST_LEN offset here: query and key positions are compared directly).
        qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))
        if EVEN_M_BLOCK:
            m = tl.load(m_ptrs + offs_m_curr)
        else:
            m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
        p = tl.exp(qk * sm_scale - m[:, None])
        # compute dv
        if EVEN_M_BLOCK:
            do = tl.load(do_ptrs)
        else:
            do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)
        if NUM_DBLOCKS >= 2:
            if EVEN_M_BLOCK:
                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
            else:
                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
        if NUM_DBLOCKS >= 2:
            dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)
        # compute dp = dot(v, do)
        if EVEN_M_BLOCK:
            Di = tl.load(D_ptrs + offs_m_curr)
        else:
            Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
        # dp starts at -delta so that after the dot products it holds (do . v) - delta.
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
        dp += tl.dot(do, tl.trans(v))
        if NUM_DBLOCKS >= 2:
            dp += tl.dot(do2, tl.trans(v2))
        # compute ds = p * (dp - delta[:, None])
        ds = p * dp * sm_scale
        # compute dk = dot(ds.T, q)
        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
        if NUM_DBLOCKS >= 2:
            dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)
        # # compute dq
        # Several key-block programs contribute to the same query rows, hence the atomic adds.
        dq = tl.dot(ds.to(Q.dtype.element_ty), k)
        if EVEN_M_BLOCK:
            tl.atomic_add(dq_ptrs, dq)
        else:
            tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)
        if NUM_DBLOCKS >= 2:
            dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
            dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
            if EVEN_M_BLOCK:
                tl.atomic_add(dq_ptrs2, dq2)
            else:
                tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
    dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
    if EVEN_N_BLOCK:
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
    else:
        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)
    if NUM_DBLOCKS >= 2:
        dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
        dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
        if EVEN_N_BLOCK:
            tl.store(dv_ptrs2, dv2)
            tl.store(dk_ptrs2, dk2)
        else:
            tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
            tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)
def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
    '''Allocate buffers, launch `_fwd_kernel`, and stash what backward needs on `ctx`.

    :param q, k, v: [batch, n_heads, seq_len, model_dim]. len of q is allowed to be different than k/v.
    :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices and CSR.col_indices of a
        block-level sparse pattern. Each element represents a block that is either fully attended or
        not attended at all. 1-D layouts are expanded (as views) across the heads.
    :param sm_scale: score multiplier passed to the kernel.
    :param BLOCK_M, BLOCK_N: query / key tile sizes for the kernel.
    :param num_warps: power of 2; when None it is derived from the smallest tile dimension.
    :param num_stages: forwarded to the kernel launch.
    :param inference: when None, inferred as "none of q/k/v requires grad". In inference mode the
        L/M buffers alias the scratch buffer and None is saved in their place.
    :param out: optional pre-allocated output tensor.
    :return: the attention output tensor.
    '''
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]

    batch, n_heads, q_len = q.shape[0], q.shape[1], q.shape[2]
    kv_len = k.shape[2]
    head_dim = q.shape[-1]

    output = torch.empty_like(q).contiguous() if out is None else out

    n_q_blocks = triton.cdiv(q_len, BLOCK_M)
    grid = (n_q_blocks, batch * n_heads)
    q_rounded_len = n_q_blocks * BLOCK_M

    def _row_buffer():
        # One float32 value per (batch*head, padded query row).
        return torch.empty((batch * n_heads, q_rounded_len), device=q.device, dtype=torch.float32)

    tmp = _row_buffer()

    if inference is None:
        inference = not (q.requires_grad or k.requires_grad or v.requires_grad)

    if inference:
        # no need to use create new tensor
        L = m = tmp
    else:
        L, m = _row_buffer(), _row_buffer()

    if layout_col_indices.dim() == 1:
        layout_crow_indices = layout_crow_indices[None].expand(n_heads, -1)
        layout_col_indices = layout_col_indices[None].expand(n_heads, -1)

    assert head_dim in [64, 128]
    BLOCK_DMODEL = 64

    if num_warps is not None:
        assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
    else:
        smallest_tile = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
        num_warps = max(1, 2 ** int(math.log2(smallest_tile / 16)))

    # Positional argument order must match the `_fwd_kernel` signature exactly.
    _fwd_kernel[grid](
        q, k, v, sm_scale,
        layout_crow_indices,
        layout_col_indices,
        layout_crow_indices.stride(0), layout_crow_indices.stride(1),
        layout_col_indices.stride(0), layout_col_indices.stride(1),
        tmp, L, m,
        output,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        output.stride(0), output.stride(1), output.stride(2), output.stride(3),
        batch, n_heads, kv_len,
        kv_len - q_len,
        q_rounded_len,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_DMODEL,
        EVEN_M_BLOCK=q_len % BLOCK_M == 0,
        EVEN_N_BLOCK=kv_len % BLOCK_N == 0,
        INFERENCE=inference,
        NUM_DBLOCKS=head_dim // BLOCK_DMODEL,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # In inference mode L/m only aliased the scratch buffer, so nothing useful is saved for them.
    saved_L, saved_m = (None, None) if inference else (L, m)
    ctx.save_for_backward(q, k, v, output, saved_L, saved_m, layout_crow_indices, layout_col_indices)

    ctx.BLOCK_M = BLOCK_M
    ctx.BLOCK_N = BLOCK_N
    ctx.BLOCK_DMODEL = BLOCK_DMODEL
    ctx.grid = grid
    ctx.sm_scale = sm_scale
    ctx.num_warps = num_warps
    ctx.num_stages = num_stages
    return output
def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
    '''Launch the backward kernels using tensors and launch settings saved on `ctx` by `_forward`.

    :param do: gradient w.r.t. the attention output; made contiguous if needed.
    :param layout_ccol_indices, layout_row_indices: CSC-style layout of the block pattern
        (computing it online is too slow, so it is passed in by the caller). 1-D layouts are
        expanded (as views) across the heads.
    :param dq, dk, dv: optional pre-allocated gradient buffers. When omitted, dq is a zero-filled
        float32 tensor (it is accumulated into by the kernel) and dk/dv are uninitialized.
    :return: (dq, dk, dv, None, None, None)
    :raises ValueError: if the saved forward output is not contiguous.
    '''
    # NOTE(review): `_bwd_preprocess` is launched over flat rows but receives k.shape[2] as N_CTX for
    # its bounds mask - confirm behaviour when the sequence length is not a multiple of BLOCK_M.
    q, k, v, o, l, m, _saved_crow, _saved_col = ctx.saved_tensors

    # `.contiguous()` returns the same tensor when it is already contiguous.
    do = do.contiguous()

    if not o.is_contiguous():
        # TODO: currently only work with contiguous q/k/v.
        raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')

    n_heads = q.shape[1]
    if layout_ccol_indices.dim() == 1:
        layout_ccol_indices = layout_ccol_indices[None].expand(n_heads, -1)
        layout_row_indices = layout_row_indices[None].expand(n_heads, -1)

    if dq is None:
        dq = torch.zeros_like(q, dtype=torch.float32)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    do_scaled = torch.empty_like(do)
    delta = torch.empty_like(l)

    # The backward kernel addresses DO/DQ/DK/DV with the output strides.
    assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()

    n_q_blocks, n_batch_heads = ctx.grid[0], ctx.grid[1]
    head_dim = q.shape[-1]

    _bwd_preprocess[(n_q_blocks * n_batch_heads, )](
        o, do, l,
        do_scaled, delta,
        k.shape[2],
        BLOCK_M=ctx.BLOCK_M, D_HEAD=head_dim,
    )

    bwd_grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), n_batch_heads)
    # Positional argument order must match the `_bwd_kernel` signature exactly.
    _bwd_kernel[bwd_grid](
        q, k, v, ctx.sm_scale,
        layout_ccol_indices,
        layout_row_indices,
        layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
        layout_row_indices.stride(0), layout_row_indices.stride(1),
        o, do_scaled,
        dq, dk, dv,
        l, m,
        delta,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q.shape[0], q.shape[1], q.shape[2],
        n_q_blocks,
        BLOCK_M=ctx.BLOCK_M,
        BLOCK_N=ctx.BLOCK_N,
        BLOCK_DMODEL=ctx.BLOCK_DMODEL,
        NUM_DBLOCKS=head_dim // ctx.BLOCK_DMODEL,
        num_warps=ctx.num_warps,
        num_stages=1,
    )
    return dq, dk, dv, None, None, None
class _sparse_attention(torch.autograd.Function):
    '''Autograd wrapper around `_forward` / `_backward` with fixed 128x128 tiles.

    The CSC layout needed by backward is rebuilt from the saved CSR layout on every
    backward call.
    '''

    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
        tile = 128
        # shape constraints
        return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, tile, tile)

    @staticmethod
    def backward(ctx, do):
        _q, _k, _v, _o, _l, _m, crow_indices, col_indices = ctx.saved_tensors
        # TODO: the following is very inefficient.
        block_pattern = crow_col_to_dense(crow_indices, col_indices)
        layout_ccol_indices, layout_row_indices = dense_to_ccol_row(block_pattern)
        return _backward(ctx, do, layout_ccol_indices, layout_row_indices)
# suppressed
class _sparse_attention_inference(_sparse_attention):
    '''Variant of `_sparse_attention` that launches the forward pass with BLOCK_M=1 and BLOCK_N=128.'''

    # TODO: does not work now, as BLOCK_M cannot be <1, as shape for tl.dot cannot be smaller than 16.
    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
        key_tile = 128
        query_tile = 1
        return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, query_tile, key_tile)
def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
    '''Return the `.apply` of a `_sparse_attention` subclass bound to the given tile sizes.

    :param BLOCK_M, BLOCK_N: tile sizes passed to `_forward`.
    :param kwargs: extra keyword arguments passed to `_forward` on every call.
    '''
    forward_options = kwargs

    class _sparse_attention_config(_sparse_attention):
        @staticmethod
        def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
            # shape constraints
            return _forward(
                ctx, q, k, v,
                layout_crow_indices, layout_col_indices,
                sm_scale, BLOCK_M, BLOCK_N,
                **forward_options
            )

    configured_op = _sparse_attention_config.apply
    return configured_op
@lru_cache(maxsize=8)
def get_local_strided_sparse_attention_op(
    n_heads: int,
    max_seq_len:int,
    sparse_block_size: int=128,
    local_blocks: int=4,
    vert_stride: int=4,
    homo_head: bool=False,
    dtype=torch.bfloat16,
    device='cuda',
    active_head_range=None,
    verbose=True,
    **kwargs):
    '''Build (and cache, up to 8 configurations) a local + vertical-strided block-sparse attention op.

    :param n_heads: total number of attention heads (regardless of tensor/model parallel)
    :param max_seq_len: max sequence length. Need to be bigger or equal to the length of sequences.
    :param sparse_block_size: sparse block size. Default to 128
    :param local_blocks: number of nearest block to attend to. Default to 4, i.e., attention to previous 4xblock_size tokens.
    :param vert_stride: Default to 4. Passed to the mask builder as the stride of the vertically-strided key blocks.
    :param homo_head: if all head shared the same pattern.
    :param dtype, device: dtype and device of the block mask.
    :param active_head_range: tuple of start & end of the heads, e..g, (8, 16). Default to use all heads.
        Mainly for tensor/model parallelization where heads are splitted to different GPUs.
        Only applied when `homo_head` is False.
    :param verbose: print the configuration when a new op is constructed.
    :param kwargs: forwarded to `get_sparse_attn_op`.
    '''
    if verbose:
        print((f'> new block_sparse_attn op constructed with config: '
               f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
               f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'))

    # Only the block-level dense mask is needed here; the CSR layout is rebuilt by get_sparse_attn_op.
    _csr_layout, block_sparse_pattern, _token_mask = _get_sparse_attn_mask(
        n_heads, max_seq_len, max_seq_len, dtype, device,
        BLOCK=sparse_block_size,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=False,
    )

    restrict_heads = not (homo_head or active_head_range is None)
    if restrict_heads:
        assert isinstance(active_head_range, tuple)
        assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
        h_start, h_end = active_head_range
        block_sparse_pattern = block_sparse_pattern[h_start:h_end]

    return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
def get_sparse_attn_op(
sparse_pattern: torch.tensor,
sparse_block_size: int=128,
kernel_block_size=128,
qkv_format='q,k,v',
**kwargs):
'''
Ccreate a block-sparse op with fixed layout. This is to avoid the need to of create CSR layout and convert it to CSC layout everytime,
which is very inefficient (use python loops on CPU. PyTorch 1.13 supports CSR->CSC, may help.)
:param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention
:param sparse_block_size: sparse block size. Default to 128
:param kernel_block_size: the tile/block size to launch a triton instance. Default to None, i.e., same as `sparse_block_size`
:param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.
:param kwargs: keyward arguments passed to `_forward`
'''
# assert qkv_format in ('q,k,v', 'q, kv', 'qkv') # to save from running `concat` at forward/backward
assert qkv_format == 'q,k,v'
if kernel_block_size is None:
kernel_block_size = sparse_block_size
else:
assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"
# print(f'>> {sparse_pattern.shape=}')
# print(f'{sparse_pattern=}')
if sparse_block_size // kernel_block_size > 1:
_mul = sparse_block_size // kernel_block_size
# need to consider if block_m and block_n are different
sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
num_sparse_blocks = sparse_pattern.size(-1)
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
sparse_pattern *= block_causal_mask.type_as(sparse_pattern)
# print(f'>> after: {sparse_pattern.shape=}')
# print(f'{sparse_pattern=}')
BLOCK_N = kernel_block_size
NUM_BLOCK = sparse_pattern.size(-1)
MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK
grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
# sparse csc layout for backward
grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)
# cache GPU backward layout. limit the size to avoid OOM as time goes.
# For inference, one only needs to cache one block as sequence length always increases
# Therefore, this cache needs to be reconstructed per every `block_size`-steps.
# For training/finetune, set to 8 to increase cache hit.
# Given an input, the block_len will be the same for all layers, so cache is very helpful.
max_cache_size = 1 if kwargs.get('inference', False) else 8
@lru_cache(maxsize=max_cache_size)
def get_backward_layout_by_block_len(block_len):
assert block_len <= NUM_BLOCK
if block_len == NUM_BLOCK:
return (grand_layout_ccol_indices, grand_layout_row_indices)
return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])
# for debugging
# if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# print(f'> {sparse_pattern.cpu().tolist()=}')
# print('----')
# print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}')
# q, k, v separated
class _q_k_v_sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
# assert q.shape[2] == 1 or q.shape[2] == k.shape[2]
# shape constraints
MIN_BLOCK_SIZE = 16
assert BLOCK_N >= MIN_BLOCK_SIZE
BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N # BLOCK_M has to be power of 2
# this following code only works for causal attention
K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
# Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0
Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)
# print(Q_START_BLOCKS, K_BLOCKS)
layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
layout_col_indices = grand_layout_col_indices
# print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices)
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
**kwargs
)
@staticmethod
def backward(ctx, do):
q, k = ctx.saved_tensors[:2]
assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
# assume q, k have same length
block_len = triton.cdiv(do.shape[2], kernel_block_size)
backward_layout = get_backward_layout_by_block_len(block_len)
return _backward(ctx, do, *backward_layout)[:4]
def _q_k_v_sparse_attention_fn(*args):
return _q_k_v_sparse_attention.apply(*args)
_q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
_q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
_q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
_q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
_q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices
return _q_k_v_sparse_attention_fn
###########################################################
###########################################################
###########################################################
################ Inference Kernels ########################
###########################################################
def blocksparse_flash_attn_padded_fwd(
    q, k, v, # (batch, tokens, n_heads, head_size)
    sm_scale,
    sparse_layout,
    *,
    left_paddings = None,
    seqlens = None,
    block_size = 64,
    max_seqlen = None
):
    '''
    Block-sparse attention forward pass (no backward) for padded, batched inputs.

    q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size).
        The number of q heads must be a multiple of the number of k heads, k and v
        must have the same shape, and q must have either 1 token (decoding) or as
        many tokens as k (prefilling).
    sm_scale: scale the kernel applies to the q @ k^T scores.
    sparse_layout: pair `(crow_indices, col_indices)` handed to the kernel as the
        row-pointer / column-index arrays of the sparse layout.
    left_paddings: (batch, ), number of left paddings for each sample.
    seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
    block_size: used as both BLOCK_M and BLOCK_N of the kernel.
    max_seqlen: if set, only used to check that k is not longer than it.

    Returns a tensor with the shape of q; it is zero-initialised, so positions the
    kernel does not write (paddings) stay zero.
    '''
    # Check ranks *before* unpacking shapes, so that a wrong-rank input fails with
    # the intended assertion rather than a tuple-unpacking ValueError.
    assert q.dim() == k.dim() == v.dim() == 4
    batches, q_len, n_heads, head_size = q.shape
    k_len = k.size(1)

    assert q.size(2) % k.size(2) == 0
    assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
    assert k.shape == v.shape  # TODO: allow diff head_size for k, v
    assert q_len == 1 or q_len == k_len, (
        f'q length must be 1 (decoding) or equal to k length (prefilling), '
        f'but got q_len={q_len}, k_len={k_len}.'
    )

    q_k_ratio = q.size(2) // k.size(2)

    if max_seqlen:
        assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'

    # paddings always has zero output, a little slower than using empty
    out = q.new_zeros(q.shape)

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    # Per-sample [start, end) token range of k inside the padded tensor.
    if left_paddings is not None:
        assert left_paddings.shape == (batches,)
        k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
    else:
        k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)

    if seqlens is not None:
        k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
        assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.'
    else:
        k_batch_ends = torch.zeros_like(k_batch_starts) + k_len

    # Per-sample [start, end) token range of q: a single token when decoding,
    # otherwise the same range as k.
    if q_len == 1:
        q_batch_starts = torch.zeros_like(k_batch_starts)
        q_batch_ends = q_batch_starts + 1
    else:
        q_batch_starts = k_batch_starts
        q_batch_ends = k_batch_ends

    # switch to use cpu to avoid too many kernel launches when iterating over
    q_lens = (q_batch_ends - q_batch_starts).cpu()
    n_blocks = ((q_lens + block_size - 1) // block_size).tolist()

    # One grid entry per (sample, q block): the sample id and the q offset of the block.
    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
                               dtype=q_batch_starts.dtype,
                               device=q_batch_starts.device)
    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
                                dtype=q_batch_starts.dtype,
                                device=q_batch_starts.device)

    grid = (len(q_start_sids), n_heads)

    _fwd_kernel_batch_inference[grid](
        q, k, v, out,
        sm_scale,
        q_batch_starts,
        q_batch_ends,
        k_batch_starts,
        k_batch_ends,
        q_batch_ids,
        q_start_sids,

        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),

        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),

        q_k_ratio,
        HAS_BATCH_DIM = True,
        D_HEAD = head_size,
        BLOCK_M = block_size,
        BLOCK_N = block_size,
        BLOCK_D = block_d,
        BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
        EVEN_D = block_d == head_size,
        num_warps = 1 if q_len == 1 else 4,
        num_stages = 1 # <---- instead of 3
    )

    return out
def blocksparse_flash_attn_varlen_fwd(
    q, k, v, # (#tokens, n_heads, head_size)
    cu_seqlens_k,
    cu_seqlens_q,
    sm_scale,
    sparse_layout,
    *,
    block_size=64,
    max_seqlen = None
):
    '''
    Block-sparse attention forward pass (no backward) for packed, variable-length inputs.

    q, k, v: (#tokens, n_heads/n_kv_heads, head_size), all samples concatenated along
        the token dim. The number of q heads must be a multiple of the number of k
        heads, and k, v must have the same shape.
    cu_seqlens_k: 1-D tensor of size batch + 1; consecutive entries are the
        [start, end) token offsets of each sample in k/v.
    cu_seqlens_q: same for q. May be None when the batch is decoding-only
        (one q token per sample) or prefilling-only (q has as many tokens as k).
    sm_scale: scale the kernel applies to the q @ k^T scores.
    sparse_layout: pair `(crow_indices, col_indices)` handed to the kernel as the
        row-pointer / column-index arrays of the sparse layout.
    block_size: used as both BLOCK_M and BLOCK_N of the kernel.
    max_seqlen: if set, only used to check that no k sequence is longer than it.

    Returns a new tensor with the shape of q.
    Raises ValueError if cu_seqlens_q is None and cannot be inferred.
    '''
    # Check ranks *before* unpacking / indexing shapes, so that a wrong-rank input
    # fails with the intended assertion rather than a ValueError / IndexError.
    assert q.dim() == k.dim() == v.dim() == 3
    assert cu_seqlens_k.dim() == 1

    # split q to blocks
    _, n_heads, head_size = q.shape
    batch_size = cu_seqlens_k.size(0) - 1

    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    assert k.shape == v.shape  # TODO: allow diff head_size for k, v

    q_k_ratio = q.size(1) // k.size(1)

    if cu_seqlens_q is None:
        if q.size(0) == batch_size:  # decoding only
            cu_seqlens_q = torch.arange(0, batch_size + 1,
                                        dtype=cu_seqlens_k.dtype,
                                        device=cu_seqlens_k.device)
        elif q.size(0) == k.size(0):
            cu_seqlens_q = cu_seqlens_k
        else:
            raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')
    else:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

    # switch to use cpu to avoid too many kernel launches when iterating over
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
        'length of q should either be 1 (decoding) or same as k (prefilling).'

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    n_blocks = ((q_lens + block_size - 1) // block_size).tolist()

    # One grid entry per (sample, q block): the sample id and the q offset of the block.
    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
                               dtype=cu_seqlens_q.dtype,
                               device=cu_seqlens_q.device)
    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
                                dtype=cu_seqlens_q.dtype,
                                device=cu_seqlens_q.device)

    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    # Plain Python bool so the launch parameters below are ordinary ints.
    decoding_only = bool((q_lens == 1).all())

    grid = (len(q_start_sids), n_heads)

    _fwd_kernel_batch_inference[grid](
        q, k, v, out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,

        # packed tensors have no batch dim, so the batch stride is 0
        0, *q.stride(),
        0, *k.stride(),
        0, *v.stride(),
        0, *out.stride(),

        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),

        q_k_ratio,
        HAS_BATCH_DIM = False,
        D_HEAD = head_size,
        BLOCK_M = block_size,
        BLOCK_N = block_size,
        BLOCK_D = block_d,
        BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
        EVEN_D = block_d == head_size,
        num_warps = 1 if decoding_only else 4,
        # NOTE(review): `blocksparse_flash_attn_padded_fwd` launches this same kernel
        # with num_stages=1; if this launch runs out of shared memory on a given GPU,
        # lowering the value here as well is the likely remedy.
        num_stages = 3
    )

    return out
# One step of the online-softmax attention loop: process a single K/V block.
#
# Looks up the K block id at `k_block_col_idx` in the sparse column-index array,
# loads that K/V block, computes the scaled q @ k scores, and folds them into the
# running accumulator `acc`, running normaliser `l_i` and running row maximum `m_i`.
# Returns the updated `(acc, l_i, m_i)`.
#
# LAST_K_BLOCK selects the variant that masks K/V rows at or beyond `k_seqlen`.
# EVEN_D == False selects the variant that masks head-dim lanes at or beyond D_HEAD.
#
# Every masked load passes `other=0.0`: without it the masked-out lanes hold
# unspecified values, which could be NaN/Inf and would then leak into valid
# output rows through `tl.dot` (0 * NaN, NaN + -inf).
@triton.jit
def _fwd_kernel_inner(
    acc, l_i, m_i,
    q, Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h, layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h, offs_m, offs_n, offs_d,
    stride_kt, stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr
):
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    # -- compute qk ----
    if LAST_K_BLOCK:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_n[None, :] + start_n < k_seqlen,
                        other=0.0)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD),
                        other=0.0)
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD,
                        other=0.0)

    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)
    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    if LAST_K_BLOCK | M_LT_N:
        # causal mask: a query at position offs_m + past_len only sees keys at or before it
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))

    # -- compute m_ij, p, l_ij
    m_ij = tl.max(qk, 1)
    p = tl.exp(qk - m_ij[:, None])
    l_ij = tl.sum(p, 1)

    # -- update m_i and l_i
    m_i_new = tl.maximum(m_i, m_ij)
    alpha = tl.exp(m_i - m_i_new)
    beta = tl.exp(m_ij - m_i_new)
    l_i_new = alpha * l_i + beta * l_ij

    # -- update output accumulator --
    # scale p
    p_scale = beta / l_i_new
    p = p * p_scale[:, None]
    # scale acc
    acc_scale = l_i / l_i_new * alpha
    acc = acc * acc_scale[:, None]

    p = p.to(Q.dtype.element_ty)

    # update acc
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_n[:, None] + start_n < k_seqlen,
                        other=0.0)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD),
                        other=0.0)
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD,
                        other=0.0)

    acc += tl.dot(p, v)

    # update m_i and l_i
    l_i = l_i_new
    m_i = m_i_new
    return acc, l_i, m_i
# M_LT_N is derived at launch time from the BLOCK_M / BLOCK_N meta-parameters.
@triton.heuristics(
    {
        'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
    }
)
@triton.jit
def _fwd_kernel_batch_inference(
    Q, K, V, Out,

    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,

    stride_qb, stride_qt, stride_qh, stride_qd,
    stride_kb, stride_kt, stride_kh, stride_kd,
    stride_vb, stride_vt, stride_vh, stride_vd,
    stride_ob, stride_ot, stride_oh, stride_od,

    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h, layout_crow_stride_m,
    layout_col_stride_h, layout_col_stride_m,

    q_k_ratio,

    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr
):
    '''
    Block-sparse attention forward kernel for batched inference.

    Each program handles one q block of one head: program_id(0) indexes into
    `q_batch_ids` / `q_start_sids`, program_id(1) is the q head.

    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id

    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)

    q and blocks in KV needs to be contiguous

    Arguments:
    q_batch_starts, q_batch_ends: per-sample start / end token offsets of q;
        their difference is used as the q sequence length.
    k_batch_starts, k_batch_ends: same for k/v; `past_len` is computed as
        k length minus q length.
    q_batch_ids: sample id for each program along grid axis 0.
    q_start_sids: q token offset (within the sample) of the block handled by
        each program along grid axis 0.
    layout_crow_ptr, layout_col_ptr: row-pointer / column-index arrays of the
        sparse layout, indexed per head through the given strides.
    q_k_ratio: number of q heads per k/v head.

    NOTE(review): the original notes also listed `kv_seq_lens` and
    `kv_storage_offsets`, which are not parameters of this kernel.
    TODO: fix this

    TODO:
    Optimize grouped-attn

    CUDA graph support issue
    1. grid is dynamic: vllm set up multiple cuda graph in decoding phase, with diff max token size (16, 32, ...)
    since we mix prompt and decoding phase here, it can be more complex.
    need to set up diff cuda-graph for diff (off_zm, off_z)

    # indeed, q_batch_ids can be padded to maximum number of grid[0], i.e., assume all decoding
    therefore, cu_seqlens_q, kv_seq_lens
    '''
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    # several q heads may share one k/v head
    off_h_for_kv = off_h // q_k_ratio
    off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
    q_start_sid = tl.load(q_start_sids + off_zm)
    start_m = q_start_sid // BLOCK_M

    # move the base pointers to this sample when the tensors carry a batch dim
    if HAS_BATCH_DIM:
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob

    # only BLOCK_M_LOADING rows of the q block are loaded / computed / stored
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start

    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start

    # number of k tokens that precede the first q token
    past_len = k_seqlen - q_seqlen

    # move the base pointers to the first token of this sample and to this head
    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    # block index of this q block in position space; selects the row of the sparse layout
    q_pbid = (past_len + q_start_sid) // BLOCK_M

    if EVEN_D:
        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
                    mask=offs_m[:, None] < q_seqlen)
    else:
        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
                    mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
                    other=0)

    sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m

    # TODO: load at once, supported in new Triton
    # [k_block_start, k_block_end) is the range of entries in the column-index array for this q block
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    # running row maximum, running normaliser and output accumulator
    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    # k is addressed as (d, n) and v as (n, d)
    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    # all blocks but the last one are processed with LAST_K_BLOCK=False
    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc, l_i, m_i,
            q, Q,
            k_block_col_idx,
            layout_col_ptr,
            layout_col_stride_h, layout_col_stride_m,
            k_ptrs,
            v_ptrs,
            off_h, offs_m, offs_n, offs_d,
            stride_kt, stride_vt,
            sm_scale,
            k_seqlen,
            past_len,
            False,
            BLOCK_M_LOADING,
            BLOCK_N,
            D_HEAD,
            EVEN_D,
            M_LT_N
        )

    # the last block is processed with LAST_K_BLOCK=True (sequence-end and causal masking)
    acc, l_i, m_i = _fwd_kernel_inner(
        acc, l_i, m_i,
        q, Q,
        k_block_end - 1,
        layout_col_ptr,
        layout_col_stride_h, layout_col_stride_m,
        k_ptrs,
        v_ptrs,
        off_h, offs_m, offs_n, offs_d,
        stride_kt, stride_vt,
        sm_scale,
        k_seqlen,
        past_len,
        True,
        BLOCK_M_LOADING,
        BLOCK_N,
        D_HEAD,
        EVEN_D,
        M_LT_N
    )

    # write output
    if EVEN_D:
        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
                 mask=offs_m[:, None] < q_seqlen)
    else:
        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
                 mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
###########################################################
###########################################################
###########################################################
################## Testing Utilities ######################
###########################################################
def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
    '''
    Reference (pure PyTorch) attention, used for verification.

    q, k, v: shape=(batch, n_heads, seq, dim)
    attn_mask: optional token-level mask broadcastable to (batch, n_heads, q_seq, k_seq);
        non-zero entries are kept, zero entries are masked out.
    sm_scale: scale applied to q @ k^T; defaults to 1 / sqrt(dim).
    block_attn_mask: optional block-level mask of shape (..., n_blocks, n_blocks);
        leading dims must broadcast against (batch, n_heads). When given, attention is
        computed one q block at a time, combined with token-level causal masking, and
        `attn_mask` must be None. This path assumes q and k have the same length.
    block_size: number of tokens per block of `block_attn_mask`.
    do: optional output gradient; if given (dense path only), dv is computed and
        printed for debugging.

    Returns the attention output with shape (batch, n_heads, q_seq, dim).
    '''
    # for verification
    if sm_scale is None:
        # standard attention scaling is 1/sqrt(d), not sqrt(d)
        sm_scale = 1.0 / math.sqrt(float(q.size(-1)))

    if block_attn_mask is not None:
        assert attn_mask is None
        seq_len = q.size(2)
        outs = []
        for s in range(0, seq_len, block_size):
            # the last block may be shorter than block_size
            e = min(s + block_size, seq_len)
            row_block = s // block_size
            attn = torch.einsum('bhmd,bhnd->bhmn', q[:, :, s:e], k[:, :, :e]).float() * sm_scale

            # Block-level visibility of this q block, expanded so that every block
            # column covers `block_size` key tokens, then trimmed to the e visible keys.
            block_row = block_attn_mask[..., row_block, :row_block + 1].to(attn.device)
            visible = block_row.repeat_interleave(block_size, dim=-1)[..., :e].bool()

            # Token-level causal mask: a query may attend to itself and to earlier keys.
            q_pos = torch.arange(s, e, device=attn.device)[:, None]
            k_pos = torch.arange(0, e, device=attn.device)[None, :]
            keep = visible[..., None, :] & (k_pos <= q_pos)

            attn = attn.masked_fill(~keep, float('-inf'))
            attn = attn.softmax(-1)
            out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e])
            outs.append(out)
        torch_output = torch.cat(outs, dim=2)
    else:
        attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
        if attn_mask is not None:
            attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))

        attn = attn.softmax(-1)
        if do is not None:
            # debugging aid: dv = attn^T @ do
            dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
            print(f'> torch_attn computed dv: {dv=}')
        torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
    return torch_output
###########################################################
###########################################################
###########################################################
#################### Unit Tests ###########################
###########################################################
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
            sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
    '''
    Compare the Triton block-sparse attention op against the PyTorch reference.

    Random q (Z, H, Q_LEN, D_HEAD) and k, v (Z, H, N_CTX, D_HEAD) are created on CUDA,
    the dense mask of the local + vertical-stride pattern is fed to `torch_attention`,
    and the output (and, if `backward`, dq/dk/dv) must match the Triton op within tolerance.

    Q_LEN: q length; defaults to N_CTX.
    sparse_attention_fn: op under test; built with `get_local_strided_sparse_attention_op` if None.
    sm_scale: softmax scale. Previously this argument was silently overwritten with
        0.0078125; it is now honoured, and 0.0078125 is kept as the default so that
        existing callers see the same numerics.
    max_length: unused; kept for interface compatibility.
    '''
    Q_LEN = Q_LEN or N_CTX
    torch.manual_seed(20)
    q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)  # .requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)  # .requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)  # .requires_grad_()

    if sm_scale is None:
        # historical default of this test (1/128), independent of D_HEAD
        sm_scale = 0.0078125

    if backward:
        q.requires_grad_()
        k.requires_grad_()
        v.requires_grad_()

    dout = torch.randn_like(q).contiguous()

    _, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
                            local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)

    if sparse_attention_fn is None:
        sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
                                                                    sparse_block_size=sparse_block_size,
                                                                    local_blocks=local_blocks,
                                                                    vert_stride=vert_stride,
                                                                    homo_head=homo_head,
                                                                    device=q.device,
                                                                    dtype=q.dtype,
                                                                    kernel_block_size=kernel_block_size)

    # reference implementation
    ref_out = torch_attention(q, k, v, mask_dense, sm_scale)

    if backward:
        ref_out.backward(dout)
        # stash the reference grads and reset .grad so the Triton pass starts clean
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None

    tri_out = sparse_attention_fn(q, k, v, sm_scale)

    assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'

    if backward:
        tri_out.backward(dout)
        tri_dv, v.grad = v.grad.clone(), None
        tri_dk, k.grad = k.grad.clone(), None
        tri_dq, q.grad = q.grad.clone(), None

        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)

    print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
###########################################################
if __name__ == '__main__':
GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()
# print(GPU_TYPE)
support_backward = True # 'A100' in GPU_TYPE. Wasn't supportted in consumer A1000.
###############
# benchmarking
HAS_DENSE_TRITON_FLASH = False
# try:
# from triton.ops.flash_attention import attention as triton_attention
# HAS_DENSE_TRITON_FLASH = True
# except:
# HAS_DENSE_TRITON_FLASH = False
# print('> cannot import Trition flash attn')
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
print('> cannot import flash_attn')
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 # 6.7B model, with 4k len
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128 # 204m model
BLOCK_SIZE = 64
LOCAl_BLOCKS = 8 # 4
VERT_STRIDE = 1 # 16 # 8
HOMO_HEAD = False
sparse_type = 'home' if HOMO_HEAD else 'hetero'
dtype = torch.bfloat16
modes = ['fwd', 'bwd'] if support_backward else ['fwd']
configs = [triton.testing.Benchmark(
x_names=['SEQ_LEN'],
x_vals=[2**i for i in range(8, 16)],
line_arg='provider',
line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
styles=[('red', '-'), ('blue', '-'), ('green', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
) for mode in modes]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
N_CTX = SEQ_LEN
if provider == 'triton':
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
sm_scale = 1.3
fn = lambda: triton_attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == 'triton_sparse':
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
sm_scale = 1.3
# q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None]
# k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None]
# local_blocks = 4 # num_block per attn, block_size is tied to BLOCK
# vert_stride =N_CTX + 1 # 4
# mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1
# mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q)
# mask = mask_dense.to_sparse_csr()
# mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)
if sparse_attention_fn is None:
# sparse_attention_fn = sparse_attention
sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
local_blocks=LOCAl_BLOCKS,
vert_stride=VERT_STRIDE,
homo_head=HOMO_HEAD,
sparse_block_size=BLOCK_SIZE,
kernel_block_size=BLOCK_SIZE,
device=q.device)
# sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8)
# fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale)
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == 'flash':
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
# if provider == 'torch':
# q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
# k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
# v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
# sm_scale = 1.3
# causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q)
# fn = lambda: torch_attention(q, k, v, causal_mask, sm_scale)
# ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
# return ms
BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 # 6.7B model, with 4k len
BLOCK_SIZE = 64
LOCAl_BLOCKS = 8 # 4
VERT_STRIDE = 16 # 8
HOMO_HEAD = False
sparse_type = 'home' if HOMO_HEAD else 'hetero'
dtype = torch.bfloat16
MAX_N_CTX = 8192
configs = [triton.testing.Benchmark(
x_names=['PAST_LEN'],
x_vals=[2**i - 1 for i in range(8, 14)],
line_arg='provider',
line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
ylabel='ms',
plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
) for mode in ['fwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
assert mode in ['fwd']
warmup = 25
rep = 100
N_CTX = PAST_LEN + Q_LEN
if provider == 'torch':
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
sm_scale = 1.3
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=VERT_STRIDE, return_dense=True)
fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == 'triton_sparse':
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
sm_scale = 1.3
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
local_blocks=LOCAl_BLOCKS,
vert_stride=VERT_STRIDE,
homo_head=HOMO_HEAD,
sparse_block_size=BLOCK_SIZE,
kernel_block_size=BLOCK_SIZE,
device=q.device,
inference=True)
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == 'triton_dense':
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
sm_scale = 1.3
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
local_blocks=1,
vert_stride=1,
homo_head=True,
sparse_block_size=BLOCK_SIZE,
kernel_block_size=BLOCK_SIZE,
device=q.device,
inference=True)
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == 'flash':
assert Q_LEN == 1
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)
# (total_q, nheads, headdim),
q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
# Driver section: runs correctness checks (`test_op`) and benchmarks in order.
# `test_op`, `support_backward`, `bench_flash_attention` and
# `bench_flash_attention_inference` are defined earlier in this file.
# NOTE(review): the positional arguments of `test_op` appear to be
# (batch, heads, seq_len, head_dim), judging by the "head_dim=64" and
# "longer sequence" comments below -- confirm against its definition.

# Quick fp16 check followed by the inference benchmark.
test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
# bench_flash_attention.run(save_path='.', print_data=True)
bench_flash_attention_inference.run(save_path='.', print_data=True)

# Early stop: everything below this line is currently unreachable and only
# runs once this statement is removed.
# `raise SystemExit` is used instead of `exit()` because `exit` is a helper
# injected by the `site` module for interactive sessions; it is not available
# under `python -S` or in embedded/frozen interpreters. The exit status (0) is
# unchanged.
raise SystemExit

# head_dim=64
test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
dtype=torch.bfloat16, homo_head=False, backward=support_backward)
# uneven length, bf16
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
kernel_block_size=64, local_blocks=8, vert_stride=8)
test_op(3, 2, 2047, 128, homo_head=False, backward=False)
# diff kernel/sparse block size
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)
# inference
# test_op(1, 4, 512 + 256, 128, Q_LEN=1, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
# dense flash attn
# (local_blocks=1, vert_stride=1 is the same setting the 'triton_dense'
# benchmark provider uses.)
test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
backward=support_backward, local_blocks=1, vert_stride=1)
# fp16
test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
# longer sequence
test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
# homo head
test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)
# sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True)
# test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None)
# test_op_inference(3, 2, 2048, 128, 2048)
# test_op_inference(3, 2, 2047, 64, 2047)
# test_op_inference(3, 2, 256, 64, 128)
# test_op_inference(3, 2, 2048, 64, 1)
# Training-style (fwd/bwd) benchmark; the inference benchmark already ran above.
bench_flash_attention.run(save_path='.', print_data=True)
# bench_flash_attention_inference.run(save_path='.', print_data=True)
# ========================
# Some Benchmark Results #
# ========================
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.057184 0.069646 0.052567
# 1 512.0 0.131688 0.187658 0.110212
# 2 1024.0 0.391844 0.524990 0.247875
# 3 2048.0 1.305190 1.456685 0.596506
# 4 4096.0 4.623019 4.968653 1.600277
# 5 8192.0 17.513062 18.332262 4.802458
# 6 16384.0 68.453377 70.337540 16.052908
# 7 32768.0 270.655487 276.020233 57.938946
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8):
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.190120 0.150313 0.181451
# 1 512.0 0.406348 0.391767 0.391177
# 2 1024.0 1.029704 1.182967 0.885741
# 3 2048.0 2.985456 3.843399 2.040469
# 4 4096.0 9.808897 13.073701 5.069609
# 5 8192.0 34.995201 47.863808 13.948782
# 6 16384.0 132.740097 182.579193 42.816513
# 7 32768.0 542.223389 714.820618 147.053574
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.050949 0.032357 0.107513
# 1 512.0 0.073624 0.050651 0.199086
# 2 1024.0 0.107472 0.080379 0.245445
# 3 2048.0 0.178423 0.129448 0.338259
# 4 4096.0 0.327647 0.223106 0.517048
# 5 8192.0 0.588423 0.411263 0.884606
# 6 16384.0 1.098898 0.798941 1.611809
# 7 32768.0 2.094537 1.594726 3.044160
# 6.7B
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.069208 0.082156 0.065097
# 1 512.0 0.138271 0.201393 0.144467
# 2 1024.0 0.391521 0.624614 0.322382
# 3 2048.0 1.268443 2.406325 0.784367
# 4 4096.0 4.455703 9.139097 2.100856
# 5 8192.0 16.764315 35.289600 6.328320
# 6 16384.0 65.221634 138.401794 21.069057
# 7 32768.0 257.251343 548.085754 76.111870
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.297118 0.266469 0.255255
# 1 512.0 0.672826 0.613685 0.552954
# 2 1024.0 1.718434 1.705066 1.251953
# 3 2048.0 4.936755 5.403875 2.927895
# 4 4096.0 15.911594 18.959362 7.436288
# 5 8192.0 55.357441 70.808578 21.140224
# 6 16384.0 208.188416 273.617920 68.018173
# 7 32768.0 806.037476 1081.453613 218.720261
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.050151 0.032337 0.107593
# 1 512.0 0.073409 0.051737 0.200200
# 2 1024.0 0.107533 0.082099 0.247067
# 3 2048.0 0.177259 0.128891 0.338510
# 4 4096.0 0.325866 0.223621 0.524842
# 5 8192.0 0.586926 0.408913 0.885490
# 6 16384.0 1.100834 0.793277 1.612271
# 7 32768.0 2.098851 1.595831 3.064544
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.066673 0.082037 0.065085
# 1 512.0 0.137379 0.201880 0.143473
# 2 1024.0 0.390675 0.624234 0.312046
# 3 2048.0 1.267739 2.406950 0.696045
# 4 4096.0 4.445138 9.136333 1.665788
# 5 8192.0 16.768614 35.265533 4.380486
# 6 16384.0 65.235970 138.393600 12.997633
# 7 32768.0 257.317902 550.442993 42.821121
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.296461 0.266581 0.254022
# 1 512.0 0.671427 0.613643 0.551283
# 2 1024.0 1.719918 1.704295 1.229982
# 3 2048.0 4.945305 5.403364 2.721906
# 4 4096.0 15.934293 18.960999 6.259371
# 5 8192.0 55.406593 70.832130 15.676929
# 6 16384.0 208.750595 275.004425 44.837891
# 7 32768.0 808.057861 1080.647705 141.856766
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero:
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.050739 0.032886 0.107837
# 1 512.0 0.073507 0.051996 0.200293
# 2 1024.0 0.106394 0.080679 0.240610
# 3 2048.0 0.177659 0.127660 0.287625
# 4 4096.0 0.326326 0.226971 0.377500
# 5 8192.0 0.586339 0.407367 0.559266
# 6 16384.0 1.102279 0.786221 0.920976
# 7 32768.0 2.097370 1.545090 1.644288
################
##### fp16 #####
################
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.032518 0.035472 0.029939
# 1 512.0 0.054266 0.087841 0.054320
# 2 1024.0 0.133447 0.263090 0.102045
# 3 2048.0 0.384615 1.023293 0.201763
# 4 4096.0 1.300890 4.023936 0.449555
# 5 8192.0 4.774144 15.816704 1.150854
# 6 16384.0 18.220032 62.771198 3.356001
# 7 32768.0 71.405571 250.273788 10.976142
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.083342 0.069742 0.079496
# 1 512.0 0.159894 0.170995 0.151705
# 2 1024.0 0.386071 0.522407 0.331443
# 3 2048.0 1.067715 1.737333 0.715248
# 4 4096.0 3.382731 6.219520 1.597457
# 5 8192.0 11.857793 23.560448 3.879035
# 6 16384.0 44.422142 91.251709 10.626843
# 7 32768.0 175.011841 359.473145 32.340992
################
##### bf16 #####
################
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.037636 0.035902 0.031512
# 1 512.0 0.058591 0.087229 0.058125
# 2 1024.0 0.143337 0.263919 0.108443
# 3 2048.0 0.414458 1.025985 0.214114
# 4 4096.0 1.390841 4.020010 0.480550
# 5 8192.0 5.067938 15.808171 1.230874
# 6 16384.0 19.442280 62.765057 3.597274
# 7 32768.0 75.501572 250.443771 11.768959
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
# 0 256.0 0.084404 0.070663 0.082613
# 1 512.0 0.161510 0.172882 0.157661
# 2 1024.0 0.388954 0.526047 0.339855
# 3 2048.0 1.075814 1.736057 0.732420
# 4 4096.0 3.401622 6.221376 1.636039
# 5 8192.0 11.915136 23.483391 3.968725
# 6 16384.0 44.660225 91.302910 10.857130
# 7 32768.0 175.038467 359.048187 32.778240