Fabrice-TIERCELIN
committed on
Upload test_attention.py
test_attention.py ADDED
@@ -0,0 +1,180 @@
import torch
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)

from hyvideo.modules.attenion import attention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
    # initialize_runtime_state,
)

def init_dist(backend="nccl"):
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print(
        f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
    )

    torch.cuda.set_device(local_rank)
    init_distributed_environment(rank=rank, world_size=world_size)
    # dist.init_process_group(backend=backend)
    # construct a hybrid sequence parallel config (ulysses=2, ring = world_size // 2)

    if world_size > 1:
        ring_degree = world_size // 2
        ulysses_degree = 2
    else:
        ring_degree = 1
        ulysses_degree = 1
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    return rank, world_size

def test_mm_double_stream_block_attention(rank, world_size):
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_v = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_v = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(img_v, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(txt_v, src=0)
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        cu_seqlens_q = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056
        mode = "torch"  # "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size
        )

        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            img_q,
            img_k,
            img_v,
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy="rear",
        )

        b, s, a, d = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_double_stream_block_attention Passed")

def test_mm_single_stream_block_attention(rank, world_size):
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    txt_len = 256
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    with torch.no_grad():
        img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        v = torch.randn(batch_size, seq_len_img + seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(v, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)

        cu_seqlens_q = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056
        mode = "torch"  # "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size
        )

        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            q[:, :-txt_len, :, :],
            k[:, :-txt_len, :, :],
            v[:, :-txt_len, :, :],
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=q[:, -txt_len:, :, :],
            joint_tensor_key=k[:, -txt_len:, :, :],
            joint_tensor_value=v[:, -txt_len:, :, :],
            joint_strategy="rear",
        )
        b, s, a, d = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_single_stream_block_attention Passed")

if __name__ == "__main__":
    rank, world_size = init_dist()
    test_mm_double_stream_block_attention(rank, world_size)
    test_mm_single_stream_block_attention(rank, world_size)
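
The script reads LOCAL_RANK, RANK, and WORLD_SIZE from the environment, so it is meant to be launched with a distributed launcher such as torchrun (for example `torchrun --nproc_per_node=2 test_attention.py`); note that for world_size > 1 it hard-codes ulysses_degree=2 with ring_degree=world_size//2, so multi-GPU runs should use an even process count. The snippet below is a minimal single-GPU smoke-run sketch and is not part of the uploaded file: it fakes the variables torchrun would set and assumes xfuser's init_distributed_environment() falls back to the env:// rendezvous (hence the arbitrary MASTER_ADDR/MASTER_PORT values), and that the GPU has enough memory for the ~119k-token sequences.

# single_gpu_smoke.py -- hypothetical helper, not in the upload: run both tests
# as a single process on one GPU without torchrun.
import os

# Fake the environment torchrun would provide for a world of size 1.
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # env:// rendezvous for the 1-process group
os.environ.setdefault("MASTER_PORT", "29512")       # arbitrary free port

import test_attention

rank, world_size = test_attention.init_dist()
test_attention.test_mm_double_stream_block_attention(rank, world_size)
test_attention.test_mm_single_stream_block_attention(rank, world_size)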