Fabrice-TIERCELIN committed
Commit 21377e0 · verified · 1 Parent(s): 33782f1

Upload test_attention.py

Files changed (1)
  1. test_attention.py +180 -0
test_attention.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ import sys
+ import os
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ project_root = os.path.dirname(current_dir)
+ sys.path.append(project_root)
+
+ from hyvideo.modules.attenion import attention
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+ from xfuser.core.distributed import (
+     init_distributed_environment,
+     initialize_model_parallel,
+     # initialize_runtime_state,
+ )
+
+ def init_dist(backend="nccl"):
+     local_rank = int(os.environ["LOCAL_RANK"])
+     rank = int(os.environ["RANK"])
+     world_size = int(os.environ["WORLD_SIZE"])
+
+     print(
+         f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
+     )
+
+     torch.cuda.set_device(local_rank)
+     init_distributed_environment(rank=rank, world_size=world_size)
+     # dist.init_process_group(backend=backend)
+     # construct a hybrid sequence parallel config (ulysses=2, ring = world_size // 2)
+
+     if world_size > 1:
+         ring_degree = world_size // 2
+         ulysses_degree = 2
+     else:
+         ring_degree = 1
+         ulysses_degree = 1
+     initialize_model_parallel(
+         sequence_parallel_degree=world_size,
+         ring_degree=ring_degree,
+         ulysses_degree=ulysses_degree,
+     )
+
+     return rank, world_size
+
+ def test_mm_double_stream_block_attention(rank, world_size):
+     device = torch.device(f"cuda:{rank}")
+     dtype = torch.bfloat16
+     batch_size = 1
+     seq_len_img = 118800
+     seq_len_txt = 256
+     heads_num = 24
+     head_dim = 128
+
+     img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
+     img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
+     img_v = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
+     txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
+     txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
+     txt_v = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
+
+     with torch.no_grad():
+         # Broadcast from rank 0 so every rank compares against identical random inputs
+         torch.distributed.broadcast(img_q, src=0)
+         torch.distributed.broadcast(img_k, src=0)
+         torch.distributed.broadcast(img_v, src=0)
+         torch.distributed.broadcast(txt_q, src=0)
+         torch.distributed.broadcast(txt_k, src=0)
+         torch.distributed.broadcast(txt_v, src=0)
+         q = torch.cat((img_q, txt_q), dim=1)
+         k = torch.cat((img_k, txt_k), dim=1)
+         v = torch.cat((img_v, txt_v), dim=1)
+
+         cu_seqlens_q = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
+         cu_seqlens_kv = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
+         max_seqlen_q = 119056
+         max_seqlen_kv = 119056
+         mode = "torch"  # "torch", "vanilla", "flash"
+
+         # Reference: attention over the full image+text sequence on a single device
+         original_output = attention(
+             q,
+             k,
+             v,
+             mode=mode,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_kv=cu_seqlens_kv,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_kv=max_seqlen_kv,
+             batch_size=batch_size
+         )
+
+         # Hybrid Ulysses/Ring sequence-parallel attention; text tokens are appended as a "rear" joint tensor
+         hybrid_seq_parallel_attn = xFuserLongContextAttention()
+         hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
+             None,
+             img_q,
+             img_k,
+             img_v,
+             dropout_p=0.0,
+             causal=False,
+             joint_tensor_query=txt_q,
+             joint_tensor_key=txt_k,
+             joint_tensor_value=txt_v,
+             joint_strategy="rear",
+         )
+
+         # Merge the head and head_dim axes to match the reference output layout
+         b, s, a, d = hybrid_seq_parallel_output.shape
+         hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)
+
+         assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"
+
+         torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
+         print("test_mm_double_stream_block_attention Passed")
+
+ def test_mm_single_stream_block_attention(rank, world_size):
+     device = torch.device(f"cuda:{rank}")
+     dtype = torch.bfloat16
+     txt_len = 256
+     batch_size = 1
+     seq_len_img = 118800
+     seq_len_txt = 256
+     heads_num = 24
+     head_dim = 128
+
+     with torch.no_grad():
+         # Single-stream variant: q/k are concatenated from image and text parts, v already covers the full sequence
+         img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
+         img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
+         txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
+         txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
+         v = torch.randn(batch_size, seq_len_img + seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
+
+         torch.distributed.broadcast(img_q, src=0)
+         torch.distributed.broadcast(img_k, src=0)
+         torch.distributed.broadcast(txt_q, src=0)
+         torch.distributed.broadcast(txt_k, src=0)
+         torch.distributed.broadcast(v, src=0)
+
+         q = torch.cat((img_q, txt_q), dim=1)
+         k = torch.cat((img_k, txt_k), dim=1)
+
+         cu_seqlens_q = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
+         cu_seqlens_kv = torch.tensor([0, 118811, 119056], device='cuda:0', dtype=torch.int32)
+         max_seqlen_q = 119056
+         max_seqlen_kv = 119056
+         mode = "torch"  # "torch", "vanilla", "flash"
+
+         original_output = attention(
+             q,
+             k,
+             v,
+             mode=mode,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_kv=cu_seqlens_kv,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_kv=max_seqlen_kv,
+             batch_size=batch_size
+         )
+
+         # The last txt_len tokens are split back off and fed through the joint-tensor path
+         hybrid_seq_parallel_attn = xFuserLongContextAttention()
+         hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
+             None,
+             q[:, :-txt_len, :, :],
+             k[:, :-txt_len, :, :],
+             v[:, :-txt_len, :, :],
+             dropout_p=0.0,
+             causal=False,
+             joint_tensor_query=q[:, -txt_len:, :, :],
+             joint_tensor_key=k[:, -txt_len:, :, :],
+             joint_tensor_value=v[:, -txt_len:, :, :],
+             joint_strategy="rear",
+         )
+         b, s, a, d = hybrid_seq_parallel_output.shape
+         hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)
+
+         assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"
+
+         torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
+         print("test_mm_single_stream_block_attention Passed")
+
+ if __name__ == "__main__":
+     rank, world_size = init_dist()
+     test_mm_double_stream_block_attention(rank, world_size)
+     test_mm_single_stream_block_attention(rank, world_size)