Files changed (1) hide show
  1. hf_model.py +0 -128
hf_model.py CHANGED
@@ -295,131 +295,3 @@ class HFTextEncoder(nn.Module):
295
  def init_parameters(self):
296
  pass
297
 
298
-
299
- """
300
- HF vision model
301
- """
302
-
303
-
304
- class HFVisionEncoder(nn.Module):
305
- output_tokens: torch.jit.Final[bool]
306
-
307
- def __init__(
308
- self,
309
- model_name_or_path: str,
310
- image_size: int,
311
- output_dim: int,
312
- config: PretrainedConfig = None,
313
- pool_type: str = 'tok',
314
- proj_type: Optional[str] = None,
315
- proj_bias: bool = False,
316
- attn_drop: float = 0.0,
317
- hidden_drop: float = 0.0,
318
- drop_path: Optional[float] = None,
319
- pretrained: bool = True,
320
- output_tokens: bool = False,
321
- trust_remote_code: bool = False,
322
- ):
323
- super().__init__()
324
- self.output_tokens = output_tokens
325
- self.output_dim = output_dim
326
- self.image_size = (image_size, image_size)
327
-
328
- if config is None:
329
- self.config = AutoConfig.from_pretrained(
330
- model_name_or_path,
331
- trust_remote_code=trust_remote_code,
332
- hidden_dropout_prob=hidden_drop,
333
- attention_probs_dropout_prob=attn_drop,
334
- drop_path_rate=drop_path,
335
- )
336
- create_func, model_args = (
337
- (AutoModel.from_pretrained, model_name_or_path)
338
- if pretrained
339
- else (AutoModel.from_config, self.config)
340
- )
341
- self.transformer = create_func(
342
- model_args,
343
- trust_remote_code=trust_remote_code,
344
- hidden_dropout_prob=hidden_drop,
345
- attention_probs_dropout_prob=attn_drop,
346
- )
347
- else:
348
- self.config = config
349
- self.transformer = AutoModel.from_config(config)
350
-
351
- if 'dinov2' in model_name_or_path:
352
- self.transformer.embeddings.mask_token.requires_grad = False
353
-
354
- assert pool_type in ('tok', 'avg', 'none')
355
- self.pool_type = pool_type
356
-
357
- d_model = self.config.hidden_size
358
- if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
359
- self.proj = nn.Identity()
360
- elif proj_type == 'linear':
361
- self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
362
- elif proj_type == 'mlp':
363
- hidden_size = (d_model + output_dim) // 2
364
- self.proj = nn.Sequential(
365
- nn.Linear(d_model, hidden_size, bias=proj_bias),
366
- nn.GELU(),
367
- nn.Linear(hidden_size, output_dim, bias=proj_bias),
368
- )
369
-
370
- def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
371
- if self.pool_type == 'avg':
372
- pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
373
- elif self.pool_type == 'tok':
374
- pooled, tokens = x[:, 0], x[:, 1:]
375
- else:
376
- pooled = tokens = x
377
-
378
- return pooled, tokens
379
-
380
- def forward(self, x: torch.Tensor):
381
- # returns a tuple of (final hidden states, token pooled outputs)
382
- x = self.transformer(x)[0]
383
- pooled, tokens = self._global_pool(x)
384
- projected = self.proj(pooled)
385
-
386
- return projected
387
-
388
- def lock(self, unlocked_layers: int = 0, freeze_bn_stats: bool = True):
389
- if not unlocked_layers: # full freezing
390
- for n, p in self.transformer.named_parameters():
391
- p.requires_grad = (
392
- (not freeze_bn_stats) if 'LayerNorm' in n.split('.') else False
393
- )
394
- return
395
-
396
- # TODO: make it work if unlocked_layers !=0
397
- encoder = (
398
- self.transformer.encoder
399
- if hasattr(self.transformer, 'encoder')
400
- else self.transformer
401
- )
402
- layer_list = getattr(
403
- encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
404
- )
405
- print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
406
- embeddings = getattr(
407
- self.transformer,
408
- _HF_ARCH_DICT[self.config.model_type]['config_names'][
409
- 'token_embeddings_attr'
410
- ],
411
- )
412
- modules = [embeddings, *layer_list][:-unlocked_layers]
413
- # freeze layers
414
- for module in modules:
415
- for n, p in module.named_parameters():
416
- p.requires_grad = (
417
- (not freeze_bn_stats) if 'LayerNorm' in n.split('.') else False
418
- )
419
-
420
- @torch.jit.ignore
421
- def set_grad_checkpointing(self, *_, **__):
422
- self.transformer.gradient_checkpointing_enable()
423
-
424
- def init_parameters(self):
425
- pass
 
295
  def init_parameters(self):
296
  pass
297