File size: 5,250 Bytes
ee124bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import torch
import torch.nn as nn
from typing import Callable, Tuple
def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).

    Tokens are split by position parity: even-indexed tokens are the
    candidates to be merged away ("src"), odd-indexed tokens are the merge
    targets ("dst"). Every src token is paired with its most similar dst
    token (cosine similarity of ``metric``), and the ``r`` src tokens with
    the highest similarity are folded into their paired dst token.

    Args:
        metric: Similarity features of size [batch, tokens, channels].
        r: Number of tokens to remove (clamped to at most 50% of tokens).

    Returns:
        A ``(merge, unmerge)`` pair of closures bound to the matching
        computed from ``metric``:

        * ``merge(x, mode="mean")`` maps [batch, tokens, c] to
          [batch, tokens - r, c]; the kept src tokens come first, followed
          by all dst tokens.
        * ``unmerge(x)`` maps a merged tensor back to [batch, tokens, c],
          copying each merged dst value into the slot of the src token
          that was folded into it.

    Raises:
        AssertionError: If ``r`` is not positive after clamping.
    """
    protected = 0  # number of tokens exempt from merging (none here)

    t = metric.shape[1]
    r = min(r, (t - protected) // 2)
    assert r > 0, r

    # The matching is pure index bookkeeping; no gradients are needed.
    with torch.no_grad():
        # Unit-normalize so the dot product below is a cosine similarity.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        # Best dst partner (and its score) for every src token.
        node_max, node_idx = scores.max(dim=-1)
        # Rank src tokens from most to least similar to their partner.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Fold the r selected src tokens of ``x`` into their dst tokens.

        ``mode`` selects the reduction: "sum" adds the src tokens onto the
        dst token; any other value is forwarded to ``scatter_reduce`` as
        its ``reduce`` argument (e.g. "mean", "amax").
        """
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        index = dst_idx.expand(n, r, c)
        if mode == "sum":
            # Same scatter_add path as before for the summing case.
            dst = dst.scatter_add(-2, index, src)
        else:
            # Honor the requested reduction instead of silently summing.
            # The dst token itself takes part in the reduction.
            dst = dst.scatter_reduce(-2, index, src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Expand a merged tensor back to the original token layout."""
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        # Each merged-away src token takes the value of the dst it joined.
        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, t, c, device=x.device, dtype=x.dtype)

        # dst tokens return to the odd positions; src-side indices are
        # doubled to address the even positions they were sliced from.
        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge
def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge tokens as a size-weighted average.

    Each token is scaled by its size before being merged with a summing
    reduction; dividing by the merged sizes then yields the weighted mean.
    When ``size`` is omitted, every token is given size 1.

    Returns the averaged tokens together with the merged token sizes.
    """
    # One size entry per token, shaped like x with a single trailing channel.
    weights = torch.ones_like(x[..., 0, None]) if size is None else size

    weighted_total = merge(x * weights, mode="sum")
    merged_sizes = merge(weights, mode="sum")

    return weighted_total / merged_sizes, merged_sizes
class ToMe16_mlp_hd64(nn.Module):
    """Projector that reduces the token count with ToMe, then applies an MLP.

    Tokens are merged down to a fixed budget (64 per input row, or 16 per
    frame when ``compress`` is set) and projected from ``mm_hidden_size`` to
    ``hidden_size`` by a Linear -> GELU -> Linear stack.
    """

    def __init__(self, config, vision_cfg):
        """
        Args:
            config: Object providing ``mm_hidden_size``, ``hidden_size`` and
                ``mm_pos_num_frames``.
            vision_cfg: Object providing ``image_size``, ``patch_size`` and
                ``num_attention_heads``.
        """
        super().__init__()
        self._config = config
        self.mm_hidden_size = config.mm_hidden_size
        # Patches per image side; forward() requires hw * hw tokens per row.
        self.hw = vision_cfg.image_size // vision_cfg.patch_size
        self.num_attention_heads = vision_cfg.num_attention_heads

        self.mlp = nn.Sequential(nn.Linear(config.mm_hidden_size, config.hidden_size),
                                 nn.GELU(),
                                 nn.Linear(config.hidden_size, config.hidden_size))

        # Stored but not read anywhere inside this class.
        self.max_pos_hw = self.hw
        self.max_pos_num_frames = config.mm_pos_num_frames
        self.num_image_patches_per_side = 8
        self.num_frame_patches_per_side = 4

    def merge_tokens(self, x, target_num_token):
        r"""
        Merge the tokens of ``x`` down to ``target_num_token`` tokens.

        A single matching step can remove at most half of the tokens, so
        the reduction is planned as a sequence of halving steps followed by
        one final step that removes the remainder.

        x = torch.randn(10, 2560, c)
        x = merge_tokens(x, target_num_token=1280)

        Args:
            x: Tensor of shape [batch, tokens, channels]; ``channels`` must
                be divisible by ``num_attention_heads``.
            target_num_token: Number of tokens to keep per batch row.

        Returns:
            Tensor of shape [batch, target_num_token, channels]. ``x`` is
            returned unchanged when it already has the target token count.

        Raises:
            AssertionError: If ``x`` has fewer tokens than requested.
        """
        size = None
        b, p, c = x.shape

        assert p >= target_num_token, f"{p} should greater than {target_num_token}"
        if p == target_num_token:
            # Nothing to merge; a matching step would need r > 0.
            return x

        # Plan how many tokens each matching step removes.
        r_merge_list = []
        remaining = p
        while remaining != target_num_token:
            excess = remaining - target_num_token
            half = remaining // 2
            if excess <= half:
                r_merge_list.append(excess)
                break
            r_merge_list.append(half)
            remaining -= half

        head = self.num_attention_heads
        dim = c // head
        for r in r_merge_list:
            # Similarity metric: channels averaged across heads.
            metric = x.reshape(b, p, head, dim).mean(2)  # [b, p, c//head]
            merge, _ = bipartite_soft_matching(
                metric,
                r
            )
            # Carry token sizes across steps so averages stay weighted.
            x, size = merge_wavg(merge, x, size)
            _, p, _ = x.shape

        return x

    def forward(self, x, compress=False, local_num_frames=-1):  # 64 tokens for a single frame
        """
        Merge tokens to the configured budget and project them.

        Args:
            x: Tensor of shape [rows, hw * hw, mm_hidden_size].
            compress: When False, every row is reduced to 64 tokens. When
                True, rows are concatenated along the token axis and
                reduced to 16 tokens per frame.
            local_num_frames: With ``compress``, the number of consecutive
                rows grouped into one sequence; -1 groups all rows into a
                single sequence. Values other than -1 and 1 require
                ``compress=True``.

        Returns:
            Tensor of shape [groups, num_tokens, hidden_size].

        Raises:
            AssertionError: If the token count is not ``hw * hw``, or if
                ``local_num_frames`` requires ``compress`` and it is not
                set to True.
        """
        height = width = self.hw
        assert height * width == x.shape[1]

        if local_num_frames != -1 and local_num_frames != 1:
            assert compress is True

        if compress:
            if local_num_frames != -1:
                num_frames = local_num_frames
                x = x.reshape(x.shape[0] // local_num_frames, -1, x.shape[-1])
            else:
                num_frames = x.shape[0]
                x = x.reshape(1, -1, x.shape[-1])
            num_tome_tokens = 16 * num_frames
        else:
            num_tome_tokens = 64

        x = self.merge_tokens(x, target_num_token=num_tome_tokens)
        x = self.mlp(x)
        return x

    @property
    def config(self):
        """Return the projector-type descriptor for this module."""
        return {"mm_projector_type": "tome16_mlp_hd64"}
def build_vision_projector(config, delay_load=False, **kwargs):
    """
    Build the vision projector selected by ``config.mm_projector_type``.

    The type defaults to "linear" when the attribute is absent. Only
    'tome16_mlp_hd64' is handled here; it is constructed from ``config`` and
    the ``vision_cfg`` keyword argument. Any other type raises ValueError.
    ``delay_load`` is accepted but not used.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type != 'tome16_mlp_hd64':
        raise ValueError(f"Unknown projector type: {projector_type}")

    return ToMe16_mlp_hd64(config, kwargs["vision_cfg"])
|