import torch
import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)

from hyvideo.modules.attenion import attention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
    # initialize_runtime_state,
)


def init_dist(backend="nccl"):
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print(
        f"Initializing distributed environment with rank {rank}, "
        f"world size {world_size}, local rank {local_rank}"
    )

    torch.cuda.set_device(local_rank)
    init_distributed_environment(rank=rank, world_size=world_size)
    # dist.init_process_group(backend=backend)

    # Construct a hybrid sequence-parallel config (ulysses=2, ring=world_size // 2).
    if world_size > 1:
        ring_degree = world_size // 2
        ulysses_degree = 2
    else:
        ring_degree = 1
        ulysses_degree = 1

    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    return rank, world_size


def test_mm_double_stream_block_attention(rank, world_size):
    """Compare single-device reference attention over the concatenated image+text
    sequence against xFuserLongContextAttention fed separate image and text (joint)
    streams with joint_strategy="rear"."""
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16

    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_v = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_v = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
        # Make sure every rank sees identical inputs.
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(img_v, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(txt_v, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        # Keep the varlen metadata on this rank's device rather than hard-coding cuda:0.
        cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056

        mode = "torch"  # "torch", "vanilla", "flash"
        # Reference: full-sequence attention on a single device.
        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        # Hybrid Ulysses + Ring sequence-parallel attention via xDiT.
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            img_q,
            img_k,
            img_v,
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy="rear",
        )
        b, s, a, d = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

    assert original_output.shape == hybrid_seq_parallel_output.shape, (
        f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"
    )
    torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
    print("test_mm_double_stream_block_attention Passed")


def test_mm_single_stream_block_attention(rank, world_size):
    """Same comparison, but the parallel path receives slices of the already
    concatenated q/k/v, splitting off the last txt_len tokens as the joint tensors."""
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    txt_len = 256
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    with torch.no_grad():
        img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        v = torch.randn(batch_size, seq_len_img + seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

        # Make sure every rank sees identical inputs.
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(v, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)

        # Keep the varlen metadata on this rank's device rather than hard-coding cuda:0.
        cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056

        mode = "torch"  # "torch", "vanilla", "flash"
        # Reference: full-sequence attention on a single device.
        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        # Hybrid Ulysses + Ring sequence-parallel attention: the last txt_len tokens
        # are passed as the rear joint (text) tensors.
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            q[:, :-txt_len, :, :],
            k[:, :-txt_len, :, :],
            v[:, :-txt_len, :, :],
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=q[:, -txt_len:, :, :],
            joint_tensor_key=k[:, -txt_len:, :, :],
            joint_tensor_value=v[:, -txt_len:, :, :],
            joint_strategy="rear",
        )
        b, s, a, d = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

    assert original_output.shape == hybrid_seq_parallel_output.shape, (
        f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"
    )
    torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
    print("test_mm_single_stream_block_attention Passed")


if __name__ == "__main__":
    rank, world_size = init_dist()
    test_mm_double_stream_block_attention(rank, world_size)
    test_mm_single_stream_block_attention(rank, world_size)
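
# Example launch (a sketch, not part of the test itself): init_dist() expects
# torchrun-style environment variables (RANK, WORLD_SIZE, LOCAL_RANK). The script
# path and GPU count below are assumptions; adjust them to your checkout and hardware.
#
#   torchrun --nproc_per_node=8 tests/test_parallel_attention.py
#
# With 8 ranks, init_dist() builds a hybrid sequence-parallel group with
# ulysses_degree=2 and ring_degree=4 (sequence_parallel_degree=8).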