michaelryoo committed
Upload model
Browse files
- config.json +1 -2
- modeling_xgenmm.py +93 -7
config.json CHANGED

@@ -18,8 +18,7 @@
   "vision_encoder_config": {
     "anyres_patch_sampling": false,
     "image_aspect_ratio": "pad",
-    "model_type": "xgenmm_vision_encoder",
-    "temporal_encoder_mode": "gttm"
+    "model_type": "xgenmm_vision_encoder"
   },
   "vision_tokenizer_config": {
     "model_type": "xgenmm_vision_tokenizer"
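Net effect of the config change: "temporal_encoder_mode" is no longer stored under "vision_encoder_config"; its default ("gttm") is now supplied through "vision_tokenizer_config" in modeling_xgenmm.py (next diff). A minimal Python sketch of the resulting default sub-configs, with values taken from the diff (the standalone dicts are for illustration only):

# Sketch of the default sub-configs after this commit (values from the diff below).
vision_encoder_config = {
    "image_aspect_ratio": "pad",
    "anyres_patch_sampling": False,
}
vision_tokenizer_config = {
    "temporal_encoder_mode": "gttm",  # previously defaulted under vision_encoder_config
}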
modeling_xgenmm.py CHANGED

@@ -78,14 +78,15 @@ class XGenMMConfig(PretrainedConfig):
             vision_encoder_config = {
                 "image_aspect_ratio": "pad",
                 "anyres_patch_sampling": False,
-                "temporal_encoder_mode": "gttm",
             }
             logger.info(
                 "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
             )

         if vision_tokenizer_config is None:
-            vision_tokenizer_config = {}
+            vision_tokenizer_config = {
+                "temporal_encoder_mode": "gttm",
+            }
             logger.info(
                 "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
             )
@@ -843,6 +844,64 @@ class TokenTuringMachineUnit(nn.Module):
         return (mem_out_tokens, output_tokens)


+class GroupedTokenTuringMachine7(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        output_size=32,
+        memory_size_per_group=4,
+        num_layers=4,
+        num_heads=8,
+    ):
+        super().__init__()
+
+        self.ttm_unit = GroupedTokenTuringMachineUnit(
+            dim=dim,
+            process_size=output_size,
+            memory_size_per_group=memory_size_per_group,
+            num_layers=num_layers,
+            num_heads=num_heads)
+
+        self.initial_memory = nn.Parameter(torch.randn(output_size, memory_size_per_group, dim))
+
+        self.pos_emb = PositionalEncoding1D(dim)
+
+        self.initial_reduction = TokenLearnerAttentionModule(dim=dim, num_target_tokens=output_size)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor):
+                shape (b, T, n, D)
+        """
+        b, T, n, D = x.shape
+
+        memory_tokens = repeat(self.initial_memory, "n g d -> b n g d", b=b)
+
+        mean_x = torch.mean(x, dim=-2, keepdim=False)
+        positional_embeddings = self.pos_emb(mean_x)  # (b, T, d)
+
+        for i in range(T):
+            step_tokens = x[:, i, :, :]
+
+            pos = positional_embeddings[:, i, :]
+            pos = pos.unsqueeze(1)
+            step_tokens = step_tokens + pos
+
+            step_tokens = self.initial_reduction(step_tokens)
+
+            # print(memory_tokens.shape)
+            # print(step_tokens.shape)
+
+            memory_tokens = self.ttm_unit(memory_tokens, step_tokens)
+
+        memory_tokens = torch.mean(memory_tokens, dim=-2, keepdim=False)
+        # memory_tokens = torch.amax(memory_tokens, dim=-2, keepdim=False)
+
+        return memory_tokens.unsqueeze(1)
+
+
 class GroupedTokenTuringMachine4(nn.Module):
     def __init__(
         self,
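For orientation, a small smoke test for the new module (not part of the commit): it feeds a dummy clip of T frames, each with n vision tokens of width D, through GroupedTokenTuringMachine7 as added above. The tensor sizes are made up, and the expected output shape follows from the mean over the per-group memory axis plus the final unsqueeze(1).

import torch

# Hypothetical smoke test, assuming modeling_xgenmm.py (with its helper modules
# GroupedTokenTuringMachineUnit, PositionalEncoding1D, TokenLearnerAttentionModule)
# is importable from the working directory.
from modeling_xgenmm import GroupedTokenTuringMachine7

b, T, n, D = 2, 8, 128, 1152          # made-up batch / frames / tokens / width
x = torch.randn(b, T, n, D)

gttm7 = GroupedTokenTuringMachine7(dim=D, output_size=32, memory_size_per_group=4)
out = gttm7(x)
print(out.shape)                       # should be (b, 1, 32, D): time is folded into the memory tokens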
@@ -1000,10 +1059,7 @@ class TokenTuringMachine(nn.Module):
             pos = pos.unsqueeze(1)
             step_tokens = step_tokens + pos

-            # print(step_tokens.shape)
             memory_tokens, output_tokens = self.ttm_unit(memory_tokens, step_tokens)
-            # print(f'memory_tokens shape: {memory_tokens.shape}')
-            # print(f'output_tokens shape: {output_tokens.shape}')
             output_tokens_list.append(output_tokens)

         if self.final_output_only:
@@ -1016,6 +1072,31 @@ class TokenTuringMachine(nn.Module):
         return output_tokens


+class TokenLearner(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        output_size=128,
+    ):
+        super().__init__()
+
+        self.final_output = TokenLearnerAttentionModule(dim=dim, num_target_tokens=output_size)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor):
+                shape (b, T, n, D)
+        """
+        b, T, n, D = x.shape
+
+        output_tokens = x.view(b, -1, D)
+        output_tokens = self.final_output(output_tokens)
+
+        return output_tokens.unsqueeze(1)
+
+
 def num_params(module, filter_to_trainable=False):
     """Returns the number of parameters in the module, or optionally only the trainable parameters"""
     if filter_to_trainable:
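The new TokenLearner path is simpler: it flattens time and per-frame tokens into one axis and lets TokenLearnerAttentionModule select a fixed number of summary tokens. A rough shape walk-through under the same made-up sizes as the sketch above (b=2, T=8, n=128, D=1152):

import torch

# Illustrative shape check mirroring TokenLearner.forward (sizes are assumptions, not from the commit).
b, T, n, D = 2, 8, 128, 1152
x = torch.randn(b, T, n, D)
flat = x.view(b, -1, D)                # (2, 1024, 1152): time and per-frame tokens flattened together
print(flat.shape)
# TokenLearnerAttentionModule(dim=D, num_target_tokens=32) would then reduce this to (2, 32, 1152),
# and the final unsqueeze(1) yields (2, 1, 32, 1152), the same layout GroupedTokenTuringMachine7 returns.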
@@ -1094,8 +1175,13 @@ class PerceiverResampler(VisionTokenizer):
         if self.temporal_encoder_mode=='gttm':
             # self.ttm = TokenTuringMachine(dim=dim, memory_size=128, memory_out_mode=True)
             self.temporal_encoder = GroupedTokenTuringMachine(dim=dim, process_size=128, memory_size_per_group=4)
-        elif self.temporal_encoder_mode=='
+        elif self.temporal_encoder_mode=='gttm4':
             self.temporal_encoder = GroupedTokenTuringMachine4(dim=dim, process_size=128, memory_size_per_group=4, output_size=32)
+        elif self.temporal_encoder_mode=='tokenlearner':
+            self.temporal_encoder = TokenLearner(dim=dim, output_size=32)
+        elif self.temporal_encoder_mode=='gttm7':
+            self.temporal_encoder = GroupedTokenTuringMachine7(dim=dim, memory_size_per_group=4, output_size=32)
+

     def forward(self, x, vision_attn_masks):
         """
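With this commit the recognized temporal_encoder_mode values are 'gttm', 'gttm4', 'tokenlearner', and 'gttm7'. A hedged configuration sketch follows; the exact plumbing from XGenMMConfig down to PerceiverResampler is not shown in this diff, so treat the keyword path as an assumption:

# Assumed wiring: temporal_encoder_mode reaches PerceiverResampler via the
# vision tokenizer sub-config (consistent with the new defaults above).
from modeling_xgenmm import XGenMMConfig   # local import; module path assumed

cfg = XGenMMConfig(
    vision_tokenizer_config={"temporal_encoder_mode": "gttm7"},  # or "gttm", "gttm4", "tokenlearner"
)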
@@ -1126,7 +1212,7 @@ class PerceiverResampler(VisionTokenizer):
             latents = attn(x, latents, vision_attn_masks) + latents
             latents = ff(latents) + latents

-        if self.
+        if self.temporal_encoder_mode is not None:
             latents = self.temporal_encoder(latents)

         if exists(self.projection):