README.md CHANGED
@@ -67,15 +67,14 @@ import torch
67
 
68
  from huggingface_mae import MAEModel
69
 
70
- huggingface_openphenom_model_dir = "."
71
- # huggingface_modelpath = "recursionpharma/OpenPhenom"
72
 
73
 
74
  @pytest.fixture
75
  def huggingface_model():
76
- # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/OpenPhenom to this directory
77
- # huggingface-cli download recursionpharma/OpenPhenom --local-dir=.
78
- huggingface_model = MAEModel.from_pretrained(huggingface_openphenom_model_dir)
79
  huggingface_model.eval()
80
  return huggingface_model
81
 
 
67
 
68
  from huggingface_mae import MAEModel
69
 
70
+ # huggingface_openphenom_model_dir = "."
71
+ huggingface_modelpath = "recursionpharma/OpenPhenom"
72
 
73
 
74
  @pytest.fixture
75
  def huggingface_model():
76
+ # This step downloads the model to a local cache, takes a bit to run
77
+ huggingface_model = MAEModel.from_pretrained(huggingface_modelpath)
 
78
  huggingface_model.eval()
79
  return huggingface_model
80
 
__init__.py ADDED
File without changes
generate_reconstructions.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
huggingface_mae.py CHANGED
@@ -4,12 +4,13 @@ import torch
4
  import torch.nn as nn
5
 
6
  from transformers import PretrainedConfig, PreTrainedModel
 
7
 
8
- from loss import FourierLoss
9
- from normalizer import Normalizer
10
- from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
11
- from mae_utils import flatten_images
12
- from vit import (
13
  generate_2d_sincos_pos_embeddings,
14
  sincos_positional_encoding_vit,
15
  vit_small_patch16_256,
@@ -285,8 +286,8 @@ class MAEModel(PreTrainedModel):
285
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
286
  filename = kwargs.pop("filename", "model.safetensors")
287
 
288
- modelpath = f"{pretrained_model_name_or_path}/{filename}"
289
  config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
290
  state_dict = torch.load(modelpath, map_location="cpu")
291
  model = cls(config)
292
  model.load_state_dict(state_dict["state_dict"])
 
4
  import torch.nn as nn
5
 
6
  from transformers import PretrainedConfig, PreTrainedModel
7
+ from transformers.utils import cached_file
8
 
9
+ from .loss import FourierLoss
10
+ from .normalizer import Normalizer
11
+ from .mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
12
+ from .mae_utils import flatten_images
13
+ from .vit import (
14
  generate_2d_sincos_pos_embeddings,
15
  sincos_positional_encoding_vit,
16
  vit_small_patch16_256,
 
286
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
287
  filename = kwargs.pop("filename", "model.safetensors")
288
 
 
289
  config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
290
+ modelpath = cached_file(pretrained_model_name_or_path, filename=filename)
291
  state_dict = torch.load(modelpath, map_location="cpu")
292
  model = cls(config)
293
  model.load_state_dict(state_dict["state_dict"])
mae_modules.py CHANGED
@@ -7,8 +7,8 @@ import torch.nn as nn
7
  from timm.models.helpers import checkpoint_seq
8
  from timm.models.vision_transformer import Block, Mlp, VisionTransformer
9
 
10
- from masking import transformer_random_masking
11
- from vit import channel_agnostic_vit
12
 
13
  # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
14
  # leverage the flattening and unflattening utilities as needed from mae_utils.py.
 
7
  from timm.models.helpers import checkpoint_seq
8
  from timm.models.vision_transformer import Block, Mlp, VisionTransformer
9
 
10
+ from .masking import transformer_random_masking
11
+ from .vit import channel_agnostic_vit
12
 
13
  # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
14
  # leverage the flattening and unflattening utilities as needed from mae_utils.py.
test_huggingface_mae.py CHANGED
@@ -1,17 +1,16 @@
1
  import pytest
2
  import torch
3
 
4
- from huggingface_mae import MAEModel
 
5
 
6
- huggingface_openphenom_model_dir = "."
7
- # huggingface_modelpath = "recursionpharma/test-pb-model"
8
 
9
 
10
  @pytest.fixture
11
  def huggingface_model():
12
- # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
13
- # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
14
- huggingface_model = MAEModel.from_pretrained(huggingface_openphenom_model_dir)
15
  huggingface_model.eval()
16
  return huggingface_model
17
 
 
1
  import pytest
2
  import torch
3
 
4
+ # huggingface_openphenom_model_dir = "."
5
+ huggingface_modelpath = "recursionpharma/OpenPhenom"
6
 
7
+ from .huggingface_mae import MAEModel
 
8
 
9
 
10
  @pytest.fixture
11
  def huggingface_model():
12
+ # This step downloads the model to a local cache, takes a bit to run
13
+ huggingface_model = MAEModel.from_pretrained(huggingface_modelpath)
 
14
  huggingface_model.eval()
15
  return huggingface_model
16