vijul.shah commited on
Commit
51ba5d6
·
1 Parent(s): dc32a0b

Added models and supporting files

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
config.yml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 42
2
+
3
+ feature_extraction_configs:
4
+ blink_detection: true
5
+ extraction_library: "mediapipe"
6
+ show_features: ['full_imgs', 'faces', 'eyes', 'blinks', 'iris']
7
+
8
+ model_configs:
9
+ models_path: "pre_trained_models"
10
+ registered_model_names: ["ResNet18", "ResNet50"]
11
+ labels: ["left_eye", "right_eye"]
12
+ targets: ["left_pupil", "right_pupil"]
13
+ num_classes: 1
14
+
15
+ xai_configs:
16
+ attribution_methods: [
17
+ "IntegratedGradients",
18
+ "Saliency",
19
+ "InputXGradient",
20
+ "GuidedBackprop",
21
+ "Deconvolution",
22
+ "GuidedGradCam",
23
+ "LayerGradCam",
24
+ "LayerGradientXActivation",
25
+ ]
26
+ cam_methods: [
27
+ "CAM",
28
+ "GradCAM",
29
+ "GradCAMpp",
30
+ "SmoothGradCAMpp",
31
+ "ScoreCAM",
32
+ "SSCAM",
33
+ "ISCAM",
34
+ "XGradCAM",
35
+ "LayerCAM",
36
+ ]
37
+
38
+ use_sr: false
39
+
40
+ upscale_configs:
41
+ upscale: [1, 2, 3, 4]
42
+ upscale_method_configs:
43
+ size: [16, 32]
44
+ antialias: true
45
+ interpolation: ["bicubic"]
46
+
47
+ sr_methods: ["GFPGAN", "RealESRGAN", "SRResNet", "CodeFormer", "HAT"]
48
+ sr_method_configs:
49
+ bg_upsampler_name: "realesrgan"
50
+ prefered_net_in_upsampler: "RRDBNet"
packages.txt ADDED
File without changes
pre_trained_models/ResNet18/right_eye.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68e2928f13900580bcb9b7c1a1f6d4bba863cfcfee2def944b49ef0c09337668
3
+ size 46843194
pre_trained_models/ResNet50/left_eye.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bd4bac728b71dae9e759b86188206a4f38fbc83b9507dd08f2a6abe1568d995
3
+ size 102554624
pre_trained_models/ResNet50/right_eye.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5179f569ea1886c9ad63ca9d047fdf721a9b59a63313cd9da3f2e3fae25de73
3
+ size 102554624
registrations/models.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch.nn as nn
3
+ import os.path as osp
4
+ from torchvision import models
5
+ import torch.nn.functional as F
6
+ from registry import MODEL_REGISTRY
7
+
8
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
9
+ sys.path.append(root_path)
10
+
11
+ # ============================= ResNets =============================
12
+
13
+
14
+ # @MODEL_REGISTRY.register()
15
+ # class ResNet18(nn.Module):
16
+ # def __init__(self, model_args):
17
+ # super(ResNet18, self).__init__()
18
+ # self.num_classes = model_args.get("num_classes", 1)
19
+ # self.resnet = models.resnet18(weights=None, num_classes=self.num_classes)
20
+
21
+ # def forward(self, x, masks=None):
22
+ # return self.resnet(x)
23
+
24
+
25
+ # @MODEL_REGISTRY.register()
26
+ # class ResNet18(nn.Module):
27
+ # def __init__(self, model_args):
28
+ # super(ResNet18, self).__init__()
29
+ # self.num_classes = model_args.get("num_classes", 1)
30
+ # self.resnet = models.resnet18(weights=None, num_classes=self.num_classes)
31
+
32
+ # def forward(self, x, masks=None):
33
+ # # Calculate the padding dynamically based on the input size
34
+ # height, width = x.shape[2], x.shape[3]
35
+ # pad_height = max(0, (224 - height) // 2)
36
+ # pad_width = max(0, (224 - width) // 2)
37
+
38
+ # # Apply padding
39
+ # x = F.pad(
40
+ # x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0
41
+ # )
42
+ # x = self.resnet(x)
43
+ # return x
44
+
45
+
46
+ @MODEL_REGISTRY.register()
47
+ class ResNet18(nn.Module):
48
+ def __init__(self, model_args):
49
+ super(ResNet18, self).__init__()
50
+ self.num_classes = model_args.get("num_classes", 1)
51
+ self.resnet = models.resnet18(weights=None)
52
+ self.regression_head = nn.Linear(1000, self.num_classes)
53
+
54
+ def forward(self, x, masks=None):
55
+ # Calculate the padding dynamically based on the input size
56
+ height, width = x.shape[2], x.shape[3]
57
+ pad_height = max(0, (224 - height) // 2)
58
+ pad_width = max(0, (224 - width) // 2)
59
+
60
+ # Apply padding
61
+ x = F.pad(
62
+ x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0
63
+ )
64
+ x = self.resnet(x)
65
+ x = self.regression_head(x)
66
+ return x
67
+
68
+
69
+ # @MODEL_REGISTRY.register()
70
+ # class ResNet50(nn.Module):
71
+ # def __init__(self, model_args):
72
+ # super(ResNet50, self).__init__()
73
+ # self.num_classes = model_args.get("num_classes", 1)
74
+ # self.resnet = models.resnet50(weights=None, num_classes=self.num_classes)
75
+
76
+ # def forward(self, x, masks=None):
77
+ # return self.resnet(x)
78
+
79
+
80
+ # @MODEL_REGISTRY.register()
81
+ # class ResNet50(nn.Module):
82
+ # def __init__(self, model_args):
83
+ # super(ResNet50, self).__init__()
84
+ # self.num_classes = model_args.get("num_classes", 1)
85
+ # self.resnet = models.resnet50(weights=None, num_classes=self.num_classes)
86
+
87
+ # def forward(self, x, masks=None):
88
+ # # Calculate the padding dynamically based on the input size
89
+ # height, width = x.shape[2], x.shape[3]
90
+ # pad_height = max(0, (224 - height) // 2)
91
+ # pad_width = max(0, (224 - width) // 2)
92
+
93
+ # # Apply padding
94
+ # x = F.pad(
95
+ # x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0
96
+ # )
97
+ # x = self.resnet(x)
98
+ # return x
99
+
100
+
101
+ @MODEL_REGISTRY.register()
102
+ class ResNet50(nn.Module):
103
+ def __init__(self, model_args):
104
+ super(ResNet50, self).__init__()
105
+ self.num_classes = model_args.get("num_classes", 1)
106
+ self.resnet = models.resnet50(weights=None)
107
+ self.regression_head = nn.Linear(1000, self.num_classes)
108
+
109
+ def forward(self, x, masks=None):
110
+ # Calculate the padding dynamically based on the input size
111
+ height, width = x.shape[2], x.shape[3]
112
+ pad_height = max(0, (224 - height) // 2)
113
+ pad_width = max(0, (224 - width) // 2)
114
+
115
+ # Apply padding
116
+ x = F.pad(
117
+ x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0
118
+ )
119
+ x = self.resnet(x)
120
+ x = self.regression_head(x)
121
+ return x
122
+
123
+
124
+ print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
registry.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2
+
3
+
4
+ class Registry:
5
+ """
6
+ The registry that provides name -> object mapping, to support third-party
7
+ users' custom modules.
8
+
9
+ To create a registry (e.g. a backbone registry):
10
+
11
+ .. code-block:: python
12
+
13
+ BACKBONE_REGISTRY = Registry('BACKBONE')
14
+
15
+ To register an object:
16
+
17
+ .. code-block:: python
18
+
19
+ @BACKBONE_REGISTRY.register()
20
+ class MyBackbone():
21
+ ...
22
+
23
+ Or:
24
+
25
+ .. code-block:: python
26
+
27
+ BACKBONE_REGISTRY.register(MyBackbone)
28
+ """
29
+
30
+ def __init__(self, name):
31
+ """
32
+ Args:
33
+ name (str): the name of this registry
34
+ """
35
+ self._name = name
36
+ self._obj_map = {}
37
+
38
+ def _do_register(self, name, obj):
39
+ assert name not in self._obj_map, (
40
+ f"An object named '{name}' was already registered "
41
+ f"in '{self._name}' registry!"
42
+ )
43
+ self._obj_map[name] = obj
44
+
45
+ def register(self, obj=None):
46
+ """
47
+ Register the given object under the the name `obj.__name__`.
48
+ Can be used as either a decorator or not.
49
+ See docstring of this class for usage.
50
+ """
51
+ if obj is None:
52
+ # used as a decorator
53
+ def deco(func_or_class):
54
+ name = func_or_class.__name__
55
+ self._do_register(name, func_or_class)
56
+ return func_or_class
57
+
58
+ return deco
59
+
60
+ # used as a function call
61
+ name = obj.__name__
62
+ self._do_register(name, obj)
63
+
64
+ def get(self, name):
65
+ ret = self._obj_map.get(name)
66
+ if ret is None:
67
+ raise KeyError(
68
+ f"No object named '{name}' found in '{self._name}' registry!"
69
+ )
70
+ return ret
71
+
72
+ def __contains__(self, name):
73
+ return name in self._obj_map
74
+
75
+ def __iter__(self):
76
+ return iter(self._obj_map.items())
77
+
78
+ def keys(self):
79
+ return self._obj_map.keys()
80
+
81
+
82
+ MODEL_REGISTRY = Registry("model")
registry_utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib
3
+ from os import path as osp
4
+
5
+
6
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
7
+ """Scan a directory to find the interested files.
8
+
9
+ Args:
10
+ dir_path (str): Path of the directory.
11
+ suffix (str | tuple(str), optional): File suffix that we are
12
+ interested in. Default: None.
13
+ recursive (bool, optional): If set to True, recursively scan the
14
+ directory. Default: False.
15
+ full_path (bool, optional): If set to True, include the dir_path.
16
+ Default: False.
17
+
18
+ Returns:
19
+ A generator for all the interested files with relative paths.
20
+ """
21
+
22
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
23
+ raise TypeError('"suffix" must be a string or tuple of strings')
24
+
25
+ root = dir_path
26
+
27
+ def _scandir(dir_path, suffix, recursive):
28
+ for entry in os.scandir(dir_path):
29
+ if not entry.name.startswith(".") and entry.is_file():
30
+ if full_path:
31
+ return_path = entry.path
32
+ else:
33
+ return_path = osp.relpath(entry.path, root)
34
+
35
+ if suffix is None:
36
+ yield return_path
37
+ elif return_path.endswith(suffix):
38
+ yield return_path
39
+ else:
40
+ if recursive:
41
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
42
+ else:
43
+ continue
44
+
45
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
46
+
47
+
48
+ def import_registered_modules(registration_folder="registrations"):
49
+ """
50
+ Import all registered modules from the specified folder.
51
+
52
+ This function automatically scans all the files under the specified folder and imports all the required modules for registry.
53
+
54
+ Parameters:
55
+ registration_folder (str, optional): Path to the folder containing registration modules. Default is "registrations".
56
+
57
+ Returns:
58
+ list: List of imported modules.
59
+ """
60
+
61
+ print("\n")
62
+
63
+ registration_modules_folder = (
64
+ osp.dirname(osp.abspath(__file__)) + f"/{registration_folder}"
65
+ )
66
+ print("registration_modules_folder = ", registration_modules_folder)
67
+
68
+ registration_modules_file_names = [
69
+ osp.splitext(osp.basename(v))[0]
70
+ for v in scandir(dir_path=registration_modules_folder)
71
+ ]
72
+ print("registration_modules_file_names = ", registration_modules_file_names)
73
+
74
+ imported_modules = [
75
+ importlib.import_module(f"{registration_folder}.{file_name}")
76
+ for file_name in registration_modules_file_names
77
+ ]
78
+ print("imported_modules = ", imported_modules)
79
+ print("\n")
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ PyYAML
3
+ numpy
4
+ pandas
5
+ matplotlib
6
+ seaborn
7
+ mlflow
8
+ pillow
9
+ scikit_learn
10
+ torch
11
+ captum
12
+ evaluate
13
+ # basicsr
14
+ facexlib
15
+ realesrgan
16
+ opencv_python
17
+ cmake
18
+ dlib
19
+ einops
20
+ transformers
21
+ # gfpgan
22
+ # streamlit
23
+ mediapipe
24
+ imutils
25
+ scipy
26
+ torchvision==0.16.0
27
+ torchcam
utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from registry import MODEL_REGISTRY
2
+
3
+
4
+ def get_model(model_configs):
5
+ registered_model = MODEL_REGISTRY.get(model_configs["registered_model_name"])
6
+ model_configs.pop("registered_model_name")
7
+ if len(model_configs) > 0:
8
+ model = registered_model(model_configs)
9
+ else:
10
+ model = registered_model()
11
+ return model