michaelryoo committed on
Commit
4d46a61
·
verified ·
1 Parent(s): 05ce84e

Upload model

Browse files
Files changed (2) hide show
  1. config.json +1 -2
  2. modeling_xgenmm.py +93 -7
config.json CHANGED
@@ -18,8 +18,7 @@
18
  "vision_encoder_config": {
19
  "anyres_patch_sampling": false,
20
  "image_aspect_ratio": "pad",
21
- "model_type": "xgenmm_vision_encoder",
22
- "temporal_encoder_mode": "gttm"
23
  },
24
  "vision_tokenizer_config": {
25
  "model_type": "xgenmm_vision_tokenizer"
 
18
  "vision_encoder_config": {
19
  "anyres_patch_sampling": false,
20
  "image_aspect_ratio": "pad",
21
+ "model_type": "xgenmm_vision_encoder"
 
22
  },
23
  "vision_tokenizer_config": {
24
  "model_type": "xgenmm_vision_tokenizer"
modeling_xgenmm.py CHANGED
@@ -78,14 +78,15 @@ class XGenMMConfig(PretrainedConfig):
78
  vision_encoder_config = {
79
  "image_aspect_ratio": "pad",
80
  "anyres_patch_sampling": False,
81
- "temporal_encoder_mode": "gttm",
82
  }
83
  logger.info(
84
  "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
85
  )
86
 
87
  if vision_tokenizer_config is None:
88
- vision_tokenizer_config = {}
 
 
89
  logger.info(
90
  "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
91
  )
@@ -843,6 +844,64 @@ class TokenTuringMachineUnit(nn.Module):
843
  return (mem_out_tokens, output_tokens)
844
 
845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
846
  class GroupedTokenTuringMachine4(nn.Module):
847
  def __init__(
848
  self,
@@ -1000,10 +1059,7 @@ class TokenTuringMachine(nn.Module):
1000
  pos = pos.unsqueeze(1)
1001
  step_tokens = step_tokens + pos
1002
 
1003
- # print(step_tokens.shape)
1004
  memory_tokens, output_tokens = self.ttm_unit(memory_tokens, step_tokens)
1005
- # print(f'memory_tokens shape: {memory_tokens.shape}')
1006
- # print(f'output_tokens shape: {output_tokens.shape}')
1007
  output_tokens_list.append(output_tokens)
1008
 
1009
  if self.final_output_only:
@@ -1016,6 +1072,31 @@ class TokenTuringMachine(nn.Module):
1016
  return output_tokens
1017
 
1018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1019
  def num_params(module, filter_to_trainable=False):
1020
  """Returns the number of parameters in the module, or optionally only the trainable parameters"""
1021
  if filter_to_trainable:
@@ -1094,8 +1175,13 @@ class PerceiverResampler(VisionTokenizer):
1094
  if self.temporal_encoder_mode=='gttm':
1095
  # self.ttm = TokenTuringMachine(dim=dim, memory_size=128, memory_out_mode=True)
1096
  self.temporal_encoder = GroupedTokenTuringMachine(dim=dim, process_size=128, memory_size_per_group=4)
1097
- elif self.temporal_encoder_mode=='gttm_pool':
1098
  self.temporal_encoder = GroupedTokenTuringMachine4(dim=dim, process_size=128, memory_size_per_group=4, output_size=32)
 
 
 
 
 
1099
 
1100
  def forward(self, x, vision_attn_masks):
1101
  """
@@ -1126,7 +1212,7 @@ class PerceiverResampler(VisionTokenizer):
1126
  latents = attn(x, latents, vision_attn_masks) + latents
1127
  latents = ff(latents) + latents
1128
 
1129
- if self.video_mode is not None:
1130
  latents = self.temporal_encoder(latents)
1131
 
1132
  if exists(self.projection):
 
78
  vision_encoder_config = {
79
  "image_aspect_ratio": "pad",
80
  "anyres_patch_sampling": False,
 
81
  }
82
  logger.info(
83
  "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
84
  )
85
 
86
  if vision_tokenizer_config is None:
87
+ vision_tokenizer_config = {
88
+ "temporal_encoder_mode": "gttm",
89
+ }
90
  logger.info(
91
  "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
92
  )
 
844
  return (mem_out_tokens, output_tokens)
845
 
846
 
847
+ class GroupedTokenTuringMachine7(nn.Module):
848
+ def __init__(
849
+ self,
850
+ *,
851
+ dim,
852
+ output_size=32,
853
+ memory_size_per_group=4,
854
+ num_layers=4,
855
+ num_heads=8,
856
+ ):
857
+ super().__init__()
858
+
859
+ self.ttm_unit = GroupedTokenTuringMachineUnit(
860
+ dim=dim,
861
+ process_size=output_size,
862
+ memory_size_per_group=memory_size_per_group,
863
+ num_layers=num_layers,
864
+ num_heads=num_heads)
865
+
866
+ self.initial_memory = nn.Parameter(torch.randn(output_size, memory_size_per_group, dim))
867
+
868
+ self.pos_emb = PositionalEncoding1D(dim)
869
+
870
+ self.initial_reduction = TokenLearnerAttentionModule(dim=dim, num_target_tokens=output_size)
871
+
872
+ def forward(self, x):
873
+ """
874
+ Args:
875
+ x (torch.Tensor):
876
+ shape (b, T, n, D)
877
+ """
878
+ b, T, n, D = x.shape
879
+
880
+ memory_tokens = repeat(self.initial_memory, "n g d -> b n g d", b=b)
881
+
882
+ mean_x = torch.mean(x, dim=-2, keepdim=False)
883
+ positional_embeddings = self.pos_emb(mean_x) # (b, T, d)
884
+
885
+ for i in range(T):
886
+ step_tokens = x[:, i, :, :]
887
+
888
+ pos = positional_embeddings[:, i, :]
889
+ pos = pos.unsqueeze(1)
890
+ step_tokens = step_tokens + pos
891
+
892
+ step_tokens = self.initial_reduction(step_tokens)
893
+
894
+ # print(memory_tokens.shape)
895
+ # print(step_tokens.shape)
896
+
897
+ memory_tokens = self.ttm_unit(memory_tokens, step_tokens)
898
+
899
+ memory_tokens = torch.mean(memory_tokens, dim=-2, keepdim=False)
900
+ # memory_tokens = torch.amax(memory_tokens, dim=-2, keepdim=False)
901
+
902
+ return memory_tokens.unsqueeze(1)
903
+
904
+
905
  class GroupedTokenTuringMachine4(nn.Module):
906
  def __init__(
907
  self,
 
1059
  pos = pos.unsqueeze(1)
1060
  step_tokens = step_tokens + pos
1061
 
 
1062
  memory_tokens, output_tokens = self.ttm_unit(memory_tokens, step_tokens)
 
 
1063
  output_tokens_list.append(output_tokens)
1064
 
1065
  if self.final_output_only:
 
1072
  return output_tokens
1073
 
1074
 
1075
+ class TokenLearner(nn.Module):
1076
+ def __init__(
1077
+ self,
1078
+ *,
1079
+ dim,
1080
+ output_size=128,
1081
+ ):
1082
+ super().__init__()
1083
+
1084
+ self.final_output = TokenLearnerAttentionModule(dim=dim, num_target_tokens=output_size)
1085
+
1086
+ def forward(self, x):
1087
+ """
1088
+ Args:
1089
+ x (torch.Tensor):
1090
+ shape (b, T, n, D)
1091
+ """
1092
+ b, T, n, D = x.shape
1093
+
1094
+ output_tokens = x.view(b, -1, D)
1095
+ output_tokens = self.final_output(output_tokens)
1096
+
1097
+ return output_tokens.unsqueeze(1)
1098
+
1099
+
1100
  def num_params(module, filter_to_trainable=False):
1101
  """Returns the number of parameters in the module, or optionally only the trainable parameters"""
1102
  if filter_to_trainable:
 
1175
  if self.temporal_encoder_mode=='gttm':
1176
  # self.ttm = TokenTuringMachine(dim=dim, memory_size=128, memory_out_mode=True)
1177
  self.temporal_encoder = GroupedTokenTuringMachine(dim=dim, process_size=128, memory_size_per_group=4)
1178
+ elif self.temporal_encoder_mode=='gttm4':
1179
  self.temporal_encoder = GroupedTokenTuringMachine4(dim=dim, process_size=128, memory_size_per_group=4, output_size=32)
1180
+ elif self.temporal_encoder_mode=='tokenlearner':
1181
+ self.temporal_encoder = TokenLearner(dim=dim, output_size=32)
1182
+ elif self.temporal_encoder_mode=='gttm7':
1183
+ self.temporal_encoder = GroupedTokenTuringMachine7(dim=dim, memory_size_per_group=4, output_size=32)
1184
+
1185
 
1186
  def forward(self, x, vision_attn_masks):
1187
  """
 
1212
  latents = attn(x, latents, vision_attn_masks) + latents
1213
  latents = ff(latents) + latents
1214
 
1215
+ if self.temporal_encoder_mode is not None:
1216
  latents = self.temporal_encoder(latents)
1217
 
1218
  if exists(self.projection):