Spaces:
Running
Running
Upload 267 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- lama/.gitignore +137 -0
- lama/LICENSE +201 -0
- lama/LaMa_inpainting.ipynb +0 -0
- lama/README.md +464 -0
- lama/bin/analyze_errors.py +316 -0
- lama/bin/blur_predicts.py +57 -0
- lama/bin/calc_dataset_stats.py +88 -0
- lama/bin/debug/analyze_overlapping_masks.sh +31 -0
- lama/bin/evaluate_predicts.py +79 -0
- lama/bin/evaluator_example.py +76 -0
- lama/bin/extract_masks.py +63 -0
- lama/bin/filter_sharded_dataset.py +69 -0
- lama/bin/gen_debug_mask_dataset.py +61 -0
- lama/bin/gen_mask_dataset.py +130 -0
- lama/bin/gen_mask_dataset_hydra.py +124 -0
- lama/bin/gen_outpainting_dataset.py +88 -0
- lama/bin/make_checkpoint.py +79 -0
- lama/bin/mask_example.py +14 -0
- lama/bin/paper_runfiles/blur_tests.sh +37 -0
- lama/bin/paper_runfiles/env.sh +8 -0
- lama/bin/paper_runfiles/find_best_checkpoint.py +54 -0
- lama/bin/paper_runfiles/generate_test_celeba-hq.sh +17 -0
- lama/bin/paper_runfiles/generate_test_ffhq.sh +17 -0
- lama/bin/paper_runfiles/generate_test_paris.sh +17 -0
- lama/bin/paper_runfiles/generate_test_paris_256.sh +17 -0
- lama/bin/paper_runfiles/generate_val_test.sh +28 -0
- lama/bin/paper_runfiles/predict_inner_features.sh +20 -0
- lama/bin/paper_runfiles/update_test_data_stats.sh +30 -0
- lama/bin/predict.py +104 -0
- lama/bin/predict_inner_features.py +120 -0
- lama/bin/report_from_tb.py +83 -0
- lama/bin/sample_from_dataset.py +87 -0
- lama/bin/side_by_side.py +76 -0
- lama/bin/split_tar.py +22 -0
- lama/bin/to_jit.py +76 -0
- lama/bin/train.py +73 -0
- lama/conda_env.yml +165 -0
- lama/configs/analyze_mask_errors.yaml +7 -0
- lama/configs/data_gen/random_medium_256.yaml +33 -0
- lama/configs/data_gen/random_medium_512.yaml +33 -0
- lama/configs/data_gen/random_thick_256.yaml +33 -0
- lama/configs/data_gen/random_thick_512.yaml +33 -0
- lama/configs/data_gen/random_thin_256.yaml +25 -0
- lama/configs/data_gen/random_thin_512.yaml +25 -0
- lama/configs/debug_mask_gen.yaml +5 -0
- lama/configs/eval1.yaml +6 -0
- lama/configs/eval2.yaml +7 -0
- lama/configs/eval2_cpu.yaml +7 -0
- lama/configs/eval2_gpu.yaml +6 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
lama/saicinpainting/evaluation/masks/countless/images/gcim.jpg filter=lfs diff=lfs merge=lfs -text
|
lama/.gitignore
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# temporary files
|
132 |
+
## IDEA
|
133 |
+
.idea/
|
134 |
+
## vscode
|
135 |
+
.vscode/
|
136 |
+
## vim
|
137 |
+
*.sw?
|
lama/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [2021] Samsung Research
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
lama/LaMa_inpainting.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lama/README.md
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🦙 LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions
|
2 |
+
|
3 |
+
by Roman Suvorov, Elizaveta Logacheva, Anton Mashikhin,
|
4 |
+
Anastasia Remizova, Arsenii Ashukha, Aleksei Silvestrov, Naejin Kong, Harshith Goka, Kiwoong Park, Victor Lempitsky.
|
5 |
+
|
6 |
+
<p align="center" "font-size:30px;">
|
7 |
+
🔥🔥🔥
|
8 |
+
<br>
|
9 |
+
<b>
|
10 |
+
LaMa generalizes surprisingly well to much higher resolutions (~2k❗️) than it saw during training (256x256), and achieves the excellent performance even in challenging scenarios, e.g. completion of periodic structures.</b>
|
11 |
+
</p>
|
12 |
+
|
13 |
+
[[Project page](https://advimman.github.io/lama-project/)] [[arXiv](https://arxiv.org/abs/2109.07161)] [[Supplementary](https://ashukha.com/projects/lama_21/lama_supmat_2021.pdf)] [[BibTeX](https://senya-ashukha.github.io/projects/lama_21/paper.txt)] [[Casual GAN Papers Summary](https://www.casualganpapers.com/large-masks-fourier-convolutions-inpainting/LaMa-explained.html)]
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
<a href="https://colab.research.google.com/drive/15KTEIScUbVZtUP6w2tCDMVpE-b1r9pkZ?usp=drive_link">
|
17 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg"/>
|
18 |
+
</a>
|
19 |
+
<br>
|
20 |
+
Try out in Google Colab
|
21 |
+
</p>
|
22 |
+
|
23 |
+
<p align="center">
|
24 |
+
<img src="https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/projects/lama_21/ezgif-4-0db51df695a8.gif" />
|
25 |
+
</p>
|
26 |
+
|
27 |
+
|
28 |
+
<p align="center">
|
29 |
+
<img src="https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/projects/lama_21/gif_for_lightning_v1_white.gif" />
|
30 |
+
</p>
|
31 |
+
|
32 |
+
# LaMa development
|
33 |
+
(Feel free to share your paper by creating an issue)
|
34 |
+
- https://github.com/geekyutao/Inpaint-Anything --- Inpaint Anything: Segment Anything Meets Image Inpainting
|
35 |
+
<p align="center">
|
36 |
+
<img src="https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/MainFramework.png" />
|
37 |
+
</p>
|
38 |
+
|
39 |
+
- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](geomagical.com))
|
40 |
+
<p align="center">
|
41 |
+
<img src="https://raw.githubusercontent.com/senya-ashukha/senya-ashukha.github.io/master/images/FeatureRefinement.png" />
|
42 |
+
</p>
|
43 |
+
|
44 |
+
# Non-official 3rd party apps:
|
45 |
+
(Feel free to share your app/implementation/demo by creating an issue)
|
46 |
+
|
47 |
+
- https://github.com/enesmsahin/simple-lama-inpainting - a simple pip package for LaMa inpainting.
|
48 |
+
- https://github.com/mallman/CoreMLaMa - Apple's Core ML model format
|
49 |
+
- [https://cleanup.pictures](https://cleanup.pictures/) - a simple interactive object removal tool by [@cyrildiagne](https://twitter.com/cyrildiagne)
|
50 |
+
- [lama-cleaner](https://github.com/Sanster/lama-cleaner) by [@Sanster](https://github.com/Sanster/lama-cleaner) is a self-host version of [https://cleanup.pictures](https://cleanup.pictures/)
|
51 |
+
- Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/lama) by [@AK391](https://github.com/AK391)
|
52 |
+
- Telegram bot [@MagicEraserBot](https://t.me/MagicEraserBot) by [@Moldoteck](https://github.com/Moldoteck), [code](https://github.com/Moldoteck/MagicEraser)
|
53 |
+
- [Auto-LaMa](https://github.com/andy971022/auto-lama) = DE:TR object detection + LaMa inpainting by [@andy971022](https://github.com/andy971022)
|
54 |
+
- [LAMA-Magic-Eraser-Local](https://github.com/zhaoyun0071/LAMA-Magic-Eraser-Local) = a standalone inpainting application built with PyQt5 by [@zhaoyun0071](https://github.com/zhaoyun0071)
|
55 |
+
- [Hama](https://www.hama.app/) - object removal with a smart brush which simplifies mask drawing.
|
56 |
+
- [ModelScope](https://www.modelscope.cn/models/damo/cv_fft_inpainting_lama/summary) = the largest Model Community in Chinese by [@chenbinghui1](https://github.com/chenbinghui1).
|
57 |
+
- [LaMa with MaskDINO](https://github.com/qwopqwop200/lama-with-maskdino) = MaskDINO object detection + LaMa inpainting with refinement by [@qwopqwop200](https://github.com/qwopqwop200).
|
58 |
+
- [CoreMLaMa](https://github.com/mallman/CoreMLaMa) - a script to convert Lama Cleaner's port of LaMa to Apple's Core ML model format.
|
59 |
+
|
60 |
+
# Environment setup
|
61 |
+
|
62 |
+
Clone the repo:
|
63 |
+
`git clone https://github.com/advimman/lama.git`
|
64 |
+
|
65 |
+
There are three options of an environment:
|
66 |
+
|
67 |
+
1. Python virtualenv:
|
68 |
+
|
69 |
+
```
|
70 |
+
virtualenv inpenv --python=/usr/bin/python3
|
71 |
+
source inpenv/bin/activate
|
72 |
+
pip install torch==1.8.0 torchvision==0.9.0
|
73 |
+
|
74 |
+
cd lama
|
75 |
+
pip install -r requirements.txt
|
76 |
+
```
|
77 |
+
|
78 |
+
2. Conda
|
79 |
+
|
80 |
+
```
|
81 |
+
% Install conda for Linux, for other OS download miniconda at https://docs.conda.io/en/latest/miniconda.html
|
82 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
83 |
+
bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda
|
84 |
+
$HOME/miniconda/bin/conda init bash
|
85 |
+
|
86 |
+
cd lama
|
87 |
+
conda env create -f conda_env.yml
|
88 |
+
conda activate lama
|
89 |
+
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -y
|
90 |
+
pip install pytorch-lightning==1.2.9
|
91 |
+
```
|
92 |
+
|
93 |
+
3. Docker: No actions are needed 🎉.
|
94 |
+
|
95 |
+
# Inference <a name="prediction"></a>
|
96 |
+
|
97 |
+
Run
|
98 |
+
```
|
99 |
+
cd lama
|
100 |
+
export TORCH_HOME=$(pwd) && export PYTHONPATH=$(pwd)
|
101 |
+
```
|
102 |
+
|
103 |
+
**1. Download pre-trained models**
|
104 |
+
|
105 |
+
The best model (Places2, Places Challenge):
|
106 |
+
|
107 |
+
```
|
108 |
+
curl -LJO https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
|
109 |
+
unzip big-lama.zip
|
110 |
+
```
|
111 |
+
|
112 |
+
All models (Places & CelebA-HQ):
|
113 |
+
|
114 |
+
```
|
115 |
+
download [https://drive.google.com/drive/folders/1B2x7eQDgecTL0oh3LSIBDGj0fTxs6Ips?usp=drive_link]
|
116 |
+
unzip lama-models.zip
|
117 |
+
```
|
118 |
+
|
119 |
+
**2. Prepare images and masks**
|
120 |
+
|
121 |
+
Download test images:
|
122 |
+
|
123 |
+
```
|
124 |
+
unzip LaMa_test_images.zip
|
125 |
+
```
|
126 |
+
<details>
|
127 |
+
<summary>OR prepare your data:</summary>
|
128 |
+
1) Create masks named as `[images_name]_maskXXX[image_suffix]`, put images and masks in the same folder.
|
129 |
+
|
130 |
+
- You can use the [script](https://github.com/advimman/lama/blob/main/bin/gen_mask_dataset.py) for random masks generation.
|
131 |
+
- Check the format of the files:
|
132 |
+
```
|
133 |
+
image1_mask001.png
|
134 |
+
image1.png
|
135 |
+
image2_mask001.png
|
136 |
+
image2.png
|
137 |
+
```
|
138 |
+
|
139 |
+
2) Specify `image_suffix`, e.g. `.png` or `.jpg` or `_input.jpg` in `configs/prediction/default.yaml`.
|
140 |
+
|
141 |
+
</details>
|
142 |
+
|
143 |
+
|
144 |
+
**3. Predict**
|
145 |
+
|
146 |
+
On the host machine:
|
147 |
+
|
148 |
+
python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output
|
149 |
+
|
150 |
+
**OR** in the docker
|
151 |
+
|
152 |
+
The following command will pull the docker image from Docker Hub and execute the prediction script
|
153 |
+
```
|
154 |
+
bash docker/2_predict.sh $(pwd)/big-lama $(pwd)/LaMa_test_images $(pwd)/output device=cpu
|
155 |
+
```
|
156 |
+
Docker cuda:
|
157 |
+
```
|
158 |
+
bash docker/2_predict_with_gpu.sh $(pwd)/big-lama $(pwd)/LaMa_test_images $(pwd)/output
|
159 |
+
```
|
160 |
+
|
161 |
+
**4. Predict with Refinement**
|
162 |
+
|
163 |
+
On the host machine:
|
164 |
+
|
165 |
+
python3 bin/predict.py refine=True model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output
|
166 |
+
|
167 |
+
# Train and Eval
|
168 |
+
|
169 |
+
Make sure you run:
|
170 |
+
|
171 |
+
```
|
172 |
+
cd lama
|
173 |
+
export TORCH_HOME=$(pwd) && export PYTHONPATH=$(pwd)
|
174 |
+
```
|
175 |
+
|
176 |
+
Then download models for _perceptual loss_:
|
177 |
+
|
178 |
+
mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
|
179 |
+
wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
|
180 |
+
|
181 |
+
|
182 |
+
## Places
|
183 |
+
|
184 |
+
⚠️ NB: FID/SSIM/LPIPS metric values for Places that we see in LaMa paper are computed on 30000 images that we produce in evaluation section below.
|
185 |
+
For more details on evaluation data check [[Section 3. Dataset splits in Supplementary](https://ashukha.com/projects/lama_21/lama_supmat_2021.pdf#subsection.3.1)] ⚠️
|
186 |
+
|
187 |
+
On the host machine:
|
188 |
+
|
189 |
+
# Download data from http://places2.csail.mit.edu/download.html
|
190 |
+
# Places365-Standard: Train(105GB)/Test(19GB)/Val(2.1GB) from High-resolution images section
|
191 |
+
wget http://data.csail.mit.edu/places/places365/train_large_places365standard.tar
|
192 |
+
wget http://data.csail.mit.edu/places/places365/val_large.tar
|
193 |
+
wget http://data.csail.mit.edu/places/places365/test_large.tar
|
194 |
+
|
195 |
+
# Unpack train/test/val data and create .yaml config for it
|
196 |
+
bash fetch_data/places_standard_train_prepare.sh
|
197 |
+
bash fetch_data/places_standard_test_val_prepare.sh
|
198 |
+
|
199 |
+
# Sample images for test and viz at the end of epoch
|
200 |
+
bash fetch_data/places_standard_test_val_sample.sh
|
201 |
+
bash fetch_data/places_standard_test_val_gen_masks.sh
|
202 |
+
|
203 |
+
# Run training
|
204 |
+
python3 bin/train.py -cn lama-fourier location=places_standard
|
205 |
+
|
206 |
+
# To evaluate trained model and report metrics as in our paper
|
207 |
+
# we need to sample previously unseen 30k images and generate masks for them
|
208 |
+
bash fetch_data/places_standard_evaluation_prepare_data.sh
|
209 |
+
|
210 |
+
# Infer model on thick/thin/medium masks in 256 and 512 and run evaluation
|
211 |
+
# like this:
|
212 |
+
python3 bin/predict.py \
|
213 |
+
model.path=$(pwd)/experiments/<user>_<date:time>_lama-fourier_/ \
|
214 |
+
indir=$(pwd)/places_standard_dataset/evaluation/random_thick_512/ \
|
215 |
+
outdir=$(pwd)/inference/random_thick_512 model.checkpoint=last.ckpt
|
216 |
+
|
217 |
+
python3 bin/evaluate_predicts.py \
|
218 |
+
$(pwd)/configs/eval2_gpu.yaml \
|
219 |
+
$(pwd)/places_standard_dataset/evaluation/random_thick_512/ \
|
220 |
+
$(pwd)/inference/random_thick_512 \
|
221 |
+
$(pwd)/inference/random_thick_512_metrics.csv
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
Docker: TODO
|
226 |
+
|
227 |
+
## CelebA
|
228 |
+
On the host machine:
|
229 |
+
|
230 |
+
# Make shure you are in lama folder
|
231 |
+
cd lama
|
232 |
+
export TORCH_HOME=$(pwd) && export PYTHONPATH=$(pwd)
|
233 |
+
|
234 |
+
# Download CelebA-HQ dataset
|
235 |
+
# Download data256x256.zip from https://drive.google.com/drive/folders/11Vz0fqHS2rXDb5pprgTjpD7S2BAJhi1P
|
236 |
+
|
237 |
+
# unzip & split into train/test/visualization & create config for it
|
238 |
+
bash fetch_data/celebahq_dataset_prepare.sh
|
239 |
+
|
240 |
+
# generate masks for test and visual_test at the end of epoch
|
241 |
+
bash fetch_data/celebahq_gen_masks.sh
|
242 |
+
|
243 |
+
# Run training
|
244 |
+
python3 bin/train.py -cn lama-fourier-celeba data.batch_size=10
|
245 |
+
|
246 |
+
# Infer model on thick/thin/medium masks in 256 and run evaluation
|
247 |
+
# like this:
|
248 |
+
python3 bin/predict.py \
|
249 |
+
model.path=$(pwd)/experiments/<user>_<date:time>_lama-fourier-celeba_/ \
|
250 |
+
indir=$(pwd)/celeba-hq-dataset/visual_test_256/random_thick_256/ \
|
251 |
+
outdir=$(pwd)/inference/celeba_random_thick_256 model.checkpoint=last.ckpt
|
252 |
+
|
253 |
+
|
254 |
+
Docker: TODO
|
255 |
+
|
256 |
+
## Places Challenge
|
257 |
+
|
258 |
+
On the host machine:
|
259 |
+
|
260 |
+
# This script downloads multiple .tar files in parallel and unpacks them
|
261 |
+
# Places365-Challenge: Train(476GB) from High-resolution images (to train Big-Lama)
|
262 |
+
bash places_challenge_train_download.sh
|
263 |
+
|
264 |
+
TODO: prepare
|
265 |
+
TODO: train
|
266 |
+
TODO: eval
|
267 |
+
|
268 |
+
Docker: TODO
|
269 |
+
|
270 |
+
## Create your data
|
271 |
+
|
272 |
+
Please check bash scripts for data preparation and mask generation from CelebaHQ section,
|
273 |
+
if you stuck at one of the following steps.
|
274 |
+
|
275 |
+
|
276 |
+
On the host machine:
|
277 |
+
|
278 |
+
# Make shure you are in lama folder
|
279 |
+
cd lama
|
280 |
+
export TORCH_HOME=$(pwd) && export PYTHONPATH=$(pwd)
|
281 |
+
|
282 |
+
# You need to prepare following image folders:
|
283 |
+
$ ls my_dataset
|
284 |
+
train
|
285 |
+
val_source # 2000 or more images
|
286 |
+
visual_test_source # 100 or more images
|
287 |
+
eval_source # 2000 or more images
|
288 |
+
|
289 |
+
# LaMa generates random masks for the train data on the flight,
|
290 |
+
# but needs fixed masks for test and visual_test for consistency of evaluation.
|
291 |
+
|
292 |
+
# Suppose, we want to evaluate and pick best models
|
293 |
+
# on 512x512 val dataset with thick/thin/medium masks
|
294 |
+
# And your images have .jpg extention:
|
295 |
+
|
296 |
+
python3 bin/gen_mask_dataset.py \
|
297 |
+
$(pwd)/configs/data_gen/random_<size>_512.yaml \ # thick, thin, medium
|
298 |
+
my_dataset/val_source/ \
|
299 |
+
my_dataset/val/random_<size>_512.yaml \# thick, thin, medium
|
300 |
+
--ext jpg
|
301 |
+
|
302 |
+
# So the mask generator will:
|
303 |
+
# 1. resize and crop val images and save them as .png
|
304 |
+
# 2. generate masks
|
305 |
+
|
306 |
+
ls my_dataset/val/random_medium_512/
|
307 |
+
image1_crop000_mask000.png
|
308 |
+
image1_crop000.png
|
309 |
+
image2_crop000_mask000.png
|
310 |
+
image2_crop000.png
|
311 |
+
...
|
312 |
+
|
313 |
+
# Generate thick, thin, medium masks for visual_test folder:
|
314 |
+
|
315 |
+
python3 bin/gen_mask_dataset.py \
|
316 |
+
$(pwd)/configs/data_gen/random_<size>_512.yaml \ #thick, thin, medium
|
317 |
+
my_dataset/visual_test_source/ \
|
318 |
+
my_dataset/visual_test/random_<size>_512/ \ #thick, thin, medium
|
319 |
+
--ext jpg
|
320 |
+
|
321 |
+
|
322 |
+
ls my_dataset/visual_test/random_thick_512/
|
323 |
+
image1_crop000_mask000.png
|
324 |
+
image1_crop000.png
|
325 |
+
image2_crop000_mask000.png
|
326 |
+
image2_crop000.png
|
327 |
+
...
|
328 |
+
|
329 |
+
# Same process for eval_source image folder:
|
330 |
+
|
331 |
+
python3 bin/gen_mask_dataset.py \
|
332 |
+
$(pwd)/configs/data_gen/random_<size>_512.yaml \ #thick, thin, medium
|
333 |
+
my_dataset/eval_source/ \
|
334 |
+
my_dataset/eval/random_<size>_512/ \ #thick, thin, medium
|
335 |
+
--ext jpg
|
336 |
+
|
337 |
+
|
338 |
+
|
339 |
+
# Generate location config file which locate these folders:
|
340 |
+
|
341 |
+
touch my_dataset.yaml
|
342 |
+
echo "data_root_dir: $(pwd)/my_dataset/" >> my_dataset.yaml
|
343 |
+
echo "out_root_dir: $(pwd)/experiments/" >> my_dataset.yaml
|
344 |
+
echo "tb_dir: $(pwd)/tb_logs/" >> my_dataset.yaml
|
345 |
+
mv my_dataset.yaml ${PWD}/configs/training/location/
|
346 |
+
|
347 |
+
|
348 |
+
# Check data config for consistency with my_dataset folder structure:
|
349 |
+
$ cat ${PWD}/configs/training/data/abl-04-256-mh-dist
|
350 |
+
...
|
351 |
+
train:
|
352 |
+
indir: ${location.data_root_dir}/train
|
353 |
+
...
|
354 |
+
val:
|
355 |
+
indir: ${location.data_root_dir}/val
|
356 |
+
img_suffix: .png
|
357 |
+
visual_test:
|
358 |
+
indir: ${location.data_root_dir}/visual_test
|
359 |
+
img_suffix: .png
|
360 |
+
|
361 |
+
|
362 |
+
# Run training
|
363 |
+
python3 bin/train.py -cn lama-fourier location=my_dataset data.batch_size=10
|
364 |
+
|
365 |
+
# Evaluation: LaMa training procedure picks best few models according to
|
366 |
+
# scores on my_dataset/val/
|
367 |
+
|
368 |
+
# To evaluate one of your best models (i.e. at epoch=32)
|
369 |
+
# on previously unseen my_dataset/eval do the following
|
370 |
+
# for thin, thick and medium:
|
371 |
+
|
372 |
+
# infer:
|
373 |
+
python3 bin/predict.py \
|
374 |
+
model.path=$(pwd)/experiments/<user>_<date:time>_lama-fourier_/ \
|
375 |
+
indir=$(pwd)/my_dataset/eval/random_<size>_512/ \
|
376 |
+
outdir=$(pwd)/inference/my_dataset/random_<size>_512 \
|
377 |
+
model.checkpoint=epoch32.ckpt
|
378 |
+
|
379 |
+
# metrics calculation:
|
380 |
+
python3 bin/evaluate_predicts.py \
|
381 |
+
$(pwd)/configs/eval2_gpu.yaml \
|
382 |
+
$(pwd)/my_dataset/eval/random_<size>_512/ \
|
383 |
+
$(pwd)/inference/my_dataset/random_<size>_512 \
|
384 |
+
$(pwd)/inference/my_dataset/random_<size>_512_metrics.csv
|
385 |
+
|
386 |
+
|
387 |
+
**OR** in the docker:
|
388 |
+
|
389 |
+
TODO: train
|
390 |
+
TODO: eval
|
391 |
+
|
392 |
+
# Hints
|
393 |
+
|
394 |
+
### Generate different kinds of masks
|
395 |
+
The following command will execute a script that generates random masks.
|
396 |
+
|
397 |
+
bash docker/1_generate_masks_from_raw_images.sh \
|
398 |
+
configs/data_gen/random_medium_512.yaml \
|
399 |
+
/directory_with_input_images \
|
400 |
+
/directory_where_to_store_images_and_masks \
|
401 |
+
--ext png
|
402 |
+
|
403 |
+
The test data generation command stores images in the format,
|
404 |
+
which is suitable for [prediction](#prediction).
|
405 |
+
|
406 |
+
The table below describes which configs we used to generate different test sets from the paper.
|
407 |
+
Note that we *do not fix a random seed*, so the results will be slightly different each time.
|
408 |
+
|
409 |
+
| | Places 512x512 | CelebA 256x256 |
|
410 |
+
|--------|------------------------|------------------------|
|
411 |
+
| Narrow | random_thin_512.yaml | random_thin_256.yaml |
|
412 |
+
| Medium | random_medium_512.yaml | random_medium_256.yaml |
|
413 |
+
| Wide | random_thick_512.yaml | random_thick_256.yaml |
|
414 |
+
|
415 |
+
Feel free to change the config path (argument #1) to any other config in `configs/data_gen`
|
416 |
+
or adjust config files themselves.
|
417 |
+
|
418 |
+
### Override parameters in configs
|
419 |
+
Also you can override parameters in config like this:
|
420 |
+
|
421 |
+
python3 bin/train.py -cn <config> data.batch_size=10 run_title=my-title
|
422 |
+
|
423 |
+
Where .yaml file extension is omitted
|
424 |
+
|
425 |
+
### Models options
|
426 |
+
Config names for models from paper (substitude into the training command):
|
427 |
+
|
428 |
+
* big-lama
|
429 |
+
* big-lama-regular
|
430 |
+
* lama-fourier
|
431 |
+
* lama-regular
|
432 |
+
* lama_small_train_masks
|
433 |
+
|
434 |
+
Which are seated in configs/training/folder
|
435 |
+
|
436 |
+
### Links
|
437 |
+
- All the data (models, test images, etc.) https://disk.yandex.ru/d/AmdeG-bIjmvSug
|
438 |
+
- Test images from the paper https://disk.yandex.ru/d/xKQJZeVRk5vLlQ
|
439 |
+
- The pre-trained models https://disk.yandex.ru/d/EgqaSnLohjuzAg
|
440 |
+
- The models for perceptual loss https://disk.yandex.ru/d/ncVmQlmT_kTemQ
|
441 |
+
- Our training logs are available at https://disk.yandex.ru/d/9Bt1wNSDS4jDkQ
|
442 |
+
|
443 |
+
|
444 |
+
### Training time & resources
|
445 |
+
|
446 |
+
TODO
|
447 |
+
|
448 |
+
## Acknowledgments
|
449 |
+
|
450 |
+
* Segmentation code and models if form [CSAILVision](https://github.com/CSAILVision/semantic-segmentation-pytorch).
|
451 |
+
* LPIPS metric is from [richzhang](https://github.com/richzhang/PerceptualSimilarity)
|
452 |
+
* SSIM is from [Po-Hsun-Su](https://github.com/Po-Hsun-Su/pytorch-ssim)
|
453 |
+
* FID is from [mseitzer](https://github.com/mseitzer/pytorch-fid)
|
454 |
+
|
455 |
+
## Citation
|
456 |
+
If you found this code helpful, please consider citing:
|
457 |
+
```
|
458 |
+
@article{suvorov2021resolution,
|
459 |
+
title={Resolution-robust Large Mask Inpainting with Fourier Convolutions},
|
460 |
+
author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor},
|
461 |
+
journal={arXiv preprint arXiv:2109.07161},
|
462 |
+
year={2021}
|
463 |
+
}
|
464 |
+
```
|
lama/bin/analyze_errors.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import sklearn
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
+
import pandas as pd
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
|
12 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image
|
13 |
+
from saicinpainting.evaluation.losses.fid.inception import InceptionV3
|
14 |
+
from saicinpainting.evaluation.utils import load_yaml
|
15 |
+
from saicinpainting.training.visualizers.base import visualize_mask_and_images
|
16 |
+
|
17 |
+
|
18 |
+
def draw_score(img, score):
|
19 |
+
img = np.transpose(img, (1, 2, 0))
|
20 |
+
cv2.putText(img, f'{score:.2f}',
|
21 |
+
(40, 40),
|
22 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
23 |
+
1,
|
24 |
+
(0, 1, 0),
|
25 |
+
thickness=3)
|
26 |
+
img = np.transpose(img, (2, 0, 1))
|
27 |
+
return img
|
28 |
+
|
29 |
+
|
30 |
+
def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname):
|
31 |
+
for cur_mask_fname in global_mask_fnames:
|
32 |
+
cur_real_fname = mask2real_fname[cur_mask_fname]
|
33 |
+
orig_img = load_image(cur_real_fname, mode='RGB')
|
34 |
+
fake_img = load_image(mask2fake_fname[cur_mask_fname], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
35 |
+
mask = load_image(cur_mask_fname, mode='L')[None, ...]
|
36 |
+
|
37 |
+
draw_score(orig_img, real_scores_by_fname.loc[cur_real_fname, 'real_score'])
|
38 |
+
draw_score(fake_img, fake_scores_by_fname.loc[cur_mask_fname, 'fake_score'])
|
39 |
+
|
40 |
+
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=mask, fake=fake_img),
|
41 |
+
keys=['image', 'fake'],
|
42 |
+
last_without_mask=True)
|
43 |
+
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
|
44 |
+
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
|
45 |
+
cv2.imwrite(os.path.join(out_dir, os.path.splitext(os.path.basename(cur_mask_fname))[0] + '.jpg'),
|
46 |
+
cur_grid)
|
47 |
+
|
48 |
+
|
49 |
+
def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir):
|
50 |
+
for real_fname in worst_best_by_real.index:
|
51 |
+
worst_mask_path = worst_best_by_real.loc[real_fname, 'worst']
|
52 |
+
best_mask_path = worst_best_by_real.loc[real_fname, 'best']
|
53 |
+
orig_img = load_image(real_fname, mode='RGB')
|
54 |
+
worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...]
|
55 |
+
worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
56 |
+
best_mask_img = load_image(best_mask_path, mode='L')[None, ...]
|
57 |
+
best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
58 |
+
|
59 |
+
draw_score(orig_img, worst_best_by_real.loc[real_fname, 'real_score'])
|
60 |
+
draw_score(worst_fake_img, worst_best_by_real.loc[real_fname, 'worst_score'])
|
61 |
+
draw_score(best_fake_img, worst_best_by_real.loc[real_fname, 'best_score'])
|
62 |
+
|
63 |
+
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=np.zeros_like(worst_mask_img),
|
64 |
+
worst_mask=worst_mask_img, worst_img=worst_fake_img,
|
65 |
+
best_mask=best_mask_img, best_img=best_fake_img),
|
66 |
+
keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'],
|
67 |
+
rescale_keys=['worst_mask', 'best_mask'],
|
68 |
+
last_without_mask=True)
|
69 |
+
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
|
70 |
+
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
|
71 |
+
cv2.imwrite(os.path.join(out_dir,
|
72 |
+
os.path.splitext(os.path.basename(real_fname))[0] + '.jpg'),
|
73 |
+
cur_grid)
|
74 |
+
|
75 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
76 |
+
cur_stat = fake_info[fake_info['real_fname'] == real_fname]
|
77 |
+
cur_stat['fake_score'].hist(ax=ax1)
|
78 |
+
cur_stat['real_score'].hist(ax=ax2)
|
79 |
+
fig.tight_layout()
|
80 |
+
fig.savefig(os.path.join(out_dir,
|
81 |
+
os.path.splitext(os.path.basename(real_fname))[0] + '_scores.png'))
|
82 |
+
plt.close(fig)
|
83 |
+
|
84 |
+
|
85 |
+
def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2):
|
86 |
+
result_pairs = []
|
87 |
+
result_scores = []
|
88 |
+
mask_fname_a = mask_fnames[cur_i]
|
89 |
+
mask_a = load_image(mask_fname_a, mode='L')[None, ...] > 0.5
|
90 |
+
cur_score_a = fake_scores_table.loc[mask_fname_a, 'fake_score']
|
91 |
+
for mask_fname_b in mask_fnames[cur_i + 1:]:
|
92 |
+
mask_b = load_image(mask_fname_b, mode='L')[None, ...] > 0.5
|
93 |
+
if not np.any(mask_a & mask_b):
|
94 |
+
continue
|
95 |
+
cur_score_b = fake_scores_table.loc[mask_fname_b, 'fake_score']
|
96 |
+
result_pairs.append((mask_fname_a, mask_fname_b))
|
97 |
+
result_scores.append(cur_score_b - cur_score_a)
|
98 |
+
if len(result_pairs) >= max_overlaps_n:
|
99 |
+
break
|
100 |
+
return result_pairs, result_scores
|
101 |
+
|
102 |
+
|
103 |
+
def main(args):
|
104 |
+
config = load_yaml(args.config)
|
105 |
+
|
106 |
+
latents_dir = os.path.join(args.outpath, 'latents')
|
107 |
+
os.makedirs(latents_dir, exist_ok=True)
|
108 |
+
global_worst_dir = os.path.join(args.outpath, 'global_worst')
|
109 |
+
os.makedirs(global_worst_dir, exist_ok=True)
|
110 |
+
global_best_dir = os.path.join(args.outpath, 'global_best')
|
111 |
+
os.makedirs(global_best_dir, exist_ok=True)
|
112 |
+
worst_best_by_best_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_max')
|
113 |
+
os.makedirs(worst_best_by_best_worst_score_diff_max_dir, exist_ok=True)
|
114 |
+
worst_best_by_best_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_min')
|
115 |
+
os.makedirs(worst_best_by_best_worst_score_diff_min_dir, exist_ok=True)
|
116 |
+
worst_best_by_real_best_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_max')
|
117 |
+
os.makedirs(worst_best_by_real_best_score_diff_max_dir, exist_ok=True)
|
118 |
+
worst_best_by_real_best_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_min')
|
119 |
+
os.makedirs(worst_best_by_real_best_score_diff_min_dir, exist_ok=True)
|
120 |
+
worst_best_by_real_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_max')
|
121 |
+
os.makedirs(worst_best_by_real_worst_score_diff_max_dir, exist_ok=True)
|
122 |
+
worst_best_by_real_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_min')
|
123 |
+
os.makedirs(worst_best_by_real_worst_score_diff_min_dir, exist_ok=True)
|
124 |
+
|
125 |
+
if not args.only_report:
|
126 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
127 |
+
inception_model = InceptionV3([block_idx]).eval().cuda()
|
128 |
+
|
129 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
130 |
+
|
131 |
+
real2vector_cache = {}
|
132 |
+
|
133 |
+
real_features = []
|
134 |
+
fake_features = []
|
135 |
+
|
136 |
+
orig_fnames = []
|
137 |
+
mask_fnames = []
|
138 |
+
mask2real_fname = {}
|
139 |
+
mask2fake_fname = {}
|
140 |
+
|
141 |
+
for batch_i, batch in enumerate(dataset):
|
142 |
+
orig_img_fname = dataset.img_filenames[batch_i]
|
143 |
+
mask_fname = dataset.mask_filenames[batch_i]
|
144 |
+
fake_fname = dataset.pred_filenames[batch_i]
|
145 |
+
mask2real_fname[mask_fname] = orig_img_fname
|
146 |
+
mask2fake_fname[mask_fname] = fake_fname
|
147 |
+
|
148 |
+
cur_real_vector = real2vector_cache.get(orig_img_fname, None)
|
149 |
+
if cur_real_vector is None:
|
150 |
+
with torch.no_grad():
|
151 |
+
in_img = torch.from_numpy(batch['image'][None, ...]).cuda()
|
152 |
+
cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
|
153 |
+
real2vector_cache[orig_img_fname] = cur_real_vector
|
154 |
+
|
155 |
+
pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda()
|
156 |
+
cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
|
157 |
+
|
158 |
+
real_features.append(cur_real_vector)
|
159 |
+
fake_features.append(cur_fake_vector)
|
160 |
+
|
161 |
+
orig_fnames.append(orig_img_fname)
|
162 |
+
mask_fnames.append(mask_fname)
|
163 |
+
|
164 |
+
ids_features = np.concatenate(real_features + fake_features, axis=0)
|
165 |
+
ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features)))
|
166 |
+
|
167 |
+
with open(os.path.join(latents_dir, 'featues.pkl'), 'wb') as f:
|
168 |
+
pickle.dump(ids_features, f, protocol=3)
|
169 |
+
with open(os.path.join(latents_dir, 'labels.pkl'), 'wb') as f:
|
170 |
+
pickle.dump(ids_labels, f, protocol=3)
|
171 |
+
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'wb') as f:
|
172 |
+
pickle.dump(orig_fnames, f, protocol=3)
|
173 |
+
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'wb') as f:
|
174 |
+
pickle.dump(mask_fnames, f, protocol=3)
|
175 |
+
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'wb') as f:
|
176 |
+
pickle.dump(mask2real_fname, f, protocol=3)
|
177 |
+
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'wb') as f:
|
178 |
+
pickle.dump(mask2fake_fname, f, protocol=3)
|
179 |
+
|
180 |
+
svm = sklearn.svm.LinearSVC(dual=False)
|
181 |
+
svm.fit(ids_features, ids_labels)
|
182 |
+
|
183 |
+
pred_scores = svm.decision_function(ids_features)
|
184 |
+
real_scores = pred_scores[:len(real_features)]
|
185 |
+
fake_scores = pred_scores[len(real_features):]
|
186 |
+
|
187 |
+
with open(os.path.join(latents_dir, 'pred_scores.pkl'), 'wb') as f:
|
188 |
+
pickle.dump(pred_scores, f, protocol=3)
|
189 |
+
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'wb') as f:
|
190 |
+
pickle.dump(real_scores, f, protocol=3)
|
191 |
+
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'wb') as f:
|
192 |
+
pickle.dump(fake_scores, f, protocol=3)
|
193 |
+
else:
|
194 |
+
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'rb') as f:
|
195 |
+
orig_fnames = pickle.load(f)
|
196 |
+
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'rb') as f:
|
197 |
+
mask_fnames = pickle.load(f)
|
198 |
+
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'rb') as f:
|
199 |
+
mask2real_fname = pickle.load(f)
|
200 |
+
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'rb') as f:
|
201 |
+
mask2fake_fname = pickle.load(f)
|
202 |
+
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'rb') as f:
|
203 |
+
real_scores = pickle.load(f)
|
204 |
+
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'rb') as f:
|
205 |
+
fake_scores = pickle.load(f)
|
206 |
+
|
207 |
+
real_info = pd.DataFrame(data=[dict(real_fname=fname,
|
208 |
+
real_score=score)
|
209 |
+
for fname, score
|
210 |
+
in zip(orig_fnames, real_scores)])
|
211 |
+
real_info.set_index('real_fname', drop=True, inplace=True)
|
212 |
+
|
213 |
+
fake_info = pd.DataFrame(data=[dict(mask_fname=fname,
|
214 |
+
fake_fname=mask2fake_fname[fname],
|
215 |
+
real_fname=mask2real_fname[fname],
|
216 |
+
fake_score=score)
|
217 |
+
for fname, score
|
218 |
+
in zip(mask_fnames, fake_scores)])
|
219 |
+
fake_info = fake_info.join(real_info, on='real_fname', how='left')
|
220 |
+
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
|
221 |
+
|
222 |
+
fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename(
|
223 |
+
{'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1)
|
224 |
+
fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real')
|
225 |
+
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
|
226 |
+
fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False)
|
227 |
+
|
228 |
+
fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame()
|
229 |
+
real_scores_table = fake_info.set_index('real_fname')['real_score'].drop_duplicates().to_frame()
|
230 |
+
|
231 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
232 |
+
ax1.hist(fake_scores)
|
233 |
+
ax2.hist(real_scores)
|
234 |
+
fig.tight_layout()
|
235 |
+
fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png'))
|
236 |
+
plt.close(fig)
|
237 |
+
|
238 |
+
global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list()
|
239 |
+
global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list()
|
240 |
+
save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table)
|
241 |
+
save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table)
|
242 |
+
|
243 |
+
# grouped by real
|
244 |
+
worst_samples_by_real = fake_info.groupby('real_fname').apply(
|
245 |
+
lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1)
|
246 |
+
best_samples_by_real = fake_info.groupby('real_fname').apply(
|
247 |
+
lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1)
|
248 |
+
worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1)
|
249 |
+
|
250 |
+
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1),
|
251 |
+
on='worst')
|
252 |
+
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1),
|
253 |
+
on='best')
|
254 |
+
worst_best_by_real = worst_best_by_real.join(real_scores_table)
|
255 |
+
|
256 |
+
worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score']
|
257 |
+
worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score']
|
258 |
+
worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score']
|
259 |
+
|
260 |
+
worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
261 |
+
worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
262 |
+
save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir)
|
263 |
+
save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir)
|
264 |
+
|
265 |
+
worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
266 |
+
worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
267 |
+
save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir)
|
268 |
+
save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir)
|
269 |
+
|
270 |
+
worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
271 |
+
worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
272 |
+
save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir)
|
273 |
+
save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir)
|
274 |
+
|
275 |
+
# analyze what change of mask causes bigger change of score
|
276 |
+
overlapping_mask_fname_pairs = []
|
277 |
+
overlapping_mask_fname_score_diffs = []
|
278 |
+
for cur_real_fname in orig_fnames:
|
279 |
+
cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname]
|
280 |
+
cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique())
|
281 |
+
|
282 |
+
cur_mask_pairs_and_scores = Parallel(args.n_jobs)(
|
283 |
+
delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table)
|
284 |
+
for i in range(len(cur_mask_fnames) - 1)
|
285 |
+
)
|
286 |
+
for cur_pairs, cur_scores in cur_mask_pairs_and_scores:
|
287 |
+
overlapping_mask_fname_pairs.extend(cur_pairs)
|
288 |
+
overlapping_mask_fname_score_diffs.extend(cur_scores)
|
289 |
+
|
290 |
+
overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs)
|
291 |
+
overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs)
|
292 |
+
overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs)
|
293 |
+
overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx]
|
294 |
+
overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx]
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
if __name__ == '__main__':
|
302 |
+
import argparse
|
303 |
+
|
304 |
+
aparser = argparse.ArgumentParser()
|
305 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
306 |
+
aparser.add_argument('datadir', type=str,
|
307 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
308 |
+
aparser.add_argument('predictdir', type=str,
|
309 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
310 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
311 |
+
aparser.add_argument('--only-report', action='store_true',
|
312 |
+
help='Whether to skip prediction and feature extraction, '
|
313 |
+
'load all the possible latents and proceed with report only')
|
314 |
+
aparser.add_argument('--n-jobs', type=int, default=8, help='how many processes to use for pair mask mining')
|
315 |
+
|
316 |
+
main(aparser.parse_args())
|
lama/bin/blur_predicts.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
10 |
+
from saicinpainting.evaluation.utils import load_yaml
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
config = load_yaml(args.config)
|
15 |
+
|
16 |
+
if not args.predictdir.endswith('/'):
|
17 |
+
args.predictdir += '/'
|
18 |
+
|
19 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
20 |
+
|
21 |
+
os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
|
22 |
+
|
23 |
+
for img_i in tqdm.trange(len(dataset)):
|
24 |
+
pred_fname = dataset.pred_filenames[img_i]
|
25 |
+
cur_out_fname = os.path.join(args.outpath, pred_fname[len(args.predictdir):])
|
26 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
27 |
+
|
28 |
+
sample = dataset[img_i]
|
29 |
+
img = sample['image']
|
30 |
+
mask = sample['mask']
|
31 |
+
inpainted = sample['inpainted']
|
32 |
+
|
33 |
+
inpainted_blurred = cv2.GaussianBlur(np.transpose(inpainted, (1, 2, 0)),
|
34 |
+
ksize=(args.k, args.k),
|
35 |
+
sigmaX=args.s, sigmaY=args.s,
|
36 |
+
borderType=cv2.BORDER_REFLECT)
|
37 |
+
|
38 |
+
cur_res = (1 - mask) * np.transpose(img, (1, 2, 0)) + mask * inpainted_blurred
|
39 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
40 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
41 |
+
cv2.imwrite(cur_out_fname, cur_res)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
|
45 |
+
import argparse
|
46 |
+
|
47 |
+
aparser = argparse.ArgumentParser()
|
48 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config')
|
49 |
+
aparser.add_argument('datadir', type=str,
|
50 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
51 |
+
aparser.add_argument('predictdir', type=str,
|
52 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
53 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
54 |
+
aparser.add_argument('-s', type=float, default=0.1, help='Gaussian blur sigma')
|
55 |
+
aparser.add_argument('-k', type=int, default=5, help='Kernel size in gaussian blur')
|
56 |
+
|
57 |
+
main(aparser.parse_args())
|
lama/bin/calc_dataset_stats.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from scipy.ndimage.morphology import distance_transform_edt
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import InpaintingDataset
|
10 |
+
from saicinpainting.evaluation.vis import save_item_for_vis
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
dataset = InpaintingDataset(args.datadir, img_suffix='.png')
|
15 |
+
|
16 |
+
area_bins = np.linspace(0, 1, args.area_bins + 1)
|
17 |
+
|
18 |
+
heights = []
|
19 |
+
widths = []
|
20 |
+
image_areas = []
|
21 |
+
hole_areas = []
|
22 |
+
hole_area_percents = []
|
23 |
+
known_pixel_distances = []
|
24 |
+
|
25 |
+
area_bins_count = np.zeros(args.area_bins)
|
26 |
+
area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
|
27 |
+
|
28 |
+
bin2i = [[] for _ in range(args.area_bins)]
|
29 |
+
|
30 |
+
for i, item in enumerate(tqdm.tqdm(dataset)):
|
31 |
+
h, w = item['image'].shape[1:]
|
32 |
+
heights.append(h)
|
33 |
+
widths.append(w)
|
34 |
+
full_area = h * w
|
35 |
+
image_areas.append(full_area)
|
36 |
+
bin_mask = item['mask'] > 0.5
|
37 |
+
hole_area = bin_mask.sum()
|
38 |
+
hole_areas.append(hole_area)
|
39 |
+
hole_percent = hole_area / full_area
|
40 |
+
hole_area_percents.append(hole_percent)
|
41 |
+
bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
|
42 |
+
area_bins_count[bin_i] += 1
|
43 |
+
bin2i[bin_i].append(i)
|
44 |
+
|
45 |
+
cur_dist = distance_transform_edt(bin_mask)
|
46 |
+
cur_dist_inside_mask = cur_dist[bin_mask]
|
47 |
+
known_pixel_distances.append(cur_dist_inside_mask.mean())
|
48 |
+
|
49 |
+
os.makedirs(args.outdir, exist_ok=True)
|
50 |
+
with open(os.path.join(args.outdir, 'summary.txt'), 'w') as f:
|
51 |
+
f.write(f'''Location: {args.datadir}
|
52 |
+
|
53 |
+
Number of samples: {len(dataset)}
|
54 |
+
|
55 |
+
Image height: min {min(heights):5d} max {max(heights):5d} mean {np.mean(heights):.2f}
|
56 |
+
Image width: min {min(widths):5d} max {max(widths):5d} mean {np.mean(widths):.2f}
|
57 |
+
Image area: min {min(image_areas):7d} max {max(image_areas):7d} mean {np.mean(image_areas):.2f}
|
58 |
+
Hole area: min {min(hole_areas):7d} max {max(hole_areas):7d} mean {np.mean(hole_areas):.2f}
|
59 |
+
Hole area %: min {min(hole_area_percents) * 100:2.2f} max {max(hole_area_percents) * 100:2.2f} mean {np.mean(hole_area_percents) * 100:2.2f}
|
60 |
+
Dist 2known: min {min(known_pixel_distances):2.2f} max {max(known_pixel_distances):2.2f} mean {np.mean(known_pixel_distances):2.2f} median {np.median(known_pixel_distances):2.2f}
|
61 |
+
|
62 |
+
Stats by hole area %:
|
63 |
+
''')
|
64 |
+
for bin_i in range(args.area_bins):
|
65 |
+
f.write(f'{area_bin_titles[bin_i]}%: '
|
66 |
+
f'samples number {area_bins_count[bin_i]}, '
|
67 |
+
f'{area_bins_count[bin_i] / len(dataset) * 100:.1f}%\n')
|
68 |
+
|
69 |
+
for bin_i in range(args.area_bins):
|
70 |
+
bindir = os.path.join(args.outdir, 'samples', area_bin_titles[bin_i])
|
71 |
+
os.makedirs(bindir, exist_ok=True)
|
72 |
+
bin_idx = bin2i[bin_i]
|
73 |
+
for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
|
74 |
+
save_item_for_vis(dataset[sample_i], os.path.join(bindir, f'{sample_i}.png'))
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
import argparse
|
79 |
+
|
80 |
+
aparser = argparse.ArgumentParser()
|
81 |
+
aparser.add_argument('datadir', type=str,
|
82 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
83 |
+
aparser.add_argument('outdir', type=str, help='Where to put results')
|
84 |
+
aparser.add_argument('--samples-n', type=int, default=10,
|
85 |
+
help='Number of sample images with masks to copy for visualization for each area bin')
|
86 |
+
aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
|
87 |
+
|
88 |
+
main(aparser.parse_args())
|
lama/bin/debug/analyze_overlapping_masks.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
BASEDIR="$(dirname $0)"
|
4 |
+
|
5 |
+
# paths are valid for mml7
|
6 |
+
|
7 |
+
# select images
|
8 |
+
#ls /data/inpainting/work/data/train | shuf | head -2000 | xargs -n1 -I{} cp {} /data/inpainting/mask_analysis/src
|
9 |
+
|
10 |
+
# generate masks
|
11 |
+
#"$BASEDIR/../gen_debug_mask_dataset.py" \
|
12 |
+
# "$BASEDIR/../../configs/debug_mask_gen.yaml" \
|
13 |
+
# "/data/inpainting/mask_analysis/src" \
|
14 |
+
# "/data/inpainting/mask_analysis/generated"
|
15 |
+
|
16 |
+
# predict
|
17 |
+
#"$BASEDIR/../predict.py" \
|
18 |
+
# model.path="simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/saved_checkpoint/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15_epoch22-step-574999" \
|
19 |
+
# indir="/data/inpainting/mask_analysis/generated" \
|
20 |
+
# outdir="/data/inpainting/mask_analysis/predicted" \
|
21 |
+
# dataset.img_suffix=.jpg \
|
22 |
+
# +out_ext=.jpg
|
23 |
+
|
24 |
+
# analyze good and bad samples
|
25 |
+
"$BASEDIR/../analyze_errors.py" \
|
26 |
+
--only-report \
|
27 |
+
--n-jobs 8 \
|
28 |
+
"$BASEDIR/../../configs/analyze_mask_errors.yaml" \
|
29 |
+
"/data/inpainting/mask_analysis/small/generated" \
|
30 |
+
"/data/inpainting/mask_analysis/small/predicted" \
|
31 |
+
"/data/inpainting/mask_analysis/small/report"
|
lama/bin/evaluate_predicts.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
8 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluator, lpips_fid100_f1
|
9 |
+
from saicinpainting.evaluation.losses.base_loss import SegmentationAwareSSIM, \
|
10 |
+
SegmentationClassStats, SSIMScore, LPIPSScore, FIDScore, SegmentationAwareLPIPS, SegmentationAwareFID
|
11 |
+
from saicinpainting.evaluation.utils import load_yaml
|
12 |
+
|
13 |
+
|
14 |
+
def main(args):
|
15 |
+
config = load_yaml(args.config)
|
16 |
+
|
17 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
18 |
+
|
19 |
+
metrics = {
|
20 |
+
'ssim': SSIMScore(),
|
21 |
+
'lpips': LPIPSScore(),
|
22 |
+
'fid': FIDScore()
|
23 |
+
}
|
24 |
+
enable_segm = config.get('segmentation', dict(enable=False)).get('enable', False)
|
25 |
+
if enable_segm:
|
26 |
+
weights_path = os.path.expandvars(config.segmentation.weights_path)
|
27 |
+
metrics.update(dict(
|
28 |
+
segm_stats=SegmentationClassStats(weights_path=weights_path),
|
29 |
+
segm_ssim=SegmentationAwareSSIM(weights_path=weights_path),
|
30 |
+
segm_lpips=SegmentationAwareLPIPS(weights_path=weights_path),
|
31 |
+
segm_fid=SegmentationAwareFID(weights_path=weights_path)
|
32 |
+
))
|
33 |
+
evaluator = InpaintingEvaluator(dataset, scores=metrics,
|
34 |
+
integral_title='lpips_fid100_f1', integral_func=lpips_fid100_f1,
|
35 |
+
**config.evaluator_kwargs)
|
36 |
+
|
37 |
+
os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
|
38 |
+
|
39 |
+
results = evaluator.evaluate()
|
40 |
+
|
41 |
+
results = pd.DataFrame(results).stack(1).unstack(0)
|
42 |
+
results.dropna(axis=1, how='all', inplace=True)
|
43 |
+
results.to_csv(args.outpath, sep='\t', float_format='%.4f')
|
44 |
+
|
45 |
+
if enable_segm:
|
46 |
+
only_short_results = results[[c for c in results.columns if not c[0].startswith('segm_')]].dropna(axis=1, how='all')
|
47 |
+
only_short_results.to_csv(args.outpath + '_short', sep='\t', float_format='%.4f')
|
48 |
+
|
49 |
+
print(only_short_results)
|
50 |
+
|
51 |
+
segm_metrics_results = results[['segm_ssim', 'segm_lpips', 'segm_fid']].dropna(axis=1, how='all').transpose().unstack(0).reorder_levels([1, 0], axis=1)
|
52 |
+
segm_metrics_results.drop(['mean', 'std'], axis=0, inplace=True)
|
53 |
+
|
54 |
+
segm_stats_results = results['segm_stats'].dropna(axis=1, how='all').transpose()
|
55 |
+
segm_stats_results.index = pd.MultiIndex.from_tuples(n.split('/') for n in segm_stats_results.index)
|
56 |
+
segm_stats_results = segm_stats_results.unstack(0).reorder_levels([1, 0], axis=1)
|
57 |
+
segm_stats_results.sort_index(axis=1, inplace=True)
|
58 |
+
segm_stats_results.dropna(axis=0, how='all', inplace=True)
|
59 |
+
|
60 |
+
segm_results = pd.concat([segm_metrics_results, segm_stats_results], axis=1, sort=True)
|
61 |
+
segm_results.sort_values(('mask_freq', 'total'), ascending=False, inplace=True)
|
62 |
+
|
63 |
+
segm_results.to_csv(args.outpath + '_segm', sep='\t', float_format='%.4f')
|
64 |
+
else:
|
65 |
+
print(results)
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == '__main__':
|
69 |
+
import argparse
|
70 |
+
|
71 |
+
aparser = argparse.ArgumentParser()
|
72 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config')
|
73 |
+
aparser.add_argument('datadir', type=str,
|
74 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
75 |
+
aparser.add_argument('predictdir', type=str,
|
76 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
77 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
78 |
+
|
79 |
+
main(aparser.parse_args())
|
lama/bin/evaluator_example.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from skimage import io
|
7 |
+
from skimage.transform import resize
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluator
|
11 |
+
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
|
12 |
+
|
13 |
+
|
14 |
+
class SimpleImageDataset(Dataset):
|
15 |
+
def __init__(self, root_dir, image_size=(400, 600)):
|
16 |
+
self.root_dir = root_dir
|
17 |
+
self.files = sorted(os.listdir(root_dir))
|
18 |
+
self.image_size = image_size
|
19 |
+
|
20 |
+
def __getitem__(self, index):
|
21 |
+
img_name = os.path.join(self.root_dir, self.files[index])
|
22 |
+
image = io.imread(img_name)
|
23 |
+
image = resize(image, self.image_size, anti_aliasing=True)
|
24 |
+
image = torch.FloatTensor(image).permute(2, 0, 1)
|
25 |
+
return image
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.files)
|
29 |
+
|
30 |
+
|
31 |
+
def create_rectangle_mask(height, width):
|
32 |
+
mask = np.ones((height, width))
|
33 |
+
up_left_corner = width // 4, height // 4
|
34 |
+
down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
|
35 |
+
cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
|
36 |
+
return mask
|
37 |
+
|
38 |
+
|
39 |
+
class Model():
|
40 |
+
def __call__(self, img_batch, mask_batch):
|
41 |
+
mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
|
42 |
+
inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :]
|
43 |
+
return inpainted
|
44 |
+
|
45 |
+
|
46 |
+
class SimpleImageSquareMaskDataset(Dataset):
|
47 |
+
def __init__(self, dataset):
|
48 |
+
self.dataset = dataset
|
49 |
+
self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
|
50 |
+
self.model = Model()
|
51 |
+
|
52 |
+
def __getitem__(self, index):
|
53 |
+
img = self.dataset[index]
|
54 |
+
mask = self.mask.clone()
|
55 |
+
inpainted = self.model(img[None, ...], mask[None, ...])
|
56 |
+
return dict(image=img, mask=mask, inpainted=inpainted)
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.dataset)
|
60 |
+
|
61 |
+
|
62 |
+
dataset = SimpleImageDataset('imgs')
|
63 |
+
mask_dataset = SimpleImageSquareMaskDataset(dataset)
|
64 |
+
model = Model()
|
65 |
+
metrics = {
|
66 |
+
'ssim': SSIMScore(),
|
67 |
+
'lpips': LPIPSScore(),
|
68 |
+
'fid': FIDScore()
|
69 |
+
}
|
70 |
+
|
71 |
+
evaluator = InpaintingEvaluator(
|
72 |
+
mask_dataset, scores=metrics, batch_size=3, area_grouping=True
|
73 |
+
)
|
74 |
+
|
75 |
+
results = evaluator.evaluate(model)
|
76 |
+
print(results)
|
lama/bin/extract_masks.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL.Image as Image
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def main(args):
|
7 |
+
if not args.indir.endswith('/'):
|
8 |
+
args.indir += '/'
|
9 |
+
os.makedirs(args.outdir, exist_ok=True)
|
10 |
+
|
11 |
+
src_images = [
|
12 |
+
args.indir+fname for fname in os.listdir(args.indir)]
|
13 |
+
|
14 |
+
tgt_masks = [
|
15 |
+
args.outdir+fname[:-4] + f'_mask000.png'
|
16 |
+
for fname in os.listdir(args.indir)]
|
17 |
+
|
18 |
+
for img_name, msk_name in zip(src_images, tgt_masks):
|
19 |
+
#print(img)
|
20 |
+
#print(msk)
|
21 |
+
|
22 |
+
image = Image.open(img_name).convert('RGB')
|
23 |
+
image = np.transpose(np.array(image), (2, 0, 1))
|
24 |
+
|
25 |
+
mask = (image == 255).astype(int)
|
26 |
+
|
27 |
+
print(mask.dtype, mask.shape)
|
28 |
+
|
29 |
+
|
30 |
+
Image.fromarray(
|
31 |
+
np.clip(mask[0,:,:] * 255, 0, 255).astype('uint8'),mode='L'
|
32 |
+
).save(msk_name)
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
'''
|
38 |
+
for infile in src_images:
|
39 |
+
try:
|
40 |
+
file_relpath = infile[len(indir):]
|
41 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
42 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
43 |
+
|
44 |
+
image = Image.open(infile).convert('RGB')
|
45 |
+
|
46 |
+
mask =
|
47 |
+
|
48 |
+
Image.fromarray(
|
49 |
+
np.clip(
|
50 |
+
cur_mask * 255, 0, 255).astype('uint8'),
|
51 |
+
mode='L'
|
52 |
+
).save(cur_basename + f'_mask{i:03d}.png')
|
53 |
+
'''
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
import argparse
|
59 |
+
aparser = argparse.ArgumentParser()
|
60 |
+
aparser.add_argument('--indir', type=str, help='Path to folder with images')
|
61 |
+
aparser.add_argument('--outdir', type=str, help='Path to folder to store aligned images and masks to')
|
62 |
+
|
63 |
+
main(aparser.parse_args())
|
lama/bin/filter_sharded_dataset.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
|
8 |
+
import braceexpand
|
9 |
+
import webdataset as wds
|
10 |
+
|
11 |
+
DEFAULT_CATS_FILE = os.path.join(os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt')
|
12 |
+
|
13 |
+
def is_good_key(key, cats):
|
14 |
+
return any(c in key for c in cats)
|
15 |
+
|
16 |
+
|
17 |
+
def main(args):
|
18 |
+
if args.categories == 'nofilter':
|
19 |
+
good_categories = None
|
20 |
+
else:
|
21 |
+
with open(args.categories, 'r') as f:
|
22 |
+
good_categories = set(line.strip().split(' ')[0] for line in f if line.strip())
|
23 |
+
|
24 |
+
all_input_files = list(braceexpand.braceexpand(args.infile))
|
25 |
+
chunk_size = int(math.ceil(len(all_input_files) / args.n_read_streams))
|
26 |
+
|
27 |
+
input_iterators = [iter(wds.Dataset(all_input_files[start : start + chunk_size]).shuffle(args.shuffle_buffer))
|
28 |
+
for start in range(0, len(all_input_files), chunk_size)]
|
29 |
+
output_datasets = [wds.ShardWriter(args.outpattern.format(i)) for i in range(args.n_write_streams)]
|
30 |
+
|
31 |
+
good_readers = list(range(len(input_iterators)))
|
32 |
+
step_i = 0
|
33 |
+
good_samples = 0
|
34 |
+
bad_samples = 0
|
35 |
+
while len(good_readers) > 0:
|
36 |
+
if step_i % args.print_freq == 0:
|
37 |
+
print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')
|
38 |
+
|
39 |
+
step_i += 1
|
40 |
+
|
41 |
+
ri = random.choice(good_readers)
|
42 |
+
try:
|
43 |
+
sample = next(input_iterators[ri])
|
44 |
+
except StopIteration:
|
45 |
+
good_readers = list(set(good_readers) - {ri})
|
46 |
+
continue
|
47 |
+
|
48 |
+
if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
|
49 |
+
bad_samples += 1
|
50 |
+
continue
|
51 |
+
|
52 |
+
wi = random.randint(0, args.n_write_streams - 1)
|
53 |
+
output_datasets[wi].write(sample)
|
54 |
+
good_samples += 1
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
import argparse
|
59 |
+
|
60 |
+
aparser = argparse.ArgumentParser()
|
61 |
+
aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE)
|
62 |
+
aparser.add_argument('--shuffle-buffer', type=int, default=10000)
|
63 |
+
aparser.add_argument('--n-read-streams', type=int, default=10)
|
64 |
+
aparser.add_argument('--n-write-streams', type=int, default=10)
|
65 |
+
aparser.add_argument('--print-freq', type=int, default=1000)
|
66 |
+
aparser.add_argument('infile', type=str)
|
67 |
+
aparser.add_argument('outpattern', type=str)
|
68 |
+
|
69 |
+
main(aparser.parse_args())
|
lama/bin/gen_debug_mask_dataset.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
import PIL.Image as Image
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import tqdm
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
|
13 |
+
from saicinpainting.evaluation.utils import load_yaml
|
14 |
+
|
15 |
+
|
16 |
+
def generate_masks_for_img(infile, outmask_pattern, mask_size=200, step=0.5):
|
17 |
+
inimg = Image.open(infile)
|
18 |
+
width, height = inimg.size
|
19 |
+
step_abs = int(mask_size * step)
|
20 |
+
|
21 |
+
mask = np.zeros((height, width), dtype='uint8')
|
22 |
+
mask_i = 0
|
23 |
+
|
24 |
+
for start_vertical in range(0, height - step_abs, step_abs):
|
25 |
+
for start_horizontal in range(0, width - step_abs, step_abs):
|
26 |
+
mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 255
|
27 |
+
|
28 |
+
cv2.imwrite(outmask_pattern.format(mask_i), mask)
|
29 |
+
|
30 |
+
mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 0
|
31 |
+
mask_i += 1
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
if not args.indir.endswith('/'):
|
36 |
+
args.indir += '/'
|
37 |
+
if not args.outdir.endswith('/'):
|
38 |
+
args.outdir += '/'
|
39 |
+
|
40 |
+
config = load_yaml(args.config)
|
41 |
+
|
42 |
+
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*{config.img_ext}'), recursive=True))
|
43 |
+
for infile in tqdm.tqdm(in_files):
|
44 |
+
outimg = args.outdir + infile[len(args.indir):]
|
45 |
+
outmask_pattern = outimg[:-len(config.img_ext)] + '_mask{:04d}.png'
|
46 |
+
|
47 |
+
os.makedirs(os.path.dirname(outimg), exist_ok=True)
|
48 |
+
shutil.copy2(infile, outimg)
|
49 |
+
|
50 |
+
generate_masks_for_img(infile, outmask_pattern, **config.gen_kwargs)
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
import argparse
|
55 |
+
|
56 |
+
aparser = argparse.ArgumentParser()
|
57 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
58 |
+
aparser.add_argument('indir', type=str, help='Path to folder with images')
|
59 |
+
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
|
60 |
+
|
61 |
+
main(aparser.parse_args())
|
lama/bin/gen_mask_dataset.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import traceback
|
7 |
+
|
8 |
+
import PIL.Image as Image
|
9 |
+
import numpy as np
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
|
12 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
13 |
+
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
14 |
+
from saicinpainting.training.data.masks import MixedMaskGenerator
|
15 |
+
|
16 |
+
|
17 |
+
class MakeManyMasksWrapper:
|
18 |
+
def __init__(self, impl, variants_n=2):
|
19 |
+
self.impl = impl
|
20 |
+
self.variants_n = variants_n
|
21 |
+
|
22 |
+
def get_masks(self, img):
|
23 |
+
img = np.transpose(np.array(img), (2, 0, 1))
|
24 |
+
return [self.impl(img)[0] for _ in range(self.variants_n)]
|
25 |
+
|
26 |
+
|
27 |
+
def process_images(src_images, indir, outdir, config):
|
28 |
+
if config.generator_kind == 'segmentation':
|
29 |
+
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
|
30 |
+
elif config.generator_kind == 'random':
|
31 |
+
variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
|
32 |
+
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
|
33 |
+
variants_n=variants_n)
|
34 |
+
else:
|
35 |
+
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
|
36 |
+
|
37 |
+
max_tamper_area = config.get('max_tamper_area', 1)
|
38 |
+
|
39 |
+
for infile in src_images:
|
40 |
+
try:
|
41 |
+
file_relpath = infile[len(indir):]
|
42 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
43 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
44 |
+
|
45 |
+
image = Image.open(infile).convert('RGB')
|
46 |
+
|
47 |
+
# scale input image to output resolution and filter smaller images
|
48 |
+
if min(image.size) < config.cropping.out_min_size:
|
49 |
+
handle_small_mode = SmallMode(config.cropping.handle_small_mode)
|
50 |
+
if handle_small_mode == SmallMode.DROP:
|
51 |
+
continue
|
52 |
+
elif handle_small_mode == SmallMode.UPSCALE:
|
53 |
+
factor = config.cropping.out_min_size / min(image.size)
|
54 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
55 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
56 |
+
else:
|
57 |
+
factor = config.cropping.out_min_size / min(image.size)
|
58 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
59 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
60 |
+
|
61 |
+
# generate and select masks
|
62 |
+
src_masks = mask_generator.get_masks(image)
|
63 |
+
|
64 |
+
filtered_image_mask_pairs = []
|
65 |
+
for cur_mask in src_masks:
|
66 |
+
if config.cropping.out_square_crop:
|
67 |
+
(crop_left,
|
68 |
+
crop_top,
|
69 |
+
crop_right,
|
70 |
+
crop_bottom) = propose_random_square_crop(cur_mask,
|
71 |
+
min_overlap=config.cropping.crop_min_overlap)
|
72 |
+
cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
|
73 |
+
cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
|
74 |
+
else:
|
75 |
+
cur_image = image
|
76 |
+
|
77 |
+
if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
|
78 |
+
continue
|
79 |
+
|
80 |
+
filtered_image_mask_pairs.append((cur_image, cur_mask))
|
81 |
+
|
82 |
+
mask_indices = np.random.choice(len(filtered_image_mask_pairs),
|
83 |
+
size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
|
84 |
+
replace=False)
|
85 |
+
|
86 |
+
# crop masks; save masks together with input image
|
87 |
+
mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
|
88 |
+
for i, idx in enumerate(mask_indices):
|
89 |
+
cur_image, cur_mask = filtered_image_mask_pairs[idx]
|
90 |
+
cur_basename = mask_basename + f'_crop{i:03d}'
|
91 |
+
Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
|
92 |
+
mode='L').save(cur_basename + f'_mask{i:03d}.png')
|
93 |
+
cur_image.save(cur_basename + '.png')
|
94 |
+
except KeyboardInterrupt:
|
95 |
+
return
|
96 |
+
except Exception as ex:
|
97 |
+
print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
98 |
+
|
99 |
+
|
100 |
+
def main(args):
|
101 |
+
if not args.indir.endswith('/'):
|
102 |
+
args.indir += '/'
|
103 |
+
|
104 |
+
os.makedirs(args.outdir, exist_ok=True)
|
105 |
+
|
106 |
+
config = load_yaml(args.config)
|
107 |
+
|
108 |
+
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
|
109 |
+
if args.n_jobs == 0:
|
110 |
+
process_images(in_files, args.indir, args.outdir, config)
|
111 |
+
else:
|
112 |
+
in_files_n = len(in_files)
|
113 |
+
chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
|
114 |
+
Parallel(n_jobs=args.n_jobs)(
|
115 |
+
delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
|
116 |
+
for start in range(0, len(in_files), chunk_size)
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
import argparse
|
122 |
+
|
123 |
+
aparser = argparse.ArgumentParser()
|
124 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
125 |
+
aparser.add_argument('indir', type=str, help='Path to folder with images')
|
126 |
+
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
|
127 |
+
aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
|
128 |
+
aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')
|
129 |
+
|
130 |
+
main(aparser.parse_args())
|
lama/bin/gen_mask_dataset_hydra.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import traceback
|
7 |
+
import hydra
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
import PIL.Image as Image
|
11 |
+
import numpy as np
|
12 |
+
from joblib import Parallel, delayed
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
15 |
+
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
16 |
+
from saicinpainting.training.data.masks import MixedMaskGenerator
|
17 |
+
|
18 |
+
|
19 |
+
class MakeManyMasksWrapper:
|
20 |
+
def __init__(self, impl, variants_n=2):
|
21 |
+
self.impl = impl
|
22 |
+
self.variants_n = variants_n
|
23 |
+
|
24 |
+
def get_masks(self, img):
|
25 |
+
img = np.transpose(np.array(img), (2, 0, 1))
|
26 |
+
return [self.impl(img)[0] for _ in range(self.variants_n)]
|
27 |
+
|
28 |
+
|
29 |
+
def process_images(src_images, indir, outdir, config):
|
30 |
+
if config.generator_kind == 'segmentation':
|
31 |
+
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
|
32 |
+
elif config.generator_kind == 'random':
|
33 |
+
mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
|
34 |
+
variants_n = mask_generator_kwargs.pop('variants_n', 2)
|
35 |
+
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
|
36 |
+
variants_n=variants_n)
|
37 |
+
else:
|
38 |
+
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
|
39 |
+
|
40 |
+
max_tamper_area = config.get('max_tamper_area', 1)
|
41 |
+
|
42 |
+
for infile in src_images:
|
43 |
+
try:
|
44 |
+
file_relpath = infile[len(indir):]
|
45 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
46 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
47 |
+
|
48 |
+
image = Image.open(infile).convert('RGB')
|
49 |
+
|
50 |
+
# scale input image to output resolution and filter smaller images
|
51 |
+
if min(image.size) < config.cropping.out_min_size:
|
52 |
+
handle_small_mode = SmallMode(config.cropping.handle_small_mode)
|
53 |
+
if handle_small_mode == SmallMode.DROP:
|
54 |
+
continue
|
55 |
+
elif handle_small_mode == SmallMode.UPSCALE:
|
56 |
+
factor = config.cropping.out_min_size / min(image.size)
|
57 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
58 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
59 |
+
else:
|
60 |
+
factor = config.cropping.out_min_size / min(image.size)
|
61 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
62 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
63 |
+
|
64 |
+
# generate and select masks
|
65 |
+
src_masks = mask_generator.get_masks(image)
|
66 |
+
|
67 |
+
filtered_image_mask_pairs = []
|
68 |
+
for cur_mask in src_masks:
|
69 |
+
if config.cropping.out_square_crop:
|
70 |
+
(crop_left,
|
71 |
+
crop_top,
|
72 |
+
crop_right,
|
73 |
+
crop_bottom) = propose_random_square_crop(cur_mask,
|
74 |
+
min_overlap=config.cropping.crop_min_overlap)
|
75 |
+
cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
|
76 |
+
cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
|
77 |
+
else:
|
78 |
+
cur_image = image
|
79 |
+
|
80 |
+
if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
|
81 |
+
continue
|
82 |
+
|
83 |
+
filtered_image_mask_pairs.append((cur_image, cur_mask))
|
84 |
+
|
85 |
+
mask_indices = np.random.choice(len(filtered_image_mask_pairs),
|
86 |
+
size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
|
87 |
+
replace=False)
|
88 |
+
|
89 |
+
# crop masks; save masks together with input image
|
90 |
+
mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
|
91 |
+
for i, idx in enumerate(mask_indices):
|
92 |
+
cur_image, cur_mask = filtered_image_mask_pairs[idx]
|
93 |
+
cur_basename = mask_basename + f'_crop{i:03d}'
|
94 |
+
Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
|
95 |
+
mode='L').save(cur_basename + f'_mask{i:03d}.png')
|
96 |
+
cur_image.save(cur_basename + '.png')
|
97 |
+
except KeyboardInterrupt:
|
98 |
+
return
|
99 |
+
except Exception as ex:
|
100 |
+
print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
101 |
+
|
102 |
+
|
103 |
+
@hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
|
104 |
+
def main(config: OmegaConf):
|
105 |
+
if not config.indir.endswith('/'):
|
106 |
+
config.indir += '/'
|
107 |
+
|
108 |
+
os.makedirs(config.outdir, exist_ok=True)
|
109 |
+
|
110 |
+
in_files = list(glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
|
111 |
+
recursive=True))
|
112 |
+
if config.n_jobs == 0:
|
113 |
+
process_images(in_files, config.indir, config.outdir, config)
|
114 |
+
else:
|
115 |
+
in_files_n = len(in_files)
|
116 |
+
chunk_size = in_files_n // config.n_jobs + (1 if in_files_n % config.n_jobs > 0 else 0)
|
117 |
+
Parallel(n_jobs=config.n_jobs)(
|
118 |
+
delayed(process_images)(in_files[start:start+chunk_size], config.indir, config.outdir, config)
|
119 |
+
for start in range(0, len(in_files), chunk_size)
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
main()
|
lama/bin/gen_outpainting_dataset.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import sys
|
7 |
+
import traceback
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import load_image
|
10 |
+
from saicinpainting.evaluation.utils import move_to_device
|
11 |
+
|
12 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
13 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
14 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
15 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
16 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import hydra
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import tqdm
|
23 |
+
import yaml
|
24 |
+
from omegaconf import OmegaConf
|
25 |
+
from torch.utils.data._utils.collate import default_collate
|
26 |
+
|
27 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
28 |
+
from saicinpainting.training.trainers import load_checkpoint
|
29 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
30 |
+
|
31 |
+
LOGGER = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
try:
|
36 |
+
if not args.indir.endswith('/'):
|
37 |
+
args.indir += '/'
|
38 |
+
|
39 |
+
for in_img in glob.glob(os.path.join(args.indir, '**', '*' + args.img_suffix), recursive=True):
|
40 |
+
if 'mask' in os.path.basename(in_img):
|
41 |
+
continue
|
42 |
+
|
43 |
+
out_img_path = os.path.join(args.outdir, os.path.splitext(in_img[len(args.indir):])[0] + '.png')
|
44 |
+
out_mask_path = f'{os.path.splitext(out_img_path)[0]}_mask.png'
|
45 |
+
|
46 |
+
os.makedirs(os.path.dirname(out_img_path), exist_ok=True)
|
47 |
+
|
48 |
+
img = load_image(in_img)
|
49 |
+
height, width = img.shape[1:]
|
50 |
+
pad_h, pad_w = int(height * args.coef / 2), int(width * args.coef / 2)
|
51 |
+
|
52 |
+
mask = np.zeros((height, width), dtype='uint8')
|
53 |
+
|
54 |
+
if args.expand:
|
55 |
+
img = np.pad(img, ((0, 0), (pad_h, pad_h), (pad_w, pad_w)))
|
56 |
+
mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='constant', constant_values=255)
|
57 |
+
else:
|
58 |
+
mask[:pad_h] = 255
|
59 |
+
mask[-pad_h:] = 255
|
60 |
+
mask[:, :pad_w] = 255
|
61 |
+
mask[:, -pad_w:] = 255
|
62 |
+
|
63 |
+
# img = np.pad(img, ((0, 0), (pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode='symmetric')
|
64 |
+
# mask = np.pad(mask, ((pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode = 'symmetric')
|
65 |
+
|
66 |
+
img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype('uint8')
|
67 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
68 |
+
cv2.imwrite(out_img_path, img)
|
69 |
+
|
70 |
+
cv2.imwrite(out_mask_path, mask)
|
71 |
+
except KeyboardInterrupt:
|
72 |
+
LOGGER.warning('Interrupted by user')
|
73 |
+
except Exception as ex:
|
74 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
75 |
+
sys.exit(1)
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
import argparse
|
80 |
+
|
81 |
+
aparser = argparse.ArgumentParser()
|
82 |
+
aparser.add_argument('indir', type=str, help='Root directory with images')
|
83 |
+
aparser.add_argument('outdir', type=str, help='Where to store results')
|
84 |
+
aparser.add_argument('--img-suffix', type=str, default='.png', help='Input image extension')
|
85 |
+
aparser.add_argument('--expand', action='store_true', help='Generate mask by padding (true) or by cropping (false)')
|
86 |
+
aparser.add_argument('--coef', type=float, default=0.2, help='How much to crop/expand in order to get masks')
|
87 |
+
|
88 |
+
main(aparser.parse_args())
|
lama/bin/make_checkpoint.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def get_checkpoint_files(s):
|
10 |
+
s = s.strip()
|
11 |
+
if ',' in s:
|
12 |
+
return [get_checkpoint_files(chunk) for chunk in s.split(',')]
|
13 |
+
return 'last.ckpt' if s == 'last' else f'{s}.ckpt'
|
14 |
+
|
15 |
+
|
16 |
+
def main(args):
|
17 |
+
checkpoint_fnames = get_checkpoint_files(args.epochs)
|
18 |
+
if isinstance(checkpoint_fnames, str):
|
19 |
+
checkpoint_fnames = [checkpoint_fnames]
|
20 |
+
assert len(checkpoint_fnames) >= 1
|
21 |
+
|
22 |
+
checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0])
|
23 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
24 |
+
del checkpoint['optimizer_states']
|
25 |
+
|
26 |
+
if len(checkpoint_fnames) > 1:
|
27 |
+
for fname in checkpoint_fnames[1:]:
|
28 |
+
print('sum', fname)
|
29 |
+
sum_tensors_cnt = 0
|
30 |
+
other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu')
|
31 |
+
for k in checkpoint['state_dict'].keys():
|
32 |
+
if checkpoint['state_dict'][k].dtype is torch.float:
|
33 |
+
checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data)
|
34 |
+
sum_tensors_cnt += 1
|
35 |
+
print('summed', sum_tensors_cnt, 'tensors')
|
36 |
+
|
37 |
+
for k in checkpoint['state_dict'].keys():
|
38 |
+
if checkpoint['state_dict'][k].dtype is torch.float:
|
39 |
+
checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames)))
|
40 |
+
|
41 |
+
state_dict = checkpoint['state_dict']
|
42 |
+
|
43 |
+
if not args.leave_discriminators:
|
44 |
+
for k in list(state_dict.keys()):
|
45 |
+
if k.startswith('discriminator.'):
|
46 |
+
del state_dict[k]
|
47 |
+
|
48 |
+
if not args.leave_losses:
|
49 |
+
for k in list(state_dict.keys()):
|
50 |
+
if k.startswith('loss_'):
|
51 |
+
del state_dict[k]
|
52 |
+
|
53 |
+
out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt')
|
54 |
+
os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True)
|
55 |
+
|
56 |
+
torch.save(checkpoint, out_checkpoint_path)
|
57 |
+
|
58 |
+
shutil.copy2(os.path.join(args.indir, 'config.yaml'),
|
59 |
+
os.path.join(args.outdir, 'config.yaml'))
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
import argparse
|
64 |
+
|
65 |
+
aparser = argparse.ArgumentParser()
|
66 |
+
aparser.add_argument('indir',
|
67 |
+
help='Path to directory with output of training '
|
68 |
+
'(i.e. directory, which has samples, modules, config.yaml and train.log')
|
69 |
+
aparser.add_argument('outdir',
|
70 |
+
help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"')
|
71 |
+
aparser.add_argument('--epochs', type=str, default='last',
|
72 |
+
help='Which checkpoint to take. '
|
73 |
+
'Can be "last" or integer - number of epoch')
|
74 |
+
aparser.add_argument('--leave-discriminators', action='store_true',
|
75 |
+
help='If enabled, the state of discriminators will not be removed from the checkpoint')
|
76 |
+
aparser.add_argument('--leave-losses', action='store_true',
|
77 |
+
help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed')
|
78 |
+
|
79 |
+
main(aparser.parse_args())
|
lama/bin/mask_example.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from skimage import io
|
3 |
+
from skimage.transform import resize
|
4 |
+
|
5 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask
|
6 |
+
|
7 |
+
im = io.imread('imgs/ex4.jpg')
|
8 |
+
im = resize(im, (512, 1024), anti_aliasing=True)
|
9 |
+
mask_seg = SegmentationMask(num_variants_per_mask=10)
|
10 |
+
mask_examples = mask_seg.get_masks(im)
|
11 |
+
for i, example in enumerate(mask_examples):
|
12 |
+
plt.imshow(example)
|
13 |
+
plt.show()
|
14 |
+
plt.imsave(f'tmp/img_masks/{i}.png', example)
|
lama/bin/paper_runfiles/blur_tests.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##!/usr/bin/env bash
|
2 |
+
#
|
3 |
+
## !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
|
4 |
+
#
|
5 |
+
## paths to data are valid for mml7
|
6 |
+
#PLACES_ROOT="/data/inpainting/Places365"
|
7 |
+
#OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
|
8 |
+
#
|
9 |
+
#source "$(dirname $0)/env.sh"
|
10 |
+
#
|
11 |
+
#for datadir in test_large_30k # val_large
|
12 |
+
#do
|
13 |
+
# for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
14 |
+
# do
|
15 |
+
# "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
16 |
+
# "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
|
17 |
+
#
|
18 |
+
# "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
19 |
+
# done
|
20 |
+
#
|
21 |
+
# for conf in segm_256 segm_512
|
22 |
+
# do
|
23 |
+
# "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
24 |
+
# "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
|
25 |
+
#
|
26 |
+
# "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
27 |
+
# done
|
28 |
+
#done
|
29 |
+
#
|
30 |
+
#IN_DIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k/random_medium_512"
|
31 |
+
#PRED_DIR="/data/inpainting/predictions/final/images/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37/random_medium_512"
|
32 |
+
#BLUR_OUT_DIR="/data/inpainting/predictions/final/blur/images"
|
33 |
+
#
|
34 |
+
#for b in 0.1
|
35 |
+
#
|
36 |
+
#"$BINDIR/blur_predicts.py" "$BASEDIR/../../configs/eval2.yaml" "$CUR_IN_DIR" "$CUR_OUT_DIR" "$CUR_EVAL_DIR"
|
37 |
+
#
|
lama/bin/paper_runfiles/env.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DIRNAME="$(dirname $0)"
|
2 |
+
DIRNAME="$(realpath ""$DIRNAME"")"
|
3 |
+
|
4 |
+
BINDIR="$DIRNAME/.."
|
5 |
+
SRCDIR="$BINDIR/.."
|
6 |
+
CONFIGDIR="$SRCDIR/configs"
|
7 |
+
|
8 |
+
export PYTHONPATH="$SRCDIR:$PYTHONPATH"
|
lama/bin/paper_runfiles/find_best_checkpoint.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
|
7 |
+
|
8 |
+
def ssim_fid100_f1(metrics, fid_scale=100):
|
9 |
+
ssim = metrics.loc['total', 'ssim']['mean']
|
10 |
+
fid = metrics.loc['total', 'fid']['mean']
|
11 |
+
fid_rel = max(0, fid_scale - fid) / fid_scale
|
12 |
+
f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
|
13 |
+
return f1
|
14 |
+
|
15 |
+
|
16 |
+
def find_best_checkpoint(model_list, models_dir):
|
17 |
+
with open(model_list) as f:
|
18 |
+
models = [m.strip() for m in f.readlines()]
|
19 |
+
with open(f'{model_list}_best', 'w') as f:
|
20 |
+
for model in models:
|
21 |
+
print(model)
|
22 |
+
best_f1 = 0
|
23 |
+
best_epoch = 0
|
24 |
+
best_step = 0
|
25 |
+
with open(os.path.join(models_dir, model, 'train.log')) as fm:
|
26 |
+
lines = fm.readlines()
|
27 |
+
for line_index in range(len(lines)):
|
28 |
+
line = lines[line_index]
|
29 |
+
if 'Validation metrics after epoch' in line:
|
30 |
+
sharp_index = line.index('#')
|
31 |
+
cur_ep = line[sharp_index + 1:]
|
32 |
+
comma_index = cur_ep.index(',')
|
33 |
+
cur_ep = int(cur_ep[:comma_index])
|
34 |
+
total_index = line.index('total ')
|
35 |
+
step = int(line[total_index:].split()[1].strip())
|
36 |
+
total_line = lines[line_index + 5]
|
37 |
+
if not total_line.startswith('total'):
|
38 |
+
continue
|
39 |
+
words = total_line.strip().split()
|
40 |
+
f1 = float(words[-1])
|
41 |
+
print(f'\tEpoch: {cur_ep}, f1={f1}')
|
42 |
+
if f1 > best_f1:
|
43 |
+
best_f1 = f1
|
44 |
+
best_epoch = cur_ep
|
45 |
+
best_step = step
|
46 |
+
f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
parser = ArgumentParser()
|
51 |
+
parser.add_argument('model_list')
|
52 |
+
parser.add_argument('models_dir')
|
53 |
+
args = parser.parse_args()
|
54 |
+
find_best_checkpoint(args.model_list, args.models_dir)
|
lama/bin/paper_runfiles/generate_test_celeba-hq.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/CelebA-HQ_val_test"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in "val" "test"
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-celeba-hq \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
lama/bin/paper_runfiles/generate_test_ffhq.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/FFHQ_val"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in test
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-ffhq \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
lama/bin/paper_runfiles/generate_test_paris.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in paris_eval_gt
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
|
13 |
+
location.out_dir=OUT_DIR cropping.out_square_crop=False cropping.out_min_size=227
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
lama/bin/paper_runfiles/generate_test_paris_256.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val_256"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in paris_eval_gt
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False cropping.out_min_size=256
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
lama/bin/paper_runfiles/generate_val_test.sh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
|
4 |
+
|
5 |
+
# paths to data are valid for mml7
|
6 |
+
PLACES_ROOT="/data/inpainting/Places365"
|
7 |
+
OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
|
8 |
+
|
9 |
+
source "$(dirname $0)/env.sh"
|
10 |
+
|
11 |
+
for datadir in test_large_30k # val_large
|
12 |
+
do
|
13 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
14 |
+
do
|
15 |
+
"$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
16 |
+
"$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
|
17 |
+
|
18 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
19 |
+
done
|
20 |
+
|
21 |
+
for conf in segm_256 segm_512
|
22 |
+
do
|
23 |
+
"$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
24 |
+
"$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
|
25 |
+
|
26 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
27 |
+
done
|
28 |
+
done
|
lama/bin/paper_runfiles/predict_inner_features.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml7
|
4 |
+
|
5 |
+
source "$(dirname $0)/env.sh"
|
6 |
+
|
7 |
+
"$BINDIR/predict_inner_features.py" \
|
8 |
+
-cn default_inner_features_ffc \
|
9 |
+
model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-34-05_train_ablv2_work_ffc075_resume_epoch39" \
|
10 |
+
indir="/data/inpainting/paper_data/inner_features_vis/input/" \
|
11 |
+
outdir="/data/inpainting/paper_data/inner_features_vis/output/ffc" \
|
12 |
+
dataset.img_suffix=.png
|
13 |
+
|
14 |
+
|
15 |
+
"$BINDIR/predict_inner_features.py" \
|
16 |
+
-cn default_inner_features_work \
|
17 |
+
model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37" \
|
18 |
+
indir="/data/inpainting/paper_data/inner_features_vis/input/" \
|
19 |
+
outdir="/data/inpainting/paper_data/inner_features_vis/output/work" \
|
20 |
+
dataset.img_suffix=.png
|
lama/bin/paper_runfiles/update_test_data_stats.sh
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml7
|
4 |
+
|
5 |
+
source "$(dirname $0)/env.sh"
|
6 |
+
|
7 |
+
#INDIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k"
|
8 |
+
#
|
9 |
+
#for dataset in random_medium_256 random_medium_512 random_thick_256 random_thick_512 random_thin_256 random_thin_512
|
10 |
+
#do
|
11 |
+
# "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
12 |
+
#done
|
13 |
+
#
|
14 |
+
#"$BINDIR/calc_dataset_stats.py" "/data/inpainting/evalset2" "/data/inpainting/evalset2_stats2"
|
15 |
+
|
16 |
+
|
17 |
+
INDIR="/data/inpainting/paper_data/CelebA-HQ_val_test/test"
|
18 |
+
|
19 |
+
for dataset in random_medium_256 random_thick_256 random_thin_256
|
20 |
+
do
|
21 |
+
"$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
22 |
+
done
|
23 |
+
|
24 |
+
|
25 |
+
INDIR="/data/inpainting/paper_data/Paris_StreetView_Dataset_val_256/paris_eval_gt"
|
26 |
+
|
27 |
+
for dataset in random_medium_256 random_thick_256 random_thin_256
|
28 |
+
do
|
29 |
+
"$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
30 |
+
done
|
lama/bin/predict.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Example command:
|
4 |
+
# ./bin/predict.py \
|
5 |
+
# model.path=<path to checkpoint, prepared by make_checkpoint.py> \
|
6 |
+
# indir=<path to input data> \
|
7 |
+
# outdir=<where to store predicts>
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.utils import move_to_device
|
15 |
+
from saicinpainting.evaluation.refinement import refine_predict
|
16 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
17 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
18 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
19 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
20 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import hydra
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import tqdm
|
27 |
+
import yaml
|
28 |
+
from omegaconf import OmegaConf
|
29 |
+
from torch.utils.data._utils.collate import default_collate
|
30 |
+
|
31 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
32 |
+
from saicinpainting.training.trainers import load_checkpoint
|
33 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
34 |
+
|
35 |
+
LOGGER = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
@hydra.main(config_path='../configs/prediction', config_name='default.yaml')
|
39 |
+
def main(predict_config: OmegaConf):
|
40 |
+
try:
|
41 |
+
if sys.platform != 'win32':
|
42 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
43 |
+
|
44 |
+
device = torch.device("cpu")
|
45 |
+
|
46 |
+
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
|
47 |
+
with open(train_config_path, 'r') as f:
|
48 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
49 |
+
|
50 |
+
train_config.training_model.predict_only = True
|
51 |
+
train_config.visualizer.kind = 'noop'
|
52 |
+
|
53 |
+
out_ext = predict_config.get('out_ext', '.png')
|
54 |
+
|
55 |
+
checkpoint_path = os.path.join(predict_config.model.path,
|
56 |
+
'models',
|
57 |
+
predict_config.model.checkpoint)
|
58 |
+
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
|
59 |
+
model.freeze()
|
60 |
+
if not predict_config.get('refine', False):
|
61 |
+
model.to(device)
|
62 |
+
|
63 |
+
if not predict_config.indir.endswith('/'):
|
64 |
+
predict_config.indir += '/'
|
65 |
+
|
66 |
+
dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
|
67 |
+
for img_i in tqdm.trange(len(dataset)):
|
68 |
+
mask_fname = dataset.mask_filenames[img_i]
|
69 |
+
cur_out_fname = os.path.join(
|
70 |
+
predict_config.outdir,
|
71 |
+
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
|
72 |
+
)
|
73 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
74 |
+
batch = default_collate([dataset[img_i]])
|
75 |
+
if predict_config.get('refine', False):
|
76 |
+
assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement"
|
77 |
+
# image unpadding is taken care of in the refiner, so that output image
|
78 |
+
# is same size as the input image
|
79 |
+
cur_res = refine_predict(batch, model, **predict_config.refiner)
|
80 |
+
cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
|
81 |
+
else:
|
82 |
+
with torch.no_grad():
|
83 |
+
batch = move_to_device(batch, device)
|
84 |
+
batch['mask'] = (batch['mask'] > 0) * 1
|
85 |
+
batch = model(batch)
|
86 |
+
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
|
87 |
+
unpad_to_size = batch.get('unpad_to_size', None)
|
88 |
+
if unpad_to_size is not None:
|
89 |
+
orig_height, orig_width = unpad_to_size
|
90 |
+
cur_res = cur_res[:orig_height, :orig_width]
|
91 |
+
|
92 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
93 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
94 |
+
cv2.imwrite(cur_out_fname, cur_res)
|
95 |
+
|
96 |
+
except KeyboardInterrupt:
|
97 |
+
LOGGER.warning('Interrupted by user')
|
98 |
+
except Exception as ex:
|
99 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
100 |
+
sys.exit(1)
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == '__main__':
|
104 |
+
main()
|
lama/bin/predict_inner_features.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Example command:
|
4 |
+
# ./bin/predict.py \
|
5 |
+
# model.path=<path to checkpoint, prepared by make_checkpoint.py> \
|
6 |
+
# indir=<path to input data> \
|
7 |
+
# outdir=<where to store predicts>
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.utils import move_to_device
|
15 |
+
|
16 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
17 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
18 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
19 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
20 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import hydra
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import tqdm
|
27 |
+
import yaml
|
28 |
+
from omegaconf import OmegaConf
|
29 |
+
from torch.utils.data._utils.collate import default_collate
|
30 |
+
|
31 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
32 |
+
from saicinpainting.training.trainers import load_checkpoint, DefaultInpaintingTrainingModule
|
33 |
+
from saicinpainting.utils import register_debug_signal_handlers, get_shape
|
34 |
+
|
35 |
+
LOGGER = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
@hydra.main(config_path='../configs/prediction', config_name='default_inner_features.yaml')
|
39 |
+
def main(predict_config: OmegaConf):
|
40 |
+
try:
|
41 |
+
if sys.platform != 'win32':
|
42 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
43 |
+
|
44 |
+
device = torch.device(predict_config.device)
|
45 |
+
|
46 |
+
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
|
47 |
+
with open(train_config_path, 'r') as f:
|
48 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
49 |
+
|
50 |
+
checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
|
51 |
+
model = load_checkpoint(train_config, checkpoint_path, strict=False)
|
52 |
+
model.freeze()
|
53 |
+
model.to(device)
|
54 |
+
|
55 |
+
assert isinstance(model, DefaultInpaintingTrainingModule), 'Only DefaultInpaintingTrainingModule is supported'
|
56 |
+
assert isinstance(getattr(model.generator, 'model', None), torch.nn.Sequential)
|
57 |
+
|
58 |
+
if not predict_config.indir.endswith('/'):
|
59 |
+
predict_config.indir += '/'
|
60 |
+
|
61 |
+
dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
|
62 |
+
|
63 |
+
max_level = max(predict_config.levels)
|
64 |
+
|
65 |
+
with torch.no_grad():
|
66 |
+
for img_i in tqdm.trange(len(dataset)):
|
67 |
+
mask_fname = dataset.mask_filenames[img_i]
|
68 |
+
cur_out_fname = os.path.join(predict_config.outdir, os.path.splitext(mask_fname[len(predict_config.indir):])[0])
|
69 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
70 |
+
|
71 |
+
batch = move_to_device(default_collate([dataset[img_i]]), device)
|
72 |
+
|
73 |
+
img = batch['image']
|
74 |
+
mask = batch['mask']
|
75 |
+
mask[:] = 0
|
76 |
+
mask_h, mask_w = mask.shape[-2:]
|
77 |
+
mask[:, :,
|
78 |
+
mask_h // 2 - predict_config.hole_radius : mask_h // 2 + predict_config.hole_radius,
|
79 |
+
mask_w // 2 - predict_config.hole_radius : mask_w // 2 + predict_config.hole_radius] = 1
|
80 |
+
|
81 |
+
masked_img = torch.cat([img * (1 - mask), mask], dim=1)
|
82 |
+
|
83 |
+
feats = masked_img
|
84 |
+
for level_i, level in enumerate(model.generator.model):
|
85 |
+
feats = level(feats)
|
86 |
+
if level_i in predict_config.levels:
|
87 |
+
cur_feats = torch.cat([f for f in feats if torch.is_tensor(f)], dim=1) \
|
88 |
+
if isinstance(feats, tuple) else feats
|
89 |
+
|
90 |
+
if predict_config.slice_channels:
|
91 |
+
cur_feats = cur_feats[:, slice(*predict_config.slice_channels)]
|
92 |
+
|
93 |
+
cur_feat = cur_feats.pow(2).mean(1).pow(0.5).clone()
|
94 |
+
cur_feat -= cur_feat.min()
|
95 |
+
cur_feat /= cur_feat.std()
|
96 |
+
cur_feat = cur_feat.clamp(0, 1) / 1
|
97 |
+
cur_feat = cur_feat.cpu().numpy()[0]
|
98 |
+
cur_feat *= 255
|
99 |
+
cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
|
100 |
+
cv2.imwrite(cur_out_fname + f'_lev{level_i:02d}_norm.png', cur_feat)
|
101 |
+
|
102 |
+
# for channel_i in predict_config.channels:
|
103 |
+
#
|
104 |
+
# cur_feat = cur_feats[0, channel_i].clone().detach().cpu().numpy()
|
105 |
+
# cur_feat -= cur_feat.min()
|
106 |
+
# cur_feat /= cur_feat.max()
|
107 |
+
# cur_feat *= 255
|
108 |
+
# cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
|
109 |
+
# cv2.imwrite(cur_out_fname + f'_lev{level_i}_ch{channel_i}.png', cur_feat)
|
110 |
+
elif level_i >= max_level:
|
111 |
+
break
|
112 |
+
except KeyboardInterrupt:
|
113 |
+
LOGGER.warning('Interrupted by user')
|
114 |
+
except Exception as ex:
|
115 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
116 |
+
sys.exit(1)
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == '__main__':
|
120 |
+
main()
|
lama/bin/report_from_tb.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
import tensorflow as tf
|
8 |
+
from torch.utils.tensorboard import SummaryWriter
|
9 |
+
|
10 |
+
|
11 |
+
GROUPING_RULES = [
|
12 |
+
re.compile(r'^(?P<group>train|test|val|extra_val_.*?(256|512))_(?P<title>.*)', re.I)
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
DROP_RULES = [
|
17 |
+
re.compile(r'_std$', re.I)
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def need_drop(tag):
|
22 |
+
for rule in DROP_RULES:
|
23 |
+
if rule.search(tag):
|
24 |
+
return True
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
def get_group_and_title(tag):
|
29 |
+
for rule in GROUPING_RULES:
|
30 |
+
match = rule.search(tag)
|
31 |
+
if match is None:
|
32 |
+
continue
|
33 |
+
return match.group('group'), match.group('title')
|
34 |
+
return None, None
|
35 |
+
|
36 |
+
|
37 |
+
def main(args):
|
38 |
+
os.makedirs(args.outdir, exist_ok=True)
|
39 |
+
|
40 |
+
ignored_events = set()
|
41 |
+
|
42 |
+
for orig_fname in glob.glob(args.inglob):
|
43 |
+
cur_dirpath = os.path.dirname(orig_fname) # remove filename, this should point to "version_0" directory
|
44 |
+
subdirname = os.path.basename(cur_dirpath) # == "version_0" most of time
|
45 |
+
exp_root_path = os.path.dirname(cur_dirpath) # remove "version_0"
|
46 |
+
exp_name = os.path.basename(exp_root_path)
|
47 |
+
|
48 |
+
writers_by_group = {}
|
49 |
+
|
50 |
+
for e in tf.compat.v1.train.summary_iterator(orig_fname):
|
51 |
+
for v in e.summary.value:
|
52 |
+
if need_drop(v.tag):
|
53 |
+
continue
|
54 |
+
|
55 |
+
cur_group, cur_title = get_group_and_title(v.tag)
|
56 |
+
if cur_group is None:
|
57 |
+
if v.tag not in ignored_events:
|
58 |
+
print(f'WARNING: Could not detect group for {v.tag}, ignoring it')
|
59 |
+
ignored_events.add(v.tag)
|
60 |
+
continue
|
61 |
+
|
62 |
+
cur_writer = writers_by_group.get(cur_group, None)
|
63 |
+
if cur_writer is None:
|
64 |
+
if args.include_version:
|
65 |
+
cur_outdir = os.path.join(args.outdir, exp_name, f'{subdirname}_{cur_group}')
|
66 |
+
else:
|
67 |
+
cur_outdir = os.path.join(args.outdir, exp_name, cur_group)
|
68 |
+
cur_writer = SummaryWriter(cur_outdir)
|
69 |
+
writers_by_group[cur_group] = cur_writer
|
70 |
+
|
71 |
+
cur_writer.add_scalar(cur_title, v.simple_value, global_step=e.step, walltime=e.wall_time)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == '__main__':
|
75 |
+
import argparse
|
76 |
+
|
77 |
+
aparser = argparse.ArgumentParser()
|
78 |
+
aparser.add_argument('inglob', type=str)
|
79 |
+
aparser.add_argument('outdir', type=str)
|
80 |
+
aparser.add_argument('--include-version', action='store_true',
|
81 |
+
help='Include subdirectory name e.g. "version_0" into output path')
|
82 |
+
|
83 |
+
main(aparser.parse_args())
|
lama/bin/sample_from_dataset.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from skimage import io
|
8 |
+
from skimage.segmentation import mark_boundaries
|
9 |
+
|
10 |
+
from saicinpainting.evaluation.data import InpaintingDataset
|
11 |
+
from saicinpainting.evaluation.vis import save_item_for_vis
|
12 |
+
|
13 |
+
def save_mask_for_sidebyside(item, out_file):
|
14 |
+
mask = item['mask']# > 0.5
|
15 |
+
if mask.ndim == 3:
|
16 |
+
mask = mask[0]
|
17 |
+
mask = np.clip(mask * 255, 0, 255).astype('uint8')
|
18 |
+
io.imsave(out_file, mask)
|
19 |
+
|
20 |
+
def save_img_for_sidebyside(item, out_file):
|
21 |
+
img = np.transpose(item['image'], (1, 2, 0))
|
22 |
+
img = np.clip(img * 255, 0, 255).astype('uint8')
|
23 |
+
io.imsave(out_file, img)
|
24 |
+
|
25 |
+
def save_masked_img_for_sidebyside(item, out_file):
|
26 |
+
mask = item['mask']
|
27 |
+
img = item['image']
|
28 |
+
|
29 |
+
img = (1-mask) * img + mask
|
30 |
+
img = np.transpose(img, (1, 2, 0))
|
31 |
+
|
32 |
+
img = np.clip(img * 255, 0, 255).astype('uint8')
|
33 |
+
io.imsave(out_file, img)
|
34 |
+
|
35 |
+
def main(args):
|
36 |
+
dataset = InpaintingDataset(args.datadir, img_suffix='.png')
|
37 |
+
|
38 |
+
area_bins = np.linspace(0, 1, args.area_bins + 1)
|
39 |
+
|
40 |
+
heights = []
|
41 |
+
widths = []
|
42 |
+
image_areas = []
|
43 |
+
hole_areas = []
|
44 |
+
hole_area_percents = []
|
45 |
+
area_bins_count = np.zeros(args.area_bins)
|
46 |
+
area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
|
47 |
+
|
48 |
+
bin2i = [[] for _ in range(args.area_bins)]
|
49 |
+
|
50 |
+
for i, item in enumerate(tqdm.tqdm(dataset)):
|
51 |
+
h, w = item['image'].shape[1:]
|
52 |
+
heights.append(h)
|
53 |
+
widths.append(w)
|
54 |
+
full_area = h * w
|
55 |
+
image_areas.append(full_area)
|
56 |
+
hole_area = (item['mask'] == 1).sum()
|
57 |
+
hole_areas.append(hole_area)
|
58 |
+
hole_percent = hole_area / full_area
|
59 |
+
hole_area_percents.append(hole_percent)
|
60 |
+
bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
|
61 |
+
area_bins_count[bin_i] += 1
|
62 |
+
bin2i[bin_i].append(i)
|
63 |
+
|
64 |
+
os.makedirs(args.outdir, exist_ok=True)
|
65 |
+
|
66 |
+
for bin_i in range(args.area_bins):
|
67 |
+
bindir = os.path.join(args.outdir, area_bin_titles[bin_i])
|
68 |
+
os.makedirs(bindir, exist_ok=True)
|
69 |
+
bin_idx = bin2i[bin_i]
|
70 |
+
for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
|
71 |
+
item = dataset[sample_i]
|
72 |
+
path = os.path.join(bindir, dataset.img_filenames[sample_i].split('/')[-1])
|
73 |
+
save_masked_img_for_sidebyside(item, path)
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__':
|
77 |
+
import argparse
|
78 |
+
|
79 |
+
aparser = argparse.ArgumentParser()
|
80 |
+
aparser.add_argument('--datadir', type=str,
|
81 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
82 |
+
aparser.add_argument('--outdir', type=str, help='Where to put results')
|
83 |
+
aparser.add_argument('--samples-n', type=int, default=10,
|
84 |
+
help='Number of sample images with masks to copy for visualization for each area bin')
|
85 |
+
aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
|
86 |
+
|
87 |
+
main(aparser.parse_args())
|
lama/bin/side_by_side.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
9 |
+
from saicinpainting.evaluation.utils import load_yaml
|
10 |
+
from saicinpainting.training.visualizers.base import visualize_mask_and_images
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
config = load_yaml(args.config)
|
15 |
+
|
16 |
+
datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs)
|
17 |
+
for cur_predictdir in args.predictdirs]
|
18 |
+
assert len({len(ds) for ds in datasets}) == 1
|
19 |
+
len_first = len(datasets[0])
|
20 |
+
|
21 |
+
indices = list(range(len_first))
|
22 |
+
if len_first > args.max_n:
|
23 |
+
indices = sorted(random.sample(indices, args.max_n))
|
24 |
+
|
25 |
+
os.makedirs(args.outpath, exist_ok=True)
|
26 |
+
|
27 |
+
filename2i = {}
|
28 |
+
|
29 |
+
keys = ['image'] + [i for i in range(len(datasets))]
|
30 |
+
for img_i in indices:
|
31 |
+
try:
|
32 |
+
mask_fname = os.path.basename(datasets[0].mask_filenames[img_i])
|
33 |
+
if mask_fname in filename2i:
|
34 |
+
filename2i[mask_fname] += 1
|
35 |
+
idx = filename2i[mask_fname]
|
36 |
+
mask_fname_only, ext = os.path.split(mask_fname)
|
37 |
+
mask_fname = f'{mask_fname_only}_{idx}{ext}'
|
38 |
+
else:
|
39 |
+
filename2i[mask_fname] = 1
|
40 |
+
|
41 |
+
cur_vis_dict = datasets[0][img_i]
|
42 |
+
for ds_i, ds in enumerate(datasets):
|
43 |
+
cur_vis_dict[ds_i] = ds[img_i]['inpainted']
|
44 |
+
|
45 |
+
vis_img = visualize_mask_and_images(cur_vis_dict, keys,
|
46 |
+
last_without_mask=False,
|
47 |
+
mask_only_first=True,
|
48 |
+
black_mask=args.black)
|
49 |
+
vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
|
50 |
+
|
51 |
+
out_fname = os.path.join(args.outpath, mask_fname)
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
|
56 |
+
cv2.imwrite(out_fname, vis_img)
|
57 |
+
except Exception as ex:
|
58 |
+
print(f'Could not process {img_i} due to {ex}')
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
import argparse
|
63 |
+
|
64 |
+
aparser = argparse.ArgumentParser()
|
65 |
+
aparser.add_argument('--max-n', type=int, default=100, help='Maximum number of images to print')
|
66 |
+
aparser.add_argument('--black', action='store_true', help='Whether to fill mask on GT with black')
|
67 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config (e.g. configs/eval1.yaml)')
|
68 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
69 |
+
aparser.add_argument('datadir', type=str,
|
70 |
+
help='Path to folder with images and masks')
|
71 |
+
aparser.add_argument('predictdirs', type=str,
|
72 |
+
nargs='+',
|
73 |
+
help='Path to folders with predicts')
|
74 |
+
|
75 |
+
|
76 |
+
main(aparser.parse_args())
|
lama/bin/split_tar.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import tqdm
|
5 |
+
import webdataset as wds
|
6 |
+
|
7 |
+
|
8 |
+
def main(args):
|
9 |
+
input_dataset = wds.Dataset(args.infile)
|
10 |
+
output_dataset = wds.ShardWriter(args.outpattern)
|
11 |
+
for rec in tqdm.tqdm(input_dataset):
|
12 |
+
output_dataset.write(rec)
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
aparser = argparse.ArgumentParser()
|
19 |
+
aparser.add_argument('infile', type=str)
|
20 |
+
aparser.add_argument('outpattern', type=str)
|
21 |
+
|
22 |
+
main(aparser.parse_args())
|
lama/bin/to_jit.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import hydra
|
5 |
+
import torch
|
6 |
+
import yaml
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from saicinpainting.training.trainers import load_checkpoint
|
11 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
12 |
+
|
13 |
+
|
14 |
+
class JITWrapper(nn.Module):
|
15 |
+
def __init__(self, model):
|
16 |
+
super().__init__()
|
17 |
+
self.model = model
|
18 |
+
|
19 |
+
def forward(self, image, mask):
|
20 |
+
batch = {
|
21 |
+
"image": image,
|
22 |
+
"mask": mask
|
23 |
+
}
|
24 |
+
out = self.model(batch)
|
25 |
+
return out["inpainted"]
|
26 |
+
|
27 |
+
|
28 |
+
@hydra.main(config_path="../configs/prediction", config_name="default.yaml")
|
29 |
+
def main(predict_config: OmegaConf):
|
30 |
+
if sys.platform != 'win32':
|
31 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
32 |
+
|
33 |
+
train_config_path = os.path.join(predict_config.model.path, "config.yaml")
|
34 |
+
with open(train_config_path, "r") as f:
|
35 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
36 |
+
|
37 |
+
train_config.training_model.predict_only = True
|
38 |
+
train_config.visualizer.kind = "noop"
|
39 |
+
|
40 |
+
checkpoint_path = os.path.join(
|
41 |
+
predict_config.model.path, "models", predict_config.model.checkpoint
|
42 |
+
)
|
43 |
+
model = load_checkpoint(
|
44 |
+
train_config, checkpoint_path, strict=False, map_location="cpu"
|
45 |
+
)
|
46 |
+
model.eval()
|
47 |
+
jit_model_wrapper = JITWrapper(model)
|
48 |
+
|
49 |
+
image = torch.rand(1, 3, 120, 120)
|
50 |
+
mask = torch.rand(1, 1, 120, 120)
|
51 |
+
output = jit_model_wrapper(image, mask)
|
52 |
+
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
device = torch.device("cuda")
|
55 |
+
else:
|
56 |
+
device = torch.device("cpu")
|
57 |
+
|
58 |
+
image = image.to(device)
|
59 |
+
mask = mask.to(device)
|
60 |
+
traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False).to(device)
|
61 |
+
|
62 |
+
save_path = Path(predict_config.save_path)
|
63 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
64 |
+
|
65 |
+
print(f"Saving big-lama.pt model to {save_path}")
|
66 |
+
traced_model.save(save_path)
|
67 |
+
|
68 |
+
print(f"Checking jit model output...")
|
69 |
+
jit_model = torch.jit.load(str(save_path))
|
70 |
+
jit_output = jit_model(image, mask)
|
71 |
+
diff = (output - jit_output).abs().sum()
|
72 |
+
print(f"diff: {diff}")
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
main()
|
lama/bin/train.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import traceback
|
7 |
+
|
8 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
9 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
10 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
11 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
12 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
13 |
+
|
14 |
+
import hydra
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from pytorch_lightning import Trainer
|
17 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
18 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
19 |
+
from pytorch_lightning.plugins import DDPPlugin
|
20 |
+
|
21 |
+
from saicinpainting.training.trainers import make_training_model
|
22 |
+
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
|
23 |
+
handle_deterministic_config
|
24 |
+
|
25 |
+
LOGGER = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@handle_ddp_subprocess()
|
29 |
+
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
|
30 |
+
def main(config: OmegaConf):
|
31 |
+
try:
|
32 |
+
need_set_deterministic = handle_deterministic_config(config)
|
33 |
+
|
34 |
+
if sys.platform != 'win32':
|
35 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
36 |
+
|
37 |
+
is_in_ddp_subprocess = handle_ddp_parent_process()
|
38 |
+
|
39 |
+
config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
|
40 |
+
if not is_in_ddp_subprocess:
|
41 |
+
LOGGER.info(OmegaConf.to_yaml(config))
|
42 |
+
OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))
|
43 |
+
|
44 |
+
checkpoints_dir = os.path.join(os.getcwd(), 'models')
|
45 |
+
os.makedirs(checkpoints_dir, exist_ok=True)
|
46 |
+
|
47 |
+
# there is no need to suppress this logger in ddp, because it handles rank on its own
|
48 |
+
metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
|
49 |
+
metrics_logger.log_hyperparams(config)
|
50 |
+
|
51 |
+
training_model = make_training_model(config)
|
52 |
+
|
53 |
+
trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
|
54 |
+
if need_set_deterministic:
|
55 |
+
trainer_kwargs['deterministic'] = True
|
56 |
+
|
57 |
+
trainer = Trainer(
|
58 |
+
# there is no need to suppress checkpointing in ddp, because it handles rank on its own
|
59 |
+
callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
|
60 |
+
logger=metrics_logger,
|
61 |
+
default_root_dir=os.getcwd(),
|
62 |
+
**trainer_kwargs
|
63 |
+
)
|
64 |
+
trainer.fit(training_model)
|
65 |
+
except KeyboardInterrupt:
|
66 |
+
LOGGER.warning('Interrupted by user')
|
67 |
+
except Exception as ex:
|
68 |
+
LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
|
69 |
+
sys.exit(1)
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == '__main__':
|
73 |
+
main()
|
lama/conda_env.yml
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: lama
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
- conda-forge
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- _openmp_mutex=4.5=1_gnu
|
8 |
+
- absl-py=0.13.0=py36h06a4308_0
|
9 |
+
- aiohttp=3.7.4.post0=py36h7f8727e_2
|
10 |
+
- antlr-python-runtime=4.8=py36h9f0ad1d_2
|
11 |
+
- async-timeout=3.0.1=py36h06a4308_0
|
12 |
+
- attrs=21.2.0=pyhd3eb1b0_0
|
13 |
+
- blas=1.0=mkl
|
14 |
+
- blinker=1.4=py36h06a4308_0
|
15 |
+
- brotlipy=0.7.0=py36h27cfd23_1003
|
16 |
+
- bzip2=1.0.8=h7b6447c_0
|
17 |
+
- c-ares=1.17.1=h27cfd23_0
|
18 |
+
- ca-certificates=2021.7.5=h06a4308_1
|
19 |
+
- cachetools=4.2.2=pyhd3eb1b0_0
|
20 |
+
- certifi=2021.5.30=py36h06a4308_0
|
21 |
+
- cffi=1.14.6=py36h400218f_0
|
22 |
+
- chardet=4.0.0=py36h06a4308_1003
|
23 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
24 |
+
- click=8.0.1=pyhd3eb1b0_0
|
25 |
+
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
26 |
+
- coverage=5.5=py36h27cfd23_2
|
27 |
+
- cryptography=3.4.7=py36hd23ed53_0
|
28 |
+
- cudatoolkit=10.2.89=hfd86e86_1
|
29 |
+
- cycler=0.10.0=py36_0
|
30 |
+
- cython=0.29.24=py36h295c915_0
|
31 |
+
- cytoolz=0.11.0=py36h7b6447c_0
|
32 |
+
- dask-core=1.1.4=py36_1
|
33 |
+
- dataclasses=0.8=pyh4f3eec9_6
|
34 |
+
- dbus=1.13.18=hb2f20db_0
|
35 |
+
- decorator=5.0.9=pyhd3eb1b0_0
|
36 |
+
- easydict=1.9=py_0
|
37 |
+
- expat=2.4.1=h2531618_2
|
38 |
+
- ffmpeg=4.2.2=h20bf706_0
|
39 |
+
- fontconfig=2.13.1=h6c09931_0
|
40 |
+
- freetype=2.10.4=h5ab3b9f_0
|
41 |
+
- fsspec=2021.8.1=pyhd3eb1b0_0
|
42 |
+
- future=0.18.2=py36_1
|
43 |
+
- glib=2.69.1=h5202010_0
|
44 |
+
- gmp=6.2.1=h2531618_2
|
45 |
+
- gnutls=3.6.15=he1e5248_0
|
46 |
+
- google-auth=1.33.0=pyhd3eb1b0_0
|
47 |
+
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
|
48 |
+
- grpcio=1.36.1=py36h2157cd5_1
|
49 |
+
- gst-plugins-base=1.14.0=h8213a91_2
|
50 |
+
- gstreamer=1.14.0=h28cd5cc_2
|
51 |
+
- hydra-core=1.1.0=pyhd8ed1ab_0
|
52 |
+
- icu=58.2=he6710b0_3
|
53 |
+
- idna=3.2=pyhd3eb1b0_0
|
54 |
+
- idna_ssl=1.1.0=py36h06a4308_0
|
55 |
+
- imageio=2.9.0=pyhd3eb1b0_0
|
56 |
+
- importlib-metadata=4.8.1=py36h06a4308_0
|
57 |
+
- importlib_resources=5.2.0=pyhd3eb1b0_1
|
58 |
+
- intel-openmp=2021.3.0=h06a4308_3350
|
59 |
+
- joblib=1.0.1=pyhd3eb1b0_0
|
60 |
+
- jpeg=9b=h024ee3a_2
|
61 |
+
- kiwisolver=1.3.1=py36h2531618_0
|
62 |
+
- lame=3.100=h7b6447c_0
|
63 |
+
- lcms2=2.12=h3be6417_0
|
64 |
+
- ld_impl_linux-64=2.35.1=h7274673_9
|
65 |
+
- libblas=3.9.0=11_linux64_mkl
|
66 |
+
- libcblas=3.9.0=11_linux64_mkl
|
67 |
+
- libffi=3.3=he6710b0_2
|
68 |
+
- libgcc-ng=9.3.0=h5101ec6_17
|
69 |
+
- libgfortran-ng=9.3.0=ha5ec8a7_17
|
70 |
+
- libgfortran5=9.3.0=ha5ec8a7_17
|
71 |
+
- libgomp=9.3.0=h5101ec6_17
|
72 |
+
- libidn2=2.3.2=h7f8727e_0
|
73 |
+
- liblapack=3.9.0=11_linux64_mkl
|
74 |
+
- libopus=1.3.1=h7b6447c_0
|
75 |
+
- libpng=1.6.37=hbc83047_0
|
76 |
+
- libprotobuf=3.17.2=h4ff587b_1
|
77 |
+
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
78 |
+
- libtasn1=4.16.0=h27cfd23_0
|
79 |
+
- libtiff=4.2.0=h85742a9_0
|
80 |
+
- libunistring=0.9.10=h27cfd23_0
|
81 |
+
- libuuid=1.0.3=h1bed415_2
|
82 |
+
- libuv=1.40.0=h7b6447c_0
|
83 |
+
- libvpx=1.7.0=h439df22_0
|
84 |
+
- libwebp-base=1.2.0=h27cfd23_0
|
85 |
+
- libxcb=1.14=h7b6447c_0
|
86 |
+
- libxml2=2.9.12=h03d6c58_0
|
87 |
+
- lz4-c=1.9.3=h295c915_1
|
88 |
+
- markdown=3.3.4=py36h06a4308_0
|
89 |
+
- matplotlib=3.3.4=py36h06a4308_0
|
90 |
+
- matplotlib-base=3.3.4=py36h62a2d02_0
|
91 |
+
- mkl=2021.3.0=h06a4308_520
|
92 |
+
- multidict=5.1.0=py36h27cfd23_2
|
93 |
+
- ncurses=6.2=he6710b0_1
|
94 |
+
- nettle=3.7.3=hbbd107a_1
|
95 |
+
- networkx=2.2=py36_1
|
96 |
+
- ninja=1.10.2=hff7bd54_1
|
97 |
+
- numpy=1.19.5=py36hfc0c790_2
|
98 |
+
- oauthlib=3.1.1=pyhd3eb1b0_0
|
99 |
+
- olefile=0.46=py36_0
|
100 |
+
- omegaconf=2.1.1=py36h5fab9bb_0
|
101 |
+
- openh264=2.1.0=hd408876_0
|
102 |
+
- openjpeg=2.4.0=h3ad879b_0
|
103 |
+
- openssl=1.1.1l=h7f8727e_0
|
104 |
+
- packaging=21.0=pyhd3eb1b0_0
|
105 |
+
- pandas=1.1.5=py36h284efc9_0
|
106 |
+
- pcre=8.45=h295c915_0
|
107 |
+
- pillow=8.3.1=py36h2c7a002_0
|
108 |
+
- pip=21.0.1=py36h06a4308_0
|
109 |
+
- protobuf=3.17.2=py36h295c915_0
|
110 |
+
- pyasn1=0.4.8=pyhd3eb1b0_0
|
111 |
+
- pyasn1-modules=0.2.8=py_0
|
112 |
+
- pycparser=2.20=py_2
|
113 |
+
- pyjwt=2.1.0=py36h06a4308_0
|
114 |
+
- pyopenssl=20.0.1=pyhd3eb1b0_1
|
115 |
+
- pyparsing=2.4.7=pyhd3eb1b0_0
|
116 |
+
- pyqt=5.9.2=py36h05f1152_2
|
117 |
+
- pysocks=1.7.1=py36h06a4308_0
|
118 |
+
- python=3.6.13=h12debd9_1
|
119 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
120 |
+
- python_abi=3.6=2_cp36m
|
121 |
+
- pytz=2021.1=pyhd3eb1b0_0
|
122 |
+
- pywavelets=1.1.1=py36h7b6447c_2
|
123 |
+
- pyyaml=5.4.1=py36h27cfd23_1
|
124 |
+
- qt=5.9.7=h5867ecd_1
|
125 |
+
- readline=8.1=h27cfd23_0
|
126 |
+
- requests=2.26.0=pyhd3eb1b0_0
|
127 |
+
- requests-oauthlib=1.3.0=py_0
|
128 |
+
- rsa=4.7.2=pyhd3eb1b0_1
|
129 |
+
- scikit-image=0.17.2=py36h284efc9_4
|
130 |
+
- scikit-learn=0.24.2=py36ha9443f7_0
|
131 |
+
- scipy=1.5.3=py36h9e8f40b_0
|
132 |
+
- setuptools=58.0.4=py36h06a4308_0
|
133 |
+
- sip=4.19.8=py36hf484d3e_0
|
134 |
+
- six=1.16.0=pyhd3eb1b0_0
|
135 |
+
- sqlite=3.36.0=hc218d9a_0
|
136 |
+
- tabulate=0.8.9=py36h06a4308_0
|
137 |
+
- tensorboard=2.4.0=pyhc547734_0
|
138 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
139 |
+
- threadpoolctl=2.2.0=pyh0d69192_0
|
140 |
+
- tifffile=2020.10.1=py36hdd07704_2
|
141 |
+
- tk=8.6.11=h1ccaba5_0
|
142 |
+
- toolz=0.11.1=pyhd3eb1b0_0
|
143 |
+
- tqdm=4.62.2=pyhd3eb1b0_1
|
144 |
+
- typing-extensions=3.10.0.2=hd3eb1b0_0
|
145 |
+
- typing_extensions=3.10.0.2=pyh06a4308_0
|
146 |
+
- urllib3=1.26.6=pyhd3eb1b0_1
|
147 |
+
- werkzeug=2.0.1=pyhd3eb1b0_0
|
148 |
+
- wheel=0.37.0=pyhd3eb1b0_1
|
149 |
+
- x264=1!157.20191217=h7b6447c_0
|
150 |
+
- xz=5.2.5=h7b6447c_0
|
151 |
+
- yaml=0.2.5=h7b6447c_0
|
152 |
+
- yarl=1.6.3=py36h27cfd23_0
|
153 |
+
- zipp=3.5.0=pyhd3eb1b0_0
|
154 |
+
- zlib=1.2.11=h7b6447c_3
|
155 |
+
- zstd=1.4.9=haebb681_0
|
156 |
+
- pip:
|
157 |
+
- albumentations==0.5.2
|
158 |
+
- braceexpand==0.1.7
|
159 |
+
- imgaug==0.4.0
|
160 |
+
- kornia==0.5.0
|
161 |
+
- opencv-python==4.5.3.56
|
162 |
+
- opencv-python-headless==4.5.3.56
|
163 |
+
- shapely==1.7.1
|
164 |
+
- webdataset==0.1.76
|
165 |
+
- wldhx-yadisk-direct==0.0.6
|
lama/configs/analyze_mask_errors.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_kwargs:
|
2 |
+
img_suffix: .jpg
|
3 |
+
inpainted_suffix: .jpg
|
4 |
+
|
5 |
+
take_global_top: 30
|
6 |
+
take_worst_best_top: 30
|
7 |
+
take_overlapping_top: 30
|
lama/configs/data_gen/random_medium_256.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 5
|
8 |
+
max_width: 50
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 100
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 0
|
15 |
+
bbox_min_size: 10
|
16 |
+
bbox_max_size: 50
|
17 |
+
max_times: 5
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 256
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
lama/configs/data_gen/random_medium_512.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 10
|
8 |
+
max_width: 100
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 200
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 0
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 150
|
17 |
+
max_times: 5
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 512
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
lama/configs/data_gen/random_thick_256.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 1
|
7 |
+
max_times: 5
|
8 |
+
max_width: 100
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 200
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 10
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 150
|
17 |
+
max_times: 3
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 256
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
lama/configs/data_gen/random_thick_512.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 1
|
7 |
+
max_times: 5
|
8 |
+
max_width: 250
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 450
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 10
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 300
|
17 |
+
max_times: 4
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 512
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
lama/configs/data_gen/random_thin_256.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 50
|
8 |
+
max_width: 10
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 40
|
11 |
+
box_proba: 0
|
12 |
+
segm_proba: 0
|
13 |
+
squares_proba: 0
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 256
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 0.5
|
lama/configs/data_gen/random_thin_512.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 70
|
8 |
+
max_width: 20
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 100
|
11 |
+
box_proba: 0
|
12 |
+
segm_proba: 0
|
13 |
+
squares_proba: 0
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 512
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 0.5
|
lama/configs/debug_mask_gen.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
img_ext: .jpg
|
2 |
+
|
3 |
+
gen_kwargs:
|
4 |
+
mask_size: 200
|
5 |
+
step: 0.5
|
lama/configs/eval1.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
evaluator_kwargs:
|
2 |
+
batch_size: 8
|
3 |
+
|
4 |
+
dataset_kwargs:
|
5 |
+
img_suffix: .png
|
6 |
+
inpainted_suffix: .jpg
|
lama/configs/eval2.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
evaluator_kwargs:
|
2 |
+
batch_size: 8
|
3 |
+
device: cuda
|
4 |
+
|
5 |
+
dataset_kwargs:
|
6 |
+
img_suffix: .png
|
7 |
+
inpainted_suffix: .png
|
lama/configs/eval2_cpu.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
evaluator_kwargs:
|
2 |
+
batch_size: 8
|
3 |
+
device: cpu
|
4 |
+
|
5 |
+
dataset_kwargs:
|
6 |
+
img_suffix: .png
|
7 |
+
inpainted_suffix: .png
|
lama/configs/eval2_gpu.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
evaluator_kwargs:
|
2 |
+
batch_size: 8
|
3 |
+
|
4 |
+
dataset_kwargs:
|
5 |
+
img_suffix: .png
|
6 |
+
inpainted_suffix: .png
|