Satandon1999
committed on
Update triton_flash_blocksparse_attn.py
Browse files
Add suggestion similar to https://huggingface.co/THUDM/cogagent-chat-hf/blob/d519da3b191401234f4bd86ce1c287c61bc276a3/util.py#L210 to avoid the error:
```ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)```
- triton_flash_blocksparse_attn.py +25 -24
triton_flash_blocksparse_attn.py
CHANGED
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
|
|
611 |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
612 |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
613 |
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
|
|
638 |
if inference:
|
639 |
L, m = None, None
|
640 |
|
|
|
611 |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
612 |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
613 |
|
614 |
+
with torch.cuda.device(q.device.index):
|
615 |
+
_fwd_kernel[grid](
|
616 |
+
q, k, v, sm_scale,
|
617 |
+
layout_crow_indices,
|
618 |
+
layout_col_indices,
|
619 |
+
layout_crow_indices.stride(0), layout_crow_indices.stride(1),
|
620 |
+
layout_col_indices.stride(0), layout_col_indices.stride(1),
|
621 |
+
tmp, L, m,
|
622 |
+
o,
|
623 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
624 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
625 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
626 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
627 |
+
q.shape[0], q.shape[1], k.shape[2],
|
628 |
+
k.shape[2] - q.shape[2],
|
629 |
+
q_rounded_len,
|
630 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
631 |
+
BLOCK_DMODEL=BLOCK_DMODEL,
|
632 |
+
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
|
633 |
+
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
|
634 |
+
INFERENCE=inference,
|
635 |
+
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
|
636 |
+
num_warps=num_warps,
|
637 |
+
num_stages=num_stages,
|
638 |
+
)
|
639 |
if inference:
|
640 |
L, m = None, None
|
641 |
|