nick_93 commited on
Commit
bcec54e
·
1 Parent(s): 2fb7fdb
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ *.ckpt
7
+ *.pth
8
+ refer/refer/data/
9
+ depth/kitti_dataset/
10
+ depth/nyu_depth_v2/
11
+
12
+ # C extensions
13
+ .so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/#use-with-ide
116
+ .pdm.toml
117
+
118
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119
+ __pypackages__/
120
+
121
+ # Celery stuff
122
+ celerybeat-schedule
123
+ celerybeat.pid
124
+
125
+ # SageMath parsed files
126
+ *.sage.py
127
+
128
+ # Environments
129
+ .env
130
+ .venv
131
+ env/
132
+ venv/
133
+ ENV/
134
+ env.bak/
135
+ venv.bak/
136
+
137
+ # Spyder project settings
138
+ .spyderproject
139
+ .spyproject
140
+
141
+ # Rope project settings
142
+ .ropeproject
143
+
144
+ # mkdocs documentation
145
+ /site
146
+
147
+ # mypy
148
+ .mypy_cache/
149
+ .dmypy.json
150
+ dmypy.json
151
+
152
+ # Pyre type checker
153
+ .pyre/
154
+
155
+ # pytype static type analyzer
156
+ .pytype/
157
+
158
+ # Cython debug symbols
159
+ cython_debug/
160
+
161
+ # PyCharm
162
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
165
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Mykola Lavreniuk
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ depth_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth'))
5
+ sys.path.append(depth_directory)
6
+ os.chdir(depth_directory)
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import torch.backends.cudnn as cudnn
12
+ from depth.models_depth.model import EVPDepth
13
+ from depth.configs.train_options import TrainOptions
14
+ from depth.configs.test_options import TestOptions
15
+ import glob
16
+ import utils
17
+ import torchvision.transforms as transforms
18
+ from utils_depth.misc import colorize
19
+ from PIL import Image
20
+ import torch.nn.functional as F
21
+ import gradio as gr
22
+ import tempfile
23
+
24
+
25
+ css = """
26
+ #img-display-container {
27
+ max-height: 50vh;
28
+ }
29
+ #img-display-input {
30
+ max-height: 40vh;
31
+ }
32
+ #img-display-output {
33
+ max-height: 40vh;
34
+ }
35
+
36
+ """
37
+
38
+ def create_demo(model, device):
39
+ gr.Markdown("### Depth Prediction demo")
40
+ with gr.Row():
41
+ input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
42
+ depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
43
+ raw_file = gr.File(label="16-bit raw depth, multiplier:256")
44
+ submit = gr.Button("Submit")
45
+
46
+ def on_submit(image):
47
+ transform = transforms.ToTensor()
48
+ image = transform(image).unsqueeze(0).to(device)
49
+ shape = image.shape
50
+ image = torch.nn.functional.interpolate(image, (440,480), mode='bilinear', align_corners=True)
51
+ image = F.pad(image, (0, 0, 40, 0))
52
+ with torch.no_grad():
53
+ pred = model(image)['pred_d']
54
+
55
+ pred = pred[:,:,40:,:]
56
+ pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
57
+ pred_d_numpy = pred.squeeze().cpu().numpy()
58
+ colored_depth, _, _ = colorize(pred_d_numpy, cmap='gray_r')
59
+
60
+ tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
61
+ raw_depth = Image.fromarray((pred_d_numpy*256).astype('uint16'))
62
+ raw_depth.save(tmp.name)
63
+ return [colored_depth, tmp.name]
64
+
65
+ submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
66
+ examples = gr.Examples(examples=["test_img.jpg"],
67
+ inputs=[input_image])
68
+
69
+
70
+ def main():
71
+ opt = TestOptions().initialize()
72
+ opt.add_argument('--img_path', type=str)
73
+ args = opt.parse_args()
74
+
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
+ model = EVPDepth(args=args, caption_aggregation=True)
77
+ cudnn.benchmark = True
78
+ model.to(device)
79
+ model_weight = torch.load(args.ckpt_dir)['model']
80
+ if 'module' in next(iter(model_weight.items()))[0]:
81
+ model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
82
+ model.load_state_dict(model_weight, strict=False)
83
+ model.eval()
84
+
85
+ title = "# EVP"
86
+ description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
87
+ Refinement and Regularized Image-Text Alignment**.
88
+ EVP is a deep learning model for metric depth estimation from a single image.
89
+ Please refer to our [paper](https://arxiv.org/abs/2312.08548) or [github](https://github.com/Lavreniuk/EVP) for more details."""
90
+
91
+ with gr.Blocks() as demo:
92
+ gr.Markdown(title)
93
+ gr.Markdown(description)
94
+ with gr.Tab("Depth Prediction"):
95
+ create_demo(model, device)
96
+ gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/shariqfarooq/ZoeDepth?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
97
+ <p><img src="https://visitor-badge.glitch.me/badge?page_id=shariqfarooq.zoedepth_demo_hf" alt="visitors"></p></center>''')
98
+
99
+ demo.queue().launch(share=True)
100
+
101
+
102
+ if __name__ == '__main__':
103
+ main()
depth/README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Depth Estimation
2
+ ## Getting Started
3
+
4
+ 1. Install the [mmcv-full](https://github.com/open-mmlab/mmcv) library and some required packages.
5
+
6
+ ```bash
7
+ pip install openmim
8
+ mim install mmcv-full
9
+ pip install -r requirements.txt
10
+ ```
11
+
12
+ 2. Prepare NYUDepthV2 datasets following [GLPDepth](https://github.com/vinvino02/GLPDepth) and [BTS](https://github.com/cleinc/bts/tree/master).
13
+
14
+ ```
15
+ mkdir nyu_depth_v2
16
+ wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat
17
+ python extract_official_train_test_set_from_mat.py nyu_depth_v2_labeled.mat splits.mat ./nyu_depth_v2/official_splits/
18
+ ```
19
+
20
+ Download sync.zip provided by the authors of BTS from this [url](https://drive.google.com/file/d/1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP/view) and unzip in `./nyu_depth_v2` folder.
21
+
22
+ Your dataset directory should be:
23
+
24
+ ```
25
+ │nyu_depth_v2/
26
+ ├──official_splits/
27
+ │ ├── test
28
+ │ ├── train
29
+ ├──sync/
30
+ ```
31
+
32
+ ## Results and Fine-tuned Models
33
+
34
+ EVP obtains 0.224 RMSE on NYUv2 depth estimation benchmark, establishing the new state-of-the-art.
35
+
36
+ | | RMSE | d1 | d2 | d3 | REL | log_10 |
37
+ |---------|-------|-------|--------|------|-------|-------|
38
+ | **EVP** | 0.224 | 0.976 | 0.997 | 0.999 | 0.061 | 0.027 |
39
+
40
+ EVP obtains 0.048 REL and 0.136 SqREL on KITTI depth estimation benchmark, establishing the new state-of-the-art.
41
+
42
+ | | REL | SqREL | RMSE | RMSE log | d1 | d2 | d3 |
43
+ |---------|-------|-------|--------|------|-------|-------|-------|
44
+ | **EVP** | 0.048 | 0.136 | 2.015 | 0.073 | 0.980 | 0.998 | 1.000 |
45
+
46
+ ## Training
47
+
48
+ Run the following instuction to train the EVP-Depth model.
49
+
50
+ ```
51
+ bash train.sh <LOG_DIR>
52
+ ```
53
+
54
+ ## Evaluation
55
+ Command format:
56
+ ```
57
+ bash test.sh <CHECKPOINT_PATH>
58
+ ```
59
+
60
+ ## Custom inference
61
+ ```
62
+ PYTHONPATH="../":$PYTHONPATH python inference.py --img_path test_img.jpg --ckpt_dir nyu.ckpt
63
+ ```
depth/configs/base_options.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # Modified by Zigang Geng ([email protected]).
5
+ # ------------------------------------------------------------------------------
6
+
7
+ import argparse
8
+
9
+
10
+ def str2bool(v):
11
+ if isinstance(v, bool):
12
+ return v
13
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
14
+ return True
15
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
16
+ return False
17
+ else:
18
+ raise argparse.ArgumentTypeError('Boolean value expected.')
19
+
20
+
21
+ class BaseOptions():
22
+ def __init__(self):
23
+ pass
24
+
25
+ def initialize(self):
26
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
27
+ # base configs
28
+ parser.add_argument('--resume_from', type=str, default='')
29
+ parser.add_argument('--exp_name', type=str, default='')
30
+ parser.add_argument('--gpu_or_cpu', type=str, default='gpu')
31
+ parser.add_argument('--data_path', type=str, default='/data/ssd1/')
32
+ parser.add_argument('--dataset', type=str, default='nyudepthv2',
33
+ choices=['nyudepthv2', 'kitti', 'imagepath'])
34
+ parser.add_argument('--batch_size', type=int, default=8)
35
+ parser.add_argument('--workers', type=int, default=8)
36
+
37
+ # depth configs
38
+ parser.add_argument('--max_depth', type=float, default=10.0)
39
+ parser.add_argument('--max_depth_eval', type=float, default=10.0)
40
+ parser.add_argument('--min_depth_eval', type=float, default=1e-3)
41
+ parser.add_argument('--do_kb_crop', type=int, default=1)
42
+ parser.add_argument('--kitti_crop', type=str, default=None,
43
+ choices=['garg_crop', 'eigen_crop'])
44
+
45
+ parser.add_argument('--pretrained', type=str, default='')
46
+ parser.add_argument('--drop_path_rate', type=float, default=0.3)
47
+ parser.add_argument('--use_checkpoint', type=str2bool, default='False')
48
+ parser.add_argument('--num_deconv', type=int, default=3)
49
+ parser.add_argument('--num_filters', nargs='+', type=int, default=[32,32,32])
50
+ parser.add_argument('--deconv_kernels', nargs='+', type=int, default=[2,2,2])
51
+
52
+ parser.add_argument('--shift_window_test', action='store_true')
53
+ parser.add_argument('--shift_size', type=int, default=2)
54
+ parser.add_argument('--flip_test', action='store_true')
55
+
56
+ return parser
depth/configs/test_options.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # ------------------------------------------------------------------------------
5
+
6
+ from configs.base_options import BaseOptions
7
+
8
+ class TestOptions(BaseOptions):
9
+ def initialize(self):
10
+ parser = BaseOptions.initialize(self)
11
+
12
+ # experiment configs
13
+ parser.add_argument('--ckpt_dir', type=str,
14
+ default='./ckpt/best_model_nyu.ckpt',
15
+ help='load ckpt path')
16
+ parser.add_argument('--result_dir', type=str, default='./results',
17
+ help='save result images into result_dir/exp_name')
18
+ parser.add_argument('--crop_h', type=int, default=448)
19
+ parser.add_argument('--crop_w', type=int, default=576)
20
+
21
+ parser.add_argument('--save_eval_pngs', action='store_true',
22
+ help='save result image into evaluation form')
23
+ parser.add_argument('--save_visualize', action='store_true',
24
+ help='save result image into visulized form')
25
+ return parser
26
+
27
+
depth/configs/train_options.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # Modified by Zigang Geng ([email protected]).
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from configs.base_options import BaseOptions
8
+ import argparse
9
+
10
+
11
+ def str2bool(v):
12
+ if isinstance(v, bool):
13
+ return v
14
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
15
+ return True
16
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
17
+ return False
18
+ else:
19
+ raise argparse.ArgumentTypeError('Boolean value expected.')
20
+
21
+
22
+ class TrainOptions(BaseOptions):
23
+ def initialize(self):
24
+ parser = BaseOptions.initialize(self)
25
+
26
+ # experiment configs
27
+ parser.add_argument('--epochs', type=int, default=25)
28
+ parser.add_argument('--max_lr', type=float, default=5e-4)
29
+ parser.add_argument('--min_lr', type=float, default=3e-5)
30
+ parser.add_argument('--weight_decay', type=float, default=5e-2)
31
+ parser.add_argument('--layer_decay', type=float, default=0.9)
32
+
33
+ parser.add_argument('--crop_h', type=int, default=448)
34
+ parser.add_argument('--crop_w', type=int, default=576)
35
+ parser.add_argument('--log_dir', type=str, default='./logs')
36
+
37
+ # logging options
38
+ parser.add_argument('--val_freq', type=int, default=1)
39
+ parser.add_argument('--pro_bar', type=str2bool, default='False')
40
+ parser.add_argument('--save_freq', type=int, default=1)
41
+ parser.add_argument('--print_freq', type=int, default=100)
42
+ parser.add_argument('--save_model', action='store_true')
43
+ parser.add_argument(
44
+ '--resume-from', help='the checkpoint file to resume from')
45
+ parser.add_argument('--auto_resume', action='store_true')
46
+ parser.add_argument('--save_result', action='store_true')
47
+
48
+
49
+
50
+ return parser
depth/inference.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ from models_depth.model import EVPDepth
7
+ from configs.train_options import TrainOptions
8
+ from configs.test_options import TestOptions
9
+ import glob
10
+ import utils
11
+ import torchvision.transforms as transforms
12
+ from utils_depth.misc import colorize
13
+ from PIL import Image
14
+ import torch.nn.functional as F
15
+
16
+
17
+ def main():
18
+ opt = TestOptions().initialize()
19
+ opt.add_argument('--img_path', type=str)
20
+ args = opt.parse_args()
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ model = EVPDepth(args=args, caption_aggregation=True)
24
+ cudnn.benchmark = True
25
+ model.to(device)
26
+ model_weight = torch.load(args.ckpt_dir)['model']
27
+ if 'module' in next(iter(model_weight.items()))[0]:
28
+ model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
29
+ model.load_state_dict(model_weight, strict=False)
30
+ model.eval()
31
+
32
+ img_path = args.img_path
33
+ image = cv2.imread(img_path)
34
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
35
+ transform = transforms.ToTensor()
36
+ image = transform(image).unsqueeze(0).to(device)
37
+ shape = image.shape
38
+ image = torch.nn.functional.interpolate(image, (440,480), mode='bilinear', align_corners=True)
39
+ image = F.pad(image, (0, 0, 40, 0))
40
+
41
+ with torch.no_grad():
42
+ pred = model(image)['pred_d']
43
+
44
+ pred = pred[:,:,40:,:]
45
+ pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
46
+ pred_d_numpy = pred.squeeze().cpu().numpy()
47
+ pred_d_color, _, _ = colorize(pred_d_numpy, cmap='gray_r')
48
+ Image.fromarray(pred_d_color).save('res.png')
49
+
50
+ return 0
51
+
52
+ if __name__ == '__main__':
53
+ main()
depth/models_depth/attractor.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ @torch.jit.script
30
+ def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
31
+ """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor
32
+
33
+ Args:
34
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
35
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
36
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
37
+
38
+ Returns:
39
+ torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
40
+ """
41
+ return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx)
42
+
43
+
44
+ @torch.jit.script
45
+ def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
46
+ """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
47
+ This is the default one according to the accompanying paper.
48
+
49
+ Args:
50
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
51
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
52
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
53
+
54
+ Returns:
55
+ torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
56
+ """
57
+ return dx.div(1+alpha*dx.pow(gamma))
58
+
59
+
60
+ class AttractorLayer(nn.Module):
61
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
62
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
63
+ """
64
+ Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
65
+ """
66
+ super().__init__()
67
+
68
+ self.n_attractors = n_attractors
69
+ self.n_bins = n_bins
70
+ self.min_depth = min_depth
71
+ self.max_depth = max_depth
72
+ self.alpha = alpha
73
+ self.gamma = gamma
74
+ self.kind = kind
75
+ self.attractor_type = attractor_type
76
+ self.memory_efficient = memory_efficient
77
+
78
+ self._net = nn.Sequential(
79
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
80
+ nn.ReLU(inplace=True),
81
+ nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm
82
+ nn.ReLU(inplace=True)
83
+ )
84
+
85
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
86
+ """
87
+ Args:
88
+ x (torch.Tensor) : feature block; shape - n, c, h, w
89
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
90
+
91
+ Returns:
92
+ tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w
93
+ """
94
+ if prev_b_embedding is not None:
95
+ if interpolate:
96
+ prev_b_embedding = nn.functional.interpolate(
97
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
98
+ x = x + prev_b_embedding
99
+
100
+ A = self._net(x)
101
+ eps = 1e-3
102
+ A = A + eps
103
+ n, c, h, w = A.shape
104
+ A = A.view(n, self.n_attractors, 2, h, w)
105
+ A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w
106
+ A_normed = A[:, :, 0, ...] # n, na, h, w
107
+
108
+ b_prev = nn.functional.interpolate(
109
+ b_prev, (h, w), mode='bilinear', align_corners=True)
110
+ b_centers = b_prev
111
+
112
+ if self.attractor_type == 'exp':
113
+ dist = exp_attractor
114
+ else:
115
+ dist = inv_attractor
116
+
117
+ if not self.memory_efficient:
118
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
119
+ # .shape N, nbins, h, w
120
+ delta_c = func(dist(A_normed.unsqueeze(
121
+ 2) - b_centers.unsqueeze(1)), dim=1)
122
+ else:
123
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
124
+ for i in range(self.n_attractors):
125
+ # .shape N, nbins, h, w
126
+ delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)
127
+
128
+ if self.kind == 'mean':
129
+ delta_c = delta_c / self.n_attractors
130
+
131
+ b_new_centers = b_centers + delta_c
132
+ B_centers = (self.max_depth - self.min_depth) * \
133
+ b_new_centers + self.min_depth
134
+ B_centers, _ = torch.sort(B_centers, dim=1)
135
+ B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
136
+ return b_new_centers, B_centers
137
+
138
+
139
+ class AttractorLayerUnnormed(nn.Module):
140
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
141
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
142
+ """
143
+ Attractor layer for bin centers. Bin centers are unbounded
144
+ """
145
+ super().__init__()
146
+
147
+ self.n_attractors = n_attractors
148
+ self.n_bins = n_bins
149
+ self.min_depth = min_depth
150
+ self.max_depth = max_depth
151
+ self.alpha = alpha
152
+ self.gamma = gamma
153
+ self.kind = kind
154
+ self.attractor_type = attractor_type
155
+ self.memory_efficient = memory_efficient
156
+
157
+ self._net = nn.Sequential(
158
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
159
+ nn.ReLU(inplace=True),
160
+ nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
161
+ nn.Softplus()
162
+ )
163
+
164
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
165
+ """
166
+ Args:
167
+ x (torch.Tensor) : feature block; shape - n, c, h, w
168
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
169
+
170
+ Returns:
171
+ tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
172
+ """
173
+ if prev_b_embedding is not None:
174
+ if interpolate:
175
+ prev_b_embedding = nn.functional.interpolate(
176
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
177
+ x = x + prev_b_embedding
178
+
179
+ A = self._net(x)
180
+ n, c, h, w = A.shape
181
+
182
+ b_prev = nn.functional.interpolate(
183
+ b_prev, (h, w), mode='bilinear', align_corners=True)
184
+ b_centers = b_prev
185
+
186
+ if self.attractor_type == 'exp':
187
+ dist = exp_attractor
188
+ else:
189
+ dist = inv_attractor
190
+
191
+ if not self.memory_efficient:
192
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
193
+ # .shape N, nbins, h, w
194
+ delta_c = func(
195
+ dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)
196
+ else:
197
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
198
+ for i in range(self.n_attractors):
199
+ delta_c += dist(A[:, i, ...].unsqueeze(1) -
200
+ b_centers) # .shape N, nbins, h, w
201
+
202
+ if self.kind == 'mean':
203
+ delta_c = delta_c / self.n_attractors
204
+
205
+ b_new_centers = b_centers + delta_c
206
+ B_centers = b_new_centers
207
+
208
+ return b_new_centers, B_centers
depth/models_depth/checkpoint.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # The code is from Swin Transformer.
5
+ # (https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmcv_custom/checkpoint.py)
6
+ # ------------------------------------------------------------------------------
7
+
8
+ import io
9
+ import os
10
+ import os.path as osp
11
+ import pkgutil
12
+ import time
13
+ import warnings
14
+ import numpy as np
15
+ from collections import OrderedDict
16
+ from importlib import import_module
17
+ from tempfile import TemporaryDirectory
18
+ from scipy import interpolate
19
+
20
+ import torch
21
+ import torchvision
22
+ import torch.distributed as dist
23
+ from torch.optim import Optimizer
24
+ from torch.utils import model_zoo
25
+ from torch.nn import functional as F
26
+
27
+ import mmcv
28
+ from mmcv.fileio import FileClient
29
+ from mmcv.fileio import load as load_file
30
+ from mmcv.parallel import is_module_wrapper
31
+ from mmcv.utils import mkdir_or_exist
32
+ from mmcv.runner import get_dist_info
33
+ from mmcv.utils import get_logger
34
+
35
+ import logging
36
+
37
+
38
+ def get_root_logger(log_file=None, log_level=logging.INFO):
39
+ """Get the root logger.
40
+
41
+ The logger will be initialized if it has not been initialized. By default a
42
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
43
+ also be added. The name of the root logger is the top-level package name,
44
+ e.g., "mmseg".
45
+
46
+ Args:
47
+ log_file (str | None): The log filename. If specified, a FileHandler
48
+ will be added to the root logger.
49
+ log_level (int): The root logger level. Note that only the process of
50
+ rank 0 is affected, while other processes will set the level to
51
+ "Error" and be silent most of the time.
52
+
53
+ Returns:
54
+ logging.Logger: The root logger.
55
+ """
56
+
57
+ logger = get_logger(name='mmpose', log_file=log_file, log_level=log_level)
58
+
59
+ return logger
60
+
61
+
62
+ def _get_mmcv_home():
63
+ mmcv_home = os.path.expanduser(
64
+ os.getenv(
65
+ ENV_MMCV_HOME,
66
+ os.path.join(
67
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
68
+
69
+ mkdir_or_exist(mmcv_home)
70
+ return mmcv_home
71
+
72
+
73
+ def load_state_dict(module, state_dict, strict=False, logger=None):
74
+ """Load state_dict to a module.
75
+
76
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
77
+ Default value for ``strict`` is set to ``False`` and the message for
78
+ param mismatch will be shown even if strict is False.
79
+
80
+ Args:
81
+ module (Module): Module that receives the state_dict.
82
+ state_dict (OrderedDict): Weights.
83
+ strict (bool): whether to strictly enforce that the keys
84
+ in :attr:`state_dict` match the keys returned by this module's
85
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
86
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
87
+ message. If not specified, print function will be used.
88
+ """
89
+ unexpected_keys = []
90
+ all_missing_keys = []
91
+ err_msg = []
92
+
93
+ metadata = getattr(state_dict, '_metadata', None)
94
+ state_dict = state_dict.copy()
95
+ if metadata is not None:
96
+ state_dict._metadata = metadata
97
+
98
+ # use _load_from_state_dict to enable checkpoint version control
99
+ def load(module, prefix=''):
100
+ # recursively check parallel module in case that the model has a
101
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
102
+ if is_module_wrapper(module):
103
+ module = module.module
104
+ local_metadata = {} if metadata is None else metadata.get(
105
+ prefix[:-1], {})
106
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
107
+ all_missing_keys, unexpected_keys,
108
+ err_msg)
109
+ for name, child in module._modules.items():
110
+ if child is not None:
111
+ load(child, prefix + name + '.')
112
+
113
+ load(module)
114
+ load = None # break load->load reference cycle
115
+
116
+ # ignore "num_batches_tracked" of BN layers
117
+ missing_keys = [
118
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
119
+ ]
120
+
121
+ if unexpected_keys:
122
+ err_msg.append('unexpected key in source '
123
+ f'state_dict: {", ".join(unexpected_keys)}\n')
124
+ if missing_keys:
125
+ err_msg.append(
126
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
127
+
128
+ rank, _ = get_dist_info()
129
+ if len(err_msg) > 0 and rank == 0:
130
+ err_msg.insert(
131
+ 0, 'The model and loaded state dict do not match exactly\n')
132
+ err_msg = '\n'.join(err_msg)
133
+ if strict:
134
+ raise RuntimeError(err_msg)
135
+ elif logger is not None:
136
+ logger.warning(err_msg)
137
+ else:
138
+ print(err_msg)
139
+
140
+
141
+ def load_url_dist(url, model_dir=None):
142
+ """In distributed setting, this function only download checkpoint at local
143
+ rank 0."""
144
+ rank, world_size = get_dist_info()
145
+ rank = int(os.environ.get('LOCAL_RANK', rank))
146
+ if rank == 0:
147
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
148
+ if world_size > 1:
149
+ torch.distributed.barrier()
150
+ if rank > 0:
151
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
152
+ return checkpoint
153
+
154
+
155
+ def load_pavimodel_dist(model_path, map_location=None):
156
+ """In distributed setting, this function only download checkpoint at local
157
+ rank 0."""
158
+ try:
159
+ from pavi import modelcloud
160
+ except ImportError:
161
+ raise ImportError(
162
+ 'Please install pavi to load checkpoint from modelcloud.')
163
+ rank, world_size = get_dist_info()
164
+ rank = int(os.environ.get('LOCAL_RANK', rank))
165
+ if rank == 0:
166
+ model = modelcloud.get(model_path)
167
+ with TemporaryDirectory() as tmp_dir:
168
+ downloaded_file = osp.join(tmp_dir, model.name)
169
+ model.download(downloaded_file)
170
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
171
+ if world_size > 1:
172
+ torch.distributed.barrier()
173
+ if rank > 0:
174
+ model = modelcloud.get(model_path)
175
+ with TemporaryDirectory() as tmp_dir:
176
+ downloaded_file = osp.join(tmp_dir, model.name)
177
+ model.download(downloaded_file)
178
+ checkpoint = torch.load(
179
+ downloaded_file, map_location=map_location)
180
+ return checkpoint
181
+
182
+
183
+ def load_fileclient_dist(filename, backend, map_location):
184
+ """In distributed setting, this function only download checkpoint at local
185
+ rank 0."""
186
+ rank, world_size = get_dist_info()
187
+ rank = int(os.environ.get('LOCAL_RANK', rank))
188
+ allowed_backends = ['ceph']
189
+ if backend not in allowed_backends:
190
+ raise ValueError(f'Load from Backend {backend} is not supported.')
191
+ if rank == 0:
192
+ fileclient = FileClient(backend=backend)
193
+ buffer = io.BytesIO(fileclient.get(filename))
194
+ checkpoint = torch.load(buffer, map_location=map_location)
195
+ if world_size > 1:
196
+ torch.distributed.barrier()
197
+ if rank > 0:
198
+ fileclient = FileClient(backend=backend)
199
+ buffer = io.BytesIO(fileclient.get(filename))
200
+ checkpoint = torch.load(buffer, map_location=map_location)
201
+ return checkpoint
202
+
203
+
204
+ def get_torchvision_models():
205
+ model_urls = dict()
206
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
207
+ if ispkg:
208
+ continue
209
+ _zoo = import_module(f'torchvision.models.{name}')
210
+ if hasattr(_zoo, 'model_urls'):
211
+ _urls = getattr(_zoo, 'model_urls')
212
+ model_urls.update(_urls)
213
+ return model_urls
214
+
215
+
216
+ def get_external_models():
217
+ mmcv_home = _get_mmcv_home()
218
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
219
+ default_urls = load_file(default_json_path)
220
+ assert isinstance(default_urls, dict)
221
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
222
+ if osp.exists(external_json_path):
223
+ external_urls = load_file(external_json_path)
224
+ assert isinstance(external_urls, dict)
225
+ default_urls.update(external_urls)
226
+
227
+ return default_urls
228
+
229
+
230
+ def get_mmcls_models():
231
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
232
+ mmcls_urls = load_file(mmcls_json_path)
233
+
234
+ return mmcls_urls
235
+
236
+
237
+ def get_deprecated_model_names():
238
+ deprecate_json_path = osp.join(mmcv.__path__[0],
239
+ 'model_zoo/deprecated.json')
240
+ deprecate_urls = load_file(deprecate_json_path)
241
+ assert isinstance(deprecate_urls, dict)
242
+
243
+ return deprecate_urls
244
+
245
+
246
+ def _process_mmcls_checkpoint(checkpoint):
247
+ state_dict = checkpoint['state_dict']
248
+ new_state_dict = OrderedDict()
249
+ for k, v in state_dict.items():
250
+ if k.startswith('backbone.'):
251
+ new_state_dict[k[9:]] = v
252
+ new_checkpoint = dict(state_dict=new_state_dict)
253
+
254
+ return new_checkpoint
255
+
256
+
257
+ def _load_checkpoint(filename, map_location=None):
258
+ """Load checkpoint from somewhere (modelzoo, file, url).
259
+
260
+ Args:
261
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
262
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
263
+ details.
264
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
265
+
266
+ Returns:
267
+ dict | OrderedDict: The loaded checkpoint. It can be either an
268
+ OrderedDict storing model weights or a dict containing other
269
+ information, which depends on the checkpoint.
270
+ """
271
+ if filename.startswith('modelzoo://'):
272
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
273
+ 'use "torchvision://" instead')
274
+ model_urls = get_torchvision_models()
275
+ model_name = filename[11:]
276
+ checkpoint = load_url_dist(model_urls[model_name])
277
+ elif filename.startswith('torchvision://'):
278
+ model_urls = get_torchvision_models()
279
+ model_name = filename[14:]
280
+ checkpoint = load_url_dist(model_urls[model_name])
281
+ elif filename.startswith('open-mmlab://'):
282
+ model_urls = get_external_models()
283
+ model_name = filename[13:]
284
+ deprecated_urls = get_deprecated_model_names()
285
+ if model_name in deprecated_urls:
286
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
287
+ f'of open-mmlab://{deprecated_urls[model_name]}')
288
+ model_name = deprecated_urls[model_name]
289
+ model_url = model_urls[model_name]
290
+ # check if is url
291
+ if model_url.startswith(('http://', 'https://')):
292
+ checkpoint = load_url_dist(model_url)
293
+ else:
294
+ filename = osp.join(_get_mmcv_home(), model_url)
295
+ if not osp.isfile(filename):
296
+ raise IOError(f'{filename} is not a checkpoint file')
297
+ checkpoint = torch.load(filename, map_location=map_location)
298
+ elif filename.startswith('mmcls://'):
299
+ model_urls = get_mmcls_models()
300
+ model_name = filename[8:]
301
+ checkpoint = load_url_dist(model_urls[model_name])
302
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
303
+ elif filename.startswith(('http://', 'https://')):
304
+ checkpoint = load_url_dist(filename)
305
+ elif filename.startswith('pavi://'):
306
+ model_path = filename[7:]
307
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
308
+ elif filename.startswith('s3://'):
309
+ checkpoint = load_fileclient_dist(
310
+ filename, backend='ceph', map_location=map_location)
311
+ else:
312
+ if not osp.isfile(filename):
313
+ raise IOError(f'{filename} is not a checkpoint file')
314
+ checkpoint = torch.load(filename, map_location=map_location)
315
+ return checkpoint
316
+
317
+
318
+ def load_checkpoint_swin(model,
319
+ filename,
320
+ map_location='cpu',
321
+ strict=False,
322
+ rpe_interpolation='outer_mask',
323
+ logger=None):
324
+ """Load checkpoint from a file or URI.
325
+
326
+ Args:
327
+ model (Module): Module to load checkpoint.
328
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
329
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
330
+ details.
331
+ map_location (str): Same as :func:`torch.load`.
332
+ strict (bool): Whether to allow different params for the model and
333
+ checkpoint.
334
+ logger (:mod:`logging.Logger` or None): The logger for error message.
335
+
336
+ Returns:
337
+ dict or OrderedDict: The loaded checkpoint.
338
+ """
339
+ checkpoint = _load_checkpoint(filename, map_location)
340
+ # OrderedDict is a subclass of dict
341
+ if not isinstance(checkpoint, dict):
342
+ raise RuntimeError(
343
+ f'No state_dict found in checkpoint file {filename}')
344
+ # get state_dict from checkpoint
345
+ if 'state_dict' in checkpoint:
346
+ state_dict = checkpoint['state_dict']
347
+ elif 'model' in checkpoint:
348
+ state_dict = checkpoint['model']
349
+ elif 'module' in checkpoint:
350
+ state_dict = checkpoint['module']
351
+ else:
352
+ state_dict = checkpoint
353
+ # strip prefix of state_dict
354
+ if list(state_dict.keys())[0].startswith('module.'):
355
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
356
+
357
+ # for MoBY, load model of online branch
358
+ if sorted(list(state_dict.keys()))[2].startswith('encoder'):
359
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
360
+
361
+ # reshape absolute position embedding for Swin
362
+ if state_dict.get('absolute_pos_embed') is not None:
363
+ absolute_pos_embed = state_dict['absolute_pos_embed']
364
+ N1, L, C1 = absolute_pos_embed.size()
365
+ N2, C2, H, W = model.absolute_pos_embed.size()
366
+ if N1 != N2 or C1 != C2 or L != H * W:
367
+ logger.warning("Error in loading absolute_pos_embed, pass")
368
+ else:
369
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
370
+
371
+ # interpolate position bias table if needed
372
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
373
+ for k in relative_position_bias_table_keys:
374
+ table_pretrained = state_dict[k]
375
+ table_current = model.state_dict()[k]
376
+ L1, nH1 = table_pretrained.size()
377
+ L2, nH2 = table_current.size()
378
+ if nH1 != nH2:
379
+ logger.warning(f"Error in loading {k}, pass")
380
+ else:
381
+ if L1 != L2:
382
+ if rpe_interpolation in ['bicubic', 'bilinear', 'nearest']:
383
+ logger.info(f"Interpolate relative_position_bias_table using {rpe_interpolation}")
384
+ S1 = int(L1 ** 0.5)
385
+ S2 = int(L2 ** 0.5)
386
+ table_pretrained_resized = F.interpolate(
387
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
388
+ size=(S2, S2), mode=rpe_interpolation)
389
+ state_dict[k] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
390
+ elif rpe_interpolation == 'geo':
391
+ logger.info("Interpolate relative_position_bias_table using geo.")
392
+ src_size = int(L1 ** 0.5)
393
+ dst_size = int(L2 ** 0.5)
394
+
395
+ def geometric_progression(a, r, n):
396
+ return a * (1.0 - r ** n) / (1.0 - r)
397
+
398
+ left, right = 1.01, 1.5
399
+ while right - left > 1e-6:
400
+ q = (left + right) / 2.0
401
+ gp = geometric_progression(1, q, src_size // 2)
402
+ if gp > dst_size // 2:
403
+ right = q
404
+ else:
405
+ left = q
406
+
407
+ # if q > 1.13492:
408
+ # q = 1.13492
409
+
410
+ dis = []
411
+ cur = 1
412
+ for i in range(src_size // 2):
413
+ dis.append(cur)
414
+ cur += q ** (i + 1)
415
+
416
+ r_ids = [-_ for _ in reversed(dis)]
417
+
418
+ x = r_ids + [0] + dis
419
+ y = r_ids + [0] + dis
420
+
421
+ t = dst_size // 2.0
422
+ dx = np.arange(-t, t + 0.1, 1.0)
423
+ dy = np.arange(-t, t + 0.1, 1.0)
424
+
425
+ logger.info("Original positions = %s" % str(x))
426
+ logger.info("Target positions = %s" % str(dx))
427
+
428
+ all_rel_pos_bias = []
429
+
430
+ for i in range(nH1):
431
+ z = table_pretrained[:, i].view(src_size, src_size).float().numpy()
432
+ f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
433
+ all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
434
+ table_pretrained.device))
435
+
436
+ new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
437
+ state_dict[k] = new_rel_pos_bias
438
+
439
+ if 'pos_embed' in state_dict:
440
+ pos_embed_checkpoint = state_dict['pos_embed']
441
+ embedding_size = pos_embed_checkpoint.shape[-1]
442
+ num_patches = model.patch_embed.num_patches
443
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
444
+ # height (== width) for the checkpoint position embedding
445
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
446
+ # height (== width) for the new position embedding
447
+ new_size = int(num_patches ** 0.5)
448
+ # class_token and dist_token are kept unchanged
449
+ if orig_size != new_size:
450
+ if dist.get_rank() == 0:
451
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
452
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
453
+ # only the position tokens are interpolated
454
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
455
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
456
+ pos_tokens = torch.nn.functional.interpolate(
457
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
458
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
459
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
460
+ state_dict['pos_embed'] = new_pos_embed
461
+
462
+ # load state_dict
463
+ load_state_dict(model, state_dict, strict, logger)
464
+ return checkpoint
465
+
466
+
467
+ def weights_to_cpu(state_dict):
468
+ """Copy a model state_dict to cpu.
469
+
470
+ Args:
471
+ state_dict (OrderedDict): Model weights on GPU.
472
+
473
+ Returns:
474
+ OrderedDict: Model weights on GPU.
475
+ """
476
+ state_dict_cpu = OrderedDict()
477
+ for key, val in state_dict.items():
478
+ state_dict_cpu[key] = val.cpu()
479
+ return state_dict_cpu
480
+
481
+
482
+ def _save_to_state_dict(module, destination, prefix, keep_vars):
483
+ """Saves module state to `destination` dictionary.
484
+
485
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
486
+
487
+ Args:
488
+ module (nn.Module): The module to generate state_dict.
489
+ destination (dict): A dict where state will be stored.
490
+ prefix (str): The prefix for parameters and buffers used in this
491
+ module.
492
+ """
493
+ for name, param in module._parameters.items():
494
+ if param is not None:
495
+ destination[prefix + name] = param if keep_vars else param.detach()
496
+ for name, buf in module._buffers.items():
497
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
498
+ if buf is not None:
499
+ destination[prefix + name] = buf if keep_vars else buf.detach()
500
+
501
+
502
+ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
503
+ """Returns a dictionary containing a whole state of the module.
504
+
505
+ Both parameters and persistent buffers (e.g. running averages) are
506
+ included. Keys are corresponding parameter and buffer names.
507
+
508
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
509
+ recursively check parallel module in case that the model has a complicated
510
+ structure, e.g., nn.Module(nn.Module(DDP)).
511
+
512
+ Args:
513
+ module (nn.Module): The module to generate state_dict.
514
+ destination (OrderedDict): Returned dict for the state of the
515
+ module.
516
+ prefix (str): Prefix of the key.
517
+ keep_vars (bool): Whether to keep the variable property of the
518
+ parameters. Default: False.
519
+
520
+ Returns:
521
+ dict: A dictionary containing a whole state of the module.
522
+ """
523
+ # recursively check parallel module in case that the model has a
524
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
525
+ if is_module_wrapper(module):
526
+ module = module.module
527
+
528
+ # below is the same as torch.nn.Module.state_dict()
529
+ if destination is None:
530
+ destination = OrderedDict()
531
+ destination._metadata = OrderedDict()
532
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
533
+ version=module._version)
534
+ _save_to_state_dict(module, destination, prefix, keep_vars)
535
+ for name, child in module._modules.items():
536
+ if child is not None:
537
+ get_state_dict(
538
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
539
+ for hook in module._state_dict_hooks.values():
540
+ hook_result = hook(module, destination, prefix, local_metadata)
541
+ if hook_result is not None:
542
+ destination = hook_result
543
+ return destination
544
+
545
+
546
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
547
+ """Save checkpoint to file.
548
+
549
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
550
+ ``optimizer``. By default ``meta`` will contain version and time info.
551
+
552
+ Args:
553
+ model (Module): Module whose params are to be saved.
554
+ filename (str): Checkpoint filename.
555
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
556
+ meta (dict, optional): Metadata to be saved in checkpoint.
557
+ """
558
+ if meta is None:
559
+ meta = {}
560
+ elif not isinstance(meta, dict):
561
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
562
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
563
+
564
+ if is_module_wrapper(model):
565
+ model = model.module
566
+
567
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
568
+ # save class name to the meta
569
+ meta.update(CLASSES=model.CLASSES)
570
+
571
+ checkpoint = {
572
+ 'meta': meta,
573
+ 'state_dict': weights_to_cpu(get_state_dict(model))
574
+ }
575
+ # save optimizer state dict in the checkpoint
576
+ if isinstance(optimizer, Optimizer):
577
+ checkpoint['optimizer'] = optimizer.state_dict()
578
+ elif isinstance(optimizer, dict):
579
+ checkpoint['optimizer'] = {}
580
+ for name, optim in optimizer.items():
581
+ checkpoint['optimizer'][name] = optim.state_dict()
582
+
583
+ if filename.startswith('pavi://'):
584
+ try:
585
+ from pavi import modelcloud
586
+ from pavi.exception import NodeNotFoundError
587
+ except ImportError:
588
+ raise ImportError(
589
+ 'Please install pavi to load checkpoint from modelcloud.')
590
+ model_path = filename[7:]
591
+ root = modelcloud.Folder()
592
+ model_dir, model_name = osp.split(model_path)
593
+ try:
594
+ model = modelcloud.get(model_dir)
595
+ except NodeNotFoundError:
596
+ model = root.create_training_model(model_dir)
597
+ with TemporaryDirectory() as tmp_dir:
598
+ checkpoint_file = osp.join(tmp_dir, model_name)
599
+ with open(checkpoint_file, 'wb') as f:
600
+ torch.save(checkpoint, f)
601
+ f.flush()
602
+ model.create_file(checkpoint_file, name=model_name)
603
+ else:
604
+ mmcv.mkdir_or_exist(osp.dirname(filename))
605
+ # immediately flush buffer
606
+ with open(filename, 'wb') as f:
607
+ torch.save(checkpoint, f)
608
+ f.flush()
depth/models_depth/dist_layers.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ def log_binom(n, k, eps=1e-7):
30
+ """ log(nCk) using stirling approximation """
31
+ n = n + eps
32
+ k = k + eps
33
+ return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps)
34
+
35
+
36
+ class LogBinomial(nn.Module):
37
+ def __init__(self, n_classes=256, act=torch.softmax):
38
+ """Compute log binomial distribution for n_classes
39
+
40
+ Args:
41
+ n_classes (int, optional): number of output classes. Defaults to 256.
42
+ """
43
+ super().__init__()
44
+ self.K = n_classes
45
+ self.act = act
46
+ self.register_buffer('k_idx', torch.arange(
47
+ 0, n_classes).view(1, -1, 1, 1))
48
+ self.register_buffer('K_minus_1', torch.Tensor(
49
+ [self.K-1]).view(1, -1, 1, 1))
50
+
51
+ def forward(self, x, t=1., eps=1e-4):
52
+ """Compute log binomial distribution for x
53
+
54
+ Args:
55
+ x (torch.Tensor - NCHW): probabilities
56
+ t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1..
57
+ eps (float, optional): Small number for numerical stability. Defaults to 1e-4.
58
+
59
+ Returns:
60
+ torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
61
+ """
62
+ if x.ndim == 3:
63
+ x = x.unsqueeze(1) # make it nchw
64
+
65
+ one_minus_x = torch.clamp(1 - x, eps, 1)
66
+ x = torch.clamp(x, eps, 1)
67
+ y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
68
+ torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
69
+ return self.act(y/t, dim=1)
70
+
71
+
72
+ class ConditionalLogBinomial(nn.Module):
73
+ def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
74
+ """Conditional Log Binomial distribution
75
+
76
+ Args:
77
+ in_features (int): number of input channels in main feature
78
+ condition_dim (int): number of input channels in condition feature
79
+ n_classes (int, optional): Number of classes. Defaults to 256.
80
+ bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
81
+ p_eps (float, optional): small eps value. Defaults to 1e-4.
82
+ max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
83
+ min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
84
+ """
85
+ super().__init__()
86
+ self.p_eps = p_eps
87
+ self.max_temp = max_temp
88
+ self.min_temp = min_temp
89
+ self.log_binomial_transform = LogBinomial(n_classes, act=act)
90
+ bottleneck = (in_features + condition_dim) // bottleneck_factor
91
+ self.mlp = nn.Sequential(
92
+ nn.Conv2d(in_features + condition_dim, bottleneck,
93
+ kernel_size=1, stride=1, padding=0),
94
+ nn.GELU(),
95
+ # 2 for p linear norm, 2 for t linear norm
96
+ nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0),
97
+ nn.Softplus()
98
+ )
99
+
100
+ def forward(self, x, cond):
101
+ """Forward pass
102
+
103
+ Args:
104
+ x (torch.Tensor - NCHW): Main feature
105
+ cond (torch.Tensor - NCHW): condition feature
106
+
107
+ Returns:
108
+ torch.Tensor: Output log binomial distribution
109
+ """
110
+ pt = self.mlp(torch.concat((x, cond), dim=1))
111
+ p, t = pt[:, :2, ...], pt[:, 2:, ...]
112
+
113
+ p = p + self.p_eps
114
+ p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])
115
+
116
+ t = t + self.p_eps
117
+ t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
118
+ t = t.unsqueeze(1)
119
+ t = (self.max_temp - self.min_temp) * t + self.min_temp
120
+
121
+ return self.log_binomial_transform(p, t)
depth/models_depth/layers.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class PatchTransformerEncoder(nn.Module):
6
+ def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4):
7
+ super(PatchTransformerEncoder, self).__init__()
8
+ encoder_layers = nn.TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward=1024)
9
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=4) # takes shape S,N,E
10
+
11
+ self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim,
12
+ kernel_size=patch_size, stride=patch_size, padding=0)
13
+
14
+ self.positional_encodings = nn.Parameter(torch.rand(900, embedding_dim), requires_grad=True)
15
+
16
+ def forward(self, x):
17
+ embeddings = self.embedding_convPxP(x).flatten(2) # .shape = n,c,s = n, embedding_dim, s
18
+ # embeddings = nn.functional.pad(embeddings, (1,0)) # extra special token at start ?
19
+ embeddings = embeddings + self.positional_encodings[:embeddings.shape[2], :].T.unsqueeze(0)
20
+
21
+ # change to S,N,E format required by transformer
22
+ embeddings = embeddings.permute(2, 0, 1)
23
+ x = self.transformer_encoder(embeddings) # .shape = S, N, E
24
+ return x
25
+
26
+
27
+ class PixelWiseDotProduct(nn.Module):
28
+ def __init__(self):
29
+ super(PixelWiseDotProduct, self).__init__()
30
+
31
+ def forward(self, x, K):
32
+ n, c, h, w = x.size()
33
+ _, cout, ck = K.size()
34
+ assert c == ck, "Number of channels in x and Embedding dimension (at dim 2) of K matrix must match"
35
+ y = torch.matmul(x.view(n, c, h * w).permute(0, 2, 1), K.permute(0, 2, 1)) # .shape = n, hw, cout
36
+ return y.permute(0, 2, 1).view(n, cout, h, w)
depth/models_depth/localbins_layers.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ class SeedBinRegressor(nn.Module):
30
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
31
+ """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.
32
+
33
+ Args:
34
+ in_features (int): input channels
35
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
36
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
37
+ min_depth (float, optional): Min depth value. Defaults to 1e-3.
38
+ max_depth (float, optional): Max depth value. Defaults to 10.
39
+ """
40
+ super().__init__()
41
+ self.version = "1_1"
42
+ self.min_depth = min_depth
43
+ self.max_depth = max_depth
44
+
45
+ self._net = nn.Sequential(
46
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
49
+ nn.ReLU(inplace=True)
50
+ )
51
+
52
+ def forward(self, x):
53
+ """
54
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
55
+ """
56
+ B = self._net(x)
57
+ eps = 1e-3
58
+ B = B + eps
59
+ B_widths_normed = B / B.sum(dim=1, keepdim=True)
60
+ B_widths = (self.max_depth - self.min_depth) * \
61
+ B_widths_normed # .shape NCHW
62
+ # pad has the form (left, right, top, bottom, front, back)
63
+ B_widths = nn.functional.pad(
64
+ B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
65
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
66
+
67
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
68
+ return B_widths_normed, B_centers
69
+
70
+
71
+ class SeedBinRegressorUnnormed(nn.Module):
72
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
73
+ """Bin center regressor network. Bin centers are unbounded
74
+
75
+ Args:
76
+ in_features (int): input channels
77
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
78
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
79
+ min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
80
+ max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
81
+ """
82
+ super().__init__()
83
+ self.version = "1_1"
84
+ self._net = nn.Sequential(
85
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
86
+ nn.ReLU(inplace=True),
87
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
88
+ nn.Softplus()
89
+ )
90
+
91
+ def forward(self, x):
92
+ """
93
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
94
+ """
95
+ B_centers = self._net(x)
96
+ return B_centers, B_centers
97
+
98
+
99
+ class Projector(nn.Module):
100
+ def __init__(self, in_features, out_features, mlp_dim=128):
101
+ """Projector MLP
102
+
103
+ Args:
104
+ in_features (int): input channels
105
+ out_features (int): output channels
106
+ mlp_dim (int, optional): hidden dimension. Defaults to 128.
107
+ """
108
+ super().__init__()
109
+
110
+ self._net = nn.Sequential(
111
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
112
+ nn.ReLU(inplace=True),
113
+ nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
114
+ )
115
+
116
+ def forward(self, x):
117
+ return self._net(x)
118
+
119
+
120
+
121
+ class LinearSplitter(nn.Module):
122
+ def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
123
+ super().__init__()
124
+
125
+ self.prev_nbins = prev_nbins
126
+ self.split_factor = split_factor
127
+ self.min_depth = min_depth
128
+ self.max_depth = max_depth
129
+
130
+ self._net = nn.Sequential(
131
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
132
+ nn.GELU(),
133
+ nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
134
+ nn.ReLU()
135
+ )
136
+
137
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
138
+ """
139
+ x : feature block; shape - n, c, h, w
140
+ b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
141
+ """
142
+ if prev_b_embedding is not None:
143
+ if interpolate:
144
+ prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
145
+ x = x + prev_b_embedding
146
+ S = self._net(x)
147
+ eps = 1e-3
148
+ S = S + eps
149
+ n, c, h, w = S.shape
150
+ S = S.view(n, self.prev_nbins, self.split_factor, h, w)
151
+ S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits
152
+
153
+ b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True)
154
+
155
+
156
+ b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees
157
+ # print(b_prev.shape, S_normed.shape)
158
+ # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat?
159
+ b = b_prev.unsqueeze(2) * S_normed
160
+ b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w
161
+
162
+ # calculate bin centers for loss calculation
163
+ B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W
164
+ # pad has the form (left, right, top, bottom, front, back)
165
+ B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth)
166
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
167
+
168
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...])
169
+ return b, B_centers
depth/models_depth/miniViT.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .layers import PatchTransformerEncoder, PixelWiseDotProduct
5
+
6
+
7
+ class mViT(nn.Module):
8
+ def __init__(self, in_channels, n_query_channels=128, patch_size=16, dim_out=256,
9
+ embedding_dim=128, num_heads=4, norm='linear'):
10
+ super(mViT, self).__init__()
11
+ self.norm = norm
12
+ self.n_query_channels = n_query_channels
13
+ self.patch_transformer = PatchTransformerEncoder(in_channels, patch_size, embedding_dim, num_heads)
14
+ self.dot_product_layer = PixelWiseDotProduct()
15
+
16
+ self.conv3x3 = nn.Conv2d(in_channels, embedding_dim, kernel_size=3, stride=1, padding=1)
17
+ self.regressor = nn.Sequential(nn.Linear(embedding_dim, 256),
18
+ nn.LeakyReLU(),
19
+ nn.Linear(256, 256),
20
+ nn.LeakyReLU(),
21
+ nn.Linear(256, dim_out))
22
+
23
+ def forward(self, x):
24
+ # n, c, h, w = x.size()
25
+ tgt = self.patch_transformer(x.clone()) # .shape = S, N, E
26
+
27
+ x = self.conv3x3(x)
28
+
29
+ regression_head, queries = tgt[0, ...], tgt[1:self.n_query_channels + 1, ...]
30
+
31
+ # Change from S, N, E to N, S, E
32
+ queries = queries.permute(1, 0, 2)
33
+ range_attention_maps = self.dot_product_layer(x, queries) # .shape = n, n_query_channels, h, w
34
+
35
+ y = self.regressor(regression_head) # .shape = N, dim_out
36
+ if self.norm == 'linear':
37
+ y = torch.relu(y)
38
+ eps = 0.1
39
+ y = y + eps
40
+ elif self.norm == 'softmax':
41
+ return torch.softmax(y, dim=1), range_attention_maps
42
+ else:
43
+ y = torch.sigmoid(y)
44
+ y = y / y.sum(dim=1, keepdim=True)
45
+ return y, range_attention_maps
depth/models_depth/model.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # The deconvolution code is based on Simple Baseline.
5
+ # (https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py)
6
+ # Modified by Zigang Geng ([email protected]).
7
+ # ------------------------------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_, DropPath
12
+ from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
13
+ constant_init, normal_init)
14
+ from omegaconf import OmegaConf
15
+ from ldm.util import instantiate_from_config
16
+ import torch.nn.functional as F
17
+
18
+ from evp.models import UNetWrapper, TextAdapterRefer, FrozenCLIPEmbedder
19
+ from .miniViT import mViT
20
+ from .attractor import AttractorLayer, AttractorLayerUnnormed
21
+ from .dist_layers import ConditionalLogBinomial
22
+ from .localbins_layers import (Projector, SeedBinRegressor, SeedBinRegressorUnnormed)
23
+ import os
24
+
25
+
26
+ def icnr(x, scale=2, init=nn.init.kaiming_normal_):
27
+ """
28
+ Checkerboard artifact free sub-pixel convolution
29
+ https://arxiv.org/abs/1707.02937
30
+ """
31
+ ni,nf,h,w = x.shape
32
+ ni2 = int(ni/(scale**2))
33
+ k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
34
+ k = k.contiguous().view(ni2, nf, -1)
35
+ k = k.repeat(1, 1, scale**2)
36
+ k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
37
+ x.data.copy_(k)
38
+
39
+
40
+ class PixelShuffle(nn.Module):
41
+ """
42
+ Real-Time Single Image and Video Super-Resolution
43
+ https://arxiv.org/abs/1609.05158
44
+ """
45
+ def __init__(self, n_channels, scale):
46
+ super(PixelShuffle, self).__init__()
47
+ self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
48
+ icnr(self.conv.weight)
49
+ self.shuf = nn.PixelShuffle(scale)
50
+ self.relu = nn.ReLU()
51
+
52
+ def forward(self,x):
53
+ x = self.shuf(self.relu(self.conv(x)))
54
+ return x
55
+
56
+
57
+ class AttentionModule(nn.Module):
58
+ def __init__(self, in_channels, out_channels):
59
+ super(AttentionModule, self).__init__()
60
+
61
+ # Convolutional Layers
62
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
63
+
64
+ # Group Normalization
65
+ self.group_norm = nn.GroupNorm(20, out_channels)
66
+
67
+ # ReLU Activation
68
+ self.relu = nn.ReLU()
69
+
70
+ # Spatial Attention
71
+ self.spatial_attention = nn.Sequential(
72
+ nn.Conv2d(in_channels, 1, kernel_size=1),
73
+ nn.Sigmoid()
74
+ )
75
+
76
+ def forward(self, x):
77
+ # Apply spatial attention
78
+ spatial_attention = self.spatial_attention(x)
79
+ x = x * spatial_attention
80
+
81
+ # Apply convolutional layer
82
+ x = self.conv1(x)
83
+ x = self.group_norm(x)
84
+ x = self.relu(x)
85
+
86
+ return x
87
+
88
+
89
+ class AttentionDownsamplingModule(nn.Module):
90
+ def __init__(self, in_channels, out_channels, scale_factor=2):
91
+ super(AttentionDownsamplingModule, self).__init__()
92
+
93
+ # Spatial Attention
94
+ self.spatial_attention = nn.Sequential(
95
+ nn.Conv2d(in_channels, 1, kernel_size=1),
96
+ nn.Sigmoid()
97
+ )
98
+
99
+ # Channel Attention
100
+ self.channel_attention = nn.Sequential(
101
+ nn.AdaptiveAvgPool2d(1),
102
+ nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
103
+ nn.ReLU(inplace=True),
104
+ nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
105
+ nn.Sigmoid()
106
+ )
107
+
108
+ # Convolutional Layers
109
+ if scale_factor == 2:
110
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
111
+ elif scale_factor == 4:
112
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
113
+
114
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
115
+
116
+ # Group Normalization
117
+ self.group_norm = nn.GroupNorm(20, out_channels)
118
+
119
+ # ReLU Activation
120
+ self.relu = nn.ReLU(inplace=True)
121
+
122
+ def forward(self, x):
123
+ # Apply spatial attention
124
+ spatial_attention = self.spatial_attention(x)
125
+ x = x * spatial_attention
126
+
127
+ # Apply channel attention
128
+ channel_attention = self.channel_attention(x)
129
+ x = x * channel_attention
130
+
131
+ # Apply convolutional layers
132
+ x = self.conv1(x)
133
+ x = self.group_norm(x)
134
+ x = self.relu(x)
135
+ x = self.conv2(x)
136
+ x = self.group_norm(x)
137
+ x = self.relu(x)
138
+
139
+ return x
140
+
141
+
142
+ class AttentionUpsamplingModule(nn.Module):
143
+ def __init__(self, in_channels, out_channels):
144
+ super(AttentionUpsamplingModule, self).__init__()
145
+
146
+ # Spatial Attention for outs[2]
147
+ self.spatial_attention = nn.Sequential(
148
+ nn.Conv2d(in_channels, 1, kernel_size=1),
149
+ nn.Sigmoid()
150
+ )
151
+
152
+ # Channel Attention for outs[2]
153
+ self.channel_attention = nn.Sequential(
154
+ nn.AdaptiveAvgPool2d(1),
155
+ nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
156
+ nn.ReLU(),
157
+ nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
158
+ nn.Sigmoid()
159
+ )
160
+
161
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
162
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ # Group Normalization
165
+ self.group_norm = nn.GroupNorm(20, out_channels)
166
+
167
+ # ReLU Activation
168
+ self.relu = nn.ReLU()
169
+ self.upscale = PixelShuffle(in_channels, 2)
170
+
171
+ def forward(self, x):
172
+ # Apply spatial attention
173
+ spatial_attention = self.spatial_attention(x)
174
+ x = x * spatial_attention
175
+
176
+ # Apply channel attention
177
+ channel_attention = self.channel_attention(x)
178
+ x = x * channel_attention
179
+
180
+ # Apply convolutional layers
181
+ x = self.conv1(x)
182
+ x = self.group_norm(x)
183
+ x = self.relu(x)
184
+ x = self.conv2(x)
185
+ x = self.group_norm(x)
186
+ x = self.relu(x)
187
+
188
+ # Upsample
189
+ x = self.upscale(x)
190
+
191
+ return x
192
+
193
+
194
+ class ConvLayer(nn.Module):
195
+ def __init__(self, in_channels, out_channels):
196
+ super(ConvLayer, self).__init__()
197
+
198
+ self.conv1 = nn.Sequential(
199
+ nn.Conv2d(in_channels, out_channels, 1),
200
+ nn.GroupNorm(20, out_channels),
201
+ nn.ReLU(),
202
+ )
203
+
204
+ def forward(self, x):
205
+ x = self.conv1(x)
206
+
207
+ return x
208
+
209
+
210
+ class InverseMultiAttentiveFeatureRefinement(nn.Module):
211
+ def __init__(self, in_channels_list):
212
+ super(InverseMultiAttentiveFeatureRefinement, self).__init__()
213
+
214
+ self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
215
+ self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0]//2, scale_factor = 2)
216
+ self.layer3 = ConvLayer(in_channels_list[0]//2 + in_channels_list[1], in_channels_list[1])
217
+ self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1]//2, scale_factor = 2)
218
+ self.layer5 = ConvLayer(in_channels_list[1]//2 + in_channels_list[2], in_channels_list[2])
219
+ self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2]//2, scale_factor = 2)
220
+ self.layer7 = ConvLayer(in_channels_list[2]//2 + in_channels_list[3], in_channels_list[3])
221
+
222
+ '''
223
+ self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
224
+ self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
225
+ self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
226
+ self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
227
+ self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
228
+ self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
229
+ '''
230
+ def forward(self, inputs):
231
+ x_c4, x_c3, x_c2, x_c1 = inputs
232
+ x_c4 = self.layer1(x_c4)
233
+ x_c4_3 = self.layer2(x_c4)
234
+ x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
235
+ x_c3 = self.layer3(x_c3)
236
+ x_c3_2 = self.layer4(x_c3)
237
+ x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
238
+ x_c2 = self.layer5(x_c2)
239
+ x_c2_1 = self.layer6(x_c2)
240
+ x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
241
+ x_c1 = self.layer7(x_c1)
242
+ '''
243
+ x_c1_2 = self.layer8(x_c1)
244
+ x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
245
+ x_c2 = self.layer9(x_c2)
246
+ x_c2_3 = self.layer10(x_c2)
247
+ x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
248
+ x_c3 = self.layer11(x_c3)
249
+ x_c3_4 = self.layer12(x_c3)
250
+ x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
251
+ x_c4 = self.layer13(x_c4)
252
+ '''
253
+ return [x_c4, x_c3, x_c2, x_c1]
254
+
255
+
256
+ class EVPDepthEncoder(nn.Module):
257
+ def __init__(self, out_dim=1024, ldm_prior=[320, 680, 1320+1280], sd_path=None, text_dim=768,
258
+ dataset='nyu', caption_aggregation=False
259
+ ):
260
+ super().__init__()
261
+
262
+
263
+ self.layer1 = nn.Sequential(
264
+ nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
265
+ nn.GroupNorm(16, ldm_prior[0]),
266
+ nn.ReLU(),
267
+ nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
268
+ )
269
+
270
+ self.layer2 = nn.Sequential(
271
+ nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
272
+ )
273
+
274
+ self.out_layer = nn.Sequential(
275
+ nn.Conv2d(sum(ldm_prior), out_dim, 1),
276
+ nn.GroupNorm(16, out_dim),
277
+ nn.ReLU(),
278
+ )
279
+
280
+ self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])
281
+
282
+ self.apply(self._init_weights)
283
+
284
+ ### stable diffusion layers
285
+
286
+ config = OmegaConf.load('./v1-inference.yaml')
287
+ if sd_path is None:
288
+ if os.path.exists('../checkpoints/v1-5-pruned-emaonly.ckpt'):
289
+ config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
290
+ else:
291
+ config.model.params.ckpt_path = None
292
+ else:
293
+ config.model.params.ckpt_path = f'../{sd_path}'
294
+
295
+ sd_model = instantiate_from_config(config.model)
296
+ self.encoder_vq = sd_model.first_stage_model
297
+
298
+ self.unet = UNetWrapper(sd_model.model, use_attn=True)
299
+ if dataset == 'kitti':
300
+ self.unet = UNetWrapper(sd_model.model, use_attn=True, base_size=384)
301
+
302
+ del sd_model.cond_stage_model
303
+ del self.encoder_vq.decoder
304
+ del self.unet.unet.diffusion_model.out
305
+ del self.encoder_vq.post_quant_conv.weight
306
+ del self.encoder_vq.post_quant_conv.bias
307
+
308
+ for param in self.encoder_vq.parameters():
309
+ param.requires_grad = True
310
+
311
+ self.text_adapter = TextAdapterRefer(text_dim=text_dim)
312
+ self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)
313
+
314
+ if caption_aggregation:
315
+ class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')
316
+ #class_embeddings_list = [value['class_embeddings'] for key, value in class_embeddings.items()]
317
+ #stacked_embeddings = torch.stack(class_embeddings_list, dim=0)
318
+ #class_embeddings = torch.mean(stacked_embeddings, dim=0).unsqueeze(0)
319
+
320
+ if 'aggregated' in class_embeddings:
321
+ class_embeddings = class_embeddings['aggregated']
322
+ else:
323
+ clip_model = FrozenCLIPEmbedder(max_length=40,pool=False).cuda()
324
+ class_embeddings_new = [clip_model.encode(value['caption'][0]) for key, value in class_embeddings.items()]
325
+ class_embeddings_new = torch.mean(torch.stack(class_embeddings_new, dim=0), dim=0)
326
+ class_embeddings['aggregated'] = class_embeddings_new
327
+ torch.save(class_embeddings, f'{dataset}_class_embeddings_my_captions.pth')
328
+ class_embeddings = class_embeddings['aggregated']
329
+ self.register_buffer('class_embeddings', class_embeddings)
330
+ else:
331
+ self.class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')
332
+
333
+ self.clip_model = FrozenCLIPEmbedder(max_length=40,pool=False)
334
+ for param in self.clip_model.parameters():
335
+ param.requires_grad = True
336
+
337
+ #if dataset == 'kitti':
338
+ # self.text_adapter_ = TextAdapterRefer(text_dim=text_dim)
339
+ # self.gamma_ = nn.Parameter(torch.ones(text_dim) * 1e-4)
340
+
341
+ self.caption_aggregation = caption_aggregation
342
+ self.dataset = dataset
343
+
344
+ def _init_weights(self, m):
345
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
346
+ trunc_normal_(m.weight, std=.02)
347
+ nn.init.constant_(m.bias, 0)
348
+
349
+ def forward_features(self, feats):
350
+ x = self.ldm_to_net[0](feats[0])
351
+ for i in range(3):
352
+ if i > 0:
353
+ x = x + self.ldm_to_net[i](feats[i])
354
+ x = self.layers[i](x)
355
+ x = self.upsample_layers[i](x)
356
+ return self.out_conv(x)
357
+
358
+ def forward(self, x, class_ids=None, img_paths=None):
359
+ latents = self.encoder_vq.encode(x).mode()
360
+
361
+ # add division by std
362
+ if self.dataset == 'nyu':
363
+ latents = latents / 5.07543
364
+ elif self.dataset == 'kitti':
365
+ latents = latents / 4.6211
366
+ else:
367
+ print('Please calculate the STD for the dataset!')
368
+
369
+ if class_ids is not None:
370
+ if self.caption_aggregation:
371
+ class_embeddings = self.class_embeddings[[0]*len(class_ids.tolist())]#[class_ids.tolist()]
372
+ else:
373
+ class_embeddings = []
374
+
375
+ for img_path in img_paths:
376
+ class_embeddings.extend([value['caption'][0] for key, value in self.class_embeddings.items() if key in img_path.replace('//', '/')])
377
+
378
+ class_embeddings = self.clip_model.encode(class_embeddings)
379
+ else:
380
+ class_embeddings = self.class_embeddings
381
+
382
+ c_crossattn = self.text_adapter(latents, class_embeddings, self.gamma)
383
+ t = torch.ones((x.shape[0],), device=x.device).long()
384
+
385
+ #if self.dataset == 'kitti':
386
+ # c_crossattn_last = self.text_adapter_(latents, class_embeddings, self.gamma_)
387
+ # outs = self.unet(latents, t, c_crossattn=[c_crossattn, c_crossattn_last])
388
+ #else:
389
+ outs = self.unet(latents, t, c_crossattn=[c_crossattn])
390
+ outs = self.aggregation(outs)
391
+
392
+ feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
393
+ x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
394
+ return self.out_layer(x)
395
+
396
+ def get_latent(self, x):
397
+ return self.encoder_vq.encode(x).mode()
398
+
399
+
400
+ class EVPDepth(nn.Module):
401
+ def __init__(self, args=None, caption_aggregation=False):
402
+ super().__init__()
403
+ self.max_depth = args.max_depth
404
+ self.min_depth = args.min_depth_eval
405
+
406
+ embed_dim = 192
407
+
408
+ channels_in = embed_dim*8
409
+ channels_out = embed_dim
410
+
411
+ if args.dataset == 'nyudepthv2':
412
+ self.encoder = EVPDepthEncoder(out_dim=channels_in, dataset='nyu', caption_aggregation=caption_aggregation)
413
+ else:
414
+ self.encoder = EVPDepthEncoder(out_dim=channels_in, dataset='kitti', caption_aggregation=caption_aggregation)
415
+
416
+ self.decoder = Decoder(channels_in, channels_out, args)
417
+ self.decoder.init_weights()
418
+ self.mViT = False
419
+ self.custom = False
420
+
421
+
422
+ if not self.mViT and not self.custom:
423
+ n_bins = 64
424
+ bin_embedding_dim = 128
425
+ num_out_features = [32, 32, 32, 192]
426
+ min_temp = 0.0212
427
+ max_temp = 50
428
+ btlnck_features = 256
429
+ n_attractors = [16, 8, 4, 1]
430
+ attractor_alpha = 1000
431
+ attractor_gamma = 2
432
+ attractor_kind = "mean"
433
+ attractor_type = "inv"
434
+ self.bin_centers_type = "softplus"
435
+
436
+ self.bottle_neck = nn.Sequential(
437
+ nn.Conv2d(channels_in, btlnck_features, kernel_size=3, stride=1, padding=1),
438
+ nn.ReLU(inplace=False),
439
+ nn.Conv2d(btlnck_features, btlnck_features, kernel_size=3, stride=1, padding=1))
440
+
441
+
442
+ for m in self.bottle_neck.modules():
443
+ if isinstance(m, nn.Conv2d):
444
+ normal_init(m, std=0.001, bias=0)
445
+
446
+
447
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
448
+ Attractor = AttractorLayerUnnormed
449
+ self.seed_bin_regressor = SeedBinRegressorLayer(
450
+ btlnck_features, n_bins=n_bins, min_depth=self.min_depth, max_depth=self.max_depth)
451
+ self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
452
+ self.projectors = nn.ModuleList([
453
+ Projector(num_out, bin_embedding_dim)
454
+ for num_out in num_out_features
455
+ ])
456
+ self.attractors = nn.ModuleList([
457
+ Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=self.min_depth, max_depth=self.max_depth,
458
+ alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
459
+ for i in range(len(num_out_features))
460
+ ])
461
+
462
+ last_in = 192 + 1
463
+ self.conditional_log_binomial = ConditionalLogBinomial(
464
+ last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)
465
+ elif self.mViT and not self.custom:
466
+ n_bins = 256
467
+ self.adaptive_bins_layer = mViT(192, n_query_channels=192, patch_size=16,
468
+ dim_out=n_bins,
469
+ embedding_dim=192, norm='linear')
470
+ self.conv_out = nn.Sequential(nn.Conv2d(192, n_bins, kernel_size=1, stride=1, padding=0),
471
+ nn.Softmax(dim=1))
472
+
473
+
474
+ def forward(self, x, class_ids=None, img_paths=None):
475
+ b, c, h, w = x.shape
476
+ x = x*2.0 - 1.0 # normalize to [-1, 1]
477
+ if h == 480 and w == 480:
478
+ new_x = torch.zeros(b, c, 512, 512, device=x.device)
479
+ new_x[:, :, 0:480, 0:480] = x
480
+ x = new_x
481
+ elif h==352 and w==352:
482
+ new_x = torch.zeros(b, c, 384, 384, device=x.device)
483
+ new_x[:, :, 0:352, 0:352] = x
484
+ x = new_x
485
+ elif h == 512 and w == 512:
486
+ pass
487
+ else:
488
+ print(h,w)
489
+ raise NotImplementedError
490
+ conv_feats = self.encoder(x, class_ids, img_paths)
491
+
492
+ if h == 480 or h == 352:
493
+ conv_feats = conv_feats[:, :, :-1, :-1]
494
+
495
+ self.decoder.remove_hooks()
496
+ out_depth, out, x_blocks = self.decoder([conv_feats])
497
+
498
+ if not self.mViT and not self.custom:
499
+ x = self.bottle_neck(conv_feats)
500
+ _, seed_b_centers = self.seed_bin_regressor(x)
501
+
502
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
503
+ b_prev = (seed_b_centers - self.min_depth) / \
504
+ (self.max_depth - self.min_depth)
505
+ else:
506
+ b_prev = seed_b_centers
507
+
508
+ prev_b_embedding = self.seed_projector(x)
509
+
510
+ for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
511
+ b_embedding = projector(x)
512
+ b, b_centers = attractor(
513
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
514
+ b_prev = b.clone()
515
+ prev_b_embedding = b_embedding.clone()
516
+
517
+ rel_cond = torch.sigmoid(out_depth) * self.max_depth
518
+
519
+ # concat rel depth with last. First interpolate rel depth to last size
520
+ rel_cond = nn.functional.interpolate(
521
+ rel_cond, size=out.shape[2:], mode='bilinear', align_corners=True)
522
+ last = torch.cat([out, rel_cond], dim=1)
523
+
524
+ b_embedding = nn.functional.interpolate(
525
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
526
+ x = self.conditional_log_binomial(last, b_embedding)
527
+
528
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
529
+ b_centers = nn.functional.interpolate(
530
+ b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
531
+ out_depth = torch.sum(x * b_centers, dim=1, keepdim=True)
532
+
533
+ elif self.mViT and not self.custom:
534
+ bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(out)
535
+ out = self.conv_out(range_attention_maps)
536
+
537
+ bin_widths = (self.max_depth - self.min_depth) * bin_widths_normed # .shape = N, dim_out
538
+ bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_depth)
539
+ bin_edges = torch.cumsum(bin_widths, dim=1)
540
+
541
+ centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
542
+ n, dout = centers.size()
543
+ centers = centers.view(n, dout, 1, 1)
544
+
545
+ out_depth = torch.sum(out * centers, dim=1, keepdim=True)
546
+ else:
547
+ out_depth = torch.sigmoid(out_depth) * self.max_depth
548
+
549
+ return {'pred_d': out_depth}
550
+
551
+
552
+ class Decoder(nn.Module):
553
+ def __init__(self, in_channels, out_channels, args):
554
+ super().__init__()
555
+ self.deconv = args.num_deconv
556
+ self.in_channels = in_channels
557
+
558
+ embed_dim = 192
559
+
560
+ channels_in = embed_dim*8
561
+ channels_out = embed_dim
562
+
563
+ self.deconv_layers, self.intermediate_results = self._make_deconv_layer(
564
+ args.num_deconv,
565
+ args.num_filters,
566
+ args.deconv_kernels,
567
+ )
568
+ self.last_layer_depth = nn.Sequential(
569
+ nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
570
+ nn.ReLU(inplace=False),
571
+ nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))
572
+
573
+ for m in self.last_layer_depth.modules():
574
+ if isinstance(m, nn.Conv2d):
575
+ normal_init(m, std=0.001, bias=0)
576
+
577
+ conv_layers = []
578
+ conv_layers.append(
579
+ build_conv_layer(
580
+ dict(type='Conv2d'),
581
+ in_channels=args.num_filters[-1],
582
+ out_channels=out_channels,
583
+ kernel_size=3,
584
+ stride=1,
585
+ padding=1))
586
+ conv_layers.append(
587
+ build_norm_layer(dict(type='BN'), out_channels)[1])
588
+ conv_layers.append(nn.ReLU())
589
+ self.conv_layers = nn.Sequential(*conv_layers)
590
+
591
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
592
+
593
+ def forward(self, conv_feats):
594
+ out = self.deconv_layers(conv_feats[0])
595
+ out = self.conv_layers(out)
596
+ out = self.up(out)
597
+ self.intermediate_results.append(out)
598
+ out = self.up(out)
599
+ out_depth = self.last_layer_depth(out)
600
+
601
+ return out_depth, out, self.intermediate_results
602
+
603
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
604
+ """Make deconv layers."""
605
+
606
+ layers = []
607
+ in_planes = self.in_channels
608
+ intermediate_results = [] # List to store intermediate feature maps
609
+
610
+ for i in range(num_layers):
611
+ kernel, padding, output_padding = \
612
+ self._get_deconv_cfg(num_kernels[i])
613
+
614
+ planes = num_filters[i]
615
+ layers.append(
616
+ build_upsample_layer(
617
+ dict(type='deconv'),
618
+ in_channels=in_planes,
619
+ out_channels=planes,
620
+ kernel_size=kernel,
621
+ stride=2,
622
+ padding=padding,
623
+ output_padding=output_padding,
624
+ bias=False))
625
+ layers.append(nn.BatchNorm2d(planes))
626
+ layers.append(nn.ReLU())
627
+ in_planes = planes
628
+
629
+ # Add a hook to store the intermediate result
630
+ layers[-1].register_forward_hook(self._hook_fn(intermediate_results))
631
+
632
+ return nn.Sequential(*layers), intermediate_results
633
+
634
+ def _hook_fn(self, intermediate_results):
635
+ def hook(module, input, output):
636
+ intermediate_results.append(output)
637
+ return hook
638
+
639
+ def remove_hooks(self):
640
+ self.intermediate_results.clear()
641
+
642
+ def _get_deconv_cfg(self, deconv_kernel):
643
+ """Get configurations for deconv layers."""
644
+ if deconv_kernel == 4:
645
+ padding = 1
646
+ output_padding = 0
647
+ elif deconv_kernel == 3:
648
+ padding = 1
649
+ output_padding = 1
650
+ elif deconv_kernel == 2:
651
+ padding = 0
652
+ output_padding = 0
653
+ else:
654
+ raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
655
+
656
+ return deconv_kernel, padding, output_padding
657
+
658
+ def init_weights(self):
659
+ """Initialize model weights."""
660
+ for m in self.modules():
661
+ if isinstance(m, nn.Conv2d):
662
+ normal_init(m, std=0.001, bias=0)
663
+ elif isinstance(m, nn.BatchNorm2d):
664
+ constant_init(m, 1)
665
+ elif isinstance(m, nn.ConvTranspose2d):
666
+ normal_init(m, std=0.001)
depth/models_depth/model_vpd.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # The deconvolution code is based on Simple Baseline.
5
+ # (https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py)
6
+ # Modified by Zigang Geng ([email protected]).
7
+ # ------------------------------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_, DropPath
12
+ from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
13
+ constant_init, normal_init)
14
+ from omegaconf import OmegaConf
15
+ from ldm.util import instantiate_from_config
16
+ import torch.nn.functional as F
17
+
18
+ from evp.models import UNetWrapper, TextAdapterDepth
19
+
20
+ class VPDDepthEncoder(nn.Module):
21
+ def __init__(self, out_dim=1024, ldm_prior=[320, 640, 1280+1280], sd_path=None, text_dim=768,
22
+ dataset='nyu'
23
+ ):
24
+ super().__init__()
25
+
26
+
27
+ self.layer1 = nn.Sequential(
28
+ nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
29
+ nn.GroupNorm(16, ldm_prior[0]),
30
+ nn.ReLU(),
31
+ nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
32
+ )
33
+
34
+ self.layer2 = nn.Sequential(
35
+ nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
36
+ )
37
+
38
+ self.out_layer = nn.Sequential(
39
+ nn.Conv2d(sum(ldm_prior), out_dim, 1),
40
+ nn.GroupNorm(16, out_dim),
41
+ nn.ReLU(),
42
+ )
43
+
44
+ self.apply(self._init_weights)
45
+
46
+ ### stable diffusion layers
47
+
48
+ config = OmegaConf.load('./v1-inference.yaml')
49
+ if sd_path is None:
50
+ config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
51
+ else:
52
+ config.model.params.ckpt_path = f'../{sd_path}'
53
+
54
+ sd_model = instantiate_from_config(config.model)
55
+ self.encoder_vq = sd_model.first_stage_model
56
+
57
+ self.unet = UNetWrapper(sd_model.model, use_attn=False)
58
+
59
+ del sd_model.cond_stage_model
60
+ del self.encoder_vq.decoder
61
+ del self.unet.unet.diffusion_model.out
62
+
63
+ for param in self.encoder_vq.parameters():
64
+ param.requires_grad = False
65
+
66
+ if dataset == 'nyu':
67
+ self.text_adapter = TextAdapterDepth(text_dim=text_dim)
68
+ class_embeddings = torch.load('nyu_class_embeddings.pth')
69
+ else:
70
+ raise NotImplementedError
71
+
72
+ self.register_buffer('class_embeddings', class_embeddings)
73
+ self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)
74
+
75
+
76
+ def _init_weights(self, m):
77
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
78
+ trunc_normal_(m.weight, std=.02)
79
+ nn.init.constant_(m.bias, 0)
80
+
81
+ def forward_features(self, feats):
82
+ x = self.ldm_to_net[0](feats[0])
83
+ for i in range(3):
84
+ if i > 0:
85
+ x = x + self.ldm_to_net[i](feats[i])
86
+ x = self.layers[i](x)
87
+ x = self.upsample_layers[i](x)
88
+ return self.out_conv(x)
89
+
90
+ def forward(self, x, class_ids=None,img_paths=None):
91
+ with torch.no_grad():
92
+ latents = self.encoder_vq.encode(x).mode().detach()
93
+
94
+ if class_ids is not None:
95
+ class_embeddings = self.class_embeddings[class_ids.tolist()]
96
+ else:
97
+ class_embeddings = self.class_embeddings
98
+
99
+ c_crossattn = self.text_adapter(latents, class_embeddings, self.gamma) # NOTE: here the c_crossattn should be expand_dim as latents
100
+ t = torch.ones((x.shape[0],), device=x.device).long()
101
+ # import pdb; pdb.set_trace()
102
+ outs = self.unet(latents, t, c_crossattn=[c_crossattn])
103
+ feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
104
+ x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
105
+ return self.out_layer(x)
106
+
107
+ class VPDDepth(nn.Module):
108
+ def __init__(self, args=None):
109
+ super().__init__()
110
+ self.max_depth = args.max_depth
111
+
112
+ embed_dim = 192
113
+
114
+ channels_in = embed_dim*8
115
+ channels_out = embed_dim
116
+
117
+ if args.dataset == 'nyudepthv2':
118
+ self.encoder = VPDDepthEncoder(out_dim=channels_in, dataset='nyu')
119
+ else:
120
+ raise NotImplementedError
121
+
122
+ self.decoder = Decoder(channels_in, channels_out, args)
123
+ self.decoder.init_weights()
124
+
125
+ self.last_layer_depth = nn.Sequential(
126
+ nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
127
+ nn.ReLU(inplace=False),
128
+ nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))
129
+
130
+ for m in self.last_layer_depth.modules():
131
+ if isinstance(m, nn.Conv2d):
132
+ normal_init(m, std=0.001, bias=0)
133
+
134
+ def forward(self, x, class_ids=None,img_paths=None):
135
+ # import pdb; pdb.set_trace()
136
+ b, c, h, w = x.shape
137
+ x = x*2.0 - 1.0 # normalize to [-1, 1]
138
+ if h == 480 and w == 480:
139
+ new_x = torch.zeros(b, c, 512, 512, device=x.device)
140
+ new_x[:, :, 0:480, 0:480] = x
141
+ x = new_x
142
+ elif h==352 and w==352:
143
+ new_x = torch.zeros(b, c, 384, 384, device=x.device)
144
+ new_x[:, :, 0:352, 0:352] = x
145
+ x = new_x
146
+ elif h == 512 and w == 512:
147
+ pass
148
+ else:
149
+ raise NotImplementedError
150
+ conv_feats = self.encoder(x, class_ids)
151
+
152
+ if h == 480 or h == 352:
153
+ conv_feats = conv_feats[:, :, :-1, :-1]
154
+
155
+ out = self.decoder([conv_feats])
156
+ out_depth = self.last_layer_depth(out)
157
+ out_depth = torch.sigmoid(out_depth) * self.max_depth
158
+
159
+ return {'pred_d': out_depth}
160
+
161
+
162
+ class Decoder(nn.Module):
163
+ def __init__(self, in_channels, out_channels, args):
164
+ super().__init__()
165
+ self.deconv = args.num_deconv
166
+ self.in_channels = in_channels
167
+
168
+ # import pdb; pdb.set_trace()
169
+
170
+ self.deconv_layers = self._make_deconv_layer(
171
+ args.num_deconv,
172
+ args.num_filters,
173
+ args.deconv_kernels,
174
+ )
175
+
176
+ conv_layers = []
177
+ conv_layers.append(
178
+ build_conv_layer(
179
+ dict(type='Conv2d'),
180
+ in_channels=args.num_filters[-1],
181
+ out_channels=out_channels,
182
+ kernel_size=3,
183
+ stride=1,
184
+ padding=1))
185
+ conv_layers.append(
186
+ build_norm_layer(dict(type='BN'), out_channels)[1])
187
+ conv_layers.append(nn.ReLU(inplace=True))
188
+ self.conv_layers = nn.Sequential(*conv_layers)
189
+
190
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
191
+
192
+ def forward(self, conv_feats):
193
+ # import pdb; pdb.set_trace()
194
+ out = self.deconv_layers(conv_feats[0])
195
+ out = self.conv_layers(out)
196
+
197
+ out = self.up(out)
198
+ out = self.up(out)
199
+
200
+ return out
201
+
202
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
203
+ """Make deconv layers."""
204
+
205
+ layers = []
206
+ in_planes = self.in_channels
207
+ for i in range(num_layers):
208
+ kernel, padding, output_padding = \
209
+ self._get_deconv_cfg(num_kernels[i])
210
+
211
+ planes = num_filters[i]
212
+ layers.append(
213
+ build_upsample_layer(
214
+ dict(type='deconv'),
215
+ in_channels=in_planes,
216
+ out_channels=planes,
217
+ kernel_size=kernel,
218
+ stride=2,
219
+ padding=padding,
220
+ output_padding=output_padding,
221
+ bias=False))
222
+ layers.append(nn.BatchNorm2d(planes))
223
+ layers.append(nn.ReLU(inplace=True))
224
+ in_planes = planes
225
+
226
+ return nn.Sequential(*layers)
227
+
228
+ def _get_deconv_cfg(self, deconv_kernel):
229
+ """Get configurations for deconv layers."""
230
+ if deconv_kernel == 4:
231
+ padding = 1
232
+ output_padding = 0
233
+ elif deconv_kernel == 3:
234
+ padding = 1
235
+ output_padding = 1
236
+ elif deconv_kernel == 2:
237
+ padding = 0
238
+ output_padding = 0
239
+ else:
240
+ raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
241
+
242
+ return deconv_kernel, padding, output_padding
243
+
244
+ def init_weights(self):
245
+ """Initialize model weights."""
246
+ for m in self.modules():
247
+ if isinstance(m, nn.Conv2d):
248
+ normal_init(m, std=0.001, bias=0)
249
+ elif isinstance(m, nn.BatchNorm2d):
250
+ constant_init(m, 1)
251
+ elif isinstance(m, nn.ConvTranspose2d):
252
+ normal_init(m, std=0.001)
depth/models_depth/optimizer.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # The code is from SimMIM.
5
+ # (https://github.com/microsoft/SimMIM)
6
+ # ------------------------------------------------------------------------------
7
+
8
+ import json
9
+ from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
10
+ from mmcv.runner import build_optimizer
11
+ from mmcv.runner import get_dist_info
12
+
13
+
14
+ def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
15
+ var_name = var_name.replace('encoder', 'backbone') if var_name.startswith('encoder') else var_name
16
+ if var_name in ("backbone.cls_token", "backbone.mask_token",
17
+ "backbone.pos_embed", "backbone.absolute_pos_embed"):
18
+ return 0
19
+ elif var_name.startswith("backbone.patch_embed"):
20
+ return 0
21
+ elif var_name.startswith("backbone.layers"):
22
+ if var_name.split('.')[3] == "blocks":
23
+ stage_id = int(var_name.split('.')[2])
24
+ layer_id = int(var_name.split('.')[4]) \
25
+ + sum(layers_per_stage[:stage_id])
26
+ return layer_id + 1
27
+ elif var_name.split('.')[3] == "downsample":
28
+ stage_id = int(var_name.split('.')[2])
29
+ layer_id = sum(layers_per_stage[:stage_id + 1])
30
+ return layer_id
31
+ else:
32
+ return num_max_layer - 1
33
+
34
+ @OPTIMIZER_BUILDERS.register_module()
35
+ class LDMOptimizerConstructor(DefaultOptimizerConstructor):
36
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
37
+ """Add all parameters of module to the params list.
38
+ The parameters of the given module will be added to the list of param
39
+ groups, with specific rules defined by paramwise_cfg.
40
+ Args:
41
+ params (list[dict]): A list of param groups, it will be modified
42
+ in place.
43
+ module (nn.Module): The module to be added.
44
+ prefix (str): The prefix of the module
45
+ is_dcn_module (int|float|None): If the current module is a
46
+ submodule of DCN, `is_dcn_module` will be passed to
47
+ control conv_offset layer's learning rate. Defaults to None.
48
+ """
49
+ parameter_groups = {}
50
+ no_decay_names = self.paramwise_cfg.get('no_decay_names', [])
51
+ print("Build LDMOptimizerConstructor")
52
+ weight_decay = self.base_wd
53
+
54
+ for name, param in module.named_parameters():
55
+ if not param.requires_grad:
56
+ continue # frozen weights
57
+ if len(param.shape) == 1 or name.endswith(".bias") or name in ('absolute_pos_embed'):
58
+ group_name = "no_decay"
59
+ this_weight_decay = 0.
60
+ else:
61
+ group_name = "decay"
62
+ this_weight_decay = weight_decay
63
+
64
+ for nd_name in no_decay_names:
65
+ if nd_name in name:
66
+ group_name = "no_decay"
67
+ this_weight_decay = 0.
68
+ break
69
+
70
+ if 'unet' in name or 'cond_stage_model' in name or 'encoder_vq' in name or 'clip_model' in name:
71
+ layer_id = 0
72
+ else:
73
+ layer_id = 1
74
+ group_name = "layer_%d_%s" % (layer_id, group_name)
75
+
76
+ if group_name not in parameter_groups:
77
+ if layer_id == 0:
78
+ scale = 0.01
79
+ else:
80
+ scale = 1.0
81
+
82
+ parameter_groups[group_name] = {
83
+ "weight_decay": this_weight_decay,
84
+ "params": [],
85
+ "param_names": [],
86
+ "lr_scale": scale,
87
+ "group_name": group_name,
88
+ "lr": scale * self.base_lr,
89
+ }
90
+
91
+ parameter_groups[group_name]["params"].append(param)
92
+ parameter_groups[group_name]["param_names"].append(name)
93
+ rank, _ = get_dist_info()
94
+ if rank == 0:
95
+ to_display = {}
96
+ for key in parameter_groups:
97
+ to_display[key] = {
98
+ "param_names": parameter_groups[key]["param_names"],
99
+ "lr_scale": parameter_groups[key]["lr_scale"],
100
+ "lr": parameter_groups[key]["lr"],
101
+ "weight_decay": parameter_groups[key]["weight_decay"],
102
+ }
103
+
104
+ params.extend(parameter_groups.values())
105
+
106
+ def build_optimizers(model, cfgs):
107
+ """Build multiple optimizers from configs.
108
+
109
+ If `cfgs` contains several dicts for optimizers, then a dict for each
110
+ constructed optimizers will be returned.
111
+ If `cfgs` only contains one optimizer config, the constructed optimizer
112
+ itself will be returned.
113
+
114
+ For example,
115
+
116
+ 1) Multiple optimizer configs:
117
+
118
+ .. code-block:: python
119
+
120
+ optimizer_cfg = dict(
121
+ model1=dict(type='SGD', lr=lr),
122
+ model2=dict(type='SGD', lr=lr))
123
+
124
+ The return dict is
125
+ ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``
126
+
127
+ 2) Single optimizer config:
128
+
129
+ .. code-block:: python
130
+
131
+ optimizer_cfg = dict(type='SGD', lr=lr)
132
+
133
+ The return is ``torch.optim.Optimizer``.
134
+
135
+ Args:
136
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
137
+ cfgs (dict): The config dict of the optimizer.
138
+
139
+ Returns:
140
+ dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
141
+ The initialized optimizers.
142
+ """
143
+ optimizers = {}
144
+ if hasattr(model, 'module'):
145
+ model = model.module
146
+ # determine whether 'cfgs' has several dicts for optimizers
147
+ if all(isinstance(v, dict) for v in cfgs.values()):
148
+ for key, cfg in cfgs.items():
149
+ cfg_ = cfg.copy()
150
+ module = getattr(model, key)
151
+ optimizers[key] = build_optimizer(module, cfg_)
152
+ return optimizers
153
+
154
+ return build_optimizer(model, cfgs)
depth/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=1.6.0
2
+ h5py>=3.6.0
3
+ scipy>=1.7.3
4
+ opencv-python>=4.5.5
5
+ timm>=0.5.4
6
+ albumentations>=1.1.0
7
+ tensorboardX>=2.4.1
8
+ gdown>=4.2.1
depth/test_img.jpg ADDED
depth/utils.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import os
10
+ import math
11
+ import time
12
+ from collections import defaultdict, deque
13
+ import datetime
14
+ import numpy as np
15
+ from timm.utils import get_state_dict
16
+
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from torch._six import inf
22
+
23
+ from tensorboardX import SummaryWriter
24
+
25
+ class SmoothedValue(object):
26
+ """Track a series of values and provide access to smoothed values over a
27
+ window or the global series average.
28
+ """
29
+
30
+ def __init__(self, window_size=20, fmt=None):
31
+ if fmt is None:
32
+ fmt = "{median:.4f} ({global_avg:.4f})"
33
+ self.deque = deque(maxlen=window_size)
34
+ self.total = 0.0
35
+ self.count = 0
36
+ self.fmt = fmt
37
+
38
+ def update(self, value, n=1):
39
+ self.deque.append(value)
40
+ self.count += n
41
+ self.total += value * n
42
+
43
+ def synchronize_between_processes(self):
44
+ """
45
+ Warning: does not synchronize the deque!
46
+ """
47
+ if not is_dist_avail_and_initialized():
48
+ return
49
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
50
+ dist.barrier()
51
+ dist.all_reduce(t)
52
+ t = t.tolist()
53
+ self.count = int(t[0])
54
+ self.total = t[1]
55
+
56
+ @property
57
+ def median(self):
58
+ d = torch.tensor(list(self.deque))
59
+ return d.median().item()
60
+
61
+ @property
62
+ def avg(self):
63
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
64
+ return d.mean().item()
65
+
66
+ @property
67
+ def global_avg(self):
68
+ return self.total / self.count
69
+
70
+ @property
71
+ def max(self):
72
+ return max(self.deque)
73
+
74
+ @property
75
+ def value(self):
76
+ return self.deque[-1]
77
+
78
+ def __str__(self):
79
+ return self.fmt.format(
80
+ median=self.median,
81
+ avg=self.avg,
82
+ global_avg=self.global_avg,
83
+ max=self.max,
84
+ value=self.value)
85
+
86
+
87
+ class MetricLogger(object):
88
+ def __init__(self, delimiter="\t"):
89
+ self.meters = defaultdict(SmoothedValue)
90
+ self.delimiter = delimiter
91
+
92
+ def update(self, **kwargs):
93
+ for k, v in kwargs.items():
94
+ if v is None:
95
+ continue
96
+ if isinstance(v, torch.Tensor):
97
+ v = v.item()
98
+ assert isinstance(v, (float, int))
99
+ self.meters[k].update(v)
100
+
101
+ def __getattr__(self, attr):
102
+ if attr in self.meters:
103
+ return self.meters[attr]
104
+ if attr in self.__dict__:
105
+ return self.__dict__[attr]
106
+ raise AttributeError("'{}' object has no attribute '{}'".format(
107
+ type(self).__name__, attr))
108
+
109
+ def __str__(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append(
113
+ "{}: {}".format(name, str(meter))
114
+ )
115
+ return self.delimiter.join(loss_str)
116
+
117
+ def synchronize_between_processes(self):
118
+ for meter in self.meters.values():
119
+ meter.synchronize_between_processes()
120
+
121
+ def add_meter(self, name, meter):
122
+ self.meters[name] = meter
123
+
124
+ def log_every(self, iterable, print_freq, header=None):
125
+ i = 0
126
+ if not header:
127
+ header = ''
128
+ start_time = time.time()
129
+ end = time.time()
130
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
131
+ data_time = SmoothedValue(fmt='{avg:.4f}')
132
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
133
+ log_msg = [
134
+ header,
135
+ '[{0' + space_fmt + '}/{1}]',
136
+ 'eta: {eta}',
137
+ '{meters}',
138
+ 'time: {time}',
139
+ 'data: {data}'
140
+ ]
141
+ if torch.cuda.is_available():
142
+ log_msg.append('max mem: {memory:.0f}')
143
+ log_msg = self.delimiter.join(log_msg)
144
+ MB = 1024.0 * 1024.0
145
+ for obj in iterable:
146
+ data_time.update(time.time() - end)
147
+ yield obj
148
+ iter_time.update(time.time() - end)
149
+ if i % print_freq == 0 or i == len(iterable) - 1:
150
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
151
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
152
+ if torch.cuda.is_available():
153
+ print(log_msg.format(
154
+ i, len(iterable), eta=eta_string,
155
+ meters=str(self),
156
+ time=str(iter_time), data=str(data_time),
157
+ memory=torch.cuda.max_memory_allocated() / MB))
158
+ else:
159
+ print(log_msg.format(
160
+ i, len(iterable), eta=eta_string,
161
+ meters=str(self),
162
+ time=str(iter_time), data=str(data_time)))
163
+ i += 1
164
+ end = time.time()
165
+ total_time = time.time() - start_time
166
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
167
+ print('{} Total time: {} ({:.4f} s / it)'.format(
168
+ header, total_time_str, total_time / len(iterable)))
169
+
170
+
171
+ class TensorboardLogger(object):
172
+ def __init__(self, log_dir):
173
+ self.writer = SummaryWriter(logdir=log_dir)
174
+ self.step = 0
175
+
176
+ def set_step(self, step=None):
177
+ if step is not None:
178
+ self.step = step
179
+ else:
180
+ self.step += 1
181
+
182
+ def update(self, head='scalar', step=None, **kwargs):
183
+ for k, v in kwargs.items():
184
+ if v is None:
185
+ continue
186
+ if isinstance(v, torch.Tensor):
187
+ v = v.item()
188
+ assert isinstance(v, (float, int))
189
+ self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
190
+
191
+ def flush(self):
192
+ self.writer.flush()
193
+
194
+
195
+ class WandbLogger(object):
196
+ def __init__(self, args):
197
+ self.args = args
198
+
199
+ try:
200
+ import wandb
201
+ self._wandb = wandb
202
+ except ImportError:
203
+ raise ImportError(
204
+ "To use the Weights and Biases Logger please install wandb."
205
+ "Run `pip install wandb` to install it."
206
+ )
207
+
208
+ # Initialize a W&B run
209
+ if self._wandb.run is None:
210
+ self._wandb.init(
211
+ project=args.project,
212
+ config=args
213
+ )
214
+
215
+ def log_epoch_metrics(self, metrics, commit=True):
216
+ """
217
+ Log train/test metrics onto W&B.
218
+ """
219
+ # Log number of model parameters as W&B summary
220
+ self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
221
+ metrics.pop('n_parameters', None)
222
+
223
+ # Log current epoch
224
+ self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
225
+ metrics.pop('epoch')
226
+
227
+ for k, v in metrics.items():
228
+ if 'train' in k:
229
+ self._wandb.log({f'Global Train/{k}': v}, commit=False)
230
+ elif 'test' in k:
231
+ self._wandb.log({f'Global Test/{k}': v}, commit=False)
232
+
233
+ self._wandb.log({})
234
+
235
+ def log_checkpoints(self):
236
+ output_dir = self.args.output_dir
237
+ model_artifact = self._wandb.Artifact(
238
+ self._wandb.run.id + "_model", type="model"
239
+ )
240
+
241
+ model_artifact.add_dir(output_dir)
242
+ self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
243
+
244
+ def set_steps(self):
245
+ # Set global training step
246
+ self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
247
+ # Set epoch-wise step
248
+ self._wandb.define_metric('Global Train/*', step_metric='epoch')
249
+ self._wandb.define_metric('Global Test/*', step_metric='epoch')
250
+
251
+
252
+ def setup_for_distributed(is_master):
253
+ """
254
+ This function disables printing when not in master process
255
+ """
256
+ import builtins as __builtin__
257
+ builtin_print = __builtin__.print
258
+
259
+ def print(*args, **kwargs):
260
+ force = kwargs.pop('force', False)
261
+ if is_master or force:
262
+ builtin_print(*args, **kwargs)
263
+
264
+ __builtin__.print = print
265
+
266
+
267
+ def is_dist_avail_and_initialized():
268
+ if not dist.is_available():
269
+ return False
270
+ if not dist.is_initialized():
271
+ return False
272
+ return True
273
+
274
+
275
+ def get_world_size():
276
+ if not is_dist_avail_and_initialized():
277
+ return 1
278
+ return dist.get_world_size()
279
+
280
+
281
+ def get_rank():
282
+ if not is_dist_avail_and_initialized():
283
+ return 0
284
+ return dist.get_rank()
285
+
286
+
287
+ def is_main_process():
288
+ return get_rank() == 0
289
+
290
+
291
+ def save_on_master(*args, **kwargs):
292
+ if is_main_process():
293
+ torch.save(*args, **kwargs)
294
+
295
+
296
+ def init_distributed_mode(args):
297
+
298
+ if args.dist_on_itp:
299
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
300
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
301
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
302
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
303
+ os.environ['LOCAL_RANK'] = str(args.gpu)
304
+ os.environ['RANK'] = str(args.rank)
305
+ os.environ['WORLD_SIZE'] = str(args.world_size)
306
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
307
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
308
+ args.rank = int(os.environ["RANK"])
309
+ args.world_size = int(os.environ['WORLD_SIZE'])
310
+ args.gpu = int(os.environ['LOCAL_RANK'])
311
+ elif 'SLURM_PROCID' in os.environ:
312
+ args.rank = int(os.environ['SLURM_PROCID'])
313
+ args.gpu = args.rank % torch.cuda.device_count()
314
+
315
+ os.environ['RANK'] = str(args.rank)
316
+ os.environ['LOCAL_RANK'] = str(args.gpu)
317
+ os.environ['WORLD_SIZE'] = str(args.world_size)
318
+ else:
319
+ print('Not using distributed mode')
320
+ args.distributed = False
321
+ return
322
+
323
+ args.distributed = True
324
+
325
+ torch.cuda.set_device(args.gpu)
326
+ args.dist_backend = 'nccl'
327
+ print('| distributed init (rank {}): {}, gpu {}'.format(
328
+ args.rank, args.dist_url, args.gpu), flush=True)
329
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
330
+ world_size=args.world_size, rank=args.rank)
331
+ torch.distributed.barrier()
332
+ setup_for_distributed(args.rank == 0)
333
+
334
+
335
+ def init_distributed_mode_simple(args):
336
+
337
+ args.rank = int(os.environ["RANK"])
338
+ args.world_size = int(os.environ['WORLD_SIZE'])
339
+ args.gpu = int(os.environ['LOCAL_RANK'])
340
+ args.dist_url = 'env://'
341
+
342
+ args.distributed = True
343
+
344
+ torch.cuda.set_device(args.gpu)
345
+ args.dist_backend = 'nccl'
346
+ print('| distributed init (rank {}): {}, gpu {}'.format(
347
+ args.rank, args.dist_url, args.gpu), flush=True)
348
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
349
+ world_size=args.world_size, rank=args.rank)
350
+ torch.distributed.barrier()
351
+ setup_for_distributed(args.rank == 0)
352
+
353
+ def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
354
+ missing_keys = []
355
+ unexpected_keys = []
356
+ error_msgs = []
357
+ # copy state_dict so _load_from_state_dict can modify it
358
+ metadata = getattr(state_dict, '_metadata', None)
359
+ state_dict = state_dict.copy()
360
+ if metadata is not None:
361
+ state_dict._metadata = metadata
362
+
363
+ def load(module, prefix=''):
364
+ local_metadata = {} if metadata is None else metadata.get(
365
+ prefix[:-1], {})
366
+ module._load_from_state_dict(
367
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
368
+ for name, child in module._modules.items():
369
+ if child is not None:
370
+ load(child, prefix + name + '.')
371
+
372
+ load(model, prefix=prefix)
373
+
374
+ warn_missing_keys = []
375
+ ignore_missing_keys = []
376
+ for key in missing_keys:
377
+ keep_flag = True
378
+ for ignore_key in ignore_missing.split('|'):
379
+ if ignore_key in key:
380
+ keep_flag = False
381
+ break
382
+ if keep_flag:
383
+ warn_missing_keys.append(key)
384
+ else:
385
+ ignore_missing_keys.append(key)
386
+
387
+ missing_keys = warn_missing_keys
388
+
389
+ if len(missing_keys) > 0:
390
+ print("Weights of {} not initialized from pretrained model: {}".format(
391
+ model.__class__.__name__, missing_keys))
392
+ if len(unexpected_keys) > 0:
393
+ print("Weights from pretrained model not used in {}: {}".format(
394
+ model.__class__.__name__, unexpected_keys))
395
+ if len(ignore_missing_keys) > 0:
396
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
397
+ model.__class__.__name__, ignore_missing_keys))
398
+ if len(error_msgs) > 0:
399
+ print('\n'.join(error_msgs))
400
+
401
+
402
+ class NativeScalerWithGradNormCount:
403
+ state_dict_key = "amp_scaler"
404
+
405
+ def __init__(self):
406
+ self._scaler = torch.cuda.amp.GradScaler()
407
+
408
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
409
+ self._scaler.scale(loss).backward(create_graph=create_graph)
410
+ if update_grad:
411
+ if clip_grad is not None:
412
+ assert parameters is not None
413
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
414
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
415
+ else:
416
+ self._scaler.unscale_(optimizer)
417
+ norm = get_grad_norm_(parameters)
418
+ self._scaler.step(optimizer)
419
+ self._scaler.update()
420
+ else:
421
+ norm = None
422
+ return norm
423
+
424
+ def state_dict(self):
425
+ return self._scaler.state_dict()
426
+
427
+ def load_state_dict(self, state_dict):
428
+ self._scaler.load_state_dict(state_dict)
429
+
430
+
431
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
432
+ if isinstance(parameters, torch.Tensor):
433
+ parameters = [parameters]
434
+ parameters = [p for p in parameters if p.grad is not None]
435
+ norm_type = float(norm_type)
436
+ if len(parameters) == 0:
437
+ return torch.tensor(0.)
438
+ device = parameters[0].grad.device
439
+ if norm_type == inf:
440
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
441
+ else:
442
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
443
+ return total_norm
444
+
445
+
446
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
447
+ start_warmup_value=0, warmup_steps=-1):
448
+ warmup_schedule = np.array([])
449
+ warmup_iters = warmup_epochs * niter_per_ep
450
+ if warmup_steps > 0:
451
+ warmup_iters = warmup_steps
452
+ print("Set warmup steps = %d" % warmup_iters)
453
+ if warmup_epochs > 0:
454
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
455
+
456
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
457
+ schedule = np.array(
458
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
459
+
460
+ schedule = np.concatenate((warmup_schedule, schedule))
461
+
462
+ assert len(schedule) == epochs * niter_per_ep
463
+ return schedule
464
+
465
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
466
+ output_dir = Path(args.output_dir)
467
+ epoch_name = str(epoch)
468
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
469
+ for checkpoint_path in checkpoint_paths:
470
+ to_save = {
471
+ 'model': model_without_ddp.state_dict(),
472
+ 'optimizer': optimizer.state_dict(),
473
+ 'epoch': epoch,
474
+ 'scaler': loss_scaler.state_dict(),
475
+ 'args': args,
476
+ }
477
+
478
+ if model_ema is not None:
479
+ to_save['model_ema'] = get_state_dict(model_ema)
480
+
481
+ save_on_master(to_save, checkpoint_path)
482
+
483
+ if is_main_process() and isinstance(epoch, int):
484
+ to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
485
+ old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
486
+ if os.path.exists(old_ckpt):
487
+ os.remove(old_ckpt)
488
+
489
+
490
+ def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
491
+ output_dir = Path(args.output_dir)
492
+ if args.auto_resume and len(args.resume) == 0:
493
+ import glob
494
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
495
+ latest_ckpt = -1
496
+ for ckpt in all_checkpoints:
497
+ t = ckpt.split('-')[-1].split('.')[0]
498
+ if t.isdigit():
499
+ latest_ckpt = max(int(t), latest_ckpt)
500
+ if latest_ckpt >= 0:
501
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
502
+ print("Auto resume checkpoint: %s" % args.resume)
503
+
504
+ if args.resume:
505
+ if args.resume.startswith('https'):
506
+ checkpoint = torch.hub.load_state_dict_from_url(
507
+ args.resume, map_location='cpu', check_hash=True)
508
+ else:
509
+ checkpoint = torch.load(args.resume, map_location='cpu')
510
+ model_without_ddp.load_state_dict(checkpoint['model'])
511
+ print("Resume checkpoint %s" % args.resume)
512
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
513
+ optimizer.load_state_dict(checkpoint['optimizer'])
514
+ if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
515
+ args.start_epoch = checkpoint['epoch'] + 1
516
+ else:
517
+ assert args.eval, 'Does not support resuming with checkpoint-best'
518
+ if hasattr(args, 'model_ema') and args.model_ema:
519
+ if 'model_ema' in checkpoint.keys():
520
+ model_ema.ema.load_state_dict(checkpoint['model_ema'])
521
+ else:
522
+ model_ema.ema.load_state_dict(checkpoint['model'])
523
+ if 'scaler' in checkpoint:
524
+ loss_scaler.load_state_dict(checkpoint['scaler'])
525
+ print("With optim & sched!")
depth/utils_depth/criterion.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class SiLogLoss(nn.Module):
11
+ def __init__(self, lambd=0.5):
12
+ super().__init__()
13
+ self.lambd = lambd
14
+
15
+ def forward(self, pred, target):
16
+ valid_mask = (target > 0).detach()
17
+ diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
18
+ loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
19
+ self.lambd * torch.pow(diff_log.mean(), 2))
20
+
21
+ return loss
22
+
depth/utils_depth/logging.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import os
7
+ import cv2
8
+ import sys
9
+ import time
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+
15
+ TOTAL_BAR_LENGTH = 30.
16
+ last_time = time.time()
17
+ begin_time = last_time
18
+
19
+
20
+ def progress_bar(current, total, epochs, cur_epoch, msg=None):
21
+ _, term_width = os.popen('stty size', 'r').read().split()
22
+ term_width = int(term_width)
23
+ global last_time, begin_time
24
+ if current == 0:
25
+ begin_time = time.time() # Reset for new bar.
26
+
27
+ cur_len = int(TOTAL_BAR_LENGTH * current / total)
28
+ rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
29
+
30
+ sys.stdout.write(' [')
31
+ for i in range(cur_len):
32
+ sys.stdout.write('=')
33
+ sys.stdout.write('>')
34
+ for i in range(rest_len):
35
+ sys.stdout.write('.')
36
+ sys.stdout.write(']')
37
+
38
+ cur_time = time.time()
39
+ step_time = cur_time - last_time
40
+ last_time = cur_time
41
+ tot_time = cur_time - begin_time
42
+ remain_time = step_time * (total - current) + \
43
+ (epochs - cur_epoch) * step_time * total
44
+
45
+ L = []
46
+ L.append(' Step: %s' % format_time(step_time))
47
+ L.append(' | Tot: %s' % format_time(tot_time))
48
+ L.append(' | Rem: %s' % format_time(remain_time))
49
+ if msg:
50
+ L.append(' | ' + msg)
51
+
52
+ msg = ''.join(L)
53
+ sys.stdout.write(msg)
54
+ for i in range(157 - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
55
+ sys.stdout.write(' ')
56
+
57
+ # Go back to the center of the bar.
58
+ for i in range(157 - int(TOTAL_BAR_LENGTH / 2) + 2):
59
+ sys.stdout.write('\b')
60
+ sys.stdout.write(' %d/%d ' % (current + 1, total))
61
+
62
+ if current < total - 1:
63
+ sys.stdout.write('\r')
64
+ else:
65
+ sys.stdout.write('\n')
66
+ sys.stdout.flush()
67
+
68
+
69
+ class AverageMeter():
70
+ """Computes and stores the average and current value"""
71
+
72
+ def __init__(self):
73
+ self.reset()
74
+
75
+ def reset(self):
76
+ self.val = 0
77
+ self.avg = 0
78
+ self.sum = 0
79
+ self.count = 0
80
+
81
+ def update(self, val, n=1):
82
+ self.val = val
83
+ self.sum += val * n
84
+ self.count += n
85
+ self.avg = self.sum / self.count
86
+
87
+
88
+ def format_time(seconds):
89
+ days = int(seconds / 3600 / 24)
90
+ seconds = seconds - days * 3600 * 24
91
+ hours = int(seconds / 3600)
92
+ seconds = seconds - hours * 3600
93
+ minutes = int(seconds / 60)
94
+ seconds = seconds - minutes * 60
95
+ secondsf = int(seconds)
96
+ seconds = seconds - secondsf
97
+ millis = int(seconds * 1000)
98
+
99
+ f = ''
100
+ i = 1
101
+ if days > 0:
102
+ f += str(days) + 'D'
103
+ i += 1
104
+ if hours > 0 and i <= 2:
105
+ f += str(hours) + 'h'
106
+ i += 1
107
+ if minutes > 0 and i <= 2:
108
+ f += str(minutes).zfill(2) + 'm'
109
+ i += 1
110
+ if secondsf > 0 and i <= 2:
111
+ f += str(secondsf).zfill(2) + 's'
112
+ i += 1
113
+ if millis > 0 and i <= 2:
114
+ f += str(millis).zfill(3) + 'ms'
115
+ i += 1
116
+ if f == '':
117
+ f = '0ms'
118
+ return f
119
+
120
+
121
+ def display_result(result_dict):
122
+ line = "\n"
123
+ line += "=" * 100 + '\n'
124
+ for metric, value in result_dict.items():
125
+ line += "{:>10} ".format(metric)
126
+ line += "\n"
127
+ for metric, value in result_dict.items():
128
+ line += "{:10.4f} ".format(value)
129
+ line += "\n"
130
+ line += "=" * 100 + '\n'
131
+
132
+ return line
133
+
134
+
135
+ def save_images(pred, save_path):
136
+ if len(pred.shape) > 3:
137
+ pred = pred.squeeze()
138
+
139
+ if isinstance(pred, torch.Tensor):
140
+ pred = pred.cpu().numpy().astype(np.uint8)
141
+
142
+ if pred.shape[0] < 4:
143
+ pred = np.transpose(pred, (1, 2, 0))
144
+ cv2.imwrite(save_path, pred, [cv2.IMWRITE_PNG_COMPRESSION, 0])
145
+
146
+
147
+ def check_and_make_dirs(paths):
148
+ if not isinstance(paths, list):
149
+ paths = [paths]
150
+ for path in paths:
151
+ if not os.path.exists(path):
152
+ os.makedirs(path)
153
+
154
+ def log_args_to_txt(log_txt, args):
155
+ if not os.path.exists(log_txt):
156
+ with open(log_txt, 'w') as txtfile:
157
+ args_ = vars(args)
158
+ args_str = ''
159
+ for k, v in args_.items():
160
+ args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
161
+ txtfile.write(args_str + '\n')
depth/utils_depth/metrics.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import torch
7
+
8
+
9
+ def eval_depth(pred, target):
10
+ assert pred.shape == target.shape
11
+
12
+ thresh = torch.max((target / pred), (pred / target))
13
+
14
+ d1 = torch.sum(thresh < 1.25).float() / len(thresh)
15
+ d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
16
+ d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
17
+
18
+ diff = pred - target
19
+ diff_log = torch.log(pred) - torch.log(target)
20
+
21
+ abs_rel = torch.mean(torch.abs(diff) / target)
22
+ sq_rel = torch.mean(torch.pow(diff, 2) / target)
23
+
24
+ rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
25
+
26
+ rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2)))
27
+
28
+ log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
29
+ silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
30
+
31
+ return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(),
32
+ 'sq_rel': sq_rel.item(), 'rmse': rmse.item(), 'rmse_log': rmse_log.item(),
33
+ 'log10':log10.item(), 'silog':silog.item()}
34
+
35
+
36
+ def cropping_img(args, pred, gt_depth):
37
+ min_depth_eval = args.min_depth_eval
38
+
39
+ max_depth_eval = args.max_depth_eval
40
+
41
+ pred[torch.isinf(pred)] = max_depth_eval
42
+ pred[torch.isnan(pred)] = min_depth_eval
43
+
44
+ valid_mask = torch.logical_and(
45
+ gt_depth > min_depth_eval, gt_depth < max_depth_eval)
46
+
47
+ if args.dataset == 'kitti':
48
+ if args.do_kb_crop:
49
+ height, width = gt_depth.shape
50
+ top_margin = int(height - 352)
51
+ left_margin = int((width - 1216) / 2)
52
+ gt_depth = gt_depth[top_margin:top_margin +
53
+ 352, left_margin:left_margin + 1216]
54
+
55
+ if args.kitti_crop:
56
+ gt_height, gt_width = gt_depth.shape
57
+ eval_mask = torch.zeros(valid_mask.shape).to(
58
+ device=valid_mask.device)
59
+
60
+ if args.kitti_crop == 'garg_crop':
61
+ eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
62
+ int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
63
+
64
+ elif args.kitti_crop == 'eigen_crop':
65
+ eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
66
+ int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
67
+ else:
68
+ eval_mask = valid_mask
69
+
70
+ elif args.dataset == 'nyudepthv2':
71
+ eval_mask = torch.zeros(valid_mask.shape).to(device=valid_mask.device)
72
+ eval_mask[45:471, 41:601] = 1
73
+ else:
74
+ eval_mask = valid_mask
75
+
76
+ valid_mask = torch.logical_and(valid_mask, eval_mask)
77
+
78
+ return pred[valid_mask], gt_depth[valid_mask]
79
+
depth/utils_depth/misc.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # The code is from ZoeDepth (https://github.com/isl-org/ZoeDepth).
3
+ # For non-commercial purpose only (research, evaluation etc).
4
+ # ------------------------------------------------------------------------------
5
+ from scipy import ndimage
6
+
7
+ import math
8
+
9
+ import matplotlib
10
+ import matplotlib.cm
11
+ import numpy as np
12
+ import requests
13
+ import torch
14
+ from PIL import Image
15
+ from torchvision.transforms import ToTensor
16
+
17
+
18
+ def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
19
+ """Converts a depth map to a color image.
20
+
21
+ Args:
22
+ value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
23
+ vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
24
+ vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
25
+ cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
26
+ invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
27
+ invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
28
+ background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
29
+ gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
30
+ value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
31
+
32
+ Returns:
33
+ numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
34
+ """
35
+ if isinstance(value, torch.Tensor):
36
+ value = value.detach().cpu().numpy()
37
+
38
+ value = value.squeeze()
39
+ if invalid_mask is None:
40
+ invalid_mask = value == invalid_val
41
+ mask = np.logical_not(invalid_mask)
42
+
43
+ # normalize
44
+ vmin = np.percentile(value[mask],2) if vmin is None else vmin
45
+ vmax = np.percentile(value[mask],85) if vmax is None else vmax
46
+ if vmin != vmax:
47
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
48
+ else:
49
+ # Avoid 0-division
50
+ value = value * 0.
51
+
52
+ # squeeze last dim if it exists
53
+ # grey out the invalid values
54
+
55
+ value[invalid_mask] = np.nan
56
+ cmapper = matplotlib.colormaps.get_cmap(cmap)
57
+ if value_transform:
58
+ value = value_transform(value)
59
+ # value = value / value.max()
60
+ value = cmapper(value, bytes=True) # (nxmx4)
61
+
62
+ # img = value[:, :, :]
63
+ img = value[...]
64
+ img[invalid_mask] = background_color
65
+
66
+ # return img.transpose((2, 0, 1))
67
+ if gamma_corrected:
68
+ # gamma correction
69
+ img = img / 255
70
+ img = np.power(img, 2.2)
71
+ img = img * 255
72
+ img = img.astype(np.uint8)
73
+ return img, vmin, vmax
depth/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
evp/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .models import UNetWrapper, TextAdapter
evp/models.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+
3
+ import torch as th
4
+ import torch
5
+ import math
6
+ import abc
7
+
8
+ from torch import nn, einsum
9
+
10
+ from einops import rearrange, repeat
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+ from transformers import CLIPTokenizer
13
+ from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPTextTransformer#, _expand_mask
14
+ from inspect import isfunction
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def uniq(arr):
22
+ return{el: True for el in arr}.keys()
23
+
24
+
25
+ def default(val, d):
26
+ if exists(val):
27
+ return val
28
+ return d() if isfunction(d) else d
29
+
30
+
31
+
32
+ def register_attention_control(model, controller):
33
+ def ca_forward(self, place_in_unet):
34
+ def forward(x, context=None, mask=None):
35
+ h = self.heads
36
+
37
+ q = self.to_q(x)
38
+ is_cross = context is not None
39
+ context = default(context, x)
40
+ k = self.to_k(context)
41
+ v = self.to_v(context)
42
+
43
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
44
+
45
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
46
+
47
+ if exists(mask):
48
+ mask = rearrange(mask, 'b ... -> b (...)')
49
+ max_neg_value = -torch.finfo(sim.dtype).max
50
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
51
+ sim.masked_fill_(~mask, max_neg_value)
52
+
53
+ # attention, what we cannot get enough of
54
+ attn = sim.softmax(dim=-1)
55
+
56
+ attn2 = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
57
+ controller(attn2, is_cross, place_in_unet)
58
+
59
+ out = einsum('b i j, b j d -> b i d', attn, v)
60
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
61
+ return self.to_out(out)
62
+
63
+ return forward
64
+
65
+ class DummyController:
66
+ def __call__(self, *args):
67
+ return args[0]
68
+
69
+ def __init__(self):
70
+ self.num_att_layers = 0
71
+
72
+ if controller is None:
73
+ controller = DummyController()
74
+
75
+ def register_recr(net_, count, place_in_unet):
76
+ if net_.__class__.__name__ == 'CrossAttention':
77
+ net_.forward = ca_forward(net_, place_in_unet)
78
+ return count + 1
79
+ elif hasattr(net_, 'children'):
80
+ for net__ in net_.children():
81
+ count = register_recr(net__, count, place_in_unet)
82
+ return count
83
+
84
+ cross_att_count = 0
85
+ sub_nets = model.diffusion_model.named_children()
86
+
87
+ for net in sub_nets:
88
+ if "input_blocks" in net[0]:
89
+ cross_att_count += register_recr(net[1], 0, "down")
90
+ elif "output_blocks" in net[0]:
91
+ cross_att_count += register_recr(net[1], 0, "up")
92
+ elif "middle_block" in net[0]:
93
+ cross_att_count += register_recr(net[1], 0, "mid")
94
+
95
+ controller.num_att_layers = cross_att_count
96
+
97
+
98
+ class AttentionControl(abc.ABC):
99
+
100
+ def step_callback(self, x_t):
101
+ return x_t
102
+
103
+ def between_steps(self):
104
+ return
105
+
106
+ @property
107
+ def num_uncond_att_layers(self):
108
+ return 0
109
+
110
+ @abc.abstractmethod
111
+ def forward (self, attn, is_cross: bool, place_in_unet: str):
112
+ raise NotImplementedError
113
+
114
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
115
+ attn = self.forward(attn, is_cross, place_in_unet)
116
+ return attn
117
+
118
+ def reset(self):
119
+ self.cur_step = 0
120
+ self.cur_att_layer = 0
121
+
122
+ def __init__(self):
123
+ self.cur_step = 0
124
+ self.num_att_layers = -1
125
+ self.cur_att_layer = 0
126
+
127
+
128
+ class AttentionStore(AttentionControl):
129
+ @staticmethod
130
+ def get_empty_store():
131
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
132
+ "down_self": [], "mid_self": [], "up_self": []}
133
+
134
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
135
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
136
+ if attn.shape[1] <= (self.max_size) ** 2: # avoid memory overhead
137
+ self.step_store[key].append(attn)
138
+ return attn
139
+
140
+ def between_steps(self):
141
+ if len(self.attention_store) == 0:
142
+ self.attention_store = self.step_store
143
+ else:
144
+ for key in self.attention_store:
145
+ for i in range(len(self.attention_store[key])):
146
+ self.attention_store[key][i] += self.step_store[key][i]
147
+ self.step_store = self.get_empty_store()
148
+
149
+ def get_average_attention(self):
150
+ average_attention = {key: [item for item in self.step_store[key]] for key in self.step_store}
151
+ return average_attention
152
+
153
+ def reset(self):
154
+ super(AttentionStore, self).reset()
155
+ self.step_store = self.get_empty_store()
156
+ self.attention_store = {}
157
+
158
+ def __init__(self, base_size=64, max_size=None):
159
+ super(AttentionStore, self).__init__()
160
+ self.step_store = self.get_empty_store()
161
+ self.attention_store = {}
162
+ self.base_size = base_size
163
+ if max_size is None:
164
+ self.max_size = self.base_size // 2
165
+ else:
166
+ self.max_size = max_size
167
+
168
+ def register_hier_output(model):
169
+ self = model.diffusion_model
170
+ from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding
171
+ def forward(x, timesteps=None, context=None, y=None,**kwargs):
172
+ """
173
+ Apply the model to an input batch.
174
+ :param x: an [N x C x ...] Tensor of inputs.
175
+ :param timesteps: a 1-D batch of timesteps.
176
+ :param context: conditioning plugged in via crossattn
177
+ :param y: an [N] Tensor of labels, if class-conditional.
178
+ :return: an [N x C x ...] Tensor of outputs.
179
+ """
180
+ assert (y is not None) == (
181
+ self.num_classes is not None
182
+ ), "must specify y if and only if the model is class-conditional"
183
+ hs = []
184
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
185
+ emb = self.time_embed(t_emb)
186
+
187
+ if self.num_classes is not None:
188
+ assert y.shape == (x.shape[0],)
189
+ emb = emb + self.label_emb(y)
190
+
191
+ h = x.type(self.dtype)
192
+ for module in self.input_blocks:
193
+ # import pdb; pdb.set_trace()
194
+ if context.shape[1]==2:
195
+ h = module(h, emb, context[:,0,:].unsqueeze(1))
196
+ else:
197
+ h = module(h, emb, context)
198
+ hs.append(h)
199
+ if context.shape[1]==2:
200
+ h = self.middle_block(h, emb, context[:,0,:].unsqueeze(1))
201
+ else:
202
+ h = self.middle_block(h, emb, context)
203
+ out_list = []
204
+
205
+ for i_out, module in enumerate(self.output_blocks):
206
+ h = th.cat([h, hs.pop()], dim=1)
207
+ if context.shape[1]==2:
208
+ h = module(h, emb, context[:,1,:].unsqueeze(1))
209
+ else:
210
+ h = module(h, emb, context)
211
+ if i_out in [1, 4, 7]:
212
+ out_list.append(h)
213
+ h = h.type(x.dtype)
214
+
215
+ out_list.append(h)
216
+ return out_list
217
+
218
+ self.forward = forward
219
+
220
+ class UNetWrapper(nn.Module):
221
+ def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None, attn_selector='up_cross+down_cross') -> None:
222
+ super().__init__()
223
+ self.unet = unet
224
+ self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
225
+ self.size16 = base_size // 32
226
+ self.size32 = base_size // 16
227
+ self.size64 = base_size // 8
228
+ self.use_attn = use_attn
229
+ if self.use_attn:
230
+ register_attention_control(unet, self.attention_store)
231
+ register_hier_output(unet)
232
+ self.attn_selector = attn_selector.split('+')
233
+
234
+ def forward(self, *args, **kwargs):
235
+ if self.use_attn:
236
+ self.attention_store.reset()
237
+ out_list = self.unet(*args, **kwargs)
238
+ if self.use_attn:
239
+ avg_attn = self.attention_store.get_average_attention()
240
+ attn16, attn32, attn64 = self.process_attn(avg_attn)
241
+ out_list[1] = torch.cat([out_list[1], attn16], dim=1)
242
+ out_list[2] = torch.cat([out_list[2], attn32], dim=1)
243
+ if attn64 is not None:
244
+ out_list[3] = torch.cat([out_list[3], attn64], dim=1)
245
+ return out_list[::-1]
246
+
247
+ def process_attn(self, avg_attn):
248
+ attns = {self.size16: [], self.size32: [], self.size64: []}
249
+ for k in self.attn_selector:
250
+ for up_attn in avg_attn[k]:
251
+ size = int(math.sqrt(up_attn.shape[1]))
252
+ attns[size].append(rearrange(up_attn, 'b (h w) c -> b c h w', h=size))
253
+ attn16 = torch.stack(attns[self.size16]).mean(0)
254
+ attn32 = torch.stack(attns[self.size32]).mean(0)
255
+ if len(attns[self.size64]) > 0:
256
+ attn64 = torch.stack(attns[self.size64]).mean(0)
257
+ else:
258
+ attn64 = None
259
+ return attn16, attn32, attn64
260
+
261
+ class TextAdapter(nn.Module):
262
+ def __init__(self, text_dim=768, hidden_dim=None):
263
+ super().__init__()
264
+ if hidden_dim is None:
265
+ hidden_dim = text_dim
266
+ self.fc = nn.Sequential(
267
+ nn.Linear(text_dim, hidden_dim),
268
+ nn.GELU(),
269
+ nn.Linear(hidden_dim, text_dim)
270
+ )
271
+
272
+ def forward(self, latents, texts, gamma):
273
+ n_class, channel = texts.shape
274
+ bs = latents.shape[0]
275
+
276
+ texts_after = self.fc(texts)
277
+ texts = texts + gamma * texts_after
278
+ texts = repeat(texts, 'n c -> b n c', b=bs)
279
+ return texts
280
+
281
+ class TextAdapterRefer(nn.Module):
282
+ def __init__(self, text_dim=768):
283
+ super().__init__()
284
+
285
+ self.fc = nn.Sequential(
286
+ nn.Linear(text_dim, text_dim),
287
+ nn.GELU(),
288
+ nn.Linear(text_dim, text_dim)
289
+ )
290
+
291
+ def forward(self, latents, texts, gamma):
292
+ texts_after = self.fc(texts)
293
+ texts = texts + gamma * texts_after
294
+ return texts
295
+
296
+
297
+ class TextAdapterDepth(nn.Module):
298
+ def __init__(self, text_dim=768):
299
+ super().__init__()
300
+
301
+ self.fc = nn.Sequential(
302
+ nn.Linear(text_dim, text_dim),
303
+ nn.GELU(),
304
+ nn.Linear(text_dim, text_dim)
305
+ )
306
+
307
+ def forward(self, latents, texts, gamma):
308
+ # use the gamma to blend
309
+ n_sen, channel = texts.shape
310
+ bs = latents.shape[0]
311
+
312
+ texts_after = self.fc(texts)
313
+ texts = texts + gamma * texts_after
314
+ texts = repeat(texts, 'n c -> n b c', b=1)
315
+ return texts
316
+
317
+
318
+ class FrozenCLIPEmbedder(nn.Module):
319
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
320
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, pool=True):
321
+ super().__init__()
322
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
323
+ self.transformer = CLIPTextModel.from_pretrained(version)
324
+ self.device = device
325
+ self.max_length = max_length
326
+ self.freeze()
327
+
328
+ self.pool = pool
329
+
330
+ def freeze(self):
331
+ self.transformer = self.transformer.eval()
332
+ for param in self.parameters():
333
+ param.requires_grad = False
334
+
335
+ def forward(self, text):
336
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
337
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
338
+ tokens = batch_encoding["input_ids"].to(self.device)
339
+ outputs = self.transformer(input_ids=tokens)
340
+
341
+ if self.pool:
342
+ z = outputs.pooler_output
343
+ else:
344
+ z = outputs.last_hidden_state
345
+ return z
346
+
347
+ def encode(self, text):
348
+ return self(text)
349
+
refer/README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Referring Image Segmentation
2
+ ## Getting Started
3
+
4
+ 1. Install the required packages.
5
+
6
+ ```
7
+ pip install -r requirements.txt
8
+ ```
9
+
10
+ 2. Prepare RefCOCO datasets following [LAVT](https://github.com/yz93/LAVT-RIS).
11
+
12
+ * Download COCO 2014 Train Images [83K/13GB] from [COCO](https://cocodataset.org/#download), and extract `train2014.zip` to `./refer/data/images/mscoco/images`
13
+
14
+ * Follow the instructions in `./refer` to download and extract `refclef.zip, refcoco.zip, refcoco+.zip, refcocog.zip` to `./refer/data`
15
+
16
+ Your dataset directory should be:
17
+
18
+ ```
19
+ refer/
20
+ ├──data/
21
+ │ ├── images/mscoco/images/
22
+ │ ├── refclef
23
+ │ ├── refcoco
24
+ │ ├── refcoco+
25
+ │ ├── refcocog
26
+ ├──evaluation/
27
+ ├──...
28
+ ```
29
+
30
+ ## Results and Fine-tuned Models of EVP
31
+ EVP achieves 76.35 overall IoU and 77.61 mean IoU on the validation set of RefCOCO.
32
+
33
+ ## Training
34
+
35
+ We count the max length of referring sentences and set the token length of lenguage model accrodingly. The checkpoint of the best epoch would be saved at `./checkpoints/`.
36
+
37
+ * Train on RefCOCO
38
+
39
+ ```
40
+ bash train.sh refcoco /path/to/logdir <NUM_GPUS> --token_length 40
41
+ ```
42
+
43
+ * Train on RefCOCO+
44
+
45
+ ```
46
+ bash train.sh refcoco+ /path/to/logdir <NUM_GPUS> --token_length 40
47
+ ```
48
+
49
+ * Train on RefCOCOg
50
+
51
+ ```
52
+ bash train.sh refcocog /path/to/logdir <NUM_GPUS> --token_length 77 --splitBy umd
53
+ ```
54
+
55
+ ## Evaluation
56
+
57
+ * Evaluate on RefCOCO
58
+
59
+ ```
60
+ bash test.sh refcoco /path/to/evp_ris_refcoco.pth --token_length 40
61
+ ```
62
+
63
+ * Evaluate on RefCOCO+
64
+
65
+ ```
66
+ bash test.sh refcoco+ /path/to/evp_ris_refcoco+.pth --token_length 40
67
+ ```
68
+
69
+ * Evaluate on RefCOCOg
70
+
71
+ ```
72
+ bash test.sh refcocog /path/to/evp_ris_gref.pth --token_length 77 --splitBy umd
73
+ ```
74
+
75
+ ## Custom inference
76
+ ```
77
+ PYTHONPATH="../":$PYTHONPATH python inference.py --img_path test_img.jpg --resume refcoco.pth --token_length 40 --prompt 'green plant'
78
+ ```
refer/args.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def get_parser():
5
+ parser = argparse.ArgumentParser(description='EVP training and testing')
6
+ parser.add_argument('--amsgrad', action='store_true',
7
+ help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
8
+ parser.add_argument('-b', '--batch-size', default=8, type=int)
9
+ parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
10
+ parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
11
+ parser.add_argument('--ddp_trained_weights', action='store_true',
12
+ help='Only needs specified when testing,'
13
+ 'whether the weights to be loaded are from a DDP-trained model')
14
+ parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine
15
+ parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
16
+ parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
17
+ parser.add_argument('--img_size', default=480, type=int, help='input image size')
18
+ parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')
19
+ parser.add_argument("--local-rank", type=int, default=0, help='local rank for DistributedDataParallel')
20
+ parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
21
+ parser.add_argument('--model_id', default='evp', help='name to identify the model')
22
+ parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
23
+ parser.add_argument('--pin_mem', action='store_true',
24
+ help='If true, pin memory when using the data loader.')
25
+ parser.add_argument('--pretrained_swin_weights', default='',
26
+ help='path to pre-trained Swin backbone weights')
27
+ parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
28
+ parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
29
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
30
+ parser.add_argument('--split', default='val')
31
+ parser.add_argument('--splitBy', default='unc')
32
+ parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
33
+ dest='weight_decay')
34
+ parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
35
+ parser.add_argument('--token_length', default=77, type=int)
36
+
37
+ return parser
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = get_parser()
42
+ args_dict = parser.parse_args()
refer/inference.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ from models_refer.model import EVPRefer
7
+ from args import get_parser
8
+ import glob
9
+ import utils
10
+ import torchvision.transforms as transforms
11
+ from PIL import Image
12
+ import torch.nn.functional as F
13
+ from transformers import CLIPTokenizer
14
+
15
+
16
+ def main():
17
+ parser = get_parser()
18
+ parser.add_argument('--img_path', type=str)
19
+ parser.add_argument('--prompt', type=str)
20
+ args = parser.parse_args()
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
24
+ model = EVPRefer(sd_path='../checkpoints/v1-5-pruned-emaonly.ckpt')
25
+ cudnn.benchmark = True
26
+ model.to(device)
27
+ model_weight = torch.load(args.resume)['model']
28
+ if 'module' in next(iter(model_weight.items()))[0]:
29
+ model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
30
+ model.load_state_dict(model_weight, strict=False)
31
+ model.eval()
32
+
33
+ img_path = args.img_path
34
+
35
+ image = cv2.imread(img_path)
36
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
37
+ image_t = transforms.ToTensor()(image).unsqueeze(0).to(device)
38
+ image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
39
+ shape = image_t.shape
40
+ image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
41
+ input_ids = tokenizer(text=args.prompt, truncation=True, max_length=args.token_length, return_length=True,
42
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")['input_ids'].to(device)
43
+
44
+ with torch.no_grad():
45
+ pred = model(image_t, input_ids)
46
+
47
+ pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
48
+ output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
49
+
50
+ alpha = 0.65
51
+ image[output_mask == 0] = (image[output_mask == 0]*alpha).astype(np.uint8)
52
+ contours, _ = cv2.findContours(output_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
53
+ cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
54
+
55
+ Image.fromarray(image.astype(np.uint8)).save('res.png')
56
+
57
+ return 0
58
+
59
+ if __name__ == '__main__':
60
+ main()
refer/models_refer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import EVPRefer
refer/models_refer/model.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import sys
6
+ from ldm.util import instantiate_from_config
7
+ from transformers.models.clip.modeling_clip import CLIPTextModel
8
+ from omegaconf import OmegaConf
9
+ from lib.mask_predictor import SimpleDecoding
10
+
11
+ from evp.models import UNetWrapper, TextAdapterRefer
12
+
13
+
14
+ def icnr(x, scale=2, init=nn.init.kaiming_normal_):
15
+ """
16
+ Checkerboard artifact free sub-pixel convolution
17
+ https://arxiv.org/abs/1707.02937
18
+ """
19
+ ni,nf,h,w = x.shape
20
+ ni2 = int(ni/(scale**2))
21
+ k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
22
+ k = k.contiguous().view(ni2, nf, -1)
23
+ k = k.repeat(1, 1, scale**2)
24
+ k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
25
+ x.data.copy_(k)
26
+
27
+
28
+ class PixelShuffle(nn.Module):
29
+ """
30
+ Real-Time Single Image and Video Super-Resolution
31
+ https://arxiv.org/abs/1609.05158
32
+ """
33
+ def __init__(self, n_channels, scale):
34
+ super(PixelShuffle, self).__init__()
35
+ self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
36
+ icnr(self.conv.weight)
37
+ self.shuf = nn.PixelShuffle(scale)
38
+ self.relu = nn.ReLU()
39
+
40
+ def forward(self,x):
41
+ x = self.shuf(self.relu(self.conv(x)))
42
+ return x
43
+
44
+
45
+ class AttentionModule(nn.Module):
46
+ def __init__(self, in_channels, out_channels):
47
+ super(AttentionModule, self).__init__()
48
+
49
+ # Convolutional Layers
50
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
51
+
52
+ # Group Normalization
53
+ self.group_norm = nn.GroupNorm(20, out_channels)
54
+
55
+ # ReLU Activation
56
+ self.relu = nn.ReLU()
57
+
58
+ # Spatial Attention
59
+ self.spatial_attention = nn.Sequential(
60
+ nn.Conv2d(in_channels, 1, kernel_size=1),
61
+ nn.Sigmoid()
62
+ )
63
+
64
+ def forward(self, x):
65
+ # Apply spatial attention
66
+ spatial_attention = self.spatial_attention(x)
67
+ x = x * spatial_attention
68
+
69
+ # Apply convolutional layer
70
+ x = self.conv1(x)
71
+ x = self.group_norm(x)
72
+ x = self.relu(x)
73
+
74
+ return x
75
+
76
+
77
+ class AttentionDownsamplingModule(nn.Module):
78
+ def __init__(self, in_channels, out_channels, scale_factor=2):
79
+ super(AttentionDownsamplingModule, self).__init__()
80
+
81
+ # Spatial Attention
82
+ self.spatial_attention = nn.Sequential(
83
+ nn.Conv2d(in_channels, 1, kernel_size=1),
84
+ nn.Sigmoid()
85
+ )
86
+
87
+ # Channel Attention
88
+ self.channel_attention = nn.Sequential(
89
+ nn.AdaptiveAvgPool2d(1),
90
+ nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
91
+ nn.ReLU(inplace=True),
92
+ nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
93
+ nn.Sigmoid()
94
+ )
95
+
96
+ # Convolutional Layers
97
+ if scale_factor == 2:
98
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
99
+ elif scale_factor == 4:
100
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
101
+
102
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
103
+
104
+ # Group Normalization
105
+ self.group_norm = nn.GroupNorm(20, out_channels)
106
+
107
+ # ReLU Activation
108
+ self.relu = nn.ReLU(inplace=True)
109
+
110
+ def forward(self, x):
111
+ # Apply spatial attention
112
+ spatial_attention = self.spatial_attention(x)
113
+ x = x * spatial_attention
114
+
115
+ # Apply channel attention
116
+ channel_attention = self.channel_attention(x)
117
+ x = x * channel_attention
118
+
119
+ # Apply convolutional layers
120
+ x = self.conv1(x)
121
+ x = self.group_norm(x)
122
+ x = self.relu(x)
123
+ x = self.conv2(x)
124
+ x = self.group_norm(x)
125
+ x = self.relu(x)
126
+
127
+ return x
128
+
129
+
130
+ class AttentionUpsamplingModule(nn.Module):
131
+ def __init__(self, in_channels, out_channels):
132
+ super(AttentionUpsamplingModule, self).__init__()
133
+
134
+ # Spatial Attention for outs[2]
135
+ self.spatial_attention = nn.Sequential(
136
+ nn.Conv2d(in_channels, 1, kernel_size=1),
137
+ nn.Sigmoid()
138
+ )
139
+
140
+ # Channel Attention for outs[2]
141
+ self.channel_attention = nn.Sequential(
142
+ nn.AdaptiveAvgPool2d(1),
143
+ nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
144
+ nn.ReLU(),
145
+ nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
146
+ nn.Sigmoid()
147
+ )
148
+
149
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
150
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
151
+
152
+ # Group Normalization
153
+ self.group_norm = nn.GroupNorm(20, out_channels)
154
+
155
+ # ReLU Activation
156
+ self.relu = nn.ReLU()
157
+ self.upscale = PixelShuffle(in_channels, 2)
158
+
159
+ def forward(self, x):
160
+ # Apply spatial attention
161
+ spatial_attention = self.spatial_attention(x)
162
+ x = x * spatial_attention
163
+
164
+ # Apply channel attention
165
+ channel_attention = self.channel_attention(x)
166
+ x = x * channel_attention
167
+
168
+ # Apply convolutional layers
169
+ x = self.conv1(x)
170
+ x = self.group_norm(x)
171
+ x = self.relu(x)
172
+ x = self.conv2(x)
173
+ x = self.group_norm(x)
174
+ x = self.relu(x)
175
+
176
+ # Upsample
177
+ x = self.upscale(x)
178
+
179
+ return x
180
+
181
+
182
+ class ConvLayer(nn.Module):
183
+ def __init__(self, in_channels, out_channels):
184
+ super(ConvLayer, self).__init__()
185
+
186
+ self.conv1 = nn.Sequential(
187
+ nn.Conv2d(in_channels, out_channels, 1),
188
+ nn.GroupNorm(20, out_channels),
189
+ nn.ReLU(),
190
+ )
191
+
192
+ def forward(self, x):
193
+ x = self.conv1(x)
194
+
195
+ return x
196
+
197
+
198
+ class InverseMultiAttentiveFeatureRefinement(nn.Module):
199
+ def __init__(self, in_channels_list):
200
+ super(InverseMultiAttentiveFeatureRefinement, self).__init__()
201
+
202
+ self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
203
+ self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0]//2, scale_factor = 2)
204
+ self.layer3 = ConvLayer(in_channels_list[0]//2 + in_channels_list[1], in_channels_list[1])
205
+ self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1]//2, scale_factor = 2)
206
+ self.layer5 = ConvLayer(in_channels_list[1]//2 + in_channels_list[2], in_channels_list[2])
207
+ self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2]//2, scale_factor = 2)
208
+ self.layer7 = ConvLayer(in_channels_list[2]//2 + in_channels_list[3], in_channels_list[3])
209
+
210
+ '''
211
+ self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
212
+ self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
213
+ self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
214
+ self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
215
+ self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
216
+ self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
217
+ '''
218
+ def forward(self, inputs):
219
+ x_c4, x_c3, x_c2, x_c1 = inputs
220
+ x_c4 = self.layer1(x_c4)
221
+ x_c4_3 = self.layer2(x_c4)
222
+ x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
223
+ x_c3 = self.layer3(x_c3)
224
+ x_c3_2 = self.layer4(x_c3)
225
+ x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
226
+ x_c2 = self.layer5(x_c2)
227
+ x_c2_1 = self.layer6(x_c2)
228
+ x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
229
+ x_c1 = self.layer7(x_c1)
230
+ '''
231
+ x_c1_2 = self.layer8(x_c1)
232
+ x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
233
+ x_c2 = self.layer9(x_c2)
234
+ x_c2_3 = self.layer10(x_c2)
235
+ x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
236
+ x_c3 = self.layer11(x_c3)
237
+ x_c3_4 = self.layer12(x_c3)
238
+ x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
239
+ x_c4 = self.layer13(x_c4)
240
+ '''
241
+ return [x_c4, x_c3, x_c2, x_c1]
242
+
243
+
244
+
245
+ class EVPRefer(nn.Module):
246
+ """Encoder Decoder segmentors.
247
+
248
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
249
+ Note that auxiliary_head is only used for deep supervision during training,
250
+ which could be dumped during inference.
251
+ """
252
+
253
+ def __init__(self,
254
+ sd_path=None,
255
+ base_size=512,
256
+ token_embed_dim=768,
257
+ neck_dim=[320,680,1320,1280],
258
+ **args):
259
+ super().__init__()
260
+ config = OmegaConf.load('./v1-inference.yaml')
261
+ config.model.params.ckpt_path = f'{sd_path}'
262
+ sd_model = instantiate_from_config(config.model)
263
+ self.encoder_vq = sd_model.first_stage_model
264
+ self.unet = UNetWrapper(sd_model.model, base_size=base_size)
265
+ del sd_model.cond_stage_model
266
+ del self.encoder_vq.decoder
267
+ for param in self.encoder_vq.parameters():
268
+ param.requires_grad = True
269
+
270
+ self.text_adapter = TextAdapterRefer(text_dim=token_embed_dim)
271
+
272
+ self.classifier = SimpleDecoding(dims=neck_dim)
273
+
274
+ self.gamma = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)
275
+ self.aggregation = InverseMultiAttentiveFeatureRefinement([320,680,1320,1280])
276
+ self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
277
+ for param in self.clip_model.parameters():
278
+ param.requires_grad = True
279
+
280
+
281
+ def forward(self, img, sentences):
282
+ input_shape = img.shape[-2:]
283
+
284
+ latents = self.encoder_vq.encode(img).mode()
285
+ latents = latents / 4.7164
286
+
287
+ l_feats = self.clip_model(input_ids=sentences).last_hidden_state
288
+ c_crossattn = self.text_adapter(latents, l_feats, self.gamma) # NOTE: here the c_crossattn should be expand_dim as latents
289
+ t = torch.ones((img.shape[0],), device=img.device).long()
290
+ outs = self.unet(latents, t, c_crossattn=[c_crossattn])
291
+
292
+ outs = self.aggregation(outs)
293
+
294
+ x_c1, x_c2, x_c3, x_c4 = outs
295
+ x = self.classifier(x_c4, x_c3, x_c2, x_c1)
296
+ x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
297
+
298
+ return x
299
+
300
+ def get_latent(self, x):
301
+ return self.encoder_vq.encode(x).mode()
refer/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ requests
2
+ filelock
3
+ tqdm
4
+ timm
5
+ ftfy
6
+ regex
7
+ scipy
8
+ scikit-image
9
+ pycocotools==2.0.2
10
+ opencv-python==4.5.3.56
11
+ tokenizers
12
+ h5py
refer/transforms.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import random
4
+
5
+ import torch
6
+ from torchvision import transforms as T
7
+ from torchvision.transforms import functional as F
8
+
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+ def pad_if_smaller(img, size, fill=0):
13
+ min_size = min(img.size)
14
+ if min_size < size:
15
+ ow, oh = img.size
16
+ padh = size - oh if oh < size else 0
17
+ padw = size - ow if ow < size else 0
18
+ img = F.pad(img, (0, 0, padw, padh), fill=fill)
19
+ return img
20
+
21
+
22
+ class Compose(object):
23
+ def __init__(self, transforms):
24
+ self.transforms = transforms
25
+
26
+ def __call__(self, image, target):
27
+ for t in self.transforms:
28
+ image, target = t(image, target)
29
+ return image, target
30
+
31
+
32
+ class Resize(object):
33
+ def __init__(self, h, w):
34
+ self.h = h
35
+ self.w = w
36
+
37
+ def __call__(self, image, target):
38
+ image = F.resize(image, (self.h, self.w))
39
+ # If size is a sequence like (h, w), the output size will be matched to this.
40
+ # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
41
+ target = F.resize(target, (self.h, self.w))
42
+ return image, target
43
+
44
+
45
+ class RandomResize(object):
46
+ def __init__(self, min_size, max_size=None):
47
+ self.min_size = min_size
48
+ if max_size is None:
49
+ max_size = min_size
50
+ self.max_size = max_size
51
+
52
+ def __call__(self, image, target):
53
+ size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1)
54
+ image = F.resize(image, size)
55
+ # If size is a sequence like (h, w), the output size will be matched to this.
56
+ # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
57
+ target = F.resize(target, size)
58
+ return image, target
59
+
60
+
61
+ class RandomHorizontalFlip(object):
62
+ def __init__(self, flip_prob):
63
+ self.flip_prob = flip_prob
64
+
65
+ def __call__(self, image, target):
66
+ if random.random() < self.flip_prob:
67
+ image = F.hflip(image)
68
+ target = F.hflip(target)
69
+ return image, target
70
+
71
+
72
+ class RandomCrop(object):
73
+ def __init__(self, size):
74
+ self.size = size
75
+
76
+ def __call__(self, image, target):
77
+ image = pad_if_smaller(image, self.size)
78
+ target = pad_if_smaller(target, self.size, fill=255)
79
+ crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
80
+ image = F.crop(image, *crop_params)
81
+ target = F.crop(target, *crop_params)
82
+ return image, target
83
+
84
+
85
+ class CenterCrop(object):
86
+ def __init__(self, size):
87
+ self.size = size
88
+
89
+ def __call__(self, image, target):
90
+ image = F.center_crop(image, self.size)
91
+ target = F.center_crop(target, self.size)
92
+ return image, target
93
+
94
+
95
+ class ToTensor(object):
96
+ def __call__(self, image, target):
97
+ image = F.to_tensor(image)
98
+ target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64)
99
+ return image, target
100
+
101
+
102
+ class RandomAffine(object):
103
+ def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None):
104
+ self.angle = angle
105
+ self.translate = translate
106
+ self.scale = scale
107
+ self.shear = shear
108
+ self.resample = resample
109
+ self.fillcolor = fillcolor
110
+
111
+ def __call__(self, image, target):
112
+ affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size)
113
+ image = F.affine(image, *affine_params)
114
+ target = F.affine(target, *affine_params)
115
+ return image, target
116
+
117
+
118
+ class Normalize(object):
119
+ def __init__(self, mean, std):
120
+ self.mean = mean
121
+ self.std = std
122
+
123
+ def __call__(self, image, target):
124
+ image = F.normalize(image, mean=self.mean, std=self.std)
125
+ return image, target
126
+
refer/utils.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from collections import defaultdict, deque
3
+ import datetime
4
+ import math
5
+ import time
6
+ import torch
7
+ import torch.distributed as dist
8
+ import torch.backends.cudnn as cudnn
9
+
10
+ import errno
11
+ import os
12
+
13
+ import sys
14
+
15
+
16
+ class SmoothedValue(object):
17
+ """Track a series of values and provide access to smoothed values over a
18
+ window or the global series average.
19
+ """
20
+
21
+ def __init__(self, window_size=20, fmt=None):
22
+ if fmt is None:
23
+ fmt = "{median:.4f} ({global_avg:.4f})"
24
+ self.deque = deque(maxlen=window_size)
25
+ self.total = 0.0
26
+ self.count = 0
27
+ self.fmt = fmt
28
+
29
+ def update(self, value, n=1):
30
+ self.deque.append(value)
31
+ self.count += n
32
+ self.total += value * n
33
+
34
+ def synchronize_between_processes(self):
35
+ """
36
+ Warning: does not synchronize the deque!
37
+ """
38
+ if not is_dist_avail_and_initialized():
39
+ return
40
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
41
+ dist.barrier()
42
+ dist.all_reduce(t)
43
+ t = t.tolist()
44
+ self.count = int(t[0])
45
+ self.total = t[1]
46
+
47
+ @property
48
+ def median(self):
49
+ d = torch.tensor(list(self.deque))
50
+ return d.median().item()
51
+
52
+ @property
53
+ def avg(self):
54
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
55
+ return d.mean().item()
56
+
57
+ @property
58
+ def global_avg(self):
59
+ return self.total / self.count
60
+
61
+ @property
62
+ def max(self):
63
+ return max(self.deque)
64
+
65
+ @property
66
+ def value(self):
67
+ return self.deque[-1]
68
+
69
+ def __str__(self):
70
+ return self.fmt.format(
71
+ median=self.median,
72
+ avg=self.avg,
73
+ global_avg=self.global_avg,
74
+ max=self.max,
75
+ value=self.value)
76
+
77
+
78
+ class MetricLogger(object):
79
+ def __init__(self, delimiter="\t"):
80
+ self.meters = defaultdict(SmoothedValue)
81
+ self.delimiter = delimiter
82
+
83
+ def update(self, **kwargs):
84
+ for k, v in kwargs.items():
85
+ if isinstance(v, torch.Tensor):
86
+ v = v.item()
87
+ assert isinstance(v, (float, int))
88
+ self.meters[k].update(v)
89
+
90
+ def __getattr__(self, attr):
91
+ if attr in self.meters:
92
+ return self.meters[attr]
93
+ if attr in self.__dict__:
94
+ return self.__dict__[attr]
95
+ raise AttributeError("'{}' object has no attribute '{}'".format(
96
+ type(self).__name__, attr))
97
+
98
+ def __str__(self):
99
+ loss_str = []
100
+ for name, meter in self.meters.items():
101
+ loss_str.append(
102
+ "{}: {}".format(name, str(meter))
103
+ )
104
+ return self.delimiter.join(loss_str)
105
+
106
+ def synchronize_between_processes(self):
107
+ for meter in self.meters.values():
108
+ meter.synchronize_between_processes()
109
+
110
+ def add_meter(self, name, meter):
111
+ self.meters[name] = meter
112
+
113
+ def log_every(self, iterable, print_freq, header=None):
114
+ i = 0
115
+ if not header:
116
+ header = ''
117
+ start_time = time.time()
118
+ end = time.time()
119
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
120
+ data_time = SmoothedValue(fmt='{avg:.4f}')
121
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
122
+ log_msg = self.delimiter.join([
123
+ header,
124
+ '[{0' + space_fmt + '}/{1}]',
125
+ 'eta: {eta}',
126
+ '{meters}',
127
+ 'time: {time}',
128
+ 'data: {data}',
129
+ 'max mem: {memory:.0f}'
130
+ ])
131
+ MB = 1024.0 * 1024.0
132
+ for obj in iterable:
133
+ data_time.update(time.time() - end)
134
+ yield obj
135
+ iter_time.update(time.time() - end)
136
+ if i % print_freq == 0:
137
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
138
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
139
+ print(log_msg.format(
140
+ i, len(iterable), eta=eta_string,
141
+ meters=str(self),
142
+ time=str(iter_time), data=str(data_time),
143
+ memory=torch.cuda.max_memory_allocated() / MB))
144
+ sys.stdout.flush()
145
+
146
+ i += 1
147
+ end = time.time()
148
+ total_time = time.time() - start_time
149
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
150
+ print('{} Total time: {}'.format(header, total_time_str))
151
+
152
+
153
+ def mkdir(path):
154
+ try:
155
+ os.makedirs(path)
156
+ except OSError as e:
157
+ if e.errno != errno.EEXIST:
158
+ raise
159
+
160
+
161
+ def setup_for_distributed(is_master):
162
+ """
163
+ This function disables printing when not in master process
164
+ """
165
+ import builtins as __builtin__
166
+ builtin_print = __builtin__.print
167
+
168
+ def print(*args, **kwargs):
169
+ force = kwargs.pop('force', False)
170
+ if is_master or force:
171
+ builtin_print(*args, **kwargs)
172
+
173
+ __builtin__.print = print
174
+
175
+
176
+ def is_dist_avail_and_initialized():
177
+ if not dist.is_available():
178
+ return False
179
+ if not dist.is_initialized():
180
+ return False
181
+ return True
182
+
183
+
184
+ def get_world_size():
185
+ if not is_dist_avail_and_initialized():
186
+ return 1
187
+ return dist.get_world_size()
188
+
189
+
190
+ def get_rank():
191
+ if not is_dist_avail_and_initialized():
192
+ return 0
193
+ return dist.get_rank()
194
+
195
+
196
+ def is_main_process():
197
+ return get_rank() == 0
198
+
199
+
200
+ def save_on_master(*args, **kwargs):
201
+ if is_main_process():
202
+ torch.save(*args, **kwargs)
203
+
204
+
205
+ def init_distributed_mode(args):
206
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
207
+ rank = int(os.environ["RANK"])
208
+ world_size = int(os.environ['WORLD_SIZE'])
209
+ print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}")
210
+ else:
211
+ rank = -1
212
+ world_size = -1
213
+
214
+ torch.cuda.set_device(args.local_rank)
215
+ torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
216
+ torch.distributed.barrier()
217
+ setup_for_distributed(is_main_process())
218
+
219
+ if args.output_dir:
220
+ mkdir(args.output_dir)
221
+ if args.model_id:
222
+ mkdir(os.path.join('./models/', args.model_id))
refer/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder