Fix for FlashAttention RuntimeError & Triton multi-GPU fix
#17 opened by Satandon1999

Files changed:
- positional_embedding.py +2 -2
- triton_flash_blocksparse_attn.py +89 -86

positional_embedding.py
CHANGED
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
         return (
             apply_rotary_pos_emb(
                 q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
             apply_rotary_pos_emb(
                 k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
         )
 
     @classmethod
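
The cast is what resolves the RuntimeError: assuming the cos/sin caches are kept in fp32 (as the fix implies), PyTorch's type promotion silently upcasts fp16/bf16 q and k to fp32 inside apply_rotary_pos_emb, and the FlashAttention kernels downstream only accept half-precision inputs. A minimal sketch of the promotion and the fix, using a toy rotary helper rather than the repo's apply_rotary_pos_emb:

```python
import torch

def rotate_half(x):
    # Toy stand-in for the rotation used by apply_rotary_pos_emb.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, cos, sin):
    # cos/sin live in an fp32 cache, so the product silently upcasts
    # a fp16/bf16 input to fp32.
    return x * cos + rotate_half(x) * sin

q = torch.randn(128, 64, dtype=torch.float16)
cos, sin = torch.randn(128, 64), torch.randn(128, 64)  # fp32 caches

rotated = apply_rotary(q, cos, sin)
print(rotated.dtype)  # torch.float32 -> rejected by FlashAttention kernels

fixed = apply_rotary(q, cos, sin).to(q.dtype)  # the fix in this PR
print(fixed.dtype)    # torch.float16
```

Without the cast, flash-attn fails with an error along the lines of "FlashAttention only support fp16 and bf16 data type" once the promoted tensors reach the kernel.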
triton_flash_blocksparse_attn.py
CHANGED
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
     # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
     # {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
 
-    _fwd_kernel[grid](
-        q, k, v, sm_scale,
-        layout_crow_indices,
-        layout_col_indices,
-        layout_crow_indices.stride(0), layout_crow_indices.stride(1),
-        layout_col_indices.stride(0), layout_col_indices.stride(1),
-        tmp, L, m,
-        o,
-        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-        q.shape[0], q.shape[1], k.shape[2],
-        k.shape[2] - q.shape[2],
-        q_rounded_len,
-        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
-        EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0,
-        INFERENCE=inference,
-        NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel[grid](
+            q, k, v, sm_scale,
+            layout_crow_indices,
+            layout_col_indices,
+            layout_crow_indices.stride(0), layout_crow_indices.stride(1),
+            layout_col_indices.stride(0), layout_col_indices.stride(1),
+            tmp, L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], k.shape[2],
+            k.shape[2] - q.shape[2],
+            q_rounded_len,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+            BLOCK_DMODEL=BLOCK_DMODEL,
+            EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
+            EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0,
+            INFERENCE=inference,
+            NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
     if inference:
         L, m = None, None
 
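
The change at all three launch sites is the same: the kernel launch is wrapped in torch.cuda.device(q.device.index). Triton launches kernels on the current CUDA device, which defaults to cuda:0, so on a multi-GPU machine a model placed on cuda:1 would launch against the wrong device context and crash. A minimal sketch of the pattern with a trivial copy kernel (illustrative only, not the repo's kernel):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Trivial elementwise copy; only the launch pattern matters here.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

def copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    # Without this context manager the kernel is launched on the current
    # device (usually cuda:0) even when x lives on another GPU.
    with torch.cuda.device(x.device.index):
        _copy_kernel[grid](x, y, x.numel(), BLOCK=1024)
    return y

# copy(torch.randn(4096, device="cuda:1")) now works on a multi-GPU machine.
```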
@@ -991,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        q_batch_starts,
-        q_batch_ends,
-        k_batch_starts,
-        k_batch_ends,
-        q_batch_ids,
-        q_start_sids,
-
-        *q.stride(),
-        *k.stride(),
-        *v.stride(),
-        *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = True,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if q_len == 1 else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            q_batch_starts,
+            q_batch_ends,
+            k_batch_starts,
+            k_batch_ends,
+            q_batch_ids,
+            q_start_sids,
+
+            *q.stride(),
+            *k.stride(),
+            *v.stride(),
+            *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = True,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if q_len == 1 else 4,
+            num_stages = 3
+        )
 
     return out
 
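
Aside from the new wrapper, the launch arguments are unchanged; note the decode heuristic they carry. With q_len == 1 only a single query row per sequence is live, so the call drops to a 16-row loading tile and one warp instead of the full prefill tile. Restated as a hypothetical helper for clarity:

```python
def decode_launch_params(q_len: int, block_size: int) -> dict:
    # Mirrors the heuristic in the call above: small tile / one warp for
    # single-token decoding, full tile / four warps for prefill.
    if q_len == 1:
        return {"BLOCK_M_LOADING": 16, "num_warps": 1}
    return {"BLOCK_M_LOADING": block_size, "num_warps": 4}
```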
@@ -1093,37 +1095,38 @@ def blocksparse_flash_attn_varlen_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        cu_seqlens_q[:-1],
-        cu_seqlens_q[1:],
-        cu_seqlens_k[:-1],
-        cu_seqlens_k[1:],
-        q_batch_ids,
-        q_start_sids,
-
-        0, *q.stride(),
-        0, *k.stride(),
-        0, *v.stride(),
-        0, *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = False,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if decoding_only else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            cu_seqlens_q[:-1],
+            cu_seqlens_q[1:],
+            cu_seqlens_k[:-1],
+            cu_seqlens_k[1:],
+            q_batch_ids,
+            q_start_sids,
+
+            0, *q.stride(),
+            0, *k.stride(),
+            0, *v.stride(),
+            0, *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = False,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if decoding_only else 4,
+            num_stages = 3
+        )
 
     return out
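
One detail worth noting in the varlen call: the packed tensors have no batch dimension (HAS_BATCH_DIM = False), yet the kernel still takes four strides per tensor, so a dummy batch stride of 0 is prepended. A small sketch, with the (total_tokens, n_heads, head_size) layout assumed for illustration:

```python
import torch

q = torch.empty(512, 8, 64)   # packed varlen: (total_tokens, n_heads, head_size)
strides = (0, *q.stride())    # pad to 4 strides; the batch stride 0 is never used
print(strides)                # (0, 512, 64, 1)
```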