Satandon1999 committed
Update triton_flash_blocksparse_attn.py
Adding `with torch.cuda.device(q.device.index):` around the kernel launches at all applicable sections to support multi-GPU.
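Context for the change: Triton, like raw CUDA, launches kernels on the *current* device (usually `cuda:0` unless something switched it), so a launch against tensors living on another GPU dereferences pointers on the wrong device. A minimal sketch of the pattern, with a toy copy kernel standing in for `_fwd_kernel_batch_inference` (the kernel and helper below are illustrative, not from this file):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Each program instance copies one BLOCK-sized chunk.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    n = src.numel()
    grid = (triton.cdiv(n, 1024),)
    # Launch on the GPU that owns the tensors; without this guard the
    # kernel is launched on the current device (typically cuda:0), where
    # pointers into tensors that live on another GPU are invalid.
    with torch.cuda.device(src.device.index):
        _copy_kernel[grid](src, dst, n, BLOCK=1024)
    return dst

# x = torch.randn(4096, device="cuda:1")   # any GPU, not just the default
# y = copy(x)
```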
triton_flash_blocksparse_attn.py  CHANGED  (+64 -62)
@@ -992,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        q_batch_starts,
-        q_batch_ends,
-        k_batch_starts,
-        k_batch_ends,
-        q_batch_ids,
-        q_start_sids,
-
-        *q.stride(),
-        *k.stride(),
-        *v.stride(),
-        *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = True,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if q_len == 1 else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            q_batch_starts,
+            q_batch_ends,
+            k_batch_starts,
+            k_batch_ends,
+            q_batch_ids,
+            q_start_sids,
+
+            *q.stride(),
+            *k.stride(),
+            *v.stride(),
+            *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = True,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if q_len == 1 else 4,
+            num_stages = 3
+        )
 
     return out
 
@@ -1094,37 +1095,38 @@ def blocksparse_flash_attn_varlen_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        cu_seqlens_q[:-1],
-        cu_seqlens_q[1:],
-        cu_seqlens_k[:-1],
-        cu_seqlens_k[1:],
-        q_batch_ids,
-        q_start_sids,
-
-        0, *q.stride(),
-        0, *k.stride(),
-        0, *v.stride(),
-        0, *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = False,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if decoding_only else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            cu_seqlens_q[:-1],
+            cu_seqlens_q[1:],
+            cu_seqlens_k[:-1],
+            cu_seqlens_k[1:],
+            q_batch_ids,
+            q_start_sids,
+
+            0, *q.stride(),
+            0, *k.stride(),
+            0, *v.stride(),
+            0, *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = False,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if decoding_only else 4,
+            num_stages = 3
+        )
 
     return out
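A note on the choice of guard: `torch.cuda.device` is a context manager that switches the current CUDA device on entry and restores the previous one on exit, so the wrapper is a no-op when `q` already lives on the current device and never leaks the switch to the caller, unlike a bare `torch.cuda.set_device` call would. A quick illustration of those semantics (assumes a machine with at least two GPUs):

```python
import torch

torch.cuda.set_device(0)
with torch.cuda.device(1):               # current device becomes cuda:1
    assert torch.cuda.current_device() == 1
assert torch.cuda.current_device() == 0  # previous device restored on exit
```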