Commit 8d6cd57
vitaliykinakh committed
Parent(s): 931138c

Initial
Browse files
- .gitignore +351 -0
- app.py +20 -0
- requirements.txt +7 -0
- src/app/__init__.py +1 -0
- src/app/compare_models.py +292 -0
- src/app/explore_biggan.py +256 -0
- src/app/explore_cvae.py +255 -0
- src/app/explore_infoscc_gan.py +263 -0
- src/app/multipage.py +41 -0
- src/app/params.py +25 -0
- src/app/questions.py +36 -0
- src/data/__init__.py +2 -0
- src/data/data.py +10 -0
- src/data/labels.py +63 -0
- src/models/__init__.py +2 -0
- src/models/big/BigGAN2.py +469 -0
- src/models/big/LICENSE +21 -0
- src/models/big/README.md +144 -0
- src/models/big/__init__.py +0 -0
- src/models/big/animal_hash.py +439 -0
- src/models/big/cheat sheet +30 -0
- src/models/big/datasets.py +362 -0
- src/models/big/layers.py +456 -0
- src/models/big/sync_batchnorm/__init__.py +12 -0
- src/models/big/sync_batchnorm/batchnorm.py +349 -0
- src/models/big/sync_batchnorm/batchnorm_reimpl.py +74 -0
- src/models/big/sync_batchnorm/comm.py +137 -0
- src/models/big/sync_batchnorm/replicate.py +94 -0
- src/models/big/sync_batchnorm/unittest.py +29 -0
- src/models/big/utils.py +1193 -0
- src/models/cvae.py +75 -0
- src/models/infoscc_gan.py +272 -0
- src/models/neuralnetwork.py +46 -0
- src/models/parameter.py +53 -0
- src/utils/__init__.py +2 -0
- src/utils/utils.py +14 -0
.gitignore
ADDED
@@ -0,0 +1,351 @@

# Created by https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks,images
# Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm,jupyternotebooks,images

.idea/
models/InfoSCC-GAN/
models/BigGAN/
models/CVAE
*.csv


### Images ###
# JPEG
*.jpg
*.jpeg
*.jpe
*.jif
*.jfif
*.jfi

# JPEG 2000
*.jp2
*.j2k
*.jpf
*.jpx
*.jpm
*.mj2

# JPEG XR
*.jxr
*.hdp
*.wdp

# Graphics Interchange Format
*.gif

# RAW
*.raw

# Web P
*.webp

# Portable Network Graphics
*.png

# Animated Portable Network Graphics
*.apng

# Multiple-image Network Graphics
*.mng

# Tagged Image File Format
*.tiff
*.tif

# Scalable Vector Graphics
*.svg
*.svgz

# Portable Document Format
*.pdf

# X BitMap
*.xbm

# BMP
*.bmp
*.dib

# ICO
*.ico

# 3D Images
*.3dm
*.max

### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook

# IPython

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# End of https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks,images
app.py
ADDED
@@ -0,0 +1,20 @@
import streamlit as st

# Custom imports
from src.app import MultiPage
from src.app import explore_infoscc_gan, explore_biggan, explore_cvae, compare_models

# Create an instance of the app
app = MultiPage()

# Title of the main page
st.title('Galaxy Zoo generation')

# Add all your applications (pages) here
app.add_page('Compare models', compare_models.app)
app.add_page('Explore BigGAN', explore_biggan.app)
app.add_page('Explore cVAE', explore_cvae.app)
app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)

# The main app
app.run()
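Note: src/app/multipage.py (+41 lines) is also added by this commit but its diff is not shown in this section. Purely as a hedged sketch — assuming the common Streamlit page-dispatcher pattern that the add_page/run calls above imply, with hypothetical internals — MultiPage only needs to collect pages and render the selected one:

import streamlit as st


class MultiPage:
    """Minimal sketch of a page dispatcher (an assumption, not the committed code)."""

    def __init__(self):
        self.pages = []  # list of {'title': ..., 'function': ...} records

    def add_page(self, title, func) -> None:
        # func is a zero-argument callable such as compare_models.app that draws the page
        self.pages.append({'title': title, 'function': func})

    def run(self) -> None:
        # a sidebar selector picks the page; the chosen page then renders itself
        page = st.sidebar.selectbox('App Navigation', self.pages,
                                    format_func=lambda page: page['title'])
        page['function']()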
requirements.txt
ADDED
@@ -0,0 +1,7 @@
-f https://download.pytorch.org/whl/torch_stable.html
gdown
googledrivedownloader==0.4
pandas==1.4.1
streamlit==1.7.0
torch==1.9.1+cpu
torchvision==0.10.1+cpu
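The `-f` (find-links) line points pip at the PyTorch wheel index, which is what makes the CPU-only builds `torch==1.9.1+cpu` and `torchvision==0.10.1+cpu` resolvable; those local-version wheels are not published on PyPI. Once installed, the demo is launched with `streamlit run app.py`.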
src/app/__init__.py
ADDED
@@ -0,0 +1 @@
from .multipage import MultiPage
src/app/compare_models.py
ADDED
@@ -0,0 +1,292 @@
from pathlib import Path
import math

import streamlit as st
import numpy as np

import torch
import torch.nn.functional as F

import src.app.params as params
from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
    q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
from src.models import ConditionalGenerator as InfoSCC_GAN
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
from src.models import ConditionalDecoder as cVAE
from src.data import get_labels_train, make_galaxy_labels_hierarchical
from src.utils import download_file, sample_labels


device = params.device
bs = 10  # number of images to generate per model
n_cols = int(math.sqrt(bs))
size = params.size
n_layers = int(math.log2(size) - 2)

# manual labels
q1_out = [0] * len(q1_options)
q2_out = [0] * len(q2_options)
q3_out = [0] * len(q3_options)
q4_out = [0] * len(q4_options)
q5_out = [0] * len(q5_options)
q6_out = [0] * len(q6_options)
q7_out = [0] * len(q7_options)
q8_out = [0] * len(q8_options)
q9_out = [0] * len(q9_options)
q10_out = [0] * len(q10_options)
q11_out = [0] * len(q11_options)


def clear_out(elems=None):
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    if elems is None:
        elems = list(range(1, 12))

    if 1 in elems:
        q1_out = [0] * len(q1_options)
    if 2 in elems:
        q2_out = [0] * len(q2_options)
    if 3 in elems:
        q3_out = [0] * len(q3_options)
    if 4 in elems:
        q4_out = [0] * len(q4_options)
    if 5 in elems:
        q5_out = [0] * len(q5_options)
    if 6 in elems:
        q6_out = [0] * len(q6_options)
    if 7 in elems:
        q7_out = [0] * len(q7_options)
    if 8 in elems:
        q8_out = [0] * len(q8_options)
    if 9 in elems:
        q9_out = [0] * len(q9_options)
    if 10 in elems:
        q10_out = [0] * len(q10_options)
    if 11 in elems:
        q11_out = [0] * len(q11_options)


@st.cache(allow_output_mutation=True)
def load_model(model_type: str):
    print(f'Loading model: {model_type}')
    if model_type == 'InfoSCC-GAN':
        g = InfoSCC_GAN(size=params.size,
                        y_size=params.shape_label,
                        z_size=params.noise_dim)

        if not Path(params.path_infoscc_gan).exists():
            download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)

        ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
        g.load_state_dict(ckpt['g_ema'])
    elif model_type == 'BigGAN':
        g = BigGAN2Generator()

        if not Path(params.path_biggan).exists():
            download_file(params.drive_id_biggan, params.path_biggan)

        ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
        g.load_state_dict(ckpt)
    elif model_type == 'cVAE':
        g = cVAE()

        if not Path(params.path_cvae).exists():
            download_file(params.drive_id_cvae, params.path_cvae)

        ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
        g.load_state_dict(ckpt)
    else:
        raise ValueError('Unsupported model')
    g = g.eval().to(device=params.device)
    return g


@st.cache
def get_labels() -> torch.Tensor:
    path_labels = params.path_labels

    if not Path(path_labels).exists():
        download_file(params.drive_id_labels, path_labels)

    labels_train = get_labels_train(path_labels)
    return labels_train


def get_eps(n: int) -> torch.Tensor:
    eps = torch.randn((n, params.dim_z), device=device)
    return eps


def app():
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    st.title('Compare models')
    st.markdown('This demo allows you to compare the BigGAN, InfoSCC-GAN and cVAE models for conditional galaxy generation.')
    st.markdown('In each row there are images generated with the same labels by each of the models.')

    biggan = load_model('BigGAN')
    infoscc_gan = load_model('InfoSCC-GAN')
    cvae = load_model('cVAE')
    labels_train = get_labels()

    eps = get_eps(bs)  # for BigGAN and cVAE
    eps_infoscc = infoscc_gan.sample_eps(bs)

    zs = np.array([[0.0] * params.n_basis] * n_layers, dtype=np.float32)
    zs_torch = torch.from_numpy(zs).unsqueeze(0).repeat(bs, 1, 1).to(device)

    # ========================== Labels ================================
    st.subheader('Label')
    st.markdown(r'There are two ways to select labels: __Random__ - sample random labels from the dataset;'
                r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images'
                r' will be generated with the same labels.')
    label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
    if label_type == 'Random':
        labels = sample_labels(labels_train, bs).to(device)

        st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
        change_label = st.button('Sample label')

        if change_label:
            labels = sample_labels(labels_train, bs).to(device)
    elif label_type == 'Manual (Advanced)':
        st.markdown('Answer the questions below')

        q1_select_box = st.selectbox(q1, options=q1_options)
        clear_out()
        q1_out[q1_options.index(q1_select_box)] = 1
        # 1

        if q1_select_box == 'Smooth':
            q7_select_box = st.selectbox(q7, options=q7_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q7_out[q7_options.index(q7_select_box)] = 1
            # 1 - 7

            q6_select_box = st.selectbox(q6, options=q6_options)
            clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
            q6_out[q6_options.index(q6_select_box)] = 1
            # 1 - 7 - 6

            if q6_select_box == 'Yes':
                q8_select_box = st.selectbox(q8, options=q8_options)
                clear_out([2, 3, 4, 5, 8, 9, 10, 11])
                q8_out[q8_options.index(q8_select_box)] = 1
                # 1 - 7 - 6 - 8 - end

        elif q1_select_box == 'Features or disk':
            q2_select_box = st.selectbox(q2, options=q2_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q2_out[q2_options.index(q2_select_box)] = 1
            # 1 - 2

            if q2_select_box == 'Yes':
                q9_select_box = st.selectbox(q9, options=q9_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q9_out[q9_options.index(q9_select_box)] = 1
                # 1 - 2 - 9

                q6_select_box = st.selectbox(q6, options=q6_options)
                clear_out([3, 4, 5, 6, 7, 8, 10, 11])
                q6_out[q6_options.index(q6_select_box)] = 1
                # 1 - 2 - 9 - 6

                if q6_select_box == 'Yes':
                    q8_select_box = st.selectbox(q8, options=q8_options)
                    clear_out([3, 4, 5, 7, 8, 10, 11])
                    q8_out[q8_options.index(q8_select_box)] = 1
                    # 1 - 2 - 9 - 6 - 8
            else:
                q3_select_box = st.selectbox(q3, options=q3_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q3_out[q3_options.index(q3_select_box)] = 1
                # 1 - 2 - 3

                q4_select_box = st.selectbox(q4, options=q4_options)
                clear_out([4, 5, 6, 7, 8, 9, 10, 11])
                q4_out[q4_options.index(q4_select_box)] = 1
                # 1 - 2 - 3 - 4

                if q4_select_box == 'Yes':
                    q10_select_box = st.selectbox(q10, options=q10_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q10_out[q10_options.index(q10_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10

                    q11_select_box = st.selectbox(q11, options=q11_options)
                    clear_out([5, 6, 7, 8, 9, 11])
                    q11_out[q11_options.index(q11_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11

                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
                else:
                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9, 10, 11])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9, 10, 11])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End

        labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
        labels = torch.Tensor(labels).to(device)
        labels = labels.unsqueeze(0).repeat(bs, 1)
        labels = make_galaxy_labels_hierarchical(labels)
        clear_out()
    # ========================== Labels ================================

    st.subheader('Noise')
    st.markdown(r'Click on the __Change eps__ button to change the input $\varepsilon$ latent code')
    change_eps = st.button('Change eps')
    if change_eps:
        eps = get_eps(bs)  # for BigGAN and cVAE
        eps_infoscc = infoscc_gan.sample_eps(bs)

    with torch.no_grad():
        imgs_biggan = biggan(eps, labels).squeeze(0).cpu()
        imgs_infoscc = infoscc_gan(labels, eps_infoscc, zs_torch).squeeze(0).cpu()
        imgs_cvae = cvae(eps, labels).squeeze(0).cpu()

    if params.upsample:
        imgs_biggan = F.interpolate(imgs_biggan, (size * 4, size * 4), mode='bicubic')
        imgs_infoscc = F.interpolate(imgs_infoscc, (size * 4, size * 4), mode='bicubic')
        imgs_cvae = F.interpolate(imgs_cvae, (size * 4, size * 4), mode='bicubic')

    imgs_biggan = torch.clip(imgs_biggan, 0, 1)
    imgs_biggan = [(imgs_biggan[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in range(bs)]
    imgs_infoscc = [(imgs_infoscc[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
    imgs_cvae = [(imgs_cvae[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]

    c1, c2, c3 = st.columns(3)
    c1.header('BigGAN')
    c1.image(imgs_biggan, use_column_width=True)

    c2.header('InfoSCC-GAN')
    c2.image(imgs_infoscc, use_column_width=True)

    c3.header('cVAE')
    c3.image(imgs_cvae, use_column_width=True)
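An aside on the post-processing above: the BigGAN output is clipped to [0, 1] and scaled by 255, while the InfoSCC-GAN and cVAE outputs are mapped from [-1, 1] via x * 127.5 + 127.5, so the three generators apparently use different output normalizations; the per-model denormalization is what keeps the three displayed columns comparable.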
src/app/explore_biggan.py
ADDED
@@ -0,0 +1,256 @@
import math
from pathlib import Path

import streamlit as st
import numpy as np

import torch
import torch.nn.functional as F

import src.app.params as params
from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
    q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
from src.data import get_labels_train, make_galaxy_labels_hierarchical
from src.utils import download_file, sample_labels


# global parameters
device = params.device
size = params.size
y_size = shape_label = params.shape_label
n_channels = params.n_channels
upsample = params.upsample
dim_z = params.dim_z
bs = 16  # number of samples to generate
n_cols = int(math.sqrt(bs))
model_path = params.path_biggan
drive_id = params.drive_id_biggan
path_labels = params.path_labels

# manual labels
q1_out = [0] * len(q1_options)
q2_out = [0] * len(q2_options)
q3_out = [0] * len(q3_options)
q4_out = [0] * len(q4_options)
q5_out = [0] * len(q5_options)
q6_out = [0] * len(q6_options)
q7_out = [0] * len(q7_options)
q8_out = [0] * len(q8_options)
q9_out = [0] * len(q9_options)
q10_out = [0] * len(q10_options)
q11_out = [0] * len(q11_options)


def clear_out(elems=None):
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    if elems is None:
        elems = list(range(1, 12))

    if 1 in elems:
        q1_out = [0] * len(q1_options)
    if 2 in elems:
        q2_out = [0] * len(q2_options)
    if 3 in elems:
        q3_out = [0] * len(q3_options)
    if 4 in elems:
        q4_out = [0] * len(q4_options)
    if 5 in elems:
        q5_out = [0] * len(q5_options)
    if 6 in elems:
        q6_out = [0] * len(q6_options)
    if 7 in elems:
        q7_out = [0] * len(q7_options)
    if 8 in elems:
        q8_out = [0] * len(q8_options)
    if 9 in elems:
        q9_out = [0] * len(q9_options)
    if 10 in elems:
        q10_out = [0] * len(q10_options)
    if 11 in elems:
        q11_out = [0] * len(q11_options)


@st.cache(allow_output_mutation=True)
def load_model(model_path: str) -> BigGAN2Generator:
    print(f'Loading model: {model_path}')
    g = BigGAN2Generator()
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    g.load_state_dict(ckpt)
    g.eval().to(device)
    return g


def get_eps(n: int) -> torch.Tensor:
    eps = torch.randn((n, dim_z), device=device)
    return eps


@st.cache
def get_labels() -> torch.Tensor:
    if not Path(path_labels).exists():
        download_file(params.drive_id_labels, path_labels)

    labels_train = get_labels_train(path_labels)
    return labels_train


def app():
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    st.title('Explore BigGAN')
    st.markdown('This demo shows BigGAN for conditional galaxy generation')

    if not Path(model_path).exists():
        download_file(drive_id, model_path)

    model = load_model(model_path)
    eps = get_eps(bs)
    labels_train = get_labels()

    # ========================== Labels ================================
    st.subheader('Label')
    st.markdown(r'There are two ways to select labels: __Random__ - sample random labels from the dataset;'
                r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images'
                r' will be generated with the same labels.')
    label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
    if label_type == 'Random':
        labels = sample_labels(labels_train, bs).to(device)

        st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
        change_label = st.button('Sample label')

        if change_label:
            labels = sample_labels(labels_train, bs).to(device)
    elif label_type == 'Manual (Advanced)':
        st.markdown('Answer the questions below')

        q1_select_box = st.selectbox(q1, options=q1_options)
        clear_out()
        q1_out[q1_options.index(q1_select_box)] = 1
        # 1

        if q1_select_box == 'Smooth':
            q7_select_box = st.selectbox(q7, options=q7_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q7_out[q7_options.index(q7_select_box)] = 1
            # 1 - 7

            q6_select_box = st.selectbox(q6, options=q6_options)
            clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
            q6_out[q6_options.index(q6_select_box)] = 1
            # 1 - 7 - 6

            if q6_select_box == 'Yes':
                q8_select_box = st.selectbox(q8, options=q8_options)
                clear_out([2, 3, 4, 5, 8, 9, 10, 11])
                q8_out[q8_options.index(q8_select_box)] = 1
                # 1 - 7 - 6 - 8 - end

        elif q1_select_box == 'Features or disk':
            q2_select_box = st.selectbox(q2, options=q2_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q2_out[q2_options.index(q2_select_box)] = 1
            # 1 - 2

            if q2_select_box == 'Yes':
                q9_select_box = st.selectbox(q9, options=q9_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q9_out[q9_options.index(q9_select_box)] = 1
                # 1 - 2 - 9

                q6_select_box = st.selectbox(q6, options=q6_options)
                clear_out([3, 4, 5, 6, 7, 8, 10, 11])
                q6_out[q6_options.index(q6_select_box)] = 1
                # 1 - 2 - 9 - 6

                if q6_select_box == 'Yes':
                    q8_select_box = st.selectbox(q8, options=q8_options)
                    clear_out([3, 4, 5, 7, 8, 10, 11])
                    q8_out[q8_options.index(q8_select_box)] = 1
                    # 1 - 2 - 9 - 6 - 8
            else:
                q3_select_box = st.selectbox(q3, options=q3_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q3_out[q3_options.index(q3_select_box)] = 1
                # 1 - 2 - 3

                q4_select_box = st.selectbox(q4, options=q4_options)
                clear_out([4, 5, 6, 7, 8, 9, 10, 11])
                q4_out[q4_options.index(q4_select_box)] = 1
                # 1 - 2 - 3 - 4

                if q4_select_box == 'Yes':
                    q10_select_box = st.selectbox(q10, options=q10_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q10_out[q10_options.index(q10_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10

                    q11_select_box = st.selectbox(q11, options=q11_options)
                    clear_out([5, 6, 7, 8, 9, 11])
                    q11_out[q11_options.index(q11_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11

                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
                else:
                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9, 10, 11])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9, 10, 11])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End

        labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
        labels = torch.Tensor(labels).to(device)
        labels = labels.unsqueeze(0).repeat(bs, 1)
        labels = make_galaxy_labels_hierarchical(labels)
        clear_out()
    # ========================== Labels ================================

    st.subheader('Noise')
    st.markdown(r'Click on the __Change eps__ button to change the input $\varepsilon$ latent code')
    change_eps = st.button('Change eps')
    if change_eps:
        eps = get_eps(bs)

    with torch.no_grad():
        imgs = model(eps, labels)

    if upsample:
        imgs = F.interpolate(imgs, (size * 4, size * 4), mode='bicubic')

    imgs = torch.clip(imgs, 0, 1)
    imgs = [(imgs[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in range(bs)]

    counter = 0
    for r in range(bs // n_cols):
        cols = st.columns(n_cols)

        for c in range(n_cols):
            cols[c].image(imgs[counter])
            counter += 1
src/app/explore_cvae.py
ADDED
@@ -0,0 +1,255 @@
import math
from pathlib import Path

import streamlit as st
import numpy as np

import torch
import torch.nn.functional as F

import src.app.params as params
from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
    q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
from src.models import ConditionalDecoder
from src.data import get_labels_train, make_galaxy_labels_hierarchical
from src.utils import download_file, sample_labels


# global parameters
device = params.device
size = params.size
y_size = shape_label = params.shape_label
n_channels = params.n_channels
upsample = params.upsample
dim_z = params.dim_z
bs = 16  # number of samples to generate
n_cols = int(math.sqrt(bs))
model_path = params.path_cvae
drive_id = params.drive_id_cvae
path_labels = params.path_labels

# manual labels
q1_out = [0] * len(q1_options)
q2_out = [0] * len(q2_options)
q3_out = [0] * len(q3_options)
q4_out = [0] * len(q4_options)
q5_out = [0] * len(q5_options)
q6_out = [0] * len(q6_options)
q7_out = [0] * len(q7_options)
q8_out = [0] * len(q8_options)
q9_out = [0] * len(q9_options)
q10_out = [0] * len(q10_options)
q11_out = [0] * len(q11_options)


def clear_out(elems=None):
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    if elems is None:
        elems = list(range(1, 12))

    if 1 in elems:
        q1_out = [0] * len(q1_options)
    if 2 in elems:
        q2_out = [0] * len(q2_options)
    if 3 in elems:
        q3_out = [0] * len(q3_options)
    if 4 in elems:
        q4_out = [0] * len(q4_options)
    if 5 in elems:
        q5_out = [0] * len(q5_options)
    if 6 in elems:
        q6_out = [0] * len(q6_options)
    if 7 in elems:
        q7_out = [0] * len(q7_options)
    if 8 in elems:
        q8_out = [0] * len(q8_options)
    if 9 in elems:
        q9_out = [0] * len(q9_options)
    if 10 in elems:
        q10_out = [0] * len(q10_options)
    if 11 in elems:
        q11_out = [0] * len(q11_options)


@st.cache(allow_output_mutation=True)
def load_model(model_path: str) -> ConditionalDecoder:
    print(f'Loading model: {model_path}')
    g = ConditionalDecoder()
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    g.load_state_dict(ckpt)
    g.eval().to(device)
    return g


def get_eps(n: int) -> torch.Tensor:
    eps = torch.randn((n, dim_z), device=device)
    return eps


@st.cache
def get_labels() -> torch.Tensor:
    if not Path(path_labels).exists():
        download_file(params.drive_id_labels, path_labels)

    labels_train = get_labels_train(path_labels)
    return labels_train


def app():
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    st.title('Explore cVAE')
    st.markdown('This demo shows cVAE for conditional galaxy generation')

    if not Path(model_path).exists():
        download_file(drive_id, model_path)

    model = load_model(model_path)
    eps = get_eps(bs)
    labels_train = get_labels()

    # ========================== Labels ================================
    st.subheader('Label')
    st.markdown(r'There are two ways to select labels: __Random__ - sample random labels from the dataset;'
                r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images'
                r' will be generated with the same labels.')
    label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
    if label_type == 'Random':
        labels = sample_labels(labels_train, bs).to(device)

        st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
        change_label = st.button('Sample label')

        if change_label:
            labels = sample_labels(labels_train, bs).to(device)
    elif label_type == 'Manual (Advanced)':
        st.markdown('Answer the questions below')

        q1_select_box = st.selectbox(q1, options=q1_options)
        clear_out()
        q1_out[q1_options.index(q1_select_box)] = 1
        # 1

        if q1_select_box == 'Smooth':
            q7_select_box = st.selectbox(q7, options=q7_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q7_out[q7_options.index(q7_select_box)] = 1
            # 1 - 7

            q6_select_box = st.selectbox(q6, options=q6_options)
            clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
            q6_out[q6_options.index(q6_select_box)] = 1
            # 1 - 7 - 6

            if q6_select_box == 'Yes':
                q8_select_box = st.selectbox(q8, options=q8_options)
                clear_out([2, 3, 4, 5, 8, 9, 10, 11])
                q8_out[q8_options.index(q8_select_box)] = 1
                # 1 - 7 - 6 - 8 - end

        elif q1_select_box == 'Features or disk':
            q2_select_box = st.selectbox(q2, options=q2_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q2_out[q2_options.index(q2_select_box)] = 1
            # 1 - 2

            if q2_select_box == 'Yes':
                q9_select_box = st.selectbox(q9, options=q9_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q9_out[q9_options.index(q9_select_box)] = 1
                # 1 - 2 - 9

                q6_select_box = st.selectbox(q6, options=q6_options)
                clear_out([3, 4, 5, 6, 7, 8, 10, 11])
                q6_out[q6_options.index(q6_select_box)] = 1
                # 1 - 2 - 9 - 6

                if q6_select_box == 'Yes':
                    q8_select_box = st.selectbox(q8, options=q8_options)
                    clear_out([3, 4, 5, 7, 8, 10, 11])
                    q8_out[q8_options.index(q8_select_box)] = 1
                    # 1 - 2 - 9 - 6 - 8
            else:
                q3_select_box = st.selectbox(q3, options=q3_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q3_out[q3_options.index(q3_select_box)] = 1
                # 1 - 2 - 3

                q4_select_box = st.selectbox(q4, options=q4_options)
                clear_out([4, 5, 6, 7, 8, 9, 10, 11])
                q4_out[q4_options.index(q4_select_box)] = 1
                # 1 - 2 - 3 - 4

                if q4_select_box == 'Yes':
                    q10_select_box = st.selectbox(q10, options=q10_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q10_out[q10_options.index(q10_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10

                    q11_select_box = st.selectbox(q11, options=q11_options)
                    clear_out([5, 6, 7, 8, 9, 11])
                    q11_out[q11_options.index(q11_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11

                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
                else:
                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9, 10, 11])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9, 10, 11])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End

        labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
        labels = torch.Tensor(labels).to(device)
        labels = labels.unsqueeze(0).repeat(bs, 1)
        labels = make_galaxy_labels_hierarchical(labels)
        clear_out()
    # ========================== Labels ================================

    st.subheader('Noise')
    st.markdown(r'Click on the __Change eps__ button to change the input $\varepsilon$ latent code')
    change_eps = st.button('Change eps')
    if change_eps:
        eps = get_eps(bs)

    with torch.no_grad():
        imgs = model(eps, labels)

    if upsample:
        imgs = F.interpolate(imgs, (size * 4, size * 4), mode='bicubic')

    imgs = [(imgs[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]

    counter = 0
    for r in range(bs // n_cols):
        cols = st.columns(n_cols)

        for c in range(n_cols):
            cols[c].image(imgs[counter])
            counter += 1
src/app/explore_infoscc_gan.py
ADDED
@@ -0,0 +1,263 @@
from pathlib import Path
import math

import numpy as np
import streamlit as st

import torch
import torch.nn.functional as F

import src.app.params as params
from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
    q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
from src.models import ConditionalGenerator
from src.data import get_labels_train, make_galaxy_labels_hierarchical
from src.utils import download_file, sample_labels

# global parameters
device = params.device
size = params.size
y_size = params.shape_label
n_channels = params.n_channels
upsample = params.upsample
z_size = noise_dim = params.noise_dim
n_layers = int(math.log2(size) - 2)
n_basis = params.n_basis
y_type = params.y_type
bs = 16  # number of samples to generate
n_cols = int(math.sqrt(bs))
model_path = params.path_infoscc_gan  # path to the model
drive_id = params.drive_id_infoscc_gan  # Google Drive id of the model
path_labels = params.path_labels

# manual labels
q1_out = [0] * len(q1_options)
q2_out = [0] * len(q2_options)
q3_out = [0] * len(q3_options)
q4_out = [0] * len(q4_options)
q5_out = [0] * len(q5_options)
q6_out = [0] * len(q6_options)
q7_out = [0] * len(q7_options)
q8_out = [0] * len(q8_options)
q9_out = [0] * len(q9_options)
q10_out = [0] * len(q10_options)
q11_out = [0] * len(q11_options)


def clear_out(elems=None):
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    if elems is None:
        elems = list(range(1, 12))

    if 1 in elems:
        q1_out = [0] * len(q1_options)
    if 2 in elems:
        q2_out = [0] * len(q2_options)
    if 3 in elems:
        q3_out = [0] * len(q3_options)
    if 4 in elems:
        q4_out = [0] * len(q4_options)
    if 5 in elems:
        q5_out = [0] * len(q5_options)
    if 6 in elems:
        q6_out = [0] * len(q6_options)
    if 7 in elems:
        q7_out = [0] * len(q7_options)
    if 8 in elems:
        q8_out = [0] * len(q8_options)
    if 9 in elems:
        q9_out = [0] * len(q9_options)
    if 10 in elems:
        q10_out = [0] * len(q10_options)
    if 11 in elems:
        q11_out = [0] * len(q11_options)


@st.cache(allow_output_mutation=True)
def load_model(model_path: str) -> ConditionalGenerator:
    print(f'Loading model: {model_path}')
    g_ema = ConditionalGenerator(size, y_size, z_size, n_channels, n_basis, noise_dim)
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    g_ema.load_state_dict(ckpt['g_ema'])
    g_ema.eval().to(device)
    return g_ema


@st.cache
def get_labels() -> torch.Tensor:
    if not Path(path_labels).exists():
        download_file(params.drive_id_labels, path_labels)
    labels_train = get_labels_train(path_labels)
    return labels_train


def app():
    global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out

    st.title('Explore InfoSCC-GAN')
    st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
    st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')

    if not Path(model_path).exists():
        download_file(drive_id, model_path)

    model = load_model(model_path)
    eps = model.sample_eps(bs).to(device)
    labels_train = get_labels()

    # get zs
    zs = np.array([[0.0] * n_basis] * n_layers, dtype=np.float32)

    for l in range(n_layers):
        st.sidebar.markdown(f'## Layer: {l}')
        for d in range(n_basis):
            zs[l][d] = st.sidebar.slider(f'Dimension: {d}', key=f'{l}{d}',
                                         min_value=-5., max_value=5., value=0., step=0.1)

    # ========================== Labels ================================
    st.subheader('Label')
    st.markdown(r'There are two ways to select labels: __Random__ - sample random labels from the dataset;'
                r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images'
                r' will be generated with the same labels.')
    label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
    if label_type == 'Random':
        labels = sample_labels(labels_train, bs).to(device)

        st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
        change_label = st.button('Sample label')

        if change_label:
            labels = sample_labels(labels_train, bs).to(device)
    elif label_type == 'Manual (Advanced)':
        st.markdown('Answer the questions below')

        q1_select_box = st.selectbox(q1, options=q1_options)
        clear_out()
        q1_out[q1_options.index(q1_select_box)] = 1
        # 1

        if q1_select_box == 'Smooth':
            q7_select_box = st.selectbox(q7, options=q7_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q7_out[q7_options.index(q7_select_box)] = 1
            # 1 - 7

            q6_select_box = st.selectbox(q6, options=q6_options)
            clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
            q6_out[q6_options.index(q6_select_box)] = 1
            # 1 - 7 - 6

            if q6_select_box == 'Yes':
                q8_select_box = st.selectbox(q8, options=q8_options)
                clear_out([2, 3, 4, 5, 8, 9, 10, 11])
                q8_out[q8_options.index(q8_select_box)] = 1
                # 1 - 7 - 6 - 8 - end

        elif q1_select_box == 'Features or disk':
            q2_select_box = st.selectbox(q2, options=q2_options)
            clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
            q2_out[q2_options.index(q2_select_box)] = 1
            # 1 - 2

            if q2_select_box == 'Yes':
                q9_select_box = st.selectbox(q9, options=q9_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q9_out[q9_options.index(q9_select_box)] = 1
                # 1 - 2 - 9

                q6_select_box = st.selectbox(q6, options=q6_options)
                clear_out([3, 4, 5, 6, 7, 8, 10, 11])
                q6_out[q6_options.index(q6_select_box)] = 1
                # 1 - 2 - 9 - 6

                if q6_select_box == 'Yes':
                    q8_select_box = st.selectbox(q8, options=q8_options)
                    clear_out([3, 4, 5, 7, 8, 10, 11])
                    q8_out[q8_options.index(q8_select_box)] = 1
                    # 1 - 2 - 9 - 6 - 8
            else:
                q3_select_box = st.selectbox(q3, options=q3_options)
                clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
                q3_out[q3_options.index(q3_select_box)] = 1
                # 1 - 2 - 3

                q4_select_box = st.selectbox(q4, options=q4_options)
                clear_out([4, 5, 6, 7, 8, 9, 10, 11])
                q4_out[q4_options.index(q4_select_box)] = 1
                # 1 - 2 - 3 - 4

                if q4_select_box == 'Yes':
                    q10_select_box = st.selectbox(q10, options=q10_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q10_out[q10_options.index(q10_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10

                    q11_select_box = st.selectbox(q11, options=q11_options)
                    clear_out([5, 6, 7, 8, 9, 11])
                    q11_out[q11_options.index(q11_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11

                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
                else:
                    q5_select_box = st.selectbox(q5, options=q5_options)
                    clear_out([5, 6, 7, 8, 9, 10, 11])
                    q5_out[q5_options.index(q5_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5

                    q6_select_box = st.selectbox(q6, options=q6_options)
                    clear_out([6, 7, 8, 9, 10, 11])
                    q6_out[q6_options.index(q6_select_box)] = 1
                    # 1 - 2 - 3 - 4 - 5 - 6

                    if q6_select_box == 'Yes':
                        q8_select_box = st.selectbox(q8, options=q8_options)
                        clear_out([7, 8, 9, 10, 11])
                        q8_out[q8_options.index(q8_select_box)] = 1
                        # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End

        labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
        labels = torch.Tensor(labels).to(device)
        labels = labels.unsqueeze(0).repeat(bs, 1)
        labels = make_galaxy_labels_hierarchical(labels)
        clear_out()
    # ========================== Labels ================================

    st.subheader('Noise')
    st.markdown(r'Click on the __Change eps__ button to change the input $\varepsilon$ latent code')
    change_eps = st.button('Change eps')
    if change_eps:
        eps = model.sample_eps(bs).to(device)

    zs_torch = torch.from_numpy(zs).unsqueeze(0).repeat(bs, 1, 1).to(device)

    with torch.no_grad():
        imgs = model(labels, eps, zs_torch).squeeze(0).cpu()

    if upsample:
        imgs = F.interpolate(imgs, (size * 4, size * 4), mode='bicubic')

    imgs = [(imgs[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]

    counter = 0
    for r in range(bs // n_cols):
+
cols = st.columns(n_cols)
|
260 |
+
|
261 |
+
for c in range(n_cols):
|
262 |
+
cols[c].image(imgs[counter])
|
263 |
+
counter += 1
|
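For reference, the generation step above can also be exercised headlessly. A minimal sketch (not part of the commit; it assumes this module's `load_model`, `sample_labels`, `get_labels`, `model_path`, `bs`, `n_layers` and `n_basis` globals are in scope):

```python
# Hypothetical headless version of the generation call used by app() above.
import torch

model = load_model(model_path)              # InfoSCC-GAN generator
labels = sample_labels(get_labels(), bs)    # (bs, 37) hierarchical Galaxy Zoo labels
eps = model.sample_eps(bs)                  # (bs, noise_dim) epsilon noise
zs = torch.zeros(bs, n_layers, n_basis)     # neutral per-layer exploration variables
with torch.no_grad():
    imgs = model(labels, eps, zs)           # outputs in [-1, 1]
# same denormalization as the app: map [-1, 1] to [0, 255]
imgs_uint8 = ((imgs.permute(0, 2, 3, 1) * 127.5) + 127.5).to(torch.uint8)
```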
src/app/multipage.py
ADDED
@@ -0,0 +1,41 @@
"""
This file is the framework for generating multiple Streamlit applications
through an object-oriented framework.
"""

# Import necessary libraries
import streamlit as st


# Define the multipage class to manage the multiple apps in our program
class MultiPage:
    """Framework for combining multiple Streamlit applications."""

    def __init__(self) -> None:
        """Constructor class to generate a list which will store all our applications as an instance variable."""
        self.pages = []

    def add_page(self, title, func) -> None:
        """Class method to add pages to the project
        Args:
            title ([str]): The title of the page which we are adding to the list of apps

            func: Python function to render this page in Streamlit
        """
        self.pages.append({
            "title": title,
            "function": func
        })

    def run(self):
        # Dropdown to select the page to run
        page = st.sidebar.selectbox(
            'App Navigation',
            self.pages,
            format_func=lambda page: page['title']
        )

        # run the app function
        page['function']()
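A minimal usage sketch of the class above (hypothetical wiring, not part of this file; the three explore modules exist in this commit and each expose an `app()` function):

```python
# Hypothetical entry point combining the demo pages with MultiPage.
from src.app.multipage import MultiPage
from src.app import explore_infoscc_gan, explore_biggan, explore_cvae

app = MultiPage()
app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)
app.add_page('Explore BigGAN', explore_biggan.app)
app.add_page('Explore cVAE', explore_cvae.app)
app.run()  # renders the sidebar selectbox and runs the chosen page
```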
src/app/params.py
ADDED
@@ -0,0 +1,25 @@
"""
This file contains the list of global parameters for the Galaxy Zoo generation app
"""

device = 'cpu'
size = 64  # generated image size
shape_label = 37  # shape of the input label
n_channels = 3  # number of color channels in image
upsample = True  # if true, generated images will be upsampled
noise_dim = 512  # noise size in InfoSCC-GAN
n_basis = 6  # size of additional z vectors in InfoSCC-GAN
y_type = 'real'  # type of labels in InfoSCC-GAN
dim_z = 128  # z vector size in BigGAN and cVAE

path_infoscc_gan = './models/InfoSCC-GAN/generator.pt'
drive_id_infoscc_gan = '1_kIujc497OH0ZJ7PNPwS5_otNlS7jMLI'

path_biggan = './models/BigGAN/generator.pth'
drive_id_biggan = '1sMSDdnQ5GjHcno5knHTDSKAKhhoHh_4z'

path_cvae = './models/CVAE/generator.pth'
drive_id_cvae = '17FmLvhwXq8PQMrD1CtjqyoAy5BobYMTE'

path_labels = './data/training_solutions_rev1.csv'
drive_id_labels = '1dzsB_HdGtmSHE4pCppamISpFaJBfPF7E'
src/app/questions.py
ADDED
@@ -0,0 +1,36 @@
"""
This file contains the questions and options for the manual labeling
"""

q1 = 'Is the object a smooth galaxy, a galaxy with features/disk or a star?'
q1_options = ['Smooth', 'Features or disk', 'Star or artifact']

q2 = 'Is it edge-on?'
q2_options = ['Yes', 'No']

q3 = 'Is there a bar?'
q3_options = ['Yes', 'No']

q4 = 'Is there a spiral pattern?'
q4_options = ['Yes', 'No']

q5 = 'How prominent is the central bulge?'
q5_options = ['No bulge', 'Just noticeable', 'Obvious', 'Dominant']

q6 = 'Is there anything "odd" about the galaxy?'
q6_options = ['Yes', 'No']

q7 = 'How round is the smooth galaxy?'
q7_options = ['Completely round', 'In between', 'Cigar-shaped']

q8 = 'What is the odd feature?'
q8_options = ['Ring', 'Lens or arc', 'Disturbed', 'Irregular', 'Other', 'Merger', 'Dust lane']

q9 = 'What shape is the bulge in the edge-on galaxy?'
q9_options = ['Rounded', 'Boxy', 'No bulge']

q10 = 'How tightly wound are the spiral arms?'
q10_options = ['Tight', 'Medium', 'Loose']

q11 = 'How many spiral arms are there?'
q11_options = ['1', '2', '3', '4', 'more than four', "can't tell"]
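One property worth noting: the option counts of the eleven questions sum to the 37-dimensional Galaxy Zoo label vector used throughout the app. A quick sanity check (a sketch for illustration, not part of the commit):

```python
# Verifies that the questionnaire's one-hot blocks line up with shape_label.
from src.app import params, questions

option_lists = [questions.q1_options, questions.q2_options, questions.q3_options,
                questions.q4_options, questions.q5_options, questions.q6_options,
                questions.q7_options, questions.q8_options, questions.q9_options,
                questions.q10_options, questions.q11_options]
assert sum(len(opts) for opts in option_lists) == params.shape_label  # 3+2+2+2+4+2+3+7+3+3+6 == 37
```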
src/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .data import get_labels_train
from .labels import make_galaxy_labels_hierarchical
src/data/data.py
ADDED
@@ -0,0 +1,10 @@
from pandas import read_csv

import torch


def get_labels_train(file_galaxy_labels) -> torch.Tensor:
    df_galaxy_labels = read_csv(file_galaxy_labels)
    # drop the first column (galaxy id); keep the 37 label columns
    labels_train = df_galaxy_labels[df_galaxy_labels.columns[1:]].values
    labels_train = torch.from_numpy(labels_train).float()
    return labels_train
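Usage is straightforward; a short sketch (assuming the labels CSV referenced in params.py has already been downloaded):

```python
from src.app import params
from src.data import get_labels_train

labels_train = get_labels_train(params.path_labels)  # id column is dropped
print(labels_train.shape)  # expected: (num_galaxies, 37)
```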
src/data/labels.py
ADDED
@@ -0,0 +1,63 @@
import numpy as np

import torch


class_groups = {
    # group : indices (assuming 0th position is id)
    0: (),
    1: (1, 2, 3),
    2: (4, 5),
    3: (6, 7),
    4: (8, 9),
    5: (10, 11, 12, 13),
    6: (14, 15),
    7: (16, 17, 18),
    8: (19, 20, 21, 22, 23, 24, 25),
    9: (26, 27, 28),
    10: (29, 30, 31),
    11: (32, 33, 34, 35, 36, 37),
}


class_groups_indices = {g: np.array(ixs) - 1 for g, ixs in class_groups.items()}


hierarchy = {
    # group : parent (group, label)
    2: (1, 1),
    3: (2, 1),
    4: (2, 1),
    5: (2, 1),
    7: (1, 0),
    8: (6, 0),
    9: (2, 0),
    10: (4, 0),
    11: (4, 0),
}


def make_galaxy_labels_hierarchical(labels: torch.Tensor) -> torch.Tensor:
    """Transform groups of galaxy label probabilities to follow the hierarchical order defined in Galaxy Zoo.
    More info here: https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge/overview/the-galaxy-zoo-decision-tree
    labels is an NxL torch tensor, where N is the batch size and L is the number of labels;
    all labels should be non-negative.
    The indices of the label groups are listed in class_groups_indices.

    Return
    ------
    hierarchical_labels : NxL torch tensor, where L is the total number of labels
    """
    shift = labels.shape[1] > 37  # in case the id is included at 0th position, shift indices accordingly
    index = lambda i: class_groups_indices[i] + shift

    for i in range(1, 12):
        # normalize probabilities to 1
        norm = torch.sum(labels[:, index(i)], dim=1, keepdims=True)
        norm[norm == 0] += 1e-4  # add small number to prevent NaNs dividing by zero, yet keep track of gradient
        labels[:, index(i)] /= norm
        # renormalize according to hierarchical structure
        if i not in [1, 6]:
            parent_group_label = labels[:, index(hierarchy[i][0])]
            labels[:, index(i)] *= parent_group_label[:, hierarchy[i][1]].unsqueeze(-1)
    return labels
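To see what the transform does, a small demonstration (a sketch, not part of the commit): after the call, question 1 sums to one, and every child question sums to the probability of its parent answer, e.g. question 2 sums to the 'Features or disk' probability of question 1:

```python
import torch
from src.data.labels import make_galaxy_labels_hierarchical, class_groups_indices

raw = torch.rand(4, 37)  # random non-negative scores for all 11 question groups
out = make_galaxy_labels_hierarchical(raw.clone())
print(out[:, class_groups_indices[1]].sum(dim=1))  # ~1.0 for every row
print(torch.allclose(out[:, class_groups_indices[2]].sum(dim=1),
                     out[:, class_groups_indices[1]][:, 1]))  # True
```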
src/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .infoscc_gan import ConditionalGenerator
from .cvae import ConditionalDecoder
src/models/big/BigGAN2.py
ADDED
@@ -0,0 +1,469 @@
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F

import src.models.big.layers as layers
from src.models.parameter import labels_dim, parameter
from src.models.neuralnetwork import NeuralNetwork


# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    arch = {}
    arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
                 'upsample': [True] * 7,
                 'resolution': [8, 16, 32, 64, 128, 256, 512],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 10)}}
    arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
                 'upsample': [True] * 6,
                 'resolution': [8, 16, 32, 64, 128, 256],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 9)}}
    arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                 'upsample': [True] * 5,
                 'resolution': [8, 16, 32, 64, 128],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 8)}}
    arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
                'out_channels': [ch * item for item in [16, 8, 4, 2]],
                'upsample': [True] * 4,
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)}}
    arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
                'out_channels': [ch * item for item in [4, 4, 4]],
                'upsample': [True] * 3,
                'resolution': [8, 16, 32],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)}}

    return arch

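# Hypothetical illustration (not in the original file) of how the attention
# string is parsed by the dict comprehensions above, e.g. at resolution 64:
#     attention = '32_64'
#     {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) for i in range(3, 7)}
#     -> {8: False, 16: False, 32: True, 64: True}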
class Generator(NeuralNetwork):
    def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=64, labels_dim=labels_dim,
                 G_kernel_size=3, G_attn='64', n_classes=1,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # for some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                                      else self.n_classes),
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Prepare model
        # prepare label input
        self.transform_label_layer = torch.nn.Linear(labels_dim, 128)
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        # First linear layer
        self.linear = self.which_linear(self.dim_z // self.num_slots,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           which_bn=self.which_bn,
                                           activation=self.activation,
                                           upsample=(functools.partial(F.interpolate, scale_factor=2)
                                                     if self.arch['upsample'][index] else None))]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def transform_labels(self, labels):
        """ prepare labels for input to the generator """
        return self.transform_label_layer(labels)

    # Note on this forward function: we pass in a y vector which has
    # already been passed through G.shared to enable easy class-wise
    # interpolation later. If we passed in the one-hot and then ran it through
    # G.shared in this forward function, it would be harder to handle.
    def forward(self, z, y):
        # If hierarchical, concatenate zs and ys
        y = self.transform_labels(y)
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        # First linear layer
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, ys[index])

        # Apply batchnorm-relu-conv-sigmoid at output
        return torch.sigmoid(self.output_layer(h))
        # return torch.tanh(self.output_layer(h))


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    arch = {}
    arch[256] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[128] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                 'downsample': [True] * 5 + [False],
                 'resolution': [64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[64] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 7)}}
    arch[32] = {'in_channels': [3] + [item * ch for item in [4, 4, 4]],
                'out_channels': [item * ch for item in [4, 4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 16, 16, 16],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 6)}}
    return arch


class Discriminator(NeuralNetwork):

    def __init__(self, D_ch=64, D_wide=True, resolution=64, labels_dim=labels_dim,
                 D_kernel_size=3, D_attn='64', n_classes=1,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        # Prepare model
        # prepare label input
        self.transform_label_layer = torch.nn.Linear(labels_dim, 1024)
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate * 3, betas=parameter.betas)

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def transform_labels(self, labels):
        """ prepare labels for input to the discriminator """
        return self.transform_label_layer(labels)

    def forward(self, x, y=None):
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
        # Get projection of final featureset onto class vectors and add to evidence
        y = self.transform_labels(y)
        out = out + torch.sum(y * h, 1, keepdim=True)
        # out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)  ## use y = torch.tensor(0)
        return out


# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D

    def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
                split_D=False):
        # If training G, enable grad tape
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise
            G_z = self.G(z, self.G.shared(gy))
            # Cast as necessary
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()
        # Split_D means to run D once with real data and once with fake,
        # rather than concatenating along the batch dimension.
        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                D_real = self.D(x, dy)
                return D_fake, D_real
            else:
                if return_G_z:
                    return D_fake, G_z
                else:
                    return D_fake
        # If real data is provided, concatenate it with the Generator's output
        # along the batch dimension for improved efficiency.
        else:
            D_input = torch.cat([G_z, x], 0) if x is not None else G_z
            D_class = torch.cat([gy, dy], 0) if dy is not None else gy
            # Get Discriminator output
            D_out = self.D(D_input, D_class)
            if x is not None:
                return torch.split(D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
            else:
                if return_G_z:
                    return D_out, G_z
                else:
                    return D_out
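A quick smoke test of the Generator above (a sketch, not part of the commit; it assumes the repo's `layers`, `parameter` and `neuralnetwork` modules import cleanly, and uses `skip_init`/`no_optim` only to keep the example light):

```python
import torch
from src.models.big.BigGAN2 import Generator

G = Generator(resolution=64, dim_z=128, labels_dim=37, skip_init=True, no_optim=True)
z = torch.randn(2, 128)  # latent noise
y = torch.rand(2, 37)    # real-valued Galaxy Zoo label probabilities
with torch.no_grad():
    imgs = G(z, y)       # transform_labels maps y to 128-d internally
print(imgs.shape)        # expected: torch.Size([2, 3, 64, 64]), values in [0, 1]
```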
src/models/big/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Andy Brock

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/models/big/README.md
ADDED
@@ -0,0 +1,144 @@
# BigGAN-PyTorch
The author's officially unofficial PyTorch BigGAN implementation.

![Dogball? Dogball!](imgs/header_image.jpg?raw=true "Dogball? Dogball!")


This repo contains code for 4-8 GPU training of BigGANs from [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/abs/1809.11096) by Andrew Brock, Jeff Donahue, and Karen Simonyan.

This code is by Andy Brock and Alex Andonian.

## How To Use This Code
You will need:

- [PyTorch](https://PyTorch.org/), version 1.0.1
- tqdm, numpy, scipy, and h5py
- The ImageNet training set

First, you may optionally prepare a pre-processed HDF5 version of your target dataset for faster I/O. Following this (or not), you'll need the Inception moments needed to calculate FID. These can both be done by modifying and running

```sh
sh scripts/utils/prepare_data.sh
```

which by default assumes your ImageNet training set is downloaded into the root folder `data` in this directory, and will prepare the cached HDF5 at 128x128 pixel resolution.

In the scripts folder, there are multiple bash scripts which will train BigGANs with different batch sizes. This code assumes you do not have access to a full TPU pod, and accordingly spoofs mega-batches by using gradient accumulation (averaging grads over multiple minibatches, and only taking an optimizer step after N accumulations). By default, the `launch_BigGAN_bs256x8.sh` script trains a full-sized BigGAN model with a batch size of 256 and 8 gradient accumulations, for a total batch size of 2048. On 8xV100 with full-precision training (no Tensor cores), this script takes 15 days to train to 150k iterations.
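The accumulation trick described above amounts to the following loop (an illustrative sketch, not the repo's actual training function; all names here are hypothetical):

```python
# Average gradients over num_accumulations minibatches, then step once,
# so an on-device batch of 256 with 8 accumulations behaves like BS 2048.
optimizer.zero_grad()
for accumulation_index in range(num_accumulations):
    x, y = next(data_iterator)                 # one on-device minibatch
    loss = compute_loss(model(x), y)           # hypothetical loss function
    (loss / num_accumulations).backward()      # grads accumulate, pre-averaged
optimizer.step()                               # single update for the mega-batch
```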
You will first need to figure out the maximum batch size your setup can support. The pre-trained models provided here were trained on 8xV100 (16GB VRAM each) which can support slightly more than the BS256 used by default. Once you've determined this, you should modify the script so that the batch size times the number of gradient accumulations is equal to your desired total batch size (BigGAN defaults to 2048).

Note also that this script uses the `--load_in_mem` arg, which loads the entire (~64GB) I128.hdf5 file into RAM for faster data loading. If you don't have enough RAM to support this (probably 96GB+), remove this argument.


## Metrics and Sampling
![I believe I can fly!](imgs/interp_sample.jpg?raw=true "I believe I can fly!")

During training, this script will output logs with training metrics and test metrics, will save multiple copies (2 most recent and 5 highest-scoring) of the model weights/optimizer params, and will produce samples and interpolations every time it saves weights. The logs folder contains scripts to process these logs and plot the results using MATLAB (sorry not sorry).

After training, one can use `sample.py` to produce additional samples and interpolations, test with different truncation values, batch sizes, number of standing stat accumulations, etc. See the `sample_BigGAN_bs256x8.sh` script for an example.

By default, everything is saved to weights/samples/logs/data folders which are assumed to be in the same folder as this repo. You can point all of these to a different base folder using the `--base_root` argument, or pick specific locations for each of these with their respective arguments (e.g. `--logs_root`).

We include scripts to run BigGAN-deep, but we have not fully trained a model using them, so consider them untested. Additionally, we include scripts to run a model on CIFAR, and to run SA-GAN (with EMA) and SN-GAN on ImageNet. The SA-GAN code assumes you have 4xTitanX (or equivalent in terms of GPU RAM) and will run with a batch size of 128 and 2 gradient accumulations.

## An Important Note on Inception Metrics
This repo uses the PyTorch in-built inception network to calculate IS and FID. These scores are different from the scores you would get using the official TF inception code, and are only for monitoring purposes! Run sample.py on your model, with the `--sample_npz` argument, then run inception_tf13 to calculate the actual TensorFlow IS. Note that you will need to have TensorFlow 1.3 or earlier installed, as TF1.4+ breaks the original IS code.

## Pretrained models
![PyTorch Inception Score and FID](imgs/IS_FID.png)
We include two pretrained model checkpoints (with G, D, the EMA copy of G, the optimizers, and the state dict):
- The main checkpoint is for a BigGAN trained on ImageNet at 128x128, using BS256 and 8 gradient accumulations, taken just before collapse, with a TF Inception Score of 97.35 +/- 1.79: [LINK](https://drive.google.com/open?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW)
- An earlier checkpoint of the first model (100k G iters), at high performance but well before collapse, which may be easier to fine-tune: [LINK](https://drive.google.com/open?id=1dmZrcVJUAWkPBGza_XgswSuT-UODXZcO)



Pretrained models for Places-365 coming soon.

This repo also contains scripts for porting the original TFHub BigGAN Generator weights to PyTorch. See the scripts in the TFHub folder for more details.

## Fine-tuning, Using Your Own Dataset, or Making New Training Functions
![That's deep, man](imgs/DeepSamples.png?raw=true "Deep Samples")

If you wish to resume interrupted training or fine-tune a pre-trained model, run the same launch script but with the `--resume` argument added. Experiment names are automatically generated from the configuration, but can be overridden using the `--experiment_name` arg (for example, if you wish to fine-tune a model using modified optimizer settings).

To prep your own dataset, you will need to add it to datasets.py and modify the convenience dicts in utils.py (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict) to have the appropriate metadata for your dataset, as sketched below. Repeat the process in prepare_data.sh (optionally produce an HDF5 preprocessed copy, and calculate the Inception Moments for FID).
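A rough sketch of that registration step (dict names from the paragraph above; the entry values and the dataset class are hypothetical, and the exact value formats should be checked against utils.py):

```python
# In utils.py: register metadata for a hypothetical 'MyDataset'.
dset_dict.update({'MyDataset': MyDatasetClass})   # dataset class from datasets.py
imsize_dict.update({'MyDataset': 64})             # training resolution
root_dict.update({'MyDataset': 'my_dataset'})     # folder name under the data root
nclass_dict.update({'MyDataset': 10})             # number of classes
classes_per_sheet_dict.update({'MyDataset': 10})  # classes per sample sheet
```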
By default, the training script will save the top 5 best checkpoints as measured by Inception Score. For datasets other than ImageNet, Inception Score can be a very poor measure of quality, so you will likely want to use `--which_best FID` instead.

To use your own training function (e.g. train a BigVAE): either modify train_fns.GAN_training_function or add a new train fn and add it after the `if config['which_train_fn'] == 'GAN':` line in `train.py`.


## Neat Stuff
- We include the full training and metrics logs [here](https://drive.google.com/open?id=1ZhY9Mg2b_S4QwxNmt57aXJ9FOC3ZN1qb) for reference. I've found that one of the hardest things about re-implementing a paper can be checking if the logs line up early in training, especially if training takes multiple weeks. Hopefully these will be helpful for future work.
- We include an accelerated FID calculation--the original scipy version can require upwards of 10 minutes to calculate the matrix sqrt, this version uses an accelerated PyTorch version to calculate it in under a second.
- We include an accelerated, low-memory consumption ortho reg implementation.
- By default, we only compute the top singular value (the spectral norm), but this code supports computing more SVs through the `--num_G_SVs` argument.

## Key Differences Between This Code And The Original BigGAN
- We use the optimizer settings from SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1, as opposed to BigGAN's G_lr=5e-5, D_lr=2e-4, num_D_steps=2). While slightly less performant, this was the first corner we cut to bring training times down.
- By default, we do not use Cross-Replica BatchNorm (AKA Synced BatchNorm). The two variants we tried (a custom, naive one and the one included in this repo) have slightly different gradients (albeit identical forward passes) from the built-in BatchNorm, which appear to be sufficient to cripple training.
- Gradient accumulation means that we update the SV estimates and the BN statistics 8 times more frequently. This means that the BN stats are much closer to standing stats, and that the singular value estimates tend to be more accurate. Because of this, we measure metrics by default with G in test mode (using the BatchNorm running stat estimates instead of computing standing stats as in the paper). We do still support standing stats (see the sample.sh scripts). This could also conceivably result in gradients from the earlier accumulations being stale, but in practice this does not appear to be a problem.
- The currently provided pretrained models were not trained with orthogonal regularization. Training without ortho reg seems to increase the probability that models will not be amenable to truncation, but it looks like this particular model got a winning ticket. Regardless, we provide two highly optimized (fast and minimal memory consumption) ortho reg implementations which directly compute the ortho reg. gradients.

## A Note On The Design Of This Repo
This code is designed from the ground up to serve as an extensible, hackable base for further research code. We've put a lot of thought into making sure the abstractions are the *right* thickness for research--not so thick as to be impenetrable, but not so thin as to be useless. The key idea is that if you want to experiment with a SOTA setup and make some modification (try out your own new loss function, architecture, self-attention block, etc) you should be able to easily do so just by dropping your code in one or two places, without having to worry about the rest of the codebase. Things like the use of self.which_conv and functools.partial in the BigGAN.py model definition were put together with this in mind, as was the design of the Spectral Norm class inheritance.

With that said, this is a somewhat large codebase for a single project. While we tried to be thorough with the comments, if there's something you think could be more clear, better written, or better refactored, please feel free to raise an issue or a pull request.

## Feature Requests
Want to work on or improve this code? There are a couple things this repo would benefit from, but which don't yet work.

- Synchronized BatchNorm (AKA Cross-Replica BatchNorm). We tried out two variants of this, but for some unknown reason it crippled training each time. We have not tried the [apex](https://github.com/NVIDIA/apex) SyncBN as my school's servers are on ancient NVIDIA drivers that don't support it--apex would probably be a good place to start.
- Mixed precision training and making use of Tensor cores. This repo includes a naive mixed-precision Adam implementation which works early in training but leads to early collapse, and doesn't do anything to activate Tensor cores (it just reduces memory consumption). As above, integrating [apex](https://github.com/NVIDIA/apex) into this code and employing its mixed-precision training techniques to take advantage of Tensor cores and reduce memory consumption could yield substantial speed gains.

## Misc Notes
See [this directory](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for ImageNet labels.

If you use this code, please cite
```text
@inproceedings{
brock2018large,
title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis},
author={Andrew Brock and Jeff Donahue and Karen Simonyan},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=B1xsqj09Fm},
}
```

## Acknowledgments
Thanks to Google for the generous cloud credit donations.

[SyncBN](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) by Jiayuan Mao and Tete Xiao.

[Progress bar](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) originally from Jan Schlüter.

Test metrics logger from [VoxNet.](https://github.com/dimatura/voxnet)

PyTorch [implementation of cov](https://discuss.PyTorch.org/t/covariance-and-gradient-support/16217/2) from Modar M. Alfadly.

PyTorch [fast Matrix Sqrt](https://github.com/msubhransu/matrix-sqrt) for FID from Tsung-Yu Lin and Subhransu Maji.

TensorFlow Inception Score code from [OpenAI's Improved-GAN.](https://github.com/openai/improved-gan)
src/models/big/__init__.py
ADDED
File without changes
src/models/big/animal_hash.py
ADDED
@@ -0,0 +1,439 @@
c = ['Aardvark', 'Abyssinian', 'Affenpinscher', 'Akbash', 'Akita', 'Albatross',
     'Alligator', 'Alpaca', 'Angelfish', 'Ant', 'Anteater', 'Antelope', 'Ape',
     'Armadillo', 'Ass', 'Avocet', 'Axolotl', 'Baboon', 'Badger', 'Balinese',
     'Bandicoot', 'Barb', 'Barnacle', 'Barracuda', 'Bat', 'Beagle', 'Bear',
     'Beaver', 'Bee', 'Beetle', 'Binturong', 'Bird', 'Birman', 'Bison',
     'Bloodhound', 'Boar', 'Bobcat', 'Bombay', 'Bongo', 'Bonobo', 'Booby',
     'Budgerigar', 'Buffalo', 'Bulldog', 'Bullfrog', 'Burmese', 'Butterfly',
     'Caiman', 'Camel', 'Capybara', 'Caracal', 'Caribou', 'Cassowary', 'Cat',
     'Caterpillar', 'Catfish', 'Cattle', 'Centipede', 'Chameleon', 'Chamois',
     'Cheetah', 'Chicken', 'Chihuahua', 'Chimpanzee', 'Chinchilla', 'Chinook',
     'Chipmunk', 'Chough', 'Cichlid', 'Clam', 'Coati', 'Cobra', 'Cockroach',
     'Cod', 'Collie', 'Coral', 'Cormorant', 'Cougar', 'Cow', 'Coyote',
     'Crab', 'Crane', 'Crocodile', 'Crow', 'Curlew', 'Cuscus', 'Cuttlefish',
     'Dachshund', 'Dalmatian', 'Deer', 'Dhole', 'Dingo', 'Dinosaur', 'Discus',
     'Dodo', 'Dog', 'Dogball', 'Dogfish', 'Dolphin', 'Donkey', 'Dormouse',
     'Dove', 'Dragonfly', 'Drever', 'Duck', 'Dugong', 'Dunker', 'Dunlin',
     'Eagle', 'Earwig', 'Echidna', 'Eel', 'Eland', 'Elephant', 'ElephantSeal',
     'Elk', 'Emu', 'Falcon', 'Ferret', 'Finch', 'Fish', 'Flamingo', 'Flounder',
     'Fly', 'Fossa', 'Fox', 'Frigatebird', 'Frog', 'Galago', 'Gar', 'Gaur',
     'Gazelle', 'Gecko', 'Gerbil', 'Gharial', 'GiantPanda', 'Gibbon', 'Giraffe',
     'Gnat', 'Gnu', 'Goat', 'Goldfinch', 'Goldfish', 'Goose', 'Gopher',
     'Gorilla', 'Goshawk', 'Grasshopper', 'Greyhound', 'Grouse', 'Guanaco',
     'GuineaFowl', 'GuineaPig', 'Gull', 'Guppy', 'Hamster', 'Hare', 'Harrier',
     'Havanese', 'Hawk', 'Hedgehog', 'Heron', 'Herring', 'Himalayan',
     'Hippopotamus', 'Hornet', 'Horse', 'Human', 'Hummingbird', 'Hyena',
     'Ibis', 'Iguana', 'Impala', 'Indri', 'Insect', 'Jackal', 'Jaguar',
     'Javanese', 'Jay', 'Jellyfish', 'Kakapo', 'Kangaroo', 'Kingfisher',
     'Kiwi', 'Koala', 'KomodoDragon', 'Kouprey', 'Kudu', 'Labradoodle',
     'Ladybird', 'Lapwing', 'Lark', 'Lemming', 'Lemur', 'Leopard', 'Liger',
     'Lion', 'Lionfish', 'Lizard', 'Llama', 'Lobster', 'Locust', 'Loris',
     'Louse', 'Lynx', 'Lyrebird', 'Macaw', 'Magpie', 'Mallard', 'Maltese',
     'Manatee', 'Mandrill', 'Markhor', 'Marten', 'Mastiff', 'Mayfly', 'Meerkat',
     'Millipede', 'Mink', 'Mole', 'Molly', 'Mongoose', 'Mongrel', 'Monkey',
     'Moorhen', 'Moose', 'Mosquito', 'Moth', 'Mouse', 'Mule', 'Narwhal',
     'Neanderthal', 'Newfoundland', 'Newt', 'Nightingale', 'Numbat', 'Ocelot',
     'Octopus', 'Okapi', 'Olm', 'Opossum', 'Orang-utan', 'Oryx', 'Ostrich',
     'Otter', 'Owl', 'Ox', 'Oyster', 'Pademelon', 'Panther', 'Parrot',
     'Partridge', 'Peacock', 'Peafowl', 'Pekingese', 'Pelican', 'Penguin',
     'Persian', 'Pheasant', 'Pig', 'Pigeon', 'Pika', 'Pike', 'Piranha',
     'Platypus', 'Pointer', 'Pony', 'Poodle', 'Porcupine', 'Porpoise',
     'Possum', 'PrairieDog', 'Prawn', 'Puffin', 'Pug', 'Puma', 'Quail',
     'Quelea', 'Quetzal', 'Quokka', 'Quoll', 'Rabbit', 'Raccoon', 'Ragdoll',
     'Rail', 'Ram', 'Rat', 'Rattlesnake', 'Raven', 'RedDeer', 'RedPanda',
     'Reindeer', 'Rhinoceros', 'Robin', 'Rook', 'Rottweiler', 'Ruff',
     'Salamander', 'Salmon', 'SandDollar', 'Sandpiper', 'Saola',
     'Sardine', 'Scorpion', 'SeaLion', 'SeaUrchin', 'Seahorse',
     'Seal', 'Serval', 'Shark', 'Sheep', 'Shrew', 'Shrimp', 'Siamese',
     'Siberian', 'Skunk', 'Sloth', 'Snail', 'Snake', 'Snowshoe', 'Somali',
     'Sparrow', 'Spider', 'Sponge', 'Squid', 'Squirrel', 'Starfish', 'Starling',
     'Stingray', 'Stinkbug', 'Stoat', 'Stork', 'Swallow', 'Swan', 'Tang',
     'Tapir', 'Tarsier', 'Termite', 'Tetra', 'Tiffany', 'Tiger', 'Toad',
     'Tortoise', 'Toucan', 'Tropicbird', 'Trout', 'Tuatara', 'Turkey',
     'Turtle', 'Uakari', 'Uguisu', 'Umbrellabird', 'Viper', 'Vulture',
     'Wallaby', 'Walrus', 'Warthog', 'Wasp', 'WaterBuffalo', 'Weasel',
     'Whale', 'Whippet', 'Wildebeest', 'Wolf', 'Wolverine', 'Wombat',
     'Woodcock', 'Woodlouse', 'Woodpecker', 'Worm', 'Wrasse', 'Wren',
     'Yak', 'Zebra', 'Zebu', 'Zonkey']
a = ['able', 'above', 'absent', 'absolute', 'abstract', 'abundant', 'academic',
     'acceptable', 'accepted', 'accessible', 'accurate', 'accused', 'active',
     'actual', 'acute', 'added', 'additional', 'adequate', 'adjacent',
     'administrative', 'adorable', 'advanced', 'adverse', 'advisory',
     'aesthetic', 'afraid', 'african', 'aggregate', 'aggressive', 'agreeable',
     'agreed', 'agricultural', 'alert', 'alive', 'alleged', 'allied', 'alone',
     'alright', 'alternative', 'amateur', 'amazing', 'ambitious', 'american',
     'amused', 'ancient', 'angry', 'annoyed', 'annual', 'anonymous', 'anxious',
     'appalling', 'apparent', 'applicable', 'appropriate', 'arab', 'arbitrary',
     'architectural', 'armed', 'arrogant', 'artificial', 'artistic', 'ashamed',
     'asian', 'asleep', 'assistant', 'associated', 'atomic', 'attractive',
     'australian', 'automatic', 'autonomous', 'available', 'average',
     'awake', 'aware', 'awful', 'awkward', 'back', 'bad', 'balanced', 'bare',
     'basic', 'beautiful', 'beneficial', 'better', 'bewildered', 'big',
     'binding', 'biological', 'bitter', 'bizarre', 'black', 'blank', 'blind',
     'blonde', 'bloody', 'blue', 'blushing', 'boiling', 'bold', 'bored',
     'boring', 'bottom', 'brainy', 'brave', 'breakable', 'breezy', 'brief',
     'bright', 'brilliant', 'british', 'broad', 'broken', 'brown', 'bumpy',
     'burning', 'busy', 'calm', 'canadian', 'capable', 'capitalist', 'careful',
     'casual', 'catholic', 'causal', 'cautious', 'central', 'certain',
     'changing', 'characteristic', 'charming', 'cheap', 'cheerful', 'chemical',
     'chief', 'chilly', 'chinese', 'chosen', 'christian', 'chronic', 'chubby',
     'circular', 'civic', 'civil', 'civilian', 'classic', 'classical', 'clean',
     'clear', 'clever', 'clinical', 'close', 'closed', 'cloudy', 'clumsy',
     'coastal', 'cognitive', 'coherent', 'cold', 'collective', 'colonial',
     'colorful', 'colossal', 'coloured', 'colourful', 'combative', 'combined',
     'comfortable', 'coming', 'commercial', 'common', 'communist', 'compact',
     'comparable', 'comparative', 'compatible', 'competent', 'competitive',
     'complete', 'complex', 'complicated', 'comprehensive', 'compulsory',
     'conceptual', 'concerned', 'concrete', 'condemned', 'confident',
     'confidential', 'confused', 'conscious', 'conservation', 'conservative',
     'considerable', 'consistent', 'constant', 'constitutional',
     'contemporary', 'content', 'continental', 'continued', 'continuing',
     'continuous', 'controlled', 'controversial', 'convenient', 'conventional',
     'convinced', 'convincing', 'cooing', 'cool', 'cooperative', 'corporate',
     'correct', 'corresponding', 'costly', 'courageous', 'crazy', 'creative',
     'creepy', 'criminal', 'critical', 'crooked', 'crowded', 'crucial',
     'crude', 'cruel', 'cuddly', 'cultural', 'curious', 'curly', 'current',
     'curved', 'cute', 'daily', 'damaged', 'damp', 'dangerous', 'dark', 'dead',
     'deaf', 'deafening', 'dear', 'decent', 'decisive', 'deep', 'defeated',
     'defensive', 'defiant', 'definite', 'deliberate', 'delicate', 'delicious',
     'delighted', 'delightful', 'democratic', 'dependent', 'depressed',
     'desirable', 'desperate', 'detailed', 'determined', 'developed',
     'developing', 'devoted', 'different', 'difficult', 'digital', 'diplomatic',
     'direct', 'dirty', 'disabled', 'disappointed', 'disastrous',
     'disciplinary', 'disgusted', 'distant', 'distinct', 'distinctive',
     'distinguished', 'disturbed', 'disturbing', 'diverse', 'divine', 'dizzy',
     'domestic', 'dominant', 'double', 'doubtful', 'drab', 'dramatic',
     'dreadful', 'driving', 'drunk', 'dry', 'dual', 'due', 'dull', 'dusty',
     'dutch', 'dying', 'dynamic', 'eager', 'early', 'eastern', 'easy',
     'economic', 'educational', 'eerie', 'effective', 'efficient',
     'elaborate', 'elated', 'elderly', 'eldest', 'electoral', 'electric',
|
110 |
+
'electrical', 'electronic', 'elegant', 'eligible', 'embarrassed',
|
111 |
+
'embarrassing', 'emotional', 'empirical', 'empty', 'enchanting',
|
112 |
+
'encouraging', 'endless', 'energetic', 'english', 'enormous',
|
113 |
+
'enthusiastic', 'entire', 'entitled', 'envious', 'environmental', 'equal',
|
114 |
+
'equivalent', 'essential', 'established', 'estimated', 'ethical',
|
115 |
+
'ethnic', 'european', 'eventual', 'everyday', 'evident', 'evil',
|
116 |
+
'evolutionary', 'exact', 'excellent', 'exceptional', 'excess',
|
117 |
+
'excessive', 'excited', 'exciting', 'exclusive', 'existing', 'exotic',
|
118 |
+
'expected', 'expensive', 'experienced', 'experimental', 'explicit',
|
119 |
+
'extended', 'extensive', 'external', 'extra', 'extraordinary', 'extreme',
|
120 |
+
'exuberant', 'faint', 'fair', 'faithful', 'familiar', 'famous', 'fancy',
|
121 |
+
'fantastic', 'far', 'fascinating', 'fashionable', 'fast', 'fat', 'fatal',
|
122 |
+
'favourable', 'favourite', 'federal', 'fellow', 'female', 'feminist',
|
123 |
+
'few', 'fierce', 'filthy', 'final', 'financial', 'fine', 'firm', 'fiscal',
|
124 |
+
'fit', 'fixed', 'flaky', 'flat', 'flexible', 'fluffy', 'fluttering',
|
125 |
+
'flying', 'following', 'fond', 'foolish', 'foreign', 'formal',
|
126 |
+
'formidable', 'forthcoming', 'fortunate', 'forward', 'fragile',
|
127 |
+
'frail', 'frantic', 'free', 'french', 'frequent', 'fresh', 'friendly',
|
128 |
+
'frightened', 'front', 'frozen', 'fucking', 'full', 'full-time', 'fun',
|
129 |
+
'functional', 'fundamental', 'funny', 'furious', 'future', 'fuzzy',
|
130 |
+
'gastric', 'gay', 'general', 'generous', 'genetic', 'gentle', 'genuine',
|
131 |
+
'geographical', 'german', 'giant', 'gigantic', 'given', 'glad',
|
132 |
+
'glamorous', 'gleaming', 'global', 'glorious', 'golden', 'good',
|
133 |
+
'gorgeous', 'gothic', 'governing', 'graceful', 'gradual', 'grand',
|
134 |
+
'grateful', 'greasy', 'great', 'greek', 'green', 'grey', 'grieving',
|
135 |
+
'grim', 'gross', 'grotesque', 'growing', 'grubby', 'grumpy', 'guilty',
|
136 |
+
'handicapped', 'handsome', 'happy', 'hard', 'harsh', 'head', 'healthy',
|
137 |
+
'heavy', 'helpful', 'helpless', 'hidden', 'high', 'high-pitched',
|
138 |
+
'hilarious', 'hissing', 'historic', 'historical', 'hollow', 'holy',
|
139 |
+
'homeless', 'homely', 'hon', 'honest', 'horizontal', 'horrible',
|
140 |
+
'hostile', 'hot', 'huge', 'human', 'hungry', 'hurt', 'hushed', 'husky',
|
141 |
+
'icy', 'ideal', 'identical', 'ideological', 'ill', 'illegal',
|
142 |
+
'imaginative', 'immediate', 'immense', 'imperial', 'implicit',
|
143 |
+
'important', 'impossible', 'impressed', 'impressive', 'improved',
|
144 |
+
'inadequate', 'inappropriate', 'inc', 'inclined', 'increased',
|
145 |
+
'increasing', 'incredible', 'independent', 'indian', 'indirect',
|
146 |
+
'individual', 'industrial', 'inevitable', 'influential', 'informal',
|
147 |
+
'inherent', 'initial', 'injured', 'inland', 'inner', 'innocent',
|
148 |
+
'innovative', 'inquisitive', 'instant', 'institutional', 'insufficient',
|
149 |
+
'intact', 'integral', 'integrated', 'intellectual', 'intelligent',
|
150 |
+
'intense', 'intensive', 'interested', 'interesting', 'interim',
|
151 |
+
'interior', 'intermediate', 'internal', 'international', 'intimate',
|
152 |
+
'invisible', 'involved', 'iraqi', 'irish', 'irrelevant', 'islamic',
|
153 |
+
'isolated', 'israeli', 'italian', 'itchy', 'japanese', 'jealous',
|
154 |
+
'jewish', 'jittery', 'joint', 'jolly', 'joyous', 'judicial', 'juicy',
|
155 |
+
'junior', 'just', 'keen', 'key', 'kind', 'known', 'korean', 'labour',
|
156 |
+
'large', 'large-scale', 'late', 'latin', 'lazy', 'leading', 'left',
|
157 |
+
'legal', 'legislative', 'legitimate', 'lengthy', 'lesser', 'level',
|
158 |
+
'lexical', 'liable', 'liberal', 'light', 'like', 'likely', 'limited',
|
159 |
+
'linear', 'linguistic', 'liquid', 'literary', 'little', 'live', 'lively',
|
160 |
+
'living', 'local', 'logical', 'lonely', 'long', 'long-term', 'loose',
|
161 |
+
'lost', 'loud', 'lovely', 'low', 'loyal', 'ltd', 'lucky', 'mad',
|
162 |
+
'magenta', 'magic', 'magnetic', 'magnificent', 'main', 'major', 'male',
|
163 |
+
'mammoth', 'managerial', 'managing', 'manual', 'many', 'marginal',
|
164 |
+
'marine', 'marked', 'married', 'marvellous', 'marxist', 'mass', 'massive',
|
165 |
+
'mathematical', 'mature', 'maximum', 'mean', 'meaningful', 'mechanical',
|
166 |
+
'medical', 'medieval', 'melodic', 'melted', 'mental', 'mere',
|
167 |
+
'metropolitan', 'mid', 'middle', 'middle-class', 'mighty', 'mild',
|
168 |
+
'military', 'miniature', 'minimal', 'minimum', 'ministerial', 'minor',
|
169 |
+
'miserable', 'misleading', 'missing', 'misty', 'mixed', 'moaning',
|
170 |
+
'mobile', 'moderate', 'modern', 'modest', 'molecular', 'monetary',
|
171 |
+
'monthly', 'moral', 'motionless', 'muddy', 'multiple', 'mushy',
|
172 |
+
'musical', 'mute', 'mutual', 'mysterious', 'naked', 'narrow', 'nasty',
|
173 |
+
'national', 'native', 'natural', 'naughty', 'naval', 'near', 'nearby',
|
174 |
+
'neat', 'necessary', 'negative', 'neighbouring', 'nervous', 'net',
|
175 |
+
'neutral', 'new', 'nice', 'nineteenth-century', 'noble', 'noisy',
|
176 |
+
'normal', 'northern', 'nosy', 'notable', 'novel', 'nuclear', 'numerous',
|
177 |
+
'nursing', 'nutritious', 'nutty', 'obedient', 'objective', 'obliged',
|
178 |
+
'obnoxious', 'obvious', 'occasional', 'occupational', 'odd', 'official',
|
179 |
+
'ok', 'okay', 'old', 'old-fashioned', 'olympic', 'only', 'open',
|
180 |
+
'operational', 'opposite', 'optimistic', 'oral', 'orange', 'ordinary',
|
181 |
+
'organic', 'organisational', 'original', 'orthodox', 'other', 'outdoor',
|
182 |
+
'outer', 'outrageous', 'outside', 'outstanding', 'overall', 'overseas',
|
183 |
+
'overwhelming', 'painful', 'pale', 'palestinian', 'panicky', 'parallel',
|
184 |
+
'parental', 'parliamentary', 'part-time', 'partial', 'particular',
|
185 |
+
'passing', 'passive', 'past', 'patient', 'payable', 'peaceful',
|
186 |
+
'peculiar', 'perfect', 'permanent', 'persistent', 'personal', 'petite',
|
187 |
+
'philosophical', 'physical', 'pink', 'plain', 'planned', 'plastic',
|
188 |
+
'pleasant', 'pleased', 'poised', 'polish', 'polite', 'political', 'poor',
|
189 |
+
'popular', 'positive', 'possible', 'post-war', 'potential', 'powerful',
|
190 |
+
'practical', 'precious', 'precise', 'preferred', 'pregnant',
|
191 |
+
'preliminary', 'premier', 'prepared', 'present', 'presidential',
|
192 |
+
'pretty', 'previous', 'prickly', 'primary', 'prime', 'primitive',
|
193 |
+
'principal', 'printed', 'prior', 'private', 'probable', 'productive',
|
194 |
+
'professional', 'profitable', 'profound', 'progressive', 'prominent',
|
195 |
+
'promising', 'proper', 'proposed', 'prospective', 'protective',
|
196 |
+
'protestant', 'proud', 'provincial', 'psychiatric', 'psychological',
|
197 |
+
'public', 'puny', 'pure', 'purple', 'purring', 'puzzled', 'quaint',
|
198 |
+
'qualified', 'quick', 'quickest', 'quiet', 'racial', 'radical', 'rainy',
|
199 |
+
'random', 'rapid', 'rare', 'raspy', 'rational', 'ratty', 'raw', 'ready',
|
200 |
+
'real', 'realistic', 'rear', 'reasonable', 'recent', 'red', 'reduced',
|
201 |
+
'redundant', 'regional', 'registered', 'regular', 'regulatory', 'related',
|
202 |
+
'relative', 'relaxed', 'relevant', 'reliable', 'relieved', 'religious',
|
203 |
+
'reluctant', 'remaining', 'remarkable', 'remote', 'renewed',
|
204 |
+
'representative', 'repulsive', 'required', 'resident', 'residential',
|
205 |
+
'resonant', 'respectable', 'respective', 'responsible', 'resulting',
|
206 |
+
'retail', 'retired', 'revolutionary', 'rich', 'ridiculous', 'right',
|
207 |
+
'rigid', 'ripe', 'rising', 'rival', 'roasted', 'robust', 'rolling',
|
208 |
+
'roman', 'romantic', 'rotten', 'rough', 'round', 'royal', 'rubber',
|
209 |
+
'rude', 'ruling', 'running', 'rural', 'russian', 'sacred', 'sad', 'safe',
|
210 |
+
'salty', 'satisfactory', 'satisfied', 'scared', 'scary', 'scattered',
|
211 |
+
'scientific', 'scornful', 'scottish', 'scrawny', 'screeching',
|
212 |
+
'secondary', 'secret', 'secure', 'select', 'selected', 'selective',
|
213 |
+
'selfish', 'semantic', 'senior', 'sensible', 'sensitive', 'separate',
|
214 |
+
'serious', 'severe', 'sexual', 'shaggy', 'shaky', 'shallow', 'shared',
|
215 |
+
'sharp', 'sheer', 'shiny', 'shivering', 'shocked', 'short', 'short-term',
|
216 |
+
'shrill', 'shy', 'sick', 'significant', 'silent', 'silky', 'silly',
|
217 |
+
'similar', 'simple', 'single', 'skilled', 'skinny', 'sleepy', 'slight',
|
218 |
+
'slim', 'slimy', 'slippery', 'slow', 'small', 'smart', 'smiling',
|
219 |
+
'smoggy', 'smooth', 'so-called', 'social', 'socialist', 'soft', 'solar',
|
220 |
+
'sole', 'solid', 'sophisticated', 'sore', 'sorry', 'sound', 'sour',
|
221 |
+
'southern', 'soviet', 'spanish', 'spare', 'sparkling', 'spatial',
|
222 |
+
'special', 'specific', 'specified', 'spectacular', 'spicy', 'spiritual',
|
223 |
+
'splendid', 'spontaneous', 'sporting', 'spotless', 'spotty', 'square',
|
224 |
+
'squealing', 'stable', 'stale', 'standard', 'static', 'statistical',
|
225 |
+
'statutory', 'steady', 'steep', 'sticky', 'stiff', 'still', 'stingy',
|
226 |
+
'stormy', 'straight', 'straightforward', 'strange', 'strategic',
|
227 |
+
'strict', 'striking', 'striped', 'strong', 'structural', 'stuck',
|
228 |
+
'stupid', 'subjective', 'subsequent', 'substantial', 'subtle',
|
229 |
+
'successful', 'successive', 'sudden', 'sufficient', 'suitable',
|
230 |
+
'sunny', 'super', 'superb', 'superior', 'supporting', 'supposed',
|
231 |
+
'supreme', 'sure', 'surprised', 'surprising', 'surrounding',
|
232 |
+
'surviving', 'suspicious', 'sweet', 'swift', 'swiss', 'symbolic',
|
233 |
+
'sympathetic', 'systematic', 'tall', 'tame', 'tan', 'tart',
|
234 |
+
'tasteless', 'tasty', 'technical', 'technological', 'teenage',
|
235 |
+
'temporary', 'tender', 'tense', 'terrible', 'territorial', 'testy',
|
236 |
+
'then', 'theoretical', 'thick', 'thin', 'thirsty', 'thorough',
|
237 |
+
'thoughtful', 'thoughtless', 'thundering', 'tight', 'tiny', 'tired',
|
238 |
+
'top', 'tory', 'total', 'tough', 'toxic', 'traditional', 'tragic',
|
239 |
+
'tremendous', 'tricky', 'tropical', 'troubled', 'turkish', 'typical',
|
240 |
+
'ugliest', 'ugly', 'ultimate', 'unable', 'unacceptable', 'unaware',
|
241 |
+
'uncertain', 'unchanged', 'uncomfortable', 'unconscious', 'underground',
|
242 |
+
'underlying', 'unemployed', 'uneven', 'unexpected', 'unfair',
|
243 |
+
'unfortunate', 'unhappy', 'uniform', 'uninterested', 'unique', 'united',
|
244 |
+
'universal', 'unknown', 'unlikely', 'unnecessary', 'unpleasant',
|
245 |
+
'unsightly', 'unusual', 'unwilling', 'upper', 'upset', 'uptight',
|
246 |
+
'urban', 'urgent', 'used', 'useful', 'useless', 'usual', 'vague',
|
247 |
+
'valid', 'valuable', 'variable', 'varied', 'various', 'varying', 'vast',
|
248 |
+
'verbal', 'vertical', 'very', 'victorian', 'victorious', 'video-taped',
|
249 |
+
'violent', 'visible', 'visiting', 'visual', 'vital', 'vivacious',
|
250 |
+
'vivid', 'vocational', 'voiceless', 'voluntary', 'vulnerable',
|
251 |
+
'wandering', 'warm', 'wasteful', 'watery', 'weak', 'wealthy', 'weary',
|
252 |
+
'wee', 'weekly', 'weird', 'welcome', 'well', 'well-known', 'welsh',
|
253 |
+
'western', 'wet', 'whispering', 'white', 'whole', 'wicked', 'wide',
|
254 |
+
'wide-eyed', 'widespread', 'wild', 'willing', 'wise', 'witty',
|
255 |
+
'wonderful', 'wooden', 'working', 'working-class', 'worldwide',
|
256 |
+
'worried', 'worrying', 'worthwhile', 'worthy', 'written', 'wrong',
|
257 |
+
'yellow', 'young', 'yummy', 'zany', 'zealous']
|
258 |
+
b = ['abiding', 'accelerating', 'accepting', 'accomplishing', 'achieving',
|
259 |
+
'acquiring', 'acteding', 'activating', 'adapting', 'adding', 'addressing',
|
260 |
+
'administering', 'admiring', 'admiting', 'adopting', 'advising', 'affording',
|
261 |
+
'agreeing', 'alerting', 'alighting', 'allowing', 'altereding', 'amusing',
|
262 |
+
'analyzing', 'announcing', 'annoying', 'answering', 'anticipating',
|
263 |
+
'apologizing', 'appearing', 'applauding', 'applieding', 'appointing',
|
264 |
+
'appraising', 'appreciating', 'approving', 'arbitrating', 'arguing',
|
265 |
+
'arising', 'arranging', 'arresting', 'arriving', 'ascertaining', 'asking',
|
266 |
+
'assembling', 'assessing', 'assisting', 'assuring', 'attaching', 'attacking',
|
267 |
+
'attaining', 'attempting', 'attending', 'attracting', 'auditeding', 'avoiding',
|
268 |
+
'awaking', 'backing', 'baking', 'balancing', 'baning', 'banging', 'baring',
|
269 |
+
'bating', 'bathing', 'battling', 'bing', 'beaming', 'bearing', 'beating',
|
270 |
+
'becoming', 'beging', 'begining', 'behaving', 'beholding', 'belonging',
|
271 |
+
'bending', 'beseting', 'beting', 'biding', 'binding', 'biting', 'bleaching',
|
272 |
+
'bleeding', 'blessing', 'blinding', 'blinking', 'bloting', 'blowing',
|
273 |
+
'blushing', 'boasting', 'boiling', 'bolting', 'bombing', 'booking',
|
274 |
+
'boring', 'borrowing', 'bouncing', 'bowing', 'boxing', 'braking',
|
275 |
+
'branching', 'breaking', 'breathing', 'breeding', 'briefing', 'bringing',
|
276 |
+
'broadcasting', 'bruising', 'brushing', 'bubbling', 'budgeting', 'building',
|
277 |
+
'bumping', 'burning', 'bursting', 'burying', 'busting', 'buying', 'buzing',
|
278 |
+
'calculating', 'calling', 'camping', 'caring', 'carrying', 'carving',
|
279 |
+
'casting', 'cataloging', 'catching', 'causing', 'challenging', 'changing',
|
280 |
+
'charging', 'charting', 'chasing', 'cheating', 'checking', 'cheering',
|
281 |
+
'chewing', 'choking', 'choosing', 'choping', 'claiming', 'claping',
|
282 |
+
'clarifying', 'classifying', 'cleaning', 'clearing', 'clinging', 'cliping',
|
283 |
+
'closing', 'clothing', 'coaching', 'coiling', 'collecting', 'coloring',
|
284 |
+
'combing', 'coming', 'commanding', 'communicating', 'comparing', 'competing',
|
285 |
+
'compiling', 'complaining', 'completing', 'composing', 'computing',
|
286 |
+
'conceiving', 'concentrating', 'conceptualizing', 'concerning', 'concluding',
|
287 |
+
'conducting', 'confessing', 'confronting', 'confusing', 'connecting',
|
288 |
+
'conserving', 'considering', 'consisting', 'consolidating', 'constructing',
|
289 |
+
'consulting', 'containing', 'continuing', 'contracting', 'controling',
|
290 |
+
'converting', 'coordinating', 'copying', 'correcting', 'correlating',
|
291 |
+
'costing', 'coughing', 'counseling', 'counting', 'covering', 'cracking',
|
292 |
+
'crashing', 'crawling', 'creating', 'creeping', 'critiquing', 'crossing',
|
293 |
+
'crushing', 'crying', 'curing', 'curling', 'curving', 'cuting', 'cycling',
|
294 |
+
'daming', 'damaging', 'dancing', 'daring', 'dealing', 'decaying', 'deceiving',
|
295 |
+
'deciding', 'decorating', 'defining', 'delaying', 'delegating', 'delighting',
|
296 |
+
'delivering', 'demonstrating', 'depending', 'describing', 'deserting',
|
297 |
+
'deserving', 'designing', 'destroying', 'detailing', 'detecting',
|
298 |
+
'determining', 'developing', 'devising', 'diagnosing', 'diging',
|
299 |
+
'directing', 'disagreing', 'disappearing', 'disapproving', 'disarming',
|
300 |
+
'discovering', 'disliking', 'dispensing', 'displaying', 'disproving',
|
301 |
+
'dissecting', 'distributing', 'diving', 'diverting', 'dividing', 'doing',
|
302 |
+
'doubling', 'doubting', 'drafting', 'draging', 'draining', 'dramatizing',
|
303 |
+
'drawing', 'dreaming', 'dressing', 'drinking', 'driping', 'driving',
|
304 |
+
'dropping', 'drowning', 'druming', 'drying', 'dusting', 'dwelling',
|
305 |
+
'earning', 'eating', 'editeding', 'educating', 'eliminating',
|
306 |
+
'embarrassing', 'employing', 'emptying', 'enacteding', 'encouraging',
|
307 |
+
'ending', 'enduring', 'enforcing', 'engineering', 'enhancing',
|
308 |
+
'enjoying', 'enlisting', 'ensuring', 'entering', 'entertaining',
|
309 |
+
'escaping', 'establishing', 'estimating', 'evaluating', 'examining',
|
310 |
+
'exceeding', 'exciting', 'excusing', 'executing', 'exercising', 'exhibiting',
|
311 |
+
'existing', 'expanding', 'expecting', 'expediting', 'experimenting',
|
312 |
+
'explaining', 'exploding', 'expressing', 'extending', 'extracting',
|
313 |
+
'facing', 'facilitating', 'fading', 'failing', 'fancying', 'fastening',
|
314 |
+
'faxing', 'fearing', 'feeding', 'feeling', 'fencing', 'fetching', 'fighting',
|
315 |
+
'filing', 'filling', 'filming', 'finalizing', 'financing', 'finding',
|
316 |
+
'firing', 'fiting', 'fixing', 'flaping', 'flashing', 'fleing', 'flinging',
|
317 |
+
'floating', 'flooding', 'flowing', 'flowering', 'flying', 'folding',
|
318 |
+
'following', 'fooling', 'forbiding', 'forcing', 'forecasting', 'foregoing',
|
319 |
+
'foreseing', 'foretelling', 'forgeting', 'forgiving', 'forming',
|
320 |
+
'formulating', 'forsaking', 'framing', 'freezing', 'frightening', 'frying',
|
321 |
+
'gathering', 'gazing', 'generating', 'geting', 'giving', 'glowing', 'gluing',
|
322 |
+
'going', 'governing', 'grabing', 'graduating', 'grating', 'greasing', 'greeting',
|
323 |
+
'grinning', 'grinding', 'griping', 'groaning', 'growing', 'guaranteeing',
|
324 |
+
'guarding', 'guessing', 'guiding', 'hammering', 'handing', 'handling',
|
325 |
+
'handwriting', 'hanging', 'happening', 'harassing', 'harming', 'hating',
|
326 |
+
'haunting', 'heading', 'healing', 'heaping', 'hearing', 'heating', 'helping',
|
327 |
+
'hiding', 'hitting', 'holding', 'hooking', 'hoping', 'hopping', 'hovering',
|
328 |
+
'hugging', 'hmuming', 'hunting', 'hurrying', 'hurting', 'hypothesizing',
|
329 |
+
'identifying', 'ignoring', 'illustrating', 'imagining', 'implementing',
|
330 |
+
'impressing', 'improving', 'improvising', 'including', 'increasing',
|
331 |
+
'inducing', 'influencing', 'informing', 'initiating', 'injecting',
|
332 |
+
'injuring', 'inlaying', 'innovating', 'inputing', 'inspecting',
|
333 |
+
'inspiring', 'installing', 'instituting', 'instructing', 'insuring',
|
334 |
+
'integrating', 'intending', 'intensifying', 'interesting',
|
335 |
+
'interfering', 'interlaying', 'interpreting', 'interrupting',
|
336 |
+
'interviewing', 'introducing', 'inventing', 'inventorying',
|
337 |
+
'investigating', 'inviting', 'irritating', 'itching', 'jailing',
|
338 |
+
'jamming', 'jogging', 'joining', 'joking', 'judging', 'juggling', 'jumping',
|
339 |
+
'justifying', 'keeping', 'kepting', 'kicking', 'killing', 'kissing', 'kneeling',
|
340 |
+
'kniting', 'knocking', 'knotting', 'knowing', 'labeling', 'landing', 'lasting',
|
341 |
+
'laughing', 'launching', 'laying', 'leading', 'leaning', 'leaping', 'learning',
|
342 |
+
'leaving', 'lecturing', 'leding', 'lending', 'leting', 'leveling',
|
343 |
+
'licensing', 'licking', 'lying', 'lifteding', 'lighting', 'lightening',
|
344 |
+
'liking', 'listing', 'listening', 'living', 'loading', 'locating',
|
345 |
+
'locking', 'loging', 'longing', 'looking', 'losing', 'loving',
|
346 |
+
'maintaining', 'making', 'maning', 'managing', 'manipulating',
|
347 |
+
'manufacturing', 'mapping', 'marching', 'marking', 'marketing',
|
348 |
+
'marrying', 'matching', 'mating', 'mattering', 'meaning', 'measuring',
|
349 |
+
'meddling', 'mediating', 'meeting', 'melting', 'melting', 'memorizing',
|
350 |
+
'mending', 'mentoring', 'milking', 'mining', 'misleading', 'missing',
|
351 |
+
'misspelling', 'mistaking', 'misunderstanding', 'mixing', 'moaning',
|
352 |
+
'modeling', 'modifying', 'monitoring', 'mooring', 'motivating',
|
353 |
+
'mourning', 'moving', 'mowing', 'muddling', 'muging', 'multiplying',
|
354 |
+
'murdering', 'nailing', 'naming', 'navigating', 'needing', 'negotiating',
|
355 |
+
'nesting', 'noding', 'nominating', 'normalizing', 'noting', 'noticing',
|
356 |
+
'numbering', 'obeying', 'objecting', 'observing', 'obtaining', 'occuring',
|
357 |
+
'offending', 'offering', 'officiating', 'opening', 'operating', 'ordering',
|
358 |
+
'organizing', 'orienteding', 'originating', 'overcoming', 'overdoing',
|
359 |
+
'overdrawing', 'overflowing', 'overhearing', 'overtaking', 'overthrowing',
|
360 |
+
'owing', 'owning', 'packing', 'paddling', 'painting', 'parking', 'parting',
|
361 |
+
'participating', 'passing', 'pasting', 'pating', 'pausing', 'paying',
|
362 |
+
'pecking', 'pedaling', 'peeling', 'peeping', 'perceiving', 'perfecting',
|
363 |
+
'performing', 'permiting', 'persuading', 'phoning', 'photographing',
|
364 |
+
'picking', 'piloting', 'pinching', 'pining', 'pinpointing', 'pioneering',
|
365 |
+
'placing', 'planing', 'planting', 'playing', 'pleading', 'pleasing',
|
366 |
+
'plugging', 'pointing', 'poking', 'polishing', 'poping', 'possessing',
|
367 |
+
'posting', 'pouring', 'practicing', 'praiseding', 'praying', 'preaching',
|
368 |
+
'preceding', 'predicting', 'prefering', 'preparing', 'prescribing',
|
369 |
+
'presenting', 'preserving', 'preseting', 'presiding', 'pressing',
|
370 |
+
'pretending', 'preventing', 'pricking', 'printing', 'processing',
|
371 |
+
'procuring', 'producing', 'professing', 'programing', 'progressing',
|
372 |
+
'projecting', 'promising', 'promoting', 'proofreading', 'proposing',
|
373 |
+
'protecting', 'proving', 'providing', 'publicizing', 'pulling', 'pumping',
|
374 |
+
'punching', 'puncturing', 'punishing', 'purchasing', 'pushing', 'puting',
|
375 |
+
'qualifying', 'questioning', 'queuing', 'quiting', 'racing', 'radiating',
|
376 |
+
'raining', 'raising', 'ranking', 'rating', 'reaching', 'reading',
|
377 |
+
'realigning', 'realizing', 'reasoning', 'receiving', 'recognizing',
|
378 |
+
'recommending', 'reconciling', 'recording', 'recruiting', 'reducing',
|
379 |
+
'referring', 'reflecting', 'refusing', 'regreting', 'regulating',
|
380 |
+
'rehabilitating', 'reigning', 'reinforcing', 'rejecting', 'rejoicing',
|
381 |
+
'relating', 'relaxing', 'releasing', 'relying', 'remaining', 'remembering',
|
382 |
+
'reminding', 'removing', 'rendering', 'reorganizing', 'repairing',
|
383 |
+
'repeating', 'replacing', 'replying', 'reporting', 'representing',
|
384 |
+
'reproducing', 'requesting', 'rescuing', 'researching', 'resolving',
|
385 |
+
'responding', 'restoreding', 'restructuring', 'retiring', 'retrieving',
|
386 |
+
'returning', 'reviewing', 'revising', 'rhyming', 'riding', 'riding',
|
387 |
+
'ringing', 'rinsing', 'rising', 'risking', 'robing', 'rocking', 'rolling',
|
388 |
+
'roting', 'rubing', 'ruining', 'ruling', 'runing', 'rushing', 'sacking',
|
389 |
+
'sailing', 'satisfying', 'saving', 'sawing', 'saying', 'scaring',
|
390 |
+
'scattering', 'scheduling', 'scolding', 'scorching', 'scraping',
|
391 |
+
'scratching', 'screaming', 'screwing', 'scribbling', 'scrubing',
|
392 |
+
'sealing', 'searching', 'securing', 'seing', 'seeking', 'selecting',
|
393 |
+
'selling', 'sending', 'sensing', 'separating', 'serving', 'servicing',
|
394 |
+
'seting', 'settling', 'sewing', 'shading', 'shaking', 'shaping',
|
395 |
+
'sharing', 'shaving', 'shearing', 'sheding', 'sheltering', 'shining',
|
396 |
+
'shivering', 'shocking', 'shoing', 'shooting', 'shoping', 'showing',
|
397 |
+
'shrinking', 'shruging', 'shuting', 'sighing', 'signing', 'signaling',
|
398 |
+
'simplifying', 'sining', 'singing', 'sinking', 'siping', 'siting',
|
399 |
+
'sketching', 'skiing', 'skiping', 'slaping', 'slaying', 'sleeping',
|
400 |
+
'sliding', 'slinging', 'slinking', 'sliping', 'sliting', 'slowing',
|
401 |
+
'smashing', 'smelling', 'smiling', 'smiting', 'smoking', 'snatching',
|
402 |
+
'sneaking', 'sneezing', 'sniffing', 'snoring', 'snowing', 'soaking',
|
403 |
+
'solving', 'soothing', 'soothsaying', 'sorting', 'sounding', 'sowing',
|
404 |
+
'sparing', 'sparking', 'sparkling', 'speaking', 'specifying', 'speeding',
|
405 |
+
'spelling', 'spending', 'spilling', 'spining', 'spiting', 'spliting',
|
406 |
+
'spoiling', 'spoting', 'spraying', 'spreading', 'springing', 'sprouting',
|
407 |
+
'squashing', 'squeaking', 'squealing', 'squeezing', 'staining', 'stamping',
|
408 |
+
'standing', 'staring', 'starting', 'staying', 'stealing', 'steering',
|
409 |
+
'stepping', 'sticking', 'stimulating', 'stinging', 'stinking', 'stirring',
|
410 |
+
'stitching', 'stoping', 'storing', 'straping', 'streamlining',
|
411 |
+
'strengthening', 'stretching', 'striding', 'striking', 'stringing',
|
412 |
+
'stripping', 'striving', 'stroking', 'structuring', 'studying',
|
413 |
+
'stuffing', 'subleting', 'subtracting', 'succeeding', 'sucking',
|
414 |
+
'suffering', 'suggesting', 'suiting', 'summarizing', 'supervising',
|
415 |
+
'supplying', 'supporting', 'supposing', 'surprising', 'surrounding',
|
416 |
+
'suspecting', 'suspending', 'swearing', 'sweating', 'sweeping', 'swelling',
|
417 |
+
'swimming', 'swinging', 'switching', 'symbolizing', 'synthesizing',
|
418 |
+
'systemizing', 'tabulating', 'taking', 'talking', 'taming', 'taping',
|
419 |
+
'targeting', 'tasting', 'teaching', 'tearing', 'teasing', 'telephoning',
|
420 |
+
'telling', 'tempting', 'terrifying', 'testing', 'thanking', 'thawing',
|
421 |
+
'thinking', 'thriving', 'throwing', 'thrusting', 'ticking', 'tickling',
|
422 |
+
'tying', 'timing', 'tiping', 'tiring', 'touching', 'touring', 'towing',
|
423 |
+
'tracing', 'trading', 'training', 'transcribing', 'transfering',
|
424 |
+
'transforming', 'translating', 'transporting', 'traping', 'traveling',
|
425 |
+
'treading', 'treating', 'trembling', 'tricking', 'triping', 'troting',
|
426 |
+
'troubling', 'troubleshooting', 'trusting', 'trying', 'tuging', 'tumbling',
|
427 |
+
'turning', 'tutoring', 'twisting', 'typing', 'undergoing', 'understanding',
|
428 |
+
'undertaking', 'undressing', 'unfastening', 'unifying', 'uniting',
|
429 |
+
'unlocking', 'unpacking', 'untidying', 'updating', 'upgrading',
|
430 |
+
'upholding', 'upseting', 'using', 'utilizing', 'vanishing', 'verbalizing',
|
431 |
+
'verifying', 'vexing', 'visiting', 'wailing', 'waiting', 'waking',
|
432 |
+
'walking', 'wandering', 'wanting', 'warming', 'warning', 'washing',
|
433 |
+
'wasting', 'watching', 'watering', 'waving', 'wearing', 'weaving',
|
434 |
+
'wedding', 'weeping', 'weighing', 'welcoming', 'wending', 'weting',
|
435 |
+
'whining', 'whiping', 'whirling', 'whispering', 'whistling', 'wining',
|
436 |
+
'winding', 'winking', 'wiping', 'wishing', 'withdrawing', 'withholding',
|
437 |
+
'withstanding', 'wobbling', 'wondering', 'working', 'worrying', 'wrapping',
|
438 |
+
'wrecking', 'wrestling', 'wriggling', 'wringing', 'writing', 'x-raying',
|
439 |
+
'yawning', 'yelling', 'zipping', 'zooming']
|
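These word lists exist so that experiment hashes can be turned into memorable Adjective-Verbing-Animal names. A minimal sketch of how they might be combined, assuming the animal list above is bound to `c` as in the original BigGAN repository and that the module is importable as shown (both assumptions, not confirmed by this commit):

# Hypothetical helper: map an integer hash to an Adjective-Verbing-Animal name.
# Assumes the lists are importable as animal_hash.a, animal_hash.b, animal_hash.c.
from src.models.big import animal_hash

def hash_to_name(h: int) -> str:
    h, i = divmod(h, len(animal_hash.a))  # adjective index
    h, j = divmod(h, len(animal_hash.b))  # verb index
    _, k = divmod(h, len(animal_hash.c))  # animal index
    return animal_hash.a[i].capitalize() + animal_hash.b[j].capitalize() + animal_hash.c[k]

# e.g. hash_to_name(hash('my_experiment_config')) might yield a name like 'BraveSingingOtter'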
src/models/big/cheat sheet
ADDED
@@ -0,0 +1,30 @@
from big.BigGAN2 import Generator, Discriminator
from big.losses import generator_loss, discriminator_loss

generator = Generator().cuda()
discriminator = Discriminator().cuda()

label_transformed_fake = label_fc_net(label_fake)
label_transformed_real = label_fc_net(label_real)

generated_images = generator(decoder_input, label_transformed_fake)

# discriminator training step
discriminator.optim.zero_grad()
prediction_fake = discriminator(generated_images.detach(), label_transformed_fake).view(-1)
prediction_real = discriminator(images, label_transformed_real).view(-1)

d_loss_real, d_loss_fake = discriminator_loss(prediction_fake, prediction_real)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
discriminator.optim.step()

# generator training step
generator.optim.zero_grad()
prediction = discriminator(generated_images, label_transformed_fake).view(-1)

g_loss = generator_loss(prediction)
g_loss.backward()
generator.optim.step()
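After training, sampling follows the same calling convention the cheat sheet uses (a latent batch plus a transformed label batch). The snippet below is a minimal sketch under that assumption; `dim_z`, `label_fc_net`, and `label_fake` are placeholders carried over from the sheet, not confirmed API:

# Minimal sampling sketch (assumes Generator.forward(z, y) as above).
import torch

dim_z = 120  # assumed latent dimensionality
generator.eval()
with torch.no_grad():
    z = torch.randn(8, dim_z).cuda()   # batch of 8 latent vectors
    y = label_fc_net(label_fake)       # transformed labels, as in the sheet
    samples = generator(z, y)          # output expected roughly in [-1, 1]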
src/models/big/datasets.py
ADDED
@@ -0,0 +1,362 @@
''' Datasets
This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets
'''
import os
import os.path
import sys
from PIL import Image
import numpy as np
from tqdm import tqdm, trange

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_url, check_integrity
import torch.utils.data as data
from torch.utils.data import DataLoader

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
  """Checks if a file is an image.

  Args:
    filename (string): path to a file

  Returns:
    bool: True if the filename ends with a known image extension
  """
  filename_lower = filename.lower()
  return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)


def find_classes(dir):
  classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
  classes.sort()
  class_to_idx = {classes[i]: i for i in range(len(classes))}
  return classes, class_to_idx


def make_dataset(dir, class_to_idx):
  images = []
  dir = os.path.expanduser(dir)
  for target in tqdm(sorted(os.listdir(dir))):
    d = os.path.join(dir, target)
    if not os.path.isdir(d):
      continue

    for root, _, fnames in sorted(os.walk(d)):
      for fname in sorted(fnames):
        if is_image_file(fname):
          path = os.path.join(root, fname)
          item = (path, class_to_idx[target])
          images.append(item)

  return images


def pil_loader(path):
  # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  with open(path, 'rb') as f:
    img = Image.open(f)
    return img.convert('RGB')


def accimage_loader(path):
  import accimage
  try:
    return accimage.Image(path)
  except IOError:
    # Potentially a decoding problem, fall back to PIL.Image
    return pil_loader(path)


def default_loader(path):
  from torchvision import get_image_backend
  if get_image_backend() == 'accimage':
    return accimage_loader(path)
  else:
    return pil_loader(path)


class ImageFolder(data.Dataset):
  """A generic data loader where the images are arranged in this way: ::

      root/dogball/xxx.png
      root/dogball/xxy.png
      root/dogball/xxz.png

      root/cat/123.png
      root/cat/nsdf3.png
      root/cat/asd932_.png

  Args:
    root (string): Root directory path.
    transform (callable, optional): A function/transform that takes in an PIL image
      and returns a transformed version. E.g, ``transforms.RandomCrop``
    target_transform (callable, optional): A function/transform that takes in the
      target and transforms it.
    loader (callable, optional): A function to load an image given its path.

  Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    imgs (list): List of (image path, class_index) tuples
  """

  def __init__(self, root, transform=None, target_transform=None,
               loader=default_loader, load_in_mem=False,
               index_filename='imagenet_imgs.npz', **kwargs):
    classes, class_to_idx = find_classes(root)
    # Load pre-computed image directory walk
    if os.path.exists(index_filename):
      print('Loading pre-saved Index file %s...' % index_filename)
      imgs = np.load(index_filename)['imgs']
    # If first time, walk the folder directory and save the
    # results to a pre-computed file.
    else:
      print('Generating Index file %s...' % index_filename)
      imgs = make_dataset(root, class_to_idx)
      np.savez_compressed(index_filename, **{'imgs': imgs})
    if len(imgs) == 0:
      raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                         "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

    self.root = root
    self.imgs = imgs
    self.classes = classes
    self.class_to_idx = class_to_idx
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    self.load_in_mem = load_in_mem

    if self.load_in_mem:
      print('Loading all images into memory...')
      self.data, self.labels = [], []
      for index in tqdm(range(len(self.imgs))):
        path, target = imgs[index][0], imgs[index][1]
        self.data.append(self.transform(self.loader(path)))
        self.labels.append(target)

  def __getitem__(self, index):
    """
    Args:
      index (int): Index

    Returns:
      tuple: (image, target) where target is class_index of the target class.
    """
    if self.load_in_mem:
      img = self.data[index]
      target = self.labels[index]
    else:
      path, target = self.imgs[index]
      img = self.loader(str(path))
      if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, int(target)

  def __len__(self):
    return len(self.imgs)

  def __repr__(self):
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
    fmt_str += '    Root Location: {}\n'.format(self.root)
    tmp = '    Transforms (if any): '
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    tmp = '    Target Transforms (if any): '
    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str


''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid
    having to load individual images all the time. '''
import h5py as h5
import torch


class ILSVRC_HDF5(data.Dataset):
  def __init__(self, root, transform=None, target_transform=None,
               load_in_mem=False, train=True, download=False, validate_seed=0,
               val_split=0, **kwargs):  # last four are dummies

    self.root = root
    self.num_imgs = len(h5.File(root, 'r')['labels'])

    self.target_transform = target_transform

    # Set the transform here
    self.transform = transform

    # Load the entire dataset into memory?
    self.load_in_mem = load_in_mem

    # If loading into memory, do so now
    if self.load_in_mem:
      print('Loading %s into memory...' % root)
      with h5.File(root, 'r') as f:
        self.data = f['imgs'][:]
        self.labels = f['labels'][:]

  def __getitem__(self, index):
    """
    Args:
      index (int): Index

    Returns:
      tuple: (image, target) where target is class_index of the target class.
    """
    # If the entire dataset is loaded in RAM, get the image from memory
    if self.load_in_mem:
      img = self.data[index]
      target = self.labels[index]

    # Else load it from disk
    else:
      with h5.File(self.root, 'r') as f:
        img = f['imgs'][index]
        target = f['labels'][index]

    # Apply my own transform: scale uint8 pixels to [-1, 1] instead of self.transform
    img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2

    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, int(target)

  def __len__(self):
    return self.num_imgs


import pickle


class CIFAR10(dset.CIFAR10):

  def __init__(self, root, train=True,
               transform=None, target_transform=None,
               download=True, validate_seed=0,
               val_split=0, load_in_mem=True, **kwargs):
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set
    self.val_split = val_split

    if download:
      self.download()

    if not self._check_integrity():
      raise RuntimeError('Dataset not found or corrupted.'
                         ' You can use download=True to download it')

    # now load the picked numpy arrays
    self.data = []
    self.labels = []
    for fentry in self.train_list:
      f = fentry[0]
      file = os.path.join(self.root, self.base_folder, f)
      fo = open(file, 'rb')
      if sys.version_info[0] == 2:
        entry = pickle.load(fo)
      else:
        entry = pickle.load(fo, encoding='latin1')
      self.data.append(entry['data'])
      if 'labels' in entry:
        self.labels += entry['labels']
      else:
        self.labels += entry['fine_labels']
      fo.close()

    self.data = np.concatenate(self.data)
    # Randomly select indices for validation
    if self.val_split > 0:
      label_indices = [[] for _ in range(max(self.labels) + 1)]
      for i, l in enumerate(self.labels):
        label_indices[l] += [i]
      label_indices = np.asarray(label_indices)

      # randomly grab 500 elements of each class
      np.random.seed(validate_seed)
      self.val_indices = []
      for l_i in label_indices:
        self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1), replace=False)])

    if self.train == 'validate':
      self.data = self.data[self.val_indices]
      self.labels = list(np.asarray(self.labels)[self.val_indices])

      self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32))
      self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    elif self.train:
      print(np.shape(self.data))
      if self.val_split > 0:
        self.data = np.delete(self.data, self.val_indices, axis=0)
        self.labels = list(np.delete(np.asarray(self.labels), self.val_indices, axis=0))

      self.data = self.data.reshape((int(50e3 * (1. - self.val_split)), 3, 32, 32))
      self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
    else:
      f = self.test_list[0][0]
      file = os.path.join(self.root, self.base_folder, f)
      fo = open(file, 'rb')
      if sys.version_info[0] == 2:
        entry = pickle.load(fo)
      else:
        entry = pickle.load(fo, encoding='latin1')
      self.data = entry['data']
      if 'labels' in entry:
        self.labels = entry['labels']
      else:
        self.labels = entry['fine_labels']
      fo.close()
      self.data = self.data.reshape((10000, 3, 32, 32))
      self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

  def __getitem__(self, index):
    """
    Args:
      index (int): Index
    Returns:
      tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.labels[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, target

  def __len__(self):
    return len(self.data)


class CIFAR100(CIFAR10):
  base_folder = 'cifar-100-python'
  url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
  filename = "cifar-100-python.tar.gz"
  tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
  train_list = [
    ['train', '16019d7e3df5f24257cddd939b257f8d'],
  ]

  test_list = [
    ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
  ]
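A short usage sketch for the ImageFolder dataset defined above, hedged: the root path and image size here are illustrative placeholders, not values taken from this repo's configs. It wires the dataset to a DataLoader with the usual GAN-style normalization to [-1, 1]:

# Illustrative only: build an ImageFolder-backed loader.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize(128),                                  # assumed resolution
    transforms.CenterCrop(128),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map to [-1, 1]
])
dataset = ImageFolder(root='data/my_images', transform=transform)  # hypothetical path
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)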
src/models/big/layers.py
ADDED
@@ -0,0 +1,456 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P

from src.models.big.sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d


# Projection of x onto y
def proj(x, y):
  return torch.mm(y, x.t()) * y / torch.mm(y, y.t())


# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
  for y in ys:
    x = x - proj(x, y)
  return x


# Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
  # Lists holding singular vectors and values
  us, vs, svs = [], [], []
  for i, u in enumerate(u_):
    # Run one step of the power iteration
    with torch.no_grad():
      v = torch.matmul(u, W)
      # Run Gram-Schmidt to subtract components of all other singular vectors
      v = F.normalize(gram_schmidt(v, vs), eps=eps)
      # Add to the list
      vs += [v]
      # Update the other singular vector
      u = torch.matmul(v, W.t())
      # Run Gram-Schmidt to subtract components of all other singular vectors
      u = F.normalize(gram_schmidt(u, us), eps=eps)
      # Add to the list
      us += [u]
      if update:
        u_[i][:] = u
    # Compute this singular value and add it to the list
    svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
    #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
  return svs, us, vs
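

# A hedged sanity check (not part of the original file): because u_ is updated
# in place when update=True, repeated calls to power_iteration converge on the
# top singular value of W. Commented out so importing the module stays side-effect-free:
#
#   W = torch.randn(16, 32)
#   u = [torch.randn(1, 16)]  # one row vector per tracked singular value
#   for _ in range(50):
#       svs, us, vs = power_iteration(W, u, update=True)
#   # svs[0] should now be close to torch.linalg.svdvals(W)[0]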

# Convenience passthrough function
class identity(nn.Module):
  def forward(self, input):
    return input


# Spectral normalization base class
class SN(object):
  def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
    # Number of power iterations per step
    self.num_itrs = num_itrs
    # Number of singular values
    self.num_svs = num_svs
    # Transposed?
    self.transpose = transpose
    # Epsilon value for avoiding divide-by-0
    self.eps = eps
    # Register a singular vector for each sv
    for i in range(self.num_svs):
      self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
      self.register_buffer('sv%d' % i, torch.ones(1))

  # Singular vectors (u side)
  @property
  def u(self):
    return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

  # Singular values;
  # note that these buffers are just for logging and are not used in training.
  @property
  def sv(self):
    return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

  # Compute the spectrally-normalized weight
  def W_(self):
    W_mat = self.weight.view(self.weight.size(0), -1)
    if self.transpose:
      W_mat = W_mat.t()
    # Apply num_itrs power iterations
    for _ in range(self.num_itrs):
      svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
    # Update the svs
    if self.training:
      with torch.no_grad():  # Make sure to do this in a no_grad() context or you'll get memory leaks!
        for i, sv in enumerate(svs):
          self.sv[i][:] = sv
    return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1,
               padding=0, dilation=1, groups=1, bias=True,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                       padding, dilation, groups, bias)
    SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

  def forward(self, x):
    return F.conv2d(x, self.W_(), self.bias, self.stride,
                    self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
  def __init__(self, in_features, out_features, bias=True,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Linear.__init__(self, in_features, out_features, bias)
    SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

  def forward(self, x):
    return F.linear(x, self.W_(), self.bias)


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
  def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
               max_norm=None, norm_type=2, scale_grad_by_freq=False,
               sparse=False, _weight=None,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight)
    SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

  def forward(self, x):
    return F.embedding(x, self.W_())


# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
  def __init__(self, ch, which_conv=SNConv2d, name='attention'):
    super(Attention, self).__init__()
    # Channel multiplier
    self.ch = ch
    self.which_conv = which_conv
    self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
    self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
    self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
    self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
    # Learnable gain parameter
    self.gamma = P(torch.tensor(0.), requires_grad=True)

  def forward(self, x, y=None):
    # Apply convs
    theta = self.theta(x)
    phi = F.max_pool2d(self.phi(x), [2, 2])
    g = F.max_pool2d(self.g(x), [2, 2])
    # Perform reshapes
    theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
    phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
    g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
    # Matmul and softmax to get attention maps
    beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
    # Attention map times g path
    o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
    return self.gamma * o + x


# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
  # Apply scale and shift--if gain and bias are provided, fuse them here
  # Prepare scale
  scale = torch.rsqrt(var + eps)
  # If a gain is provided, use it
  if gain is not None:
    scale = scale * gain
  # Prepare shift
  shift = mean * scale
  # If bias is provided, use it
  if bias is not None:
    shift = shift - bias
  return x * scale - shift
  #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
  # Cast x to float32 if necessary
  float_x = x.float()
  # Calculate expected value of x (m) and expected value of x**2 (m2)
  # Mean of x
  m = torch.mean(float_x, [0, 2, 3], keepdim=True)
  # Mean of x squared
  m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
  # Calculate variance as mean of squared minus mean squared.
  var = (m2 - m ** 2)
  # Cast back to float 16 if necessary
  var = var.type(x.type())
  m = m.type(x.type())
  # Return mean and variance for updating stored mean/var if requested
  if return_mean_var:
    return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
  else:
    return fused_bn(x, m, var, gain, bias, eps)


# My batchnorm, supports standing stats
class myBN(nn.Module):
  def __init__(self, num_channels, eps=1e-5, momentum=0.1):
    super(myBN, self).__init__()
    # Momentum for updating running stats
    self.momentum = momentum
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Register buffers
    self.register_buffer('stored_mean', torch.zeros(num_channels))
    self.register_buffer('stored_var', torch.ones(num_channels))
    self.register_buffer('accumulation_counter', torch.zeros(1))
    # Accumulate running means and vars
    self.accumulate_standing = False

  # reset standing stats
  def reset_stats(self):
    self.stored_mean[:] = 0
    self.stored_var[:] = 0
    self.accumulation_counter[:] = 0

  def forward(self, x, gain, bias):
    if self.training:
      out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
      # If accumulating standing stats, increment them
      if self.accumulate_standing:
        self.stored_mean[:] = self.stored_mean + mean.data
        self.stored_var[:] = self.stored_var + var.data
        self.accumulation_counter += 1.0
      # If not accumulating standing stats, take running averages
      else:
        self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
        self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
      return out
    # If not in training mode, use the stored statistics
    else:
      mean = self.stored_mean.view(1, -1, 1, 1)
      var = self.stored_var.view(1, -1, 1, 1)
      # If using standing stats, divide them by the accumulation counter
      if self.accumulate_standing:
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
      return fused_bn(x, mean, var, gain, bias, self.eps)


# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
  # If number of channels specified in norm_style:
  if 'ch' in norm_style:
    ch = int(norm_style.split('_')[-1])
    groups = max(int(x.shape[1]) // ch, 1)
  # If number of groups specified in norm style
  elif 'grp' in norm_style:
    groups = int(norm_style.split('_')[-1])
  # If neither, default to groups = 16
  else:
    groups = 16
  return F.group_norm(x, groups)


# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
  def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
               cross_replica=False, mybn=False, norm_style='bn'):
    super(ccbn, self).__init__()
    self.output_size, self.input_size = output_size, input_size
    # Prepare gain and bias layers
    self.gain = which_linear(input_size, output_size)
    self.bias = which_linear(input_size, output_size)
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Use cross-replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # Norm style?
    self.norm_style = norm_style

    if self.cross_replica:
      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
|
296 |
+
elif self.mybn:
|
297 |
+
self.bn = myBN(output_size, self.eps, self.momentum)
|
298 |
+
elif self.norm_style in ['bn', 'in']:
|
299 |
+
self.register_buffer('stored_mean', torch.zeros(output_size))
|
300 |
+
self.register_buffer('stored_var', torch.ones(output_size))
|
301 |
+
|
302 |
+
|
303 |
+
def forward(self, x, y):
|
304 |
+
# Calculate class-conditional gains and biases
|
305 |
+
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
|
306 |
+
bias = self.bias(y).view(y.size(0), -1, 1, 1)
|
307 |
+
# If using my batchnorm
|
308 |
+
if self.mybn or self.cross_replica:
|
309 |
+
return self.bn(x, gain=gain, bias=bias)
|
310 |
+
# else:
|
311 |
+
else:
|
312 |
+
if self.norm_style == 'bn':
|
313 |
+
out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
|
314 |
+
self.training, 0.1, self.eps)
|
315 |
+
elif self.norm_style == 'in':
|
316 |
+
out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
|
317 |
+
self.training, 0.1, self.eps)
|
318 |
+
elif self.norm_style == 'gn':
|
319 |
+
out = groupnorm(x, self.normstyle)
|
320 |
+
elif self.norm_style == 'nonorm':
|
321 |
+
out = x
|
322 |
+
return out * gain + bias
|
323 |
+
def extra_repr(self):
|
324 |
+
s = 'out: {output_size}, in: {input_size},'
|
325 |
+
s +=' cross_replica={cross_replica}'
|
326 |
+
return s.format(**self.__dict__)
|
327 |
+
|
328 |
+
|
329 |
+
# Normal, non-class-conditional BN
|
330 |
+
class bn(nn.Module):
|
331 |
+
def __init__(self, output_size, eps=1e-5, momentum=0.1,
|
332 |
+
cross_replica=False, mybn=False):
|
333 |
+
super(bn, self).__init__()
|
334 |
+
self.output_size= output_size
|
335 |
+
# Prepare gain and bias layers
|
336 |
+
self.gain = P(torch.ones(output_size), requires_grad=True)
|
337 |
+
self.bias = P(torch.zeros(output_size), requires_grad=True)
|
338 |
+
# epsilon to avoid dividing by 0
|
339 |
+
self.eps = eps
|
340 |
+
# Momentum
|
341 |
+
self.momentum = momentum
|
342 |
+
# Use cross-replica batchnorm?
|
343 |
+
self.cross_replica = cross_replica
|
344 |
+
# Use my batchnorm?
|
345 |
+
self.mybn = mybn
|
346 |
+
|
347 |
+
if self.cross_replica:
|
348 |
+
self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
|
349 |
+
elif mybn:
|
350 |
+
self.bn = myBN(output_size, self.eps, self.momentum)
|
351 |
+
# Register buffers if neither of the above
|
352 |
+
else:
|
353 |
+
self.register_buffer('stored_mean', torch.zeros(output_size))
|
354 |
+
self.register_buffer('stored_var', torch.ones(output_size))
|
355 |
+
|
356 |
+
def forward(self, x, y=None):
|
357 |
+
if self.cross_replica or self.mybn:
|
358 |
+
gain = self.gain.view(1,-1,1,1)
|
359 |
+
bias = self.bias.view(1,-1,1,1)
|
360 |
+
return self.bn(x, gain=gain, bias=bias)
|
361 |
+
else:
|
362 |
+
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
|
363 |
+
self.bias, self.training, self.momentum, self.eps)
|
364 |
+
|
365 |
+
|
366 |
+
# Generator blocks
|
367 |
+
# Note that this class assumes the kernel size and padding (and any other
|
368 |
+
# settings) have been selected in the main generator module and passed in
|
369 |
+
# through the which_conv arg. Similar rules apply with which_bn (the input
|
370 |
+
# size [which is actually the number of channels of the conditional info] must
|
371 |
+
# be preselected)
|
372 |
+
class GBlock(nn.Module):
|
373 |
+
def __init__(self, in_channels, out_channels,
|
374 |
+
which_conv=nn.Conv2d, which_bn=bn, activation=None,
|
375 |
+
upsample=None):
|
376 |
+
super(GBlock, self).__init__()
|
377 |
+
|
378 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
379 |
+
self.which_conv, self.which_bn = which_conv, which_bn
|
380 |
+
self.activation = activation
|
381 |
+
self.upsample = upsample
|
382 |
+
# Conv layers
|
383 |
+
self.conv1 = self.which_conv(self.in_channels, self.out_channels)
|
384 |
+
self.conv2 = self.which_conv(self.out_channels, self.out_channels)
|
385 |
+
self.learnable_sc = in_channels != out_channels or upsample
|
386 |
+
if self.learnable_sc:
|
387 |
+
self.conv_sc = self.which_conv(in_channels, out_channels,
|
388 |
+
kernel_size=1, padding=0)
|
389 |
+
# Batchnorm layers
|
390 |
+
self.bn1 = self.which_bn(in_channels)
|
391 |
+
self.bn2 = self.which_bn(out_channels)
|
392 |
+
# upsample layers
|
393 |
+
self.upsample = upsample
|
394 |
+
|
395 |
+
def forward(self, x, y):
|
396 |
+
h = self.activation(self.bn1(x, y))
|
397 |
+
if self.upsample:
|
398 |
+
h = self.upsample(h)
|
399 |
+
x = self.upsample(x)
|
400 |
+
h = self.conv1(h)
|
401 |
+
h = self.activation(self.bn2(h, y))
|
402 |
+
h = self.conv2(h)
|
403 |
+
if self.learnable_sc:
|
404 |
+
x = self.conv_sc(x)
|
405 |
+
return h + x
|
406 |
+
|
407 |
+
|
408 |
+
# Residual block for the discriminator
|
409 |
+
class DBlock(nn.Module):
|
410 |
+
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
|
411 |
+
preactivation=False, activation=None, downsample=None,):
|
412 |
+
super(DBlock, self).__init__()
|
413 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
414 |
+
# If using wide D (as in SA-GAN and BigGAN), change the channel pattern
|
415 |
+
self.hidden_channels = self.out_channels if wide else self.in_channels
|
416 |
+
self.which_conv = which_conv
|
417 |
+
self.preactivation = preactivation
|
418 |
+
self.activation = activation
|
419 |
+
self.downsample = downsample
|
420 |
+
|
421 |
+
# Conv layers
|
422 |
+
self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
|
423 |
+
self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
|
424 |
+
self.learnable_sc = True if (in_channels != out_channels) or downsample else False
|
425 |
+
if self.learnable_sc:
|
426 |
+
self.conv_sc = self.which_conv(in_channels, out_channels,
|
427 |
+
kernel_size=1, padding=0)
|
428 |
+
def shortcut(self, x):
|
429 |
+
if self.preactivation:
|
430 |
+
if self.learnable_sc:
|
431 |
+
x = self.conv_sc(x)
|
432 |
+
if self.downsample:
|
433 |
+
x = self.downsample(x)
|
434 |
+
else:
|
435 |
+
if self.downsample:
|
436 |
+
x = self.downsample(x)
|
437 |
+
if self.learnable_sc:
|
438 |
+
x = self.conv_sc(x)
|
439 |
+
return x
|
440 |
+
|
441 |
+
def forward(self, x):
|
442 |
+
if self.preactivation:
|
443 |
+
# h = self.activation(x) # NOT TODAY SATAN
|
444 |
+
# Andy's note: This line *must* be an out-of-place ReLU or it
|
445 |
+
# will negatively affect the shortcut connection.
|
446 |
+
h = F.relu(x)
|
447 |
+
else:
|
448 |
+
h = x
|
449 |
+
h = self.conv1(h)
|
450 |
+
h = self.conv2(self.activation(h))
|
451 |
+
if self.downsample:
|
452 |
+
h = self.downsample(h)
|
453 |
+
|
454 |
+
return h + self.shortcut(x)
|
455 |
+
|
456 |
+
# dogball
|
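For reference, `fused_bn` folds the gain and bias into a single scale/shift, so it should agree with the plain normalization formula. Below is a minimal sanity-check sketch, assuming the repo root is on `sys.path` so this file imports as `src.models.big.layers`; it is illustrative, not part of the commit.

import torch

from src.models.big.layers import fused_bn, manual_bn

x = torch.randn(8, 16, 32, 32)
mean = x.mean([0, 2, 3], keepdim=True)
var = x.var([0, 2, 3], unbiased=False, keepdim=True)

gain = torch.ones(1, 16, 1, 1)
bias = torch.zeros(1, 16, 1, 1)

# Fused path vs. the "unfused way" noted in the comment above.
fused = fused_bn(x, mean, var, gain, bias)
unfused = (x - mean) / torch.sqrt(var + 1e-5) * gain + bias
print(torch.allclose(fused, unfused, atol=1e-5))  # True

# manual_bn computes the same biased statistics internally.
out = manual_bn(x, gain, bias)
print(torch.allclose(out, unfused, atol=1e-5))  # True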
src/models/big/sync_batchnorm/__init__.py
ADDED  @@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
src/models/big/sync_batchnorm/batchnorm.py
ADDED  @@ -0,0 +1,349 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """Sum over the first and last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                out = out + gain
            if bias is not None:
                out = out + bias
            return out

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if gain is not None:
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard deviation from the sum and square-sum.
        This method also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
        # return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
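A minimal usage sketch of the module above, assuming a machine with at least two GPUs (on a single GPU or CPU the forward pass simply falls back to `F.batch_norm`, as the code shows); `device_ids` and the toy network are illustrative only.

import torch
import torch.nn as nn

from src.models.big.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(),
)
# DataParallelWithCallback triggers __data_parallel_replicate__ on each copy,
# wiring every replica to the master's SyncMaster before the forward pass.
net = DataParallelWithCallback(net, device_ids=[0, 1])
out = net(torch.randn(8, 3, 32, 32).cuda())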
src/models/big/sync_batchnorm/batchnorm_reimpl.py
ADDED  @@ -0,0 +1,74 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
src/models/big/sync_batchnorm/comm.py
ADDED  @@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, "Previous result hasn't been fetched."
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as data parallel triggers a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
      are collected and passed to a registered callback.
    - After receiving the messages, the master device gathers the information and determines the message
      to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
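The master/slave handshake above works with any message type, not just tensors. Here is a tiny thread-based sketch of the protocol with one slave and a callback that sums plain integers; the import path is an assumption and no GPUs are involved.

import threading

from src.models.big.sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is [(identifier, message), ...]; reply to every device.
    total = sum(msg for _, msg in intermediates)
    return [(ident, total) for ident, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(1)

results = {}
worker = threading.Thread(target=lambda: results.update(slave=pipe.run_slave(3)))
worker.start()
results['master'] = master.run_master(4)
worker.join()
print(results)  # both sides see the reduced value, 7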
src/models/big/sync_batchnorm/replicate.py
ADDED  @@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all replicated modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
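Following the docstring above, a short retrofit sketch: patching an already-constructed `DataParallel` wrapper instead of rebuilding it (device ids are illustrative, multi-GPU assumed).

import torch.nn as nn

from src.models.big.sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback

sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
wrapped = nn.DataParallel(sync_bn, device_ids=[0, 1])
patch_replication_callback(wrapped)  # replication now fires the sync callbacks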
src/models/big/sync_batchnorm/unittest.py
ADDED  @@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
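A tiny usage sketch of the helper test case, assuming a standard `unittest` runner; `ExampleTest` is a hypothetical name.

import torch

from src.models.big.sync_batchnorm.unittest import TorchTestCase

class ExampleTest(TorchTestCase):
    def test_close(self):
        a = torch.ones(3)
        # Passes: the tensors differ by far less than allclose's tolerance.
        self.assertTensorClose(a, a + 1e-9)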
src/models/big/utils.py
ADDED  @@ -0,0 +1,1193 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

''' Utilities file
This file contains utility functions for bookkeeping, logging, and data loading.
Methods which directly affect training should either go in layers, the model,
or train_fns.py.
'''

from __future__ import print_function
import sys
import os
import numpy as np
import time
import datetime
import json
import pickle
from argparse import ArgumentParser
import animal_hash

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import datasets as dset

def prepare_parser():
  usage = 'Parser for all scripts.'
  parser = ArgumentParser(description=usage)

  ### Dataset/Dataloader stuff ###
  parser.add_argument(
    '--dataset', type=str, default='I128_hdf5',
    help='Which dataset to train on, out of I128, I256, C10, C100; '
         'append "_hdf5" to use the HDF5 version for ILSVRC '
         '(default: %(default)s)')
  parser.add_argument(
    '--augment', action='store_true', default=False,
    help='Augment with random crops and flips (default: %(default)s)')
  parser.add_argument(
    '--num_workers', type=int, default=8,
    help='Number of dataloader workers; consider using less for HDF5 '
         '(default: %(default)s)')
  parser.add_argument(
    '--no_pin_memory', action='store_false', dest='pin_memory', default=True,
    help='Pin data into memory through dataloader? (default: %(default)s)')
  parser.add_argument(
    '--shuffle', action='store_true', default=False,
    help='Shuffle the data (strongly recommended)? (default: %(default)s)')
  parser.add_argument(
    '--load_in_mem', action='store_true', default=False,
    help='Load all data into memory? (default: %(default)s)')
  parser.add_argument(
    '--use_multiepoch_sampler', action='store_true', default=False,
    help='Use the multi-epoch sampler for dataloader? (default: %(default)s)')

  ### Model stuff ###
  parser.add_argument(
    '--model', type=str, default='BigGAN',
    help='Name of the model module (default: %(default)s)')
  parser.add_argument(
    '--G_param', type=str, default='SN',
    help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)'
         ' or None (default: %(default)s)')
  parser.add_argument(
    '--D_param', type=str, default='SN',
    help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)'
         ' or None (default: %(default)s)')
  parser.add_argument(
    '--G_ch', type=int, default=64,
    help='Channel multiplier for G (default: %(default)s)')
  parser.add_argument(
    '--D_ch', type=int, default=64,
    help='Channel multiplier for D (default: %(default)s)')
  parser.add_argument(
    '--G_depth', type=int, default=1,
    help='Number of resblocks per stage in G? (default: %(default)s)')
  parser.add_argument(
    '--D_depth', type=int, default=1,
    help='Number of resblocks per stage in D? (default: %(default)s)')
  parser.add_argument(
    '--D_thin', action='store_false', dest='D_wide', default=True,
    help='Use the SN-GAN channel pattern for D? (default: %(default)s)')
  parser.add_argument(
    '--G_shared', action='store_true', default=False,
    help='Use shared embeddings in G? (default: %(default)s)')
  parser.add_argument(
    '--shared_dim', type=int, default=0,
    help="G's shared embedding dimensionality; if 0, will be equal to dim_z. "
         '(default: %(default)s)')
  parser.add_argument(
    '--dim_z', type=int, default=128,
    help='Noise dimensionality (default: %(default)s)')
  parser.add_argument(
    '--z_var', type=float, default=1.0,
    help='Noise variance (default: %(default)s)')
  parser.add_argument(
    '--hier', action='store_true', default=False,
    help='Use hierarchical z in G? (default: %(default)s)')
  parser.add_argument(
    '--cross_replica', action='store_true', default=False,
    help='Cross-replica batchnorm in G? (default: %(default)s)')
  parser.add_argument(
    '--mybn', action='store_true', default=False,
    help='Use my batchnorm (which supports standing stats)? (default: %(default)s)')
  parser.add_argument(
    '--G_nl', type=str, default='relu',
    help='Activation function for G (default: %(default)s)')
  parser.add_argument(
    '--D_nl', type=str, default='relu',
    help='Activation function for D (default: %(default)s)')
  parser.add_argument(
    '--G_attn', type=str, default='64',
    help='What resolutions to use attention on for G (underscore separated) '
         '(default: %(default)s)')
  parser.add_argument(
    '--D_attn', type=str, default='64',
    help='What resolutions to use attention on for D (underscore separated) '
         '(default: %(default)s)')
  parser.add_argument(
    '--norm_style', type=str, default='bn',
    help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], '
         'ln [layernorm], gn [groupnorm] (default: %(default)s)')

  ### Model init stuff ###
  parser.add_argument(
    '--seed', type=int, default=0,
    help='Random seed to use; affects both initialization and '
         'dataloading. (default: %(default)s)')
  parser.add_argument(
    '--G_init', type=str, default='ortho',
    help='Init style to use for G (default: %(default)s)')
  parser.add_argument(
    '--D_init', type=str, default='ortho',
    help='Init style to use for D (default: %(default)s)')
  parser.add_argument(
    '--skip_init', action='store_true', default=False,
    help='Skip initialization, ideal for testing when ortho init was used '
         '(default: %(default)s)')

  ### Optimizer stuff ###
  parser.add_argument(
    '--G_lr', type=float, default=5e-5,
    help='Learning rate to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_lr', type=float, default=2e-4,
    help='Learning rate to use for Discriminator (default: %(default)s)')
  parser.add_argument(
    '--G_B1', type=float, default=0.0,
    help='Beta1 to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_B1', type=float, default=0.0,
    help='Beta1 to use for Discriminator (default: %(default)s)')
  parser.add_argument(
    '--G_B2', type=float, default=0.999,
    help='Beta2 to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_B2', type=float, default=0.999,
    help='Beta2 to use for Discriminator (default: %(default)s)')

  ### Batch size, parallel, and precision stuff ###
  parser.add_argument(
    '--batch_size', type=int, default=64,
    help='Default overall batchsize (default: %(default)s)')
  parser.add_argument(
    '--G_batch_size', type=int, default=0,
    help='Batch size to use for G; if 0, same as D (default: %(default)s)')
  parser.add_argument(
    '--num_G_accumulations', type=int, default=1,
    help="Number of passes to accumulate G's gradients over "
         '(default: %(default)s)')
  parser.add_argument(
    '--num_D_steps', type=int, default=2,
    help='Number of D steps per G step (default: %(default)s)')
  parser.add_argument(
    '--num_D_accumulations', type=int, default=1,
    help="Number of passes to accumulate D's gradients over "
         '(default: %(default)s)')
  parser.add_argument(
    '--split_D', action='store_true', default=False,
    help='Run D twice rather than concatenating inputs? (default: %(default)s)')
  parser.add_argument(
    '--num_epochs', type=int, default=100,
    help='Number of epochs to train for (default: %(default)s)')
  parser.add_argument(
    '--parallel', action='store_true', default=False,
    help='Train with multiple GPUs (default: %(default)s)')
  parser.add_argument(
    '--G_fp16', action='store_true', default=False,
    help='Train with half-precision in G? (default: %(default)s)')
  parser.add_argument(
    '--D_fp16', action='store_true', default=False,
    help='Train with half-precision in D? (default: %(default)s)')
  parser.add_argument(
    '--D_mixed_precision', action='store_true', default=False,
    help='Train with half-precision activations but fp32 params in D? '
         '(default: %(default)s)')
  parser.add_argument(
    '--G_mixed_precision', action='store_true', default=False,
    help='Train with half-precision activations but fp32 params in G? '
         '(default: %(default)s)')
  parser.add_argument(
    '--accumulate_stats', action='store_true', default=False,
    help='Accumulate "standing" batchnorm stats? (default: %(default)s)')
  parser.add_argument(
    '--num_standing_accumulations', type=int, default=16,
    help='Number of forward passes to use in accumulating standing stats? '
         '(default: %(default)s)')

  ### Bookkeeping stuff ###
  parser.add_argument(
    '--G_eval_mode', action='store_true', default=False,
    help='Run G in eval mode (running/standing stats?) at sample/test time? '
         '(default: %(default)s)')
  parser.add_argument(
    '--save_every', type=int, default=2000,
    help='Save every X iterations (default: %(default)s)')
  parser.add_argument(
    '--num_save_copies', type=int, default=2,
    help='How many copies to save (default: %(default)s)')
  parser.add_argument(
    '--num_best_copies', type=int, default=2,
    help='How many previous best checkpoints to save (default: %(default)s)')
  parser.add_argument(
    '--which_best', type=str, default='IS',
    help='Which metric to use to determine when to save new "best" '
         'checkpoints, one of IS or FID (default: %(default)s)')
  parser.add_argument(
    '--no_fid', action='store_true', default=False,
    help='Calculate IS only, not FID? (default: %(default)s)')
  parser.add_argument(
    '--test_every', type=int, default=5000,
    help='Test every X iterations (default: %(default)s)')
  parser.add_argument(
    '--num_inception_images', type=int, default=50000,
    help='Number of samples to compute inception metrics with '
         '(default: %(default)s)')
  parser.add_argument(
    '--hashname', action='store_true', default=False,
    help='Use a hash of the experiment name instead of the full config '
         '(default: %(default)s)')
  parser.add_argument(
    '--base_root', type=str, default='',
    help='Default location to store all weights, samples, data, and logs '
         '(default: %(default)s)')
  parser.add_argument(
    '--data_root', type=str, default='data',
    help='Default location where data is stored (default: %(default)s)')
  parser.add_argument(
    '--weights_root', type=str, default='weights',
    help='Default location to store weights (default: %(default)s)')
  parser.add_argument(
    '--logs_root', type=str, default='logs',
    help='Default location to store logs (default: %(default)s)')
  parser.add_argument(
    '--samples_root', type=str, default='samples',
    help='Default location to store samples (default: %(default)s)')
  parser.add_argument(
    '--pbar', type=str, default='mine',
    help='Type of progressbar to use; one of "mine" or "tqdm" '
         '(default: %(default)s)')
  parser.add_argument(
    '--name_suffix', type=str, default='',
    help='Suffix for experiment name for loading weights for sampling '
         '(consider "best0") (default: %(default)s)')
  parser.add_argument(
    '--experiment_name', type=str, default='',
    help='Optionally override the automatic experiment naming with this arg. '
         '(default: %(default)s)')
  parser.add_argument(
275 |
+
'--config_from_name', action='store_true', default=False,
|
276 |
+
help='Use a hash of the experiment name instead of the full config '
|
277 |
+
'(default: %(default)s)')
|
278 |
+
|
279 |
+

  ### EMA Stuff ###
  parser.add_argument(
    '--ema', action='store_true', default=False,
    help='Keep an EMA of G\'s weights? (default: %(default)s)')
  parser.add_argument(
    '--ema_decay', type=float, default=0.9999,
    help='EMA decay rate (default: %(default)s)')
  parser.add_argument(
    '--use_ema', action='store_true', default=False,
    help='Use the EMA parameters of G for evaluation? (default: %(default)s)')
  parser.add_argument(
    '--ema_start', type=int, default=0,
    help='When to start updating the EMA weights (default: %(default)s)')

  ### Numerical precision and SV stuff ###
  parser.add_argument(
    '--adam_eps', type=float, default=1e-8,
    help='epsilon value to use for Adam (default: %(default)s)')
  parser.add_argument(
    '--BN_eps', type=float, default=1e-5,
    help='epsilon value to use for BatchNorm (default: %(default)s)')
  parser.add_argument(
    '--SN_eps', type=float, default=1e-8,
    help='epsilon value to use for Spectral Norm (default: %(default)s)')
  parser.add_argument(
    '--num_G_SVs', type=int, default=1,
    help='Number of SVs to track in G (default: %(default)s)')
  parser.add_argument(
    '--num_D_SVs', type=int, default=1,
    help='Number of SVs to track in D (default: %(default)s)')
  parser.add_argument(
    '--num_G_SV_itrs', type=int, default=1,
    help='Number of SV itrs in G (default: %(default)s)')
  parser.add_argument(
    '--num_D_SV_itrs', type=int, default=1,
    help='Number of SV itrs in D (default: %(default)s)')

  ### Ortho reg stuff ###
  parser.add_argument(
    '--G_ortho', type=float, default=0.0,  # 1e-4 is default for BigGAN
    help='Modified ortho reg coefficient in G (default: %(default)s)')
  parser.add_argument(
    '--D_ortho', type=float, default=0.0,
    help='Modified ortho reg coefficient in D (default: %(default)s)')
  parser.add_argument(
    '--toggle_grads', action='store_true', default=True,
    help='Toggle D and G\'s "requires_grad" settings when not training them? '
         '(default: %(default)s)')

  ### Which train function ###
  parser.add_argument(
    '--which_train_fn', type=str, default='GAN',
    help='Which train function to use (default: %(default)s)')

  ### Resume training stuff ###
  parser.add_argument(
    '--load_weights', type=str, default='',
    help='Suffix for which weights to load (e.g. best0, copy0) '
         '(default: %(default)s)')
  parser.add_argument(
    '--resume', action='store_true', default=False,
    help='Resume training? (default: %(default)s)')

  ### Log stuff ###
  parser.add_argument(
    '--logstyle', type=str, default='%3.3e',
    help='What style to use when logging training metrics? '
         'One of: %%#.#f / %%#.#e (float/exp, text), '
         'pickle (python pickle), '
         'npz (numpy zip), '
         'mat (MATLAB .mat file) (default: %(default)s)')
  parser.add_argument(
    '--log_G_spectra', action='store_true', default=False,
    help='Log the top 3 singular values in each SN layer in G? '
         '(default: %(default)s)')
  parser.add_argument(
    '--log_D_spectra', action='store_true', default=False,
    help='Log the top 3 singular values in each SN layer in D? '
         '(default: %(default)s)')
  parser.add_argument(
    '--sv_log_interval', type=int, default=10,
    help='Iteration interval for logging singular values '
         '(default: %(default)s)')

  return parser
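
# Usage sketch (not part of the original file): how this parser is typically
# consumed in a training entry point. Assumes the enclosing function is named
# `prepare_parser`, as in the upstream BigGAN-PyTorch utils this file follows.
#
#   parser = prepare_parser()
#   config = vars(parser.parse_args(['--dataset', 'C10', '--batch_size', '64']))
#   print(config['G_lr'], config['num_D_steps'], config['ema'])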

# Arguments for sample.py; not presently used in train.py
def add_sample_parser(parser):
  parser.add_argument(
    '--sample_npz', action='store_true', default=False,
    help='Sample "sample_num_npz" images and save to npz? '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_num_npz', type=int, default=50000,
    help='Number of images to sample when sampling NPZs '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_sheets', action='store_true', default=False,
    help='Produce class-conditional sample sheets and stick them in '
         'the samples root? (default: %(default)s)')
  parser.add_argument(
    '--sample_interps', action='store_true', default=False,
    help='Produce interpolation sheets and stick them in '
         'the samples root? (default: %(default)s)')
  parser.add_argument(
    '--sample_sheet_folder_num', type=int, default=-1,
    help='Number to use for the folder for these sample sheets '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_random', action='store_true', default=False,
    help='Produce a single random sheet? (default: %(default)s)')
  parser.add_argument(
    '--sample_trunc_curves', type=str, default='',
    help='Get inception metrics with a range of variances? '
         'To use this, specify a startpoint, step, and endpoint, e.g. '
         '--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, '
         'a stepsize of 0.1, and an endpoint of 1.0. Note that this is '
         'not exactly identical to using tf.truncated_normal, but should '
         'have approximately the same effect. (default: %(default)s)')
  parser.add_argument(
    '--sample_inception_metrics', action='store_true', default=False,
    help='Calculate Inception metrics with sample.py? (default: %(default)s)')
  return parser
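
# Sketch of how a '--sample_trunc_curves' string like '0.2_0.1_1.0' can be
# decoded (hypothetical snippet; the real consumer is sample.py, not this file):
#
#   start, step, end = [float(item) for item in '0.2_0.1_1.0'.split('_')]
#   variances = [start + step * i
#                for i in range(int(round((end - start) / step)) + 1)]
#   # -> [0.2, 0.3, ..., 1.0]; one inception-metric run per variance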

# Convenience dicts
dset_dict = {'I32': dset.ImageFolder, 'I64': dset.ImageFolder,
             'I128': dset.ImageFolder, 'I256': dset.ImageFolder,
             'I32_hdf5': dset.ILSVRC_HDF5, 'I64_hdf5': dset.ILSVRC_HDF5,
             'I128_hdf5': dset.ILSVRC_HDF5, 'I256_hdf5': dset.ILSVRC_HDF5,
             'C10': dset.CIFAR10, 'C100': dset.CIFAR100}
imsize_dict = {'I32': 32, 'I32_hdf5': 32,
               'I64': 64, 'I64_hdf5': 64,
               'I128': 128, 'I128_hdf5': 128,
               'I256': 256, 'I256_hdf5': 256,
               'C10': 32, 'C100': 32}
root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5',
             'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5',
             'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5',
             'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5',
             'C10': 'cifar', 'C100': 'cifar'}
nclass_dict = {'I32': 1000, 'I32_hdf5': 1000,
               'I64': 1000, 'I64_hdf5': 1000,
               'I128': 1000, 'I128_hdf5': 1000,
               'I256': 1000, 'I256_hdf5': 1000,
               'C10': 10, 'C100': 100}
# Number of classes to put per sample sheet
classes_per_sheet_dict = {'I32': 50, 'I32_hdf5': 50,
                          'I64': 50, 'I64_hdf5': 50,
                          'I128': 20, 'I128_hdf5': 20,
                          'I256': 20, 'I256_hdf5': 20,
                          'C10': 10, 'C100': 100}
activation_dict = {'inplace_relu': nn.ReLU(inplace=True),
                   'relu': nn.ReLU(inplace=False),
                   'ir': nn.ReLU(inplace=True)}


class CenterCropLongEdge(object):
  """Crops the given PIL Image on the long edge.
  Args:
    size (sequence or int): Desired output size of the crop. If size is an
      int instead of sequence like (h, w), a square crop (size, size) is
      made.
  """
  def __call__(self, img):
    """
    Args:
      img (PIL Image): Image to be cropped.
    Returns:
      PIL Image: Cropped image.
    """
    return transforms.functional.center_crop(img, min(img.size))

  def __repr__(self):
    return self.__class__.__name__


class RandomCropLongEdge(object):
  """Crops the given PIL Image on the long edge with a random start point.
  Args:
    size (sequence or int): Desired output size of the crop. If size is an
      int instead of sequence like (h, w), a square crop (size, size) is
      made.
  """
  def __call__(self, img):
    """
    Args:
      img (PIL Image): Image to be cropped.
    Returns:
      PIL Image: Cropped image.
    """
    size = (min(img.size), min(img.size))
    # Only step forward along this edge if it's the long edge
    i = (0 if size[0] == img.size[0]
         else np.random.randint(low=0, high=img.size[0] - size[0]))
    j = (0 if size[1] == img.size[1]
         else np.random.randint(low=0, high=img.size[1] - size[1]))
    return transforms.functional.crop(img, i, j, size[0], size[1])

  def __repr__(self):
    return self.__class__.__name__

# multi-epoch Dataset sampler to avoid memory leakage and enable resumption of
# training from the same sample regardless of whether we stop mid-epoch
class MultiEpochSampler(torch.utils.data.Sampler):
  r"""Samples elements randomly over multiple epochs.

  Arguments:
    data_source (Dataset): dataset to sample from
    num_epochs (int): number of times to loop over the dataset
    start_itr (int): which iteration to begin from
  """

  def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128):
    self.data_source = data_source
    self.num_samples = len(self.data_source)
    self.num_epochs = num_epochs
    self.start_itr = start_itr
    self.batch_size = batch_size

    if not isinstance(self.num_samples, int) or self.num_samples <= 0:
      raise ValueError("num_samples should be a positive integer "
                       "value, but got num_samples={}".format(self.num_samples))

  def __iter__(self):
    n = len(self.data_source)
    # Determine number of epochs still to run
    num_epochs = int(np.ceil((n * self.num_epochs
                              - (self.start_itr * self.batch_size)) / float(n)))
    # Sample all the indices, and then grab the last num_epochs index sets;
    # this ensures if we're starting at epoch 4, we're still grabbing epoch 4's
    # indices
    out = [torch.randperm(n) for epoch in range(self.num_epochs)][-num_epochs:]
    # Ignore the first start_itr * batch_size % n indices of the first epoch
    out[0] = out[0][(self.start_itr * self.batch_size % n):]
    # if self.replacement:
    #   return iter(torch.randint(high=n, size=(self.num_samples,),
    #                             dtype=torch.int64).tolist())
    output = torch.cat(out).tolist()
    print('Length dataset output is %d' % len(output))
    return iter(output)

  def __len__(self):
    return len(self.data_source) * self.num_epochs - self.start_itr * self.batch_size
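
# Resumption sketch (hypothetical values): with start_itr=10 and batch_size=128,
# the first 10 * 128 indices of the current epoch's permutation are skipped,
# so a restarted run continues from the same point in the data order:
#
#   sampler = MultiEpochSampler(train_set, num_epochs=100,
#                               start_itr=10, batch_size=128)
#   loader = DataLoader(train_set, batch_size=128, sampler=sampler)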

# Convenience function to centralize all data loaders
def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64,
                     num_workers=8, shuffle=True, load_in_mem=False, hdf5=False,
                     pin_memory=True, drop_last=True, start_itr=0,
                     num_epochs=500, use_multiepoch_sampler=False,
                     **kwargs):

  # Append /FILENAME.hdf5 to root if using hdf5
  data_root += '/%s' % root_dict[dataset]
  print('Using dataset root location %s' % data_root)

  which_dataset = dset_dict[dataset]
  norm_mean = [0.5, 0.5, 0.5]
  norm_std = [0.5, 0.5, 0.5]
  image_size = imsize_dict[dataset]
  # For image folder datasets, name of the file where we store the precomputed
  # image locations to avoid having to walk the dirs every time we load.
  dataset_kwargs = {'index_filename': '%s_imgs.npz' % dataset}

  # HDF5 datasets have their own inbuilt transform; no need for train_transform
  if 'hdf5' in dataset:
    train_transform = None
  else:
    if augment:
      print('Data will be augmented...')
      if dataset in ['C10', 'C100']:
        train_transform = [transforms.RandomCrop(32, padding=4),
                           transforms.RandomHorizontalFlip()]
      else:
        train_transform = [RandomCropLongEdge(),
                           transforms.Resize(image_size),
                           transforms.RandomHorizontalFlip()]
    else:
      print('Data will not be augmented...')
      if dataset in ['C10', 'C100']:
        train_transform = []
      else:
        train_transform = [CenterCropLongEdge(), transforms.Resize(image_size)]
      # train_transform = [transforms.Resize(image_size), transforms.CenterCrop]
    train_transform = transforms.Compose(train_transform + [
                        transforms.ToTensor(),
                        transforms.Normalize(norm_mean, norm_std)])
  train_set = which_dataset(root=data_root, transform=train_transform,
                            load_in_mem=load_in_mem, **dataset_kwargs)

  # Prepare loader; the loaders list is for forward compatibility with
  # using validation / test splits.
  loaders = []
  if use_multiepoch_sampler:
    print('Using multiepoch sampler from start_itr %d...' % start_itr)
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
    sampler = MultiEpochSampler(train_set, num_epochs, start_itr, batch_size)
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              sampler=sampler, **loader_kwargs)
  else:
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory,
                     'drop_last': drop_last}  # Default: drop last incomplete batch
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=shuffle, **loader_kwargs)
  loaders.append(train_loader)
  return loaders

# Utility to seed rngs
def seed_rng(seed):
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  np.random.seed(seed)


# Utility to peg all roots to a base root
# If a base root folder is provided, peg all other root folders to it.
def update_config_roots(config):
  if config['base_root']:
    print('Pegging all root folders to base root %s' % config['base_root'])
    for key in ['data', 'weights', 'logs', 'samples']:
      config['%s_root' % key] = '%s/%s' % (config['base_root'], key)
  return config


# Utility to prepare root folders if they don't exist; parent folder must exist
def prepare_root(config):
  for key in ['weights_root', 'logs_root', 'samples_root']:
    if not os.path.exists(config[key]):
      print('Making directory %s for %s...' % (config[key], key))
      os.mkdir(config[key])

# Simple wrapper that applies EMA to a model. Could be better done in 1.0 using
# the parameters() and buffers() module functions, but for now this works
# with state_dicts using .copy_
class ema(object):
  def __init__(self, source, target, decay=0.9999, start_itr=0):
    self.source = source
    self.target = target
    self.decay = decay
    # Optional parameter indicating what iteration to start the decay at
    self.start_itr = start_itr
    # Initialize target's params to be source's
    self.source_dict = self.source.state_dict()
    self.target_dict = self.target.state_dict()
    print('Initializing EMA parameters to be source parameters...')
    with torch.no_grad():
      for key in self.source_dict:
        self.target_dict[key].data.copy_(self.source_dict[key].data)
        # target_dict[key].data = source_dict[key].data # Doesn't work!

  def update(self, itr=None):
    # If an iteration counter is provided and itr is less than the start itr,
    # peg the ema weights to the underlying weights.
    if itr is not None and itr < self.start_itr:
      decay = 0.0
    else:
      decay = self.decay
    with torch.no_grad():
      for key in self.source_dict:
        self.target_dict[key].data.copy_(self.target_dict[key].data * decay
                                         + self.source_dict[key].data * (1 - decay))
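
# The update above is a standard exponential moving average; with the default
# decay of 0.9999, each step mixes in 0.01% of the source weights:
#
#   target <- 0.9999 * target + 0.0001 * source
#
# Typical driving pattern (mirrors the upstream BigGAN train loop; names here
# are illustrative):
#
#   ema_g = ema(G, G_ema, decay=config['ema_decay'],
#               start_itr=config['ema_start'])
#   ...  # after each G optimizer step:
#   ema_g.update(state_dict['itr'])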

# Apply modified ortho reg to a model
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes, and not in the blacklist
      if len(param.shape) < 2 or any([param is item for item in blacklist]):
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
                           * (1. - torch.eye(w.shape[0], device=w.device)), w))
      param.grad.data += strength * grad.view(param.shape)


# Default ortho reg
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def default_ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes & not in the blacklist
      if len(param.shape) < 2 or param in blacklist:
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
                           - torch.eye(w.shape[0], device=w.device), w))
      param.grad.data += strength * grad.view(param.shape)


# Convenience utility to switch off requires_grad
def toggle_grad(model, on_or_off):
  for param in model.parameters():
    param.requires_grad = on_or_off
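
# Sketch of where these fit in a training step (mirrors the upstream BigGAN
# train functions; variable names are illustrative):
#
#   toggle_grad(D, False); toggle_grad(G, True)
#   G_loss.backward()
#   if config['G_ortho'] > 0.0:
#     # apply modified ortho reg, skipping shared embeddings via the blacklist
#     ortho(G, config['G_ortho'],
#           blacklist=[param for param in G.shared.parameters()])
#   G.optim.step()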

# Function to join strings or ignore them
# Base string is the string to link "strings," while strings
# is a list of strings or Nones.
def join_strings(base_string, strings):
  return base_string.join([item for item in strings if item])


# Save a model's weights, optimizer, and the state_dict
def save_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None):
  root = '/'.join([weights_root, experiment_name])
  if not os.path.exists(root):
    os.mkdir(root)
  if name_suffix:
    print('Saving weights to %s/%s...' % (root, name_suffix))
  else:
    print('Saving weights to %s...' % root)
  torch.save(G.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix])))
  torch.save(G.optim.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix])))
  torch.save(D.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix])))
  torch.save(D.optim.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix])))
  torch.save(state_dict,
             '%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))
  if G_ema is not None:
    torch.save(G_ema.state_dict(),
               '%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix])))

# Load a model's weights, optimizer, and the state_dict
def load_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None, strict=True, load_optim=True):
  root = '/'.join([weights_root, experiment_name])
  if name_suffix:
    print('Loading %s weights from %s...' % (name_suffix, root))
  else:
    print('Loading weights from %s...' % root)
  if G is not None:
    G.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))),
      strict=strict)
    if load_optim:
      G.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))))
  if D is not None:
    D.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))),
      strict=strict)
    if load_optim:
      D.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))))
  # Load state dict
  for item in state_dict:
    state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item]
  if G_ema is not None:
    G_ema.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))),
      strict=strict)

''' MetricsLogger originally stolen from VoxNet source code.
    Used for logging inception metrics'''
class MetricsLogger(object):
  def __init__(self, fname, reinitialize=False):
    self.fname = fname
    self.reinitialize = reinitialize
    if os.path.exists(self.fname):
      if self.reinitialize:
        print('{} exists, deleting...'.format(self.fname))
        os.remove(self.fname)

  def log(self, record=None, **kwargs):
    """
    Assumption: no newlines in the input.
    """
    if record is None:
      record = {}
    record.update(kwargs)
    record['_stamp'] = time.time()
    with open(self.fname, 'a') as f:
      f.write(json.dumps(record, ensure_ascii=True) + '\n')

# Logstyle is either:
# '%#.#f' for floating point representation in text
# '%#.#e' for exponent representation in text
# 'npz' for output to npz # NOT YET SUPPORTED
# 'pickle' for output to a python pickle # NOT YET SUPPORTED
# 'mat' for output to a MATLAB .mat file # NOT YET SUPPORTED
class MyLogger(object):
  def __init__(self, fname, reinitialize=False, logstyle='%3.3f'):
    self.root = fname
    if not os.path.exists(self.root):
      os.mkdir(self.root)
    self.reinitialize = reinitialize
    self.metrics = []
    self.logstyle = logstyle  # One of '%3.3f' or like '%3.3e'

  # Delete log if re-starting and log already exists
  def reinit(self, item):
    if os.path.exists('%s/%s.log' % (self.root, item)):
      if self.reinitialize:
        # Only print the removal message once for singular value logs
        if 'sv' in item:
          if not any('sv' in metric for metric in self.metrics):
            print('Deleting singular value logs...')
        else:
          print('{} exists, deleting...'.format('%s/%s.log' % (self.root, item)))
        os.remove('%s/%s.log' % (self.root, item))

  # Log in plaintext; this is designed for being read in MATLAB (sorry not sorry)
  def log(self, itr, **kwargs):
    for arg in kwargs:
      if arg not in self.metrics:
        if self.reinitialize:
          self.reinit(arg)
        self.metrics += [arg]
      if self.logstyle == 'pickle':
        print('Pickle not currently supported...')
        # with open('%s/%s.log' % (self.root, arg), 'a') as f:
        #   pickle.dump(kwargs[arg], f)
      elif self.logstyle == 'mat':
        print('.mat logstyle not currently supported...')
      else:
        with open('%s/%s.log' % (self.root, arg), 'a') as f:
          f.write('%d: %s\n' % (itr, self.logstyle % kwargs[arg]))

# Write some metadata to the logs directory
def write_metadata(logs_root, experiment_name, config, state_dict):
  with open(('%s/%s/metalog.txt' %
             (logs_root, experiment_name)), 'w') as writefile:
    writefile.write('datetime: %s\n' % str(datetime.datetime.now()))
    writefile.write('config: %s\n' % str(config))
    writefile.write('state: %s\n' % str(state_dict))
"""
|
820 |
+
Very basic progress indicator to wrap an iterable in.
|
821 |
+
|
822 |
+
Author: Jan SchlΓΌter
|
823 |
+
Andy's adds: time elapsed in addition to ETA, makes it possible to add
|
824 |
+
estimated time to 1k iters instead of estimated time to completion.
|
825 |
+
"""
|
826 |
+
def progress(items, desc='', total=None, min_delay=0.1, displaytype='s1k'):
|
827 |
+
"""
|
828 |
+
Returns a generator over `items`, printing the number and percentage of
|
829 |
+
items processed and the estimated remaining processing time before yielding
|
830 |
+
the next item. `total` gives the total number of items (required if `items`
|
831 |
+
has no length), and `min_delay` gives the minimum time in seconds between
|
832 |
+
subsequent prints. `desc` gives an optional prefix text (end with a space).
|
833 |
+
"""
|
834 |
+
total = total or len(items)
|
835 |
+
t_start = time.time()
|
836 |
+
t_last = 0
|
837 |
+
for n, item in enumerate(items):
|
838 |
+
t_now = time.time()
|
839 |
+
if t_now - t_last > min_delay:
|
840 |
+
print("\r%s%d/%d (%6.2f%%)" % (
|
841 |
+
desc, n+1, total, n / float(total) * 100), end=" ")
|
842 |
+
if n > 0:
|
843 |
+
|
844 |
+
if displaytype == 's1k': # minutes/seconds for 1000 iters
|
845 |
+
next_1000 = n + (1000 - n%1000)
|
846 |
+
t_done = t_now - t_start
|
847 |
+
t_1k = t_done / n * next_1000
|
848 |
+
outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60))
|
849 |
+
print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
|
850 |
+
else:# displaytype == 'eta':
|
851 |
+
t_done = t_now - t_start
|
852 |
+
t_total = t_done / n * total
|
853 |
+
outlist = list(divmod(t_done, 60)) + list(divmod(t_total - t_done, 60))
|
854 |
+
print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
|
855 |
+
|
856 |
+
sys.stdout.flush()
|
857 |
+
t_last = t_now
|
858 |
+
yield item
|
859 |
+
t_total = time.time() - t_start
|
860 |
+
print("\r%s%d/%d (100.00%%) (took %d:%02d)" % ((desc, total, total) +
|
861 |
+
divmod(t_total, 60)))
|
862 |
+
|
863 |
+
|
864 |
+

# Sample function for use with inception metrics
def sample(G, z_, y_, config):
  with torch.no_grad():
    z_.sample_()
    y_.sample_()
    if config['parallel']:
      G_z = nn.parallel.data_parallel(G, (z_, G.shared(y_)))
    else:
      G_z = G(z_, G.shared(y_))
    return G_z, y_

# Sample function for sample sheets
def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel,
                 samples_root, experiment_name, folder_number, z_=None):
  # Prepare sample directory
  if not os.path.isdir('%s/%s' % (samples_root, experiment_name)):
    os.mkdir('%s/%s' % (samples_root, experiment_name))
  if not os.path.isdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)):
    os.mkdir('%s/%s/%d' % (samples_root, experiment_name, folder_number))
  # loop over total number of sheets
  for i in range(num_classes // classes_per_sheet):
    ims = []
    y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet, device='cuda')
    for j in range(samples_per_class):
      if (z_ is not None) and hasattr(z_, 'sample_') and classes_per_sheet <= z_.size(0):
        z_.sample_()
      else:
        z_ = torch.randn(classes_per_sheet, G.dim_z, device='cuda')
      with torch.no_grad():
        if parallel:
          o = nn.parallel.data_parallel(G, (z_[:classes_per_sheet], G.shared(y)))
        else:
          o = G(z_[:classes_per_sheet], G.shared(y))

      ims += [o.data.cpu()]
    # This line should properly unroll the images
    out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                       ims[0].shape[3]).data.float().cpu()
    # The path for the samples
    image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name,
                                                 folder_number, i)
    torchvision.utils.save_image(out_ims, image_filename,
                                 nrow=samples_per_class, normalize=True)

# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..)
def interp(x0, x1, num_midpoints):
  lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype)
  return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)))
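
# Example (not part of the original file): with num_midpoints=8, inputs of
# shape (N, 1, dim_z) yield a (N, 10, dim_z) tensor whose second axis walks
# linearly from x0 to x1, endpoints included:
#
#   zs = interp(torch.randn(4, 1, 128, device='cuda'),
#               torch.randn(4, 1, 128, device='cuda'), 8)  # -> (4, 10, 128)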

# interp sheet function
# Supports full, class-wise and intra-class interpolation
def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel,
                 samples_root, experiment_name, folder_number, sheet_number=0,
                 fix_z=False, fix_y=False, device='cuda'):
  # Prepare zs and ys
  if fix_z:  # If fix z, only sample 1 z per row
    zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device)
    zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z)
  else:
    zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                num_midpoints).view(-1, G.dim_z)
  if fix_y:  # If fix y, only sample 1 y per row
    ys = sample_1hot(num_per_sheet, num_classes)
    ys = G.shared(ys).view(num_per_sheet, 1, -1)
    ys = ys.repeat(1, num_midpoints + 2, 1).view(num_per_sheet * (num_midpoints + 2), -1)
  else:
    ys = interp(G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
                G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
                num_midpoints).view(num_per_sheet * (num_midpoints + 2), -1)
  # Run the net--note that we've already passed y through G.shared.
  if G.fp16:
    zs = zs.half()
  with torch.no_grad():
    if parallel:
      out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu()
    else:
      out_ims = G(zs, ys).data.cpu()
  interp_style = '' + ('Z' if not fix_z else '') + ('Y' if not fix_y else '')
  image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name,
                                                folder_number, interp_style,
                                                sheet_number)
  torchvision.utils.save_image(out_ims, image_filename,
                               nrow=num_midpoints + 2, normalize=True)

# Convenience debugging function to print out gradnorms and shape from each layer
# May need to rewrite this so we can actually see which parameter is which
def print_grad_norms(net):
  gradsums = [[float(torch.norm(param.grad).item()),
               float(torch.norm(param).item()), param.shape]
              for param in net.parameters()]
  order = np.argsort([item[0] for item in gradsums])
  print(['%3.3e,%3.3e, %s' % (gradsums[item_index][0],
                              gradsums[item_index][1],
                              str(gradsums[item_index][2]))
         for item_index in order])

# Get singular values to log. This will use the state dict to find them
# and substitute underscores for dots.
def get_SVs(net, prefix):
  d = net.state_dict()
  return {('%s_%s' % (prefix, key)).replace('.', '_'):
          float(d[key].item())
          for key in d if 'sv' in key}

# Name an experiment based on its config
def name_from_config(config):
  name = '_'.join([
    item for item in [
      'Big%s' % config['which_train_fn'],
      config['dataset'],
      config['model'] if config['model'] != 'BigGAN' else None,
      'seed%d' % config['seed'],
      'Gch%d' % config['G_ch'],
      'Dch%d' % config['D_ch'],
      'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None,
      'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None,
      'bs%d' % config['batch_size'],
      'Gfp16' if config['G_fp16'] else None,
      'Dfp16' if config['D_fp16'] else None,
      'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None,
      'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None,
      'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None,
      'Glr%2.1e' % config['G_lr'],
      'Dlr%2.1e' % config['D_lr'],
      'GB%3.3f' % config['G_B1'] if config['G_B1'] != 0.0 else None,
      'GBB%3.3f' % config['G_B2'] if config['G_B2'] != 0.999 else None,
      'DB%3.3f' % config['D_B1'] if config['D_B1'] != 0.0 else None,
      'DBB%3.3f' % config['D_B2'] if config['D_B2'] != 0.999 else None,
      'Gnl%s' % config['G_nl'],
      'Dnl%s' % config['D_nl'],
      'Ginit%s' % config['G_init'],
      'Dinit%s' % config['D_init'],
      'G%s' % config['G_param'] if config['G_param'] != 'SN' else None,
      'D%s' % config['D_param'] if config['D_param'] != 'SN' else None,
      'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None,
      'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None,
      'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None,
      'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None,
      config['norm_style'] if config['norm_style'] != 'bn' else None,
      'cr' if config['cross_replica'] else None,
      'Gshared' if config['G_shared'] else None,
      'hier' if config['hier'] else None,
      'ema' if config['ema'] else None,
      config['name_suffix'] if config['name_suffix'] else None,
    ]
    if item is not None])
  # dogball
  if config['hashname']:
    return hashname(name)
  else:
    return name

# A simple function to produce a unique experiment name from the animal hashes.
def hashname(name):
  h = hash(name)
  a = h % len(animal_hash.a)
  h = h // len(animal_hash.a)
  b = h % len(animal_hash.b)
  h = h // len(animal_hash.b)
  c = h % len(animal_hash.c)
  return animal_hash.a[a] + animal_hash.b[b] + animal_hash.c[c]

# Get GPU memory; -i selects the GPU index
def query_gpu(indices):
  os.system('nvidia-smi -i %s --query-gpu=memory.free --format=csv' % indices)

# Convenience function to count the number of parameters in a module
def count_parameters(module):
  print('Number of parameters: {}'.format(
    sum([p.data.nelement() for p in module.parameters()])))


# Convenience function to sample an index, not actually a 1-hot
def sample_1hot(batch_size, num_classes, device='cuda'):
  return torch.randint(low=0, high=num_classes, size=(batch_size,),
                       device=device, dtype=torch.int64, requires_grad=False)

# A highly simplified convenience class for sampling from distributions
# One could also use PyTorch's inbuilt distributions package.
# Note that this class requires initialization to proceed as
#   x = Distribution(torch.randn(size))
#   x.init_distribution(dist_type, **dist_kwargs)
#   x = x.to(device, dtype)
# This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2
class Distribution(torch.Tensor):
  # Init the params of the distribution
  def init_distribution(self, dist_type, **kwargs):
    self.dist_type = dist_type
    self.dist_kwargs = kwargs
    if self.dist_type == 'normal':
      self.mean, self.var = kwargs['mean'], kwargs['var']
    elif self.dist_type == 'categorical':
      self.num_categories = kwargs['num_categories']

  def sample_(self):
    if self.dist_type == 'normal':
      self.normal_(self.mean, self.var)
    elif self.dist_type == 'categorical':
      self.random_(0, self.num_categories)

  # Silly hack: overwrite the to() method to wrap the new object
  # in a Distribution as well
  def to(self, *args, **kwargs):
    new_obj = Distribution(self)
    new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
    new_obj.data = super().to(*args, **kwargs)
    return new_obj

# Convenience function to prepare a z and y vector
def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda',
                fp16=False, z_var=1.0):
  z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
  z_.init_distribution('normal', mean=0, var=z_var)
  z_ = z_.to(device, torch.float16 if fp16 else torch.float32)

  if fp16:
    z_ = z_.half()

  y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False))
  y_.init_distribution('categorical', num_categories=nclasses)
  y_ = y_.to(device, torch.int64)
  return z_, y_
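
# Sampling sketch (illustrative; G and config are hypothetical here):
#
#   z_, y_ = prepare_z_y(G_batch_size=64, dim_z=G.dim_z,
#                        nclasses=config['n_classes'])
#   images, labels = sample(G, z_, y_, config)  # resamples z_ and y_ in place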

def initiate_standing_stats(net):
  for module in net.modules():
    if hasattr(module, 'accumulate_standing'):
      module.reset_stats()
      module.accumulate_standing = True


def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16):
  initiate_standing_stats(net)
  net.train()
  for i in range(num_accumulations):
    with torch.no_grad():
      z.normal_()
      y.random_(0, nclasses)
      x = net(z, net.shared(y))  # No need to parallelize here unless using syncbn
  # Set to eval mode
  net.eval()

# This version of Adam keeps an fp32 copy of the parameters and
# does all of the parameter updates in fp32, while still doing the
# forwards and backwards passes using fp16 (i.e. fp16 copies of the
# parameters and fp16 activations).
#
# Note that this calls .float().cuda() on the params.
import math
from torch.optim.optimizer import Optimizer
class Adam16(Optimizer):
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
    defaults = dict(lr=lr, betas=betas, eps=eps,
                    weight_decay=weight_decay)
    params = list(params)
    super(Adam16, self).__init__(params, defaults)

  # Safety modification to make sure we floatify our state
  def load_state_dict(self, state_dict):
    super(Adam16, self).load_state_dict(state_dict)
    for group in self.param_groups:
      for p in group['params']:
        self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float()
        self.state[p]['exp_avg_sq'] = self.state[p]['exp_avg_sq'].float()
        self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float()

  def step(self, closure=None):
    """Performs a single optimization step.
    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        grad = p.grad.data.float()
        state = self.state[p]

        # State initialization
        if len(state) == 0:
          state['step'] = 0
          # Exponential moving average of gradient values
          state['exp_avg'] = grad.new().resize_as_(grad).zero_()
          # Exponential moving average of squared gradient values
          state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
          # fp32 copy of the weights
          state['fp32_p'] = p.data.float()

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1

        if group['weight_decay'] != 0:
          grad = grad.add(group['weight_decay'], state['fp32_p'])

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

        denom = exp_avg_sq.sqrt().add_(group['eps'])

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

        state['fp32_p'].addcdiv_(-step_size, exp_avg, denom)
        p.data = state['fp32_p'].half()

    return loss
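A minimal sketch of how Adam16 is meant to be driven (illustrative only; MyNet and x are hypothetical): the model is cast to half precision for the forward/backward passes, while the optimizer keeps and updates fp32 master weights. Note the in-place tensor methods above use the older positional PyTorch calling convention, consistent with the torch version this file targets.

    # hypothetical usage of Adam16 with an fp16 model
    model = MyNet().half().cuda()
    opt = Adam16(model.parameters(), lr=2e-4, betas=(0.0, 0.999), eps=1e-8)
    loss = model(x).mean()
    opt.zero_grad()
    loss.backward()   # fp16 gradients
    opt.step()        # fp32 update; params re-cast to half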
src/models/cvae.py
ADDED
@@ -0,0 +1,75 @@
from torch import cat
from torch.optim import Adam
from torch.nn import Sequential, ModuleList, \
    Conv2d, Linear, \
    LeakyReLU, Tanh, \
    BatchNorm1d, BatchNorm2d, \
    ConvTranspose2d, UpsamplingBilinear2d

from .neuralnetwork import NeuralNetwork


# parameters for cVAE
colors_dim = 3
labels_dim = 37
momentum = 0.99  # Batchnorm
negative_slope = 0.2  # LeakyReLU
optimizer = Adam
betas = (0.5, 0.999)

# hyperparameters
learning_rate = 2e-4
latent_dim = 128


def genUpsample(input_channels, output_channels, stride, pad):
    return Sequential(
        ConvTranspose2d(input_channels, output_channels, 4, stride, pad, bias=False),
        BatchNorm2d(output_channels),
        LeakyReLU(negative_slope=negative_slope))


def genUpsample2(input_channels, output_channels, kernel_size):
    return Sequential(
        Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        BatchNorm2d(output_channels),
        LeakyReLU(negative_slope=negative_slope),
        Conv2d(output_channels, output_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        BatchNorm2d(output_channels),
        LeakyReLU(negative_slope=negative_slope),
        UpsamplingBilinear2d(scale_factor=2))


class ConditionalDecoder(NeuralNetwork):
    def __init__(self, ll_scaling=1.0, dim_z=latent_dim):
        super(ConditionalDecoder, self).__init__()
        self.dim_z = dim_z
        ngf = 32
        self.init = genUpsample(self.dim_z, ngf * 16, 1, 0)
        self.embedding = Sequential(
            Linear(labels_dim, self.dim_z),
            BatchNorm1d(self.dim_z, momentum=momentum),
            LeakyReLU(negative_slope=negative_slope),
        )
        self.dense_init = Sequential(
            Linear(self.dim_z * 2, self.dim_z),
            BatchNorm1d(self.dim_z, momentum=momentum),
            LeakyReLU(negative_slope=negative_slope),
        )
        self.m_modules = ModuleList()  # to 4x4
        self.c_modules = ModuleList()
        for i in range(4):
            self.m_modules.append(genUpsample2(ngf * 2 ** (4 - i), ngf * 2 ** (3 - i), 3))
            self.c_modules.append(Sequential(Conv2d(ngf * 2 ** (3 - i), colors_dim, 3, 1, 1, bias=False), Tanh()))
        self.set_optimizer(optimizer, lr=learning_rate * ll_scaling, betas=betas)

    def forward(self, latent, labels, step=3):
        y = self.embedding(labels)
        out = cat((latent, y), dim=1)
        out = self.dense_init(out)
        out = out.unsqueeze(2).unsqueeze(3)
        out = self.init(out)
        for i in range(step):
            out = self.m_modules[i](out)
        out = self.c_modules[step](self.m_modules[step](out))
        return out
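A quick usage sketch for the decoder above (illustrative, not part of the commit): with the module defaults (latent_dim = 128, labels_dim = 37, step = 3), a latent batch plus a one-hot label batch decodes to 64x64 RGB images in the Tanh range [-1, 1].

    import torch
    from src.models.cvae import ConditionalDecoder

    decoder = ConditionalDecoder().eval()      # eval() so BatchNorm uses running stats
    latent = torch.randn(8, 128)
    labels = torch.zeros(8, 37)
    labels[:, 0] = 1.0                         # pick one class per row
    with torch.no_grad():
        imgs = decoder(latent, labels)         # -> (8, 3, 64, 64)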
src/models/infoscc_gan.py
ADDED
@@ -0,0 +1,272 @@
1 |
+
from typing import Optional, Dict
|
2 |
+
from functools import partial
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
def get_activation(activation: str = "lrelu"):
|
10 |
+
actv_layers = {
|
11 |
+
"relu": nn.ReLU,
|
12 |
+
"lrelu": partial(nn.LeakyReLU, 0.2),
|
13 |
+
}
|
14 |
+
assert activation in actv_layers, f"activation [{activation}] not implemented"
|
15 |
+
return actv_layers[activation]
|
16 |
+
|
17 |
+
|
18 |
+
def get_normalization(normalization: str = "batch_norm"):
|
19 |
+
norm_layers = {
|
20 |
+
"instance_norm": nn.InstanceNorm2d,
|
21 |
+
"batch_norm": nn.BatchNorm2d,
|
22 |
+
"group_norm": partial(nn.GroupNorm, num_groups=8),
|
23 |
+
"layer_norm": partial(nn.GroupNorm, num_groups=1),
|
24 |
+
}
|
25 |
+
assert normalization in norm_layers, f"normalization [{normalization}] not implemented"
|
26 |
+
return norm_layers[normalization]
|
27 |
+
|
28 |
+
|
29 |
+
class ConvLayer(nn.Sequential):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
in_channels: int,
|
33 |
+
out_channels: int,
|
34 |
+
kernel_size: int = 3,
|
35 |
+
stride: int = 1,
|
36 |
+
padding: Optional[int] = 1,
|
37 |
+
padding_mode: str = "zeros",
|
38 |
+
groups: int = 1,
|
39 |
+
bias: bool = True,
|
40 |
+
transposed: bool = False,
|
41 |
+
normalization: Optional[str] = None,
|
42 |
+
activation: Optional[str] = "lrelu",
|
43 |
+
pre_activate: bool = False,
|
44 |
+
):
|
45 |
+
if transposed:
|
46 |
+
conv = partial(nn.ConvTranspose2d, output_padding=stride-1)
|
47 |
+
padding_mode = "zeros"
|
48 |
+
else:
|
49 |
+
conv = nn.Conv2d
|
50 |
+
layers = [
|
51 |
+
conv(
|
52 |
+
in_channels,
|
53 |
+
out_channels,
|
54 |
+
kernel_size=kernel_size,
|
55 |
+
stride=stride,
|
56 |
+
padding=padding,
|
57 |
+
padding_mode=padding_mode,
|
58 |
+
groups=groups,
|
59 |
+
bias=bias,
|
60 |
+
)
|
61 |
+
]
|
62 |
+
|
63 |
+
norm_actv = []
|
64 |
+
if normalization is not None:
|
65 |
+
norm_actv.append(
|
66 |
+
get_normalization(normalization)(
|
67 |
+
num_channels=in_channels if pre_activate else out_channels
|
68 |
+
)
|
69 |
+
)
|
70 |
+
if activation is not None:
|
71 |
+
norm_actv.append(
|
72 |
+
get_activation(activation)(inplace=True)
|
73 |
+
)
|
74 |
+
|
75 |
+
if pre_activate:
|
76 |
+
layers = norm_actv + layers
|
77 |
+
else:
|
78 |
+
layers = layers + norm_actv
|
79 |
+
|
80 |
+
super().__init__(
|
81 |
+
*layers
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
class SubspaceLayer(nn.Module):
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
dim: int,
|
89 |
+
n_basis: int,
|
90 |
+
):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
self.U = nn.Parameter(torch.empty(n_basis, dim))
|
94 |
+
nn.init.orthogonal_(self.U)
|
95 |
+
self.L = nn.Parameter(torch.FloatTensor([3 * i for i in range(n_basis, 0, -1)]))
|
96 |
+
self.mu = nn.Parameter(torch.zeros(dim))
|
97 |
+
|
98 |
+
def forward(self, z):
|
99 |
+
return (self.L * z) @ self.U + self.mu
|
100 |
+
|
101 |
+
|
102 |
+
class EigenBlock(nn.Module):
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
width: int,
|
106 |
+
height: int,
|
107 |
+
in_channels: int,
|
108 |
+
out_channels: int,
|
109 |
+
n_basis: int,
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
        self.projection = SubspaceLayer(dim=width * height * in_channels, n_basis=n_basis)
        self.subspace_conv1 = ConvLayer(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            transposed=True,
            activation=None,
            normalization=None,
        )
        self.subspace_conv2 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            transposed=True,
            activation=None,
            normalization=None,
        )

        self.feature_conv1 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            transposed=True,
            pre_activate=True,
        )
        self.feature_conv2 = ConvLayer(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            transposed=True,
            pre_activate=True,
        )

    def forward(self, z, h):
        # Project z onto the learned subspace and inject it at two resolutions.
        phi = self.projection(z).view(h.shape)
        h = self.feature_conv1(h + self.subspace_conv1(phi))
        h = self.feature_conv2(h + self.subspace_conv2(phi))
        return h


class ConditionalGenerator(nn.Module):
    """Conditional generator.

    Generates images from a one-hot label plus noise sampled from N(0, 1),
    with an explorable z-injection space. Based on EigenGAN.
    """

    def __init__(self,
                 size: int,
                 y_size: int,
                 z_size: int,
                 out_channels: int = 3,
                 n_basis: int = 6,
                 noise_dim: int = 512,
                 base_channels: int = 16,
                 max_channels: int = 512,
                 y_type: str = 'one_hot'):

        if y_type not in ['one_hot', 'multi_label', 'mixed', 'real']:
            raise ValueError('Unsupported `y_type`')

        super(ConditionalGenerator, self).__init__()

        assert (size & (size - 1) == 0) and size != 0, "img size should be a power of 2"

        self.y_type = y_type
        self.y_size = y_size
        self.eps_size = z_size

        self.noise_dim = noise_dim
        self.n_basis = n_basis
        self.n_blocks = int(math.log(size, 2)) - 2

        def get_channels(i_block):
            return min(max_channels, base_channels * (2 ** (self.n_blocks - i_block)))

        self.y_fc = nn.Linear(self.y_size, self.y_size)
        self.concat_fc = nn.Linear(self.y_size + self.eps_size, self.noise_dim)

        self.fc = nn.Linear(self.noise_dim, 4 * 4 * get_channels(0))

        self.blocks = nn.ModuleList()
        for i in range(self.n_blocks):
            self.blocks.append(
                EigenBlock(
                    width=4 * (2 ** i),
                    height=4 * (2 ** i),
                    in_channels=get_channels(i),
                    out_channels=get_channels(i + 1),
                    n_basis=self.n_basis,
                )
            )

        self.out = nn.Sequential(
            ConvLayer(base_channels, out_channels, kernel_size=7, stride=1, padding=3, pre_activate=True),
            nn.Tanh(),
        )

    def forward(self,
                y: torch.Tensor,
                eps: Optional[torch.Tensor] = None,
                zs: Optional[torch.Tensor] = None,
                return_eps: bool = False):

        bs = y.size(0)

        if eps is None:
            eps = self.sample_eps(bs)

        if zs is None:
            zs = self.sample_zs(bs)

        y_out = self.y_fc(y)
        concat = torch.cat((y_out, eps), dim=1)
        concat = self.concat_fc(concat)

        out = self.fc(concat).view(len(eps), -1, 4, 4)
        for block, z in zip(self.blocks, zs.permute(1, 0, 2)):
            out = block(z, out)
        out = self.out(out)

        if return_eps:
            # Note: this returns the fused latent produced by concat_fc, not the raw eps.
            return out, concat

        return out

    def sample_zs(self, batch: int, truncation: float = 1.):
        device = self.get_device()
        zs = torch.randn(batch, self.n_blocks, self.n_basis, device=device)

        if truncation < 1.:
            # Interpolate toward the zero vector; the first term is identically
            # zero, so this is equivalent to zs * truncation.
            zs = torch.zeros_like(zs) * (1 - truncation) + zs * truncation
        return zs

    def sample_eps(self, batch: int, truncation: float = 1.):
        device = self.get_device()
        eps = torch.randn(batch, self.eps_size, device=device)

        if truncation < 1.:
            eps = torch.zeros_like(eps) * (1 - truncation) + eps * truncation
        return eps

    def get_device(self):
        return self.fc.weight.device

    def orthogonal_regularizer(self):
        # Penalize deviation of every subspace basis U from orthonormality.
        reg = []
        for layer in self.modules():
            if isinstance(layer, SubspaceLayer):
                UUT = layer.U @ layer.U.t()
                reg.append(
                    ((UUT - torch.eye(UUT.shape[0], device=UUT.device)) ** 2).mean()
                )
        return sum(reg) / len(reg)
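
For orientation, here is a minimal usage sketch of `ConditionalGenerator`. The image size, label count, and latent size below are illustrative assumptions, not values taken from this Space's configuration:

import torch
from src.models.infoscc_gan import ConditionalGenerator

# Illustrative settings only; the app loads its real configuration elsewhere.
g = ConditionalGenerator(size=128, y_size=10, z_size=512)

y = torch.eye(10)[torch.randint(0, 10, (4,))]  # a batch of 4 one-hot labels
imgs = g(y)                                    # eps and zs are sampled internally
print(imgs.shape)                              # expected: torch.Size([4, 3, 128, 128])

# Truncation trades sample diversity for fidelity by shrinking latents toward zero.
eps = g.sample_eps(4, truncation=0.7)
zs = g.sample_zs(4, truncation=0.7)
imgs_trunc = g(y, eps=eps, zs=zs)

# During training, the orthogonality penalty would typically be added to the
# generator loss, e.g. loss = g_loss + w * g.orthogonal_regularizer().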
src/models/neuralnetwork.py
ADDED
@@ -0,0 +1,46 @@
import torch


class NeuralNetwork(torch.nn.Module):
    """Base class with convenience procedures shared by all networks."""

    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.parameter_file = f"parameter_state_dict_{self._get_name()}.pth"
        # self.cuda()  # all networks should run on CUDA; doesn't seem to work here

    def save(self) -> None:
        """Save learned parameters to `parameter_file`."""
        torch.save(self.state_dict(), self.parameter_file)

    def load(self) -> None:
        """Load learned parameters from `parameter_file` and switch to eval mode."""
        self.load_state_dict(torch.load(self.parameter_file))
        self.eval()

    @staticmethod
    def same_padding(kernel_size: int = 1) -> int:
        """Return the padding required to mimic 'same' padding in TensorFlow."""
        return (kernel_size - 1) // 2

    def set_optimizer(self, optimizer, **kwargs) -> None:
        self.optimizer = optimizer(self.parameters(), **kwargs)

    def get_total_number_parameters(self) -> int:
        """Return the total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def zero_grad(self):
        """Faster implementation of zero_grad: set gradients to None instead of zeroing."""
        for p in self.parameters():
            p.grad = None
        # equivalent since PyTorch 1.7: super().zero_grad(set_to_none=True)


def update_networks_on_loss(loss: torch.Tensor, *networks) -> None:
    """Backpropagate `loss` and step the optimizer of every given network."""
    if not loss:  # skip when loss is None or zero
        return
    for network in networks:
        network.zero_grad()
    loss.backward()
    for network in networks:
        network.optimizer.step()
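
A short sketch of how these helpers compose. `TinyClassifier` is a hypothetical subclass invented for illustration; it is not part of this repo:

import torch
from src.models.neuralnetwork import NeuralNetwork, update_networks_on_loss

class TinyClassifier(NeuralNetwork):
    """Hypothetical subclass, used only to demonstrate the base-class helpers."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyClassifier()
model.set_optimizer(torch.optim.Adam, lr=2e-4)

x, target = torch.randn(16, 8), torch.randint(0, 2, (16,))
loss = torch.nn.functional.cross_entropy(model(x), target)
update_networks_on_loss(loss, model)  # zero_grad -> backward -> optimizer.step

model.save()  # writes parameter_state_dict_TinyClassifier.pth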
src/models/parameter.py
ADDED
@@ -0,0 +1,53 @@
""" Hardcoded parameters.

These can be changed in a Jupyter notebook at runtime via

>>> import parameter
>>> parameter.parameter = new_value

"""

from torch.optim import Adam

###############
## hardcoded ##
###############


# Input
image_dim = 64
colors_dim = 3
labels_dim = 37  # 3
input_size = (colors_dim, image_dim, image_dim)


#############
## mutable ##
#############

class Parameter:
    """Container for hyperparameters."""

    def __init__(self):
        # Encoder/Decoder
        self.latent_dim = 8
        self.decoder_dim = self.latent_dim  # differs from latent_dim if PCA is applied before the decoder

        # General
        self.learning_rate = 0.0002
        self.betas = (0.5, 0.999)  # 0.999 is the default beta2 in TensorFlow
        self.optimizer = Adam
        self.negative_slope = 0.2  # for LeakyReLU
        self.momentum = 0.99  # for BatchNorm

        # Loss weights
        self.alpha = 1  # switch VAE (1) / AE (0)
        self.beta = 1  # weight for KL-loss
        self.gamma = 1024  # weight for learned-metric loss (https://arxiv.org/pdf/1512.09300.pdf)
        self.delta = 1  # weight for class-loss
        self.zeta = 0.5  # weight for MSE-loss

    def return_parameter_dict(self):
        return self.__dict__


parameter = Parameter()
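
A sketch of the intended runtime-override pattern (the import path is assumed to mirror this repo's layout):

import src.models.parameter as parameter

# Override hyperparameters before building models.
parameter.parameter.learning_rate = 1e-4
parameter.parameter.latent_dim = 16
# Caveat: decoder_dim was copied from latent_dim in __init__, so it must be
# updated separately if the two are meant to stay equal.
parameter.parameter.decoder_dim = 16

p = parameter.parameter
print(p.return_parameter_dict())
# The stored optimizer class and betas plug straight into NeuralNetwork.set_optimizer:
# model.set_optimizer(p.optimizer, lr=p.learning_rate, betas=p.betas)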
src/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .utils import download_file
from .utils import sample_labels
src/utils/utils.py
ADDED
@@ -0,0 +1,14 @@
import numpy as np
import gdown

import torch


def download_file(file_id: str, output_path: str):
    """Download a file from Google Drive by its file id."""
    gdown.download(f'https://drive.google.com/uc?id={file_id}', output_path)


def sample_labels(labels: torch.Tensor, n: int) -> torch.Tensor:
    """Draw `n` rows from `labels` uniformly at random, with replacement."""
    high = labels.shape[0]
    idx = np.random.randint(0, high, size=n)
    return labels[idx]
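
A minimal usage sketch of the two helpers; the Google Drive file id and output path below are placeholders, not real checkpoint references:

import torch
from src.utils import download_file, sample_labels

# 'YOUR_FILE_ID' is a placeholder: pass the id of an actual Drive file.
download_file('YOUR_FILE_ID', 'models/checkpoint.pth')

labels = torch.eye(10)            # e.g. the 10 possible one-hot labels
batch = sample_labels(labels, 4)  # 4 rows drawn uniformly, with replacement
print(batch.shape)                # torch.Size([4, 10])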