elite code init
Browse files
- .gitignore +21 -0
- 2_gpu.json +11 -0
- 3_gpu.json +11 -0
- 4_gpu.json +11 -0
- LICENSE +201 -0
- README.md +3 -3
- datasets.py +541 -0
- elite.yaml +147 -0
- inference_global.py +214 -0
- inference_global.sh +12 -0
- inference_local.py +247 -0
- inference_local.sh +12 -0
- test_datasets/1.jpg +0 -0
- test_datasets/10.jpg +0 -0
- test_datasets/10_bg.png +0 -0
- test_datasets/11.jpg +0 -0
- test_datasets/11_bg.png +0 -0
- test_datasets/15.jpg +0 -0
- test_datasets/15_bg.png +0 -0
- test_datasets/16.jpg +0 -0
- test_datasets/16_bg.png +0 -0
- test_datasets/17.jpg +0 -0
- test_datasets/17_bg.png +0 -0
- test_datasets/1_bg.png +0 -0
- test_datasets/2.jpg +0 -0
- test_datasets/20.jpg +0 -0
- test_datasets/20_bg.png +0 -0
- test_datasets/2_bg.png +0 -0
- test_datasets/3.jpg +0 -0
- test_datasets/3_bg.png +0 -0
- test_datasets/4.png +0 -0
- test_datasets/4_bg.png +0 -0
- test_datasets/7.jpg +0 -0
- test_datasets/7_bg.png +0 -0
- train_global.py +715 -0
- train_global.sh +15 -0
- train_local.py +709 -0
- train_local.sh +16 -0
.gitignore
ADDED
@@ -0,0 +1,21 @@
```text
_debug*
.env
__pycache__
_sc.py
*.ckpt
*.bin

checkpoints
.idea
.idea/workspace.xml
.DS_Store
*/__pycache__git
.pyc
.iml
__pycache__/
*/__pycache__/
*/*/__pycache__/
*/*/*/__pycache__/
*/*/*/*/__pycache__/
*/*/*/*/*/__pycache__/
*/*/*/*/*/*/__pycache__/
```
2_gpu.json
ADDED
@@ -0,0 +1,11 @@
```json
{
    "compute_environment": "LOCAL_MACHINE",
    "distributed_type": "MULTI_GPU",
    "fp16": false,
    "machine_rank": 0,
    "main_process_ip": null,
    "main_process_port": null,
    "main_training_function": "main",
    "num_machines": 1,
    "num_processes": 2
}
```
3_gpu.json
ADDED
@@ -0,0 +1,11 @@
```json
{
    "compute_environment": "LOCAL_MACHINE",
    "distributed_type": "MULTI_GPU",
    "fp16": false,
    "machine_rank": 0,
    "main_process_ip": null,
    "main_process_port": null,
    "main_training_function": "main",
    "num_machines": 1,
    "num_processes": 3
}
```
4_gpu.json
ADDED
@@ -0,0 +1,11 @@
```json
{
    "compute_environment": "LOCAL_MACHINE",
    "distributed_type": "MULTI_GPU",
    "fp16": false,
    "machine_rank": 0,
    "main_process_ip": null,
    "main_process_port": null,
    "main_training_function": "main",
    "num_machines": 1,
    "num_processes": 4
}
```
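The three `*_gpu.json` files differ only in `num_processes` and read as Hugging Face Accelerate launcher configs for a single machine with 2, 3, or 4 GPUs; presumably they are passed via `accelerate launch --config_file` by the `train_global.sh`/`train_local.sh` scripts added in this commit, though those scripts are not expanded in this view, so that pairing is an assumption. A minimal sketch (not part of the repo) that sanity-checks this reading:

```python
# Sketch only: confirm the three launcher configs differ solely in the number
# of single-machine GPU worker processes they request.
import json

for path, expected in [("2_gpu.json", 2), ("3_gpu.json", 3), ("4_gpu.json", 4)]:
    with open(path) as f:
        cfg = json.load(f)
    assert cfg["distributed_type"] == "MULTI_GPU"
    assert cfg["num_machines"] == 1
    assert cfg["num_processes"] == expected
```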
LICENSE
ADDED
@@ -0,0 +1,201 @@
```text
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
```
README.md
CHANGED
@@ -1,3 +1,3 @@
```markdown
# ELITE

The detailed README is coming soon.
```
datasets.py
ADDED
@@ -0,0 +1,541 @@
```python
from packaging import version
from PIL import Image
from torchvision import transforms
import os
import PIL
from torch.utils.data import Dataset
import torchvision
import numpy as np
import torch
import random
import albumentations as A
import copy
import cv2
import pandas as pd


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }

def is_image(file):
    return 'jpg' in file.lower() or 'png' in file.lower() or 'jpeg' in file.lower()

class CustomDatasetWithBG(Dataset):
    def __init__(
            self,
            data_root,
            tokenizer,
            size=512,
            interpolation="bicubic",
            placeholder_token="*",
            template="a photo of a {}",
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token

        self.image_paths = []
        self.image_paths += [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if is_image(file_path) and not 'bg' in file_path]

        self.image_paths = sorted(self.image_paths)

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.template = template

    def __len__(self):
        return self._length

    def get_tensor_clip(self, normalize=True, toTensor=True):
        transform_list = []
        if toTensor:
            transform_list += [torchvision.transforms.ToTensor()]
        if normalize:
            transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                                (0.26862954, 0.26130258, 0.27577711))]
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        img = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = np.array(img).astype(np.float32)
        img = img / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def __getitem__(self, i):
        example = {}

        placeholder_string = self.placeholder_token
        text = self.template.format(placeholder_string)
        example["text"] = text

        placeholder_index = 0
        words = text.strip().split(' ')
        for idx, word in enumerate(words):
            if word == placeholder_string:
                placeholder_index = idx + 1

        example["index"] = torch.tensor(placeholder_index)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        image = Image.open(self.image_paths[i % self.num_images])

        mask_path = self.image_paths[i % self.num_images].replace('.jpeg', '.png').replace('.jpg', '.png').replace('.JPEG', '.png')[:-4] + '_bg.png'
        mask = np.array(Image.open(mask_path)) / 255.0

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image_np = np.array(image)
        object_tensor = image_np * mask
        example["pixel_values"] = self.process(image_np)

        ref_object_tensor = Image.fromarray(object_tensor.astype('uint8')).resize((224, 224), resample=self.interpolation)
        ref_image_tenser = Image.fromarray(image_np.astype('uint8')).resize((224, 224), resample=self.interpolation)
        example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
        example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tenser)

        ref_seg_tensor = Image.fromarray(mask.astype('uint8') * 255)
        ref_seg_tensor = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
        example["pixel_values_seg"] = torch.nn.functional.interpolate(ref_seg_tensor.unsqueeze(0), size=(128, 128), mode='nearest').squeeze(0)

        return example


class OpenImagesDataset(Dataset):
    def __init__(
            self,
            data_root,
            tokenizer,
            size=512,
            interpolation="bicubic",
            set="train",
            placeholder_token="*",
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.set_type = set

        self.random_trans = A.Compose([
            A.Resize(height=224, width=224),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=20),
            A.Blur(p=0.3),
            A.ElasticTransform(p=0.3)
        ])

        self.bbox_path_list = []
        if set == "train":
            bboxs_path = os.path.join(data_root, 'annotations', f'oidv6-train-annotations-bbox.csv')
        elif set == "validation":
            bboxs_path = os.path.join(data_root, 'annotations', f'validation-annotations-bbox.csv')
        else:
            bboxs_path = os.path.join(data_root, 'annotations', f'test-annotations-bbox.csv')

        df_val_bbox = pd.read_csv(bboxs_path)
        bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)

        bbox_full = []
        for label_name in df_val_bbox['LabelName'].unique():
            bboxs = bbox_groups.get_group(label_name)[
                ['XMin', 'XMax', 'YMin', 'YMax', 'LabelName', 'ImageID',
                 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsInside']].values.tolist()
            bboxs_new = []
            for bbox in bboxs:
                if not ((bbox[1] - bbox[0]) * (bbox[3] - bbox[2]) > 0.8 or (bbox[1] - bbox[0]) * (
                        bbox[3] - bbox[2]) < 0.02):
                    bboxs_new.append([bbox[0], bbox[1], bbox[2], bbox[3], bbox[4], bbox[5]])
            bbox_full.extend(bboxs_new)

        self.bboxs_full = bbox_full

        self.num_images = len(bbox_full)

        print('{}: total {} images ...'.format(set, self.num_images))

        self._length = self.num_images

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_templates_small


    def __len__(self):
        return self._length

    def get_tensor_clip(self, normalize=True, toTensor=True):
        transform_list = []
        if toTensor:
            transform_list += [torchvision.transforms.ToTensor()]
        if normalize:
            transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                                (0.26862954, 0.26130258, 0.27577711))]
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        img = np.array(image)
        img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = np.array(img).astype(np.float32)
        img = img / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def obtain_text(self, add_caption, object_category=None):

        if object_category is None:
            placeholder_string = self.placeholder_token
        else:
            placeholder_string = object_category

        text = random.choice(self.templates).format(placeholder_string)
        text = add_caption + text[1:]

        placeholder_index = 0
        words = text.strip().split(' ')
        for idx, word in enumerate(words):
            if word == placeholder_string:
                placeholder_index = idx + 1

        index = torch.tensor(placeholder_index)

        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]
        return input_ids, index, text

    def __getitem__(self, i):
        example = {}

        input_ids, index, text = self.obtain_text('a')
        example["input_ids"] = input_ids
        example["index"] = index
        example["text"] = text

        bbox_sample = self.bboxs_full[i % self.num_images]
        bbox_sample = copy.copy(bbox_sample)

        file_name = bbox_sample[-1] + '.jpg'
        img_path = os.path.join(self.data_root, 'images', self.set_type, file_name)

        try:
            img_p = Image.open(img_path).convert("RGB")
            img_p_np = np.array(img_p)
            bbox_sample[0] *= int(img_p_np.shape[1])
            bbox_sample[1] *= int(img_p_np.shape[1])
            bbox_sample[2] *= int(img_p_np.shape[0])
            bbox_sample[3] *= int(img_p_np.shape[0])

            bbox_pad = copy.copy(bbox_sample)
            bbox_pad[0] = int(bbox_sample[0] - min(10, bbox_sample[0] - 0))
            bbox_pad[1] = int(bbox_sample[1] + min(10, img_p.size[0] - bbox_sample[1]))
            bbox_pad[2] = int(bbox_sample[2] - min(10, bbox_sample[2] - 0))
            bbox_pad[3] = int(bbox_sample[3] + min(10, img_p.size[1] - bbox_sample[3]))

            image_tensor = img_p_np[bbox_pad[2]:bbox_pad[3], bbox_pad[0]:bbox_pad[1], :]
            example["pixel_values"] = self.process(image_tensor)

            ref_image_tensor = self.random_trans(image=image_tensor)
            ref_image_tensor = Image.fromarray(ref_image_tensor["image"])
            example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tensor)

        except Exception as e:
            example["pixel_values"] = torch.zeros((3, 512, 512))
            example["pixel_values_clip"] = torch.zeros((3, 224, 224))
            with open('error.txt', 'a+') as f:
                f.write(str(e) + '\n')

        return example


class OpenImagesDatasetWithMask(OpenImagesDataset):
    def __init__(self,
                 data_root,
                 tokenizer,
                 size=512,
                 interpolation="bicubic",
                 set="train",
                 placeholder_token="*"):

        # super().__init__(data_root, tokenizer, size, interpolation, set, placeholder_token)
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.set = set

        class_anno_path = os.path.join(data_root, 'annotations', f'oidv6-class-descriptions.csv')
        anno_files = pd.read_csv(class_anno_path)
        class_groups = anno_files.groupby(anno_files.LabelName)

        if set == "train":
            bboxs_path = os.path.join(data_root, 'annotations', f'train-annotations-object-segmentation.csv')
            dict_path = os.path.join(data_root, 'segs', f'train_bbox_dict.npy')
        elif set == "validation":
            bboxs_path = os.path.join(data_root, 'annotations', f'validation-annotations-object-segmentation.csv')
            dict_path = os.path.join(data_root, 'segs', f'validation_bbox_dict.npy')
        else:
            bboxs_path = os.path.join(data_root, 'annotations', f'test-annotations-object-segmentation.csv')
            dict_path = os.path.join(data_root, 'segs', f'test_bbox_dict.npy')

        bbox_dict = np.load(dict_path, allow_pickle=True).item()

        df_val_bbox = pd.read_csv(bboxs_path)
        bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)
        bboxes_full = []
        for label_name in df_val_bbox['LabelName'].unique():
            bboxs = bbox_groups.get_group(label_name)[
                ['BoxXMin', 'BoxXMax', 'BoxYMin', 'BoxYMax', 'LabelName', 'MaskPath']].values.tolist()
            bboxes_new = []
            for box in bboxs:
                if not box[-1] in bbox_dict:
                    continue
                bbox_data = bbox_dict[box[-1]]

                if (bbox_data[2] - bbox_data[1]) < 100 or (bbox_data[4] - bbox_data[3]) < 100:
                    continue
                if not ((bbox_data[2] - bbox_data[1]) / (bbox_data[4] - bbox_data[3]) < 0.5 or (
                        bbox_data[4] - bbox_data[3]) / (bbox_data[2] - bbox_data[1]) < 0.5):
                    class_name = class_groups.get_group(box[4])[['DisplayName']].values.tolist()[0][0]
                    bboxes_new.append([box[-1], bbox_data[1], bbox_data[2], bbox_data[3], bbox_data[4], class_name])

            bboxes_full.extend(bboxes_new)

        self.bboxes_full = bboxes_full
        self.num_images = len(bboxes_full)

        print('{}: total {} images ...'.format(set, self.num_images))

        self._length = self.num_images
        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_templates_small


    def __len__(self):
        return self._length

    ## borrowed from custom diffusion
    def custom_aug(self, instance_image):
        instance_image = Image.fromarray(instance_image)
        #### apply augmentation and create a valid image regions mask ####
        if np.random.randint(0, 3) < 2:
            random_scale = np.random.randint(self.size // 3, self.size + 1)
        else:
            random_scale = np.random.randint(int(1.2 * self.size), int(1.4 * self.size))

        if random_scale % 2 == 1:
            random_scale += 1

        if random_scale < 0.6 * self.size:
            add_to_caption = np.random.choice(["a far away", "very small"])
            cx = np.random.randint(random_scale // 2, self.size - random_scale // 2 + 1)
            cy = np.random.randint(random_scale // 2, self.size - random_scale // 2 + 1)

            instance_image1 = instance_image.resize((random_scale, random_scale), resample=self.interpolation)
            instance_image1 = np.array(instance_image1).astype(np.uint8)
            instance_image1 = (instance_image1 / 127.5 - 1.0).astype(np.float32)

            instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
            instance_image[cx - random_scale // 2: cx + random_scale // 2,
                           cy - random_scale // 2: cy + random_scale // 2, :] = instance_image1

            mask = np.zeros((self.size // 8, self.size // 8))
            mask[(cx - random_scale // 2) // 8 + 1: (cx + random_scale // 2) // 8 - 1,
                 (cy - random_scale // 2) // 8 + 1: (cy + random_scale // 2) // 8 - 1] = 1.

        elif random_scale > self.size:
            add_to_caption = np.random.choice(["zoomed in", "close up"])
            cx = np.random.randint(self.size // 2, random_scale - self.size // 2 + 1)
            cy = np.random.randint(self.size // 2, random_scale - self.size // 2 + 1)

            instance_image = instance_image.resize((random_scale, random_scale), resample=self.interpolation)
            instance_image = np.array(instance_image).astype(np.uint8)
            instance_image = (instance_image / 127.5 - 1.0).astype(np.float32)
            instance_image = instance_image[cx - self.size // 2: cx + self.size // 2,
                                            cy - self.size // 2: cy + self.size // 2, :]
            mask = np.ones((self.size // 8, self.size // 8))
        else:
            add_to_caption = "a"
            if self.size is not None:
                instance_image = instance_image.resize((self.size, self.size), resample=self.interpolation)
            instance_image = np.array(instance_image).astype(np.uint8)
            instance_image = (instance_image / 127.5 - 1.0).astype(np.float32)
            mask = np.ones((self.size // 8, self.size // 8))

        return torch.from_numpy(instance_image).permute(2, 0, 1), torch.from_numpy(mask[:, :, None]).permute(2, 0, 1), add_to_caption

    def aug_cv2(self, img, seg):

        img_auged = np.array(img).copy()
        seg_auged = np.array(seg).copy()
        # resize and crop
        if random.choice([0, 1]) == 0:
            new_size = random.randint(224, 256)
            img_auged = cv2.resize(img_auged, (new_size, new_size), interpolation=cv2.INTER_CUBIC)
            seg_auged = cv2.resize(seg_auged, (new_size, new_size), interpolation=cv2.INTER_NEAREST)

            start_x, start_y = random.randint(0, new_size - 224), random.randint(0, new_size - 224)
            img_auged = img_auged[start_x:start_x + 224, start_y:start_y + 224, :]
            seg_auged = seg_auged[start_x:start_x + 224, start_y:start_y + 224, :]

        h, w = img_auged.shape[:2]
        # rotate
        if random.choice([0, 1]) == 0:
            # print('rotate')
            angle = random.randint(-30, 30)
            M = cv2.getRotationMatrix2D((112, 112), angle, 1)
            img_auged = cv2.warpAffine(img_auged, M, (w, h), flags=cv2.INTER_CUBIC)
            seg_auged = cv2.warpAffine(seg_auged, M, (w, h), flags=cv2.INTER_NEAREST)

        # translation
        if random.choice([0, 1]) == 0:
            trans_x = random.randint(-60, 60)
            trans_y = random.randint(-60, 60)
            H = np.float32([[1, 0, trans_x],
                            [0, 1, trans_y]])
            img_auged = cv2.warpAffine(img_auged, H, (w, h), flags=cv2.INTER_CUBIC)
            seg_auged = cv2.warpAffine(seg_auged, H, (w, h), flags=cv2.INTER_NEAREST)

        img_auged = Image.fromarray(img_auged)
        seg_auged = Image.fromarray(seg_auged)

        return img_auged, seg_auged


    def __getitem__(self, i):
        example = {}

        seg_name = self.bboxes_full[i % self.num_images][0]
        file_name = seg_name.split('_')[0] + '.jpg'
        img_path = os.path.join(self.data_root, 'images', self.set, file_name)
        seg_path = os.path.join(self.data_root, 'segs', self.set, seg_name)

        try:
            # crop image and mask
            bbox_sample = self.bboxes_full[i % self.num_images][1:]
            img_p_np = cv2.imread(img_path)
            img_p_np = cv2.cvtColor(img_p_np, cv2.COLOR_BGR2RGB)
            seg_p_np = cv2.imread(seg_path).astype('float')
            seg_p_np = cv2.resize(seg_p_np, img_p_np.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)

            bbox_pad = copy.copy(bbox_sample)
            pad_size = random.choice(list(range(10, 20)))
            bbox_pad[0] = int(bbox_pad[0] - min(pad_size, bbox_pad[0] - 0))
            bbox_pad[1] = int(bbox_pad[1] + pad_size)
            bbox_pad[2] = int(bbox_pad[2] - min(pad_size, bbox_pad[2] - 0))
            bbox_pad[3] = int(bbox_pad[3] + pad_size)

            image_tensor = img_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
            seg_tensor = seg_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]

            # augmentation for input image
            augged_image, augged_mask, add_caption = self.custom_aug(image_tensor)
            input_ids, index, text = self.obtain_text(add_caption)

            example["pixel_values"] = augged_image
            example["mask_values"] = augged_mask
            example["input_ids"] = input_ids
            example["index"] = index
            example["text"] = text

            object_tensor = image_tensor * (seg_tensor / 255)
            ref_object_tensor = cv2.resize(object_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
            ref_image_tenser = cv2.resize(image_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
            ref_seg_tensor = cv2.resize(seg_tensor, (224, 224), interpolation=cv2.INTER_NEAREST)

            ref_object_tensor, ref_seg_tensor = self.aug_cv2(ref_object_tensor.astype('uint8'), ref_seg_tensor.astype('uint8'))
            example["pixel_values_clip"] = self.get_tensor_clip()(Image.fromarray(ref_image_tenser))
            example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
            example["pixel_values_seg"] = self.get_tensor_clip(normalize=False)(ref_seg_tensor)

        except Exception as e:
            example["pixel_values"] = torch.zeros((3, 512, 512))
            example["pixel_values_obj"] = torch.zeros((3, 224, 224))
            example["pixel_values_clip"] = torch.zeros((3, 224, 224))
            example["pixel_values_seg"] = torch.zeros((3, 224, 224))

            input_ids, index, text = self.obtain_text("a")
            example["input_ids"] = input_ids
            example["index"] = index
            example["text"] = text

            with open('error.txt', 'a+') as f:
                f.write(str(e) + '\n')

        return example
```
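The sketch below (not part of the repo) shows how `CustomDatasetWithBG` is consumed at inference time; `inference_global.py` in this same commit builds it the same way. It assumes a folder laid out like `test_datasets/`, i.e. every image `X.jpg` (or `.png`) has a matching `X_bg.png` foreground mask, and uses the CLIP tokenizer that the inference scripts already load.

```python
# Sketch only: inspect one item produced by CustomDatasetWithBG.
from transformers import CLIPTokenizer
from datasets import CustomDatasetWithBG  # the local datasets.py above

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
dataset = CustomDatasetWithBG(
    data_root="./test_datasets/",
    tokenizer=tokenizer,
    size=512,
    placeholder_token="S",
    template="a photo of a {}",
)

item = dataset[0]
print(item["text"], item["index"].item())  # prompt and placeholder position
print(item["pixel_values"].shape)          # (3, 512, 512), scaled to [-1, 1]
print(item["pixel_values_clip"].shape)     # (3, 224, 224), CLIP-normalized full image
print(item["pixel_values_obj"].shape)      # (3, 224, 224), CLIP-normalized masked object
print(item["pixel_values_seg"].shape)      # background mask downsampled to 128x128
```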
elite.yaml
ADDED
@@ -0,0 +1,147 @@
```yaml
name: elite
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - ca-certificates=2022.10.11=h06a4308_0
  - certifi=2022.9.24=py39h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - ncurses=6.3=h7f8727e_2
  - openssl=1.1.1s=h7f8727e_0
  - pip=22.2.2=py39h06a4308_0
  - python=3.9.12=h12debd9_1
  - readline=8.1.2=h7f8727e_1
  - sqlite=3.38.5=hc218d9a_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - pip:
    - absl-py==1.3.0
    - accelerate==0.15.0
    - aiohttp==3.8.3
    - aiosignal==1.3.1
    - albumentations==1.1.0
    - altair==4.2.0
    - antlr4-python3-runtime==4.8
    - async-timeout==4.0.2
    - attrs==22.1.0
    - blinker==1.5
    - cachetools==5.2.0
    - charset-normalizer==2.1.1
    - click==8.1.3
    - commonmark==0.9.1
    - contourpy==1.0.6
    - cycler==0.11.0
    - cython==0.29.33
    - decorator==5.1.1
    - diffusers==0.11.1
    - einops==0.4.1
    - emoji==2.2.0
    - entrypoints==0.4
    - faiss-gpu==1.7.2
    - filelock==3.8.0
    - fonttools==4.38.0
    - frozenlist==1.3.3
    - fsspec==2022.11.0
    - ftfy==6.1.1
    - future==0.18.2
    - gitdb==4.0.9
    - gitpython==3.1.29
    - google-auth==2.14.1
    - google-auth-oauthlib==0.4.6
    - grpcio==1.50.0
    - huggingface-hub==0.11.0
    - idna==3.4
    - imageio==2.14.1
    - imageio-ffmpeg==0.4.7
    - importlib-metadata==5.0.0
    - jinja2==3.1.2
    - joblib==1.2.0
    - jsonschema==4.17.0
    - kiwisolver==1.4.4
    - kornia==0.6.0
    - markdown==3.4.1
    - markupsafe==2.1.1
    - matplotlib==3.6.2
    - multidict==6.0.2
    - networkx==2.8.8
    - nltk==3.7
    - numpy==1.23.4
    - oauthlib==3.2.2
    - omegaconf==2.1.1
    - opencv-python==4.6.0.66
    - opencv-python-headless==4.6.0.66
    - packaging==21.3
    - pandas==1.5.1
    - pillow==9.0.1
    - protobuf==3.20.1
    - psutil==5.9.4
    - pudb==2019.2
    - pyarrow==10.0.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pycocotools==2.0.6
    - pydeck==0.8.0
    - pydensecrf==1.0rc2
    - pydeprecate==0.3.2
    - pygments==2.13.0
    - pympler==1.0.1
    - pyparsing==3.0.9
    - pyrsistent==0.19.2
    - python-dateutil==2.8.2
    - python-dotenv==0.21.0
    - pytorch-lightning==1.6.5
    - pytz==2022.6
    - pytz-deprecation-shim==0.1.0.post0
    - pywavelets==1.4.1
    - pyyaml==6.0
    - qudida==0.0.4
    - regex==2022.10.31
    - requests==2.28.1
    - requests-oauthlib==1.3.1
    - rich==12.6.0
    - rsa==4.9
    - sacremoses==0.0.53
    - scikit-image==0.19.3
    - scikit-learn==1.1.3
    - scipy==1.9.3
    - semver==2.13.0
    - setuptools==59.5.0
    - six==1.16.0
    - smmap==5.0.0
    - stanza==1.4.2
    - streamlit==1.15.0
    - tensorboard==2.11.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - test-tube==0.7.5
    - threadpoolctl==3.1.0
    - tifffile==2022.10.10
    - timm==0.6.12
    - tokenizers==0.12.1
    - toml==0.10.2
    - toolz==0.12.0
    - torch==1.12.1+cu116
    - torch-fidelity==0.3.0
    - torchaudio==0.12.1+cu116
    - torchmetrics==0.6.0
    - torchvision==0.13.1+cu116
    - tornado==6.2
    - tqdm==4.64.1
    - transformers==4.25.1
    - typing-extensions==4.4.0
    - tzdata==2022.6
    - tzlocal==4.2
    - urllib3==1.26.12
    - urwid==2.1.2
    - validators==0.20.0
    - watchdog==2.1.9
    - wcwidth==0.2.5
    - werkzeug==2.2.2
    - yarl==1.8.1
    - zipp==3.10.0
```
inference_global.py
ADDED
@@ -0,0 +1,214 @@
```python
import os
from typing import Optional, Tuple
import numpy as np
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from train_global import Mapper, th2image
from train_global import inj_forward_text, inj_forward_crossattention, validation
import torch.nn as nn
from datasets import CustomDatasetWithBG

def _pil_from_latents(vae, latents):
    _latents = 1 / 0.18215 * latents.clone()
    image = vae.decode(_latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    ret_pil_images = [Image.fromarray(image) for image in images]

    return ret_pil_images


def pww_load_tools(
    device: str = "cuda:0",
    scheduler_type=LMSDiscreteScheduler,
    mapper_model_path: Optional[str] = None,
    diffusion_model_path: Optional[str] = None,
    model_token: Optional[str] = None,
) -> Tuple[
    UNet2DConditionModel,
    CLIPTextModel,
    CLIPTokenizer,
    AutoencoderKL,
    CLIPVisionModel,
    Mapper,
    LMSDiscreteScheduler,
]:

    # 'CompVis/stable-diffusion-v1-4'
    local_path_only = diffusion_model_path is not None
    vae = AutoencoderKL.from_pretrained(
        diffusion_model_path,
        subfolder="vae",
        use_auth_token=model_token,
        torch_dtype=torch.float16,
        local_files_only=local_path_only,
    )

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
    image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)


    # Load models and create wrapper for stable diffusion
    for _module in text_encoder.modules():
        if _module.__class__.__name__ == "CLIPTextTransformer":
            _module.__class__.__call__ = inj_forward_text

    unet = UNet2DConditionModel.from_pretrained(
        diffusion_model_path,
        subfolder="unet",
        use_auth_token=model_token,
        torch_dtype=torch.float16,
        local_files_only=local_path_only,
    )

    mapper = Mapper(input_dim=1024, output_dim=768)

    for _name, _module in unet.named_modules():
        if _module.__class__.__name__ == "CrossAttention":
            if 'attn1' in _name: continue
            _module.__class__.__call__ = inj_forward_crossattention

            shape = _module.to_k.weight.shape
            to_k_global = nn.Linear(shape[1], shape[0], bias=False)
            mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)

            shape = _module.to_v.weight.shape
            to_v_global = nn.Linear(shape[1], shape[0], bias=False)
            mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)

    mapper.load_state_dict(torch.load(mapper_model_path, map_location='cpu'))
    mapper.half()

    for _name, _module in unet.named_modules():
        if 'attn1' in _name: continue
        if _module.__class__.__name__ == "CrossAttention":
            _module.add_module('to_k_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_k'))
            _module.add_module('to_v_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_v'))

    vae.to(device), unet.to(device), text_encoder.to(device), image_encoder.to(device), mapper.to(device)

    scheduler = scheduler_type(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    vae.eval()
    unet.eval()
    image_encoder.eval()
    text_encoder.eval()
    mapper.eval()
    return vae, unet, text_encoder, tokenizer, image_encoder, mapper, scheduler


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--token_index",
        type=str,
        default="full",
        help="Selected index for word embedding.",
    )

    parser.add_argument(
        "--global_mapper_path",
        type=str,
        required=True,
        help="Path to pretrained global mapping network.",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default='outputs',
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--placeholder_token",
        type=str,
        default="S",
        help="A token to use as a placeholder for the concept.",
    )

    parser.add_argument(
        "--template",
        type=str,
        default="a photo of a {}",
        help="Text template for customized genetation.",
    )

    parser.add_argument(
        "--test_data_dir", type=str, default=None, required=True, help="A folder containing the testing data."
    )

    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )

    parser.add_argument(
        "--suffix",
        type=str,
        default="object",
        help="Suffix of save directory.",
    )

    parser.add_argument(
        "--selected_data",
        type=int,
        default=-1,
        help="Data index. -1 for all.",
    )

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    save_dir = os.path.join(args.output_dir, f'{args.suffix}_token{args.token_index}')
    os.makedirs(save_dir, exist_ok=True)

    vae, unet, text_encoder, tokenizer, image_encoder, mapper, scheduler = pww_load_tools(
        "cuda:0",
        LMSDiscreteScheduler,
        diffusion_model_path=args.pretrained_model_name_or_path,
        mapper_model_path=args.global_mapper_path,
    )

    train_dataset = CustomDatasetWithBG(
        data_root=args.test_data_dir,
        tokenizer=tokenizer,
        size=512,
        placeholder_token=args.placeholder_token,
        template=args.template,
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
    for step, batch in enumerate(train_dataloader):
        if args.selected_data > -1 and step != args.selected_data:
            continue
        batch["pixel_values"] = batch["pixel_values"].to("cuda:0")
        batch["pixel_values_clip"] = batch["pixel_values_clip"].to("cuda:0").half()
        batch["input_ids"] = batch["input_ids"].to("cuda:0")
        batch["index"] = batch["index"].to("cuda:0").long()
        print(step, batch['text'])
        seeds = [0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
                 8269, 87367, 29379, 4658, 39, 598]
        seeds = sorted(seeds)
        for seed in seeds:
            syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5,
                                    token_index=args.token_index, seed=seed)
            concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
            Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
```
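`pww_load_tools` wires the trained mapper into a stock Stable Diffusion UNet by replacing the `__call__` of every cross-attention class (skipping self-attention `attn1` blocks) and registering per-layer `to_k_global`/`to_v_global` projections under names derived from each module's path, so the mapper checkpoint's state dict lines up with the UNet's attention layers. The toy sketch below shows that class-level patching idea in isolation; the names are illustrative only and not repo code.

```python
# Toy illustration (not repo code): patch a module class's __call__ so every
# existing instance routes through a custom forward that can accept extra
# conditioning, without rebuilding or re-loading the pretrained network.
# Note this bypasses nn.Module's usual hook machinery, as the repo's patch does.
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.to_k = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.to_k(x)

def patched_call(self, x, injected=None):
    out = self.to_k(x)
    # Extra branch the stock forward does not have.
    return out if injected is None else out + self.to_k_global(injected)

block = TinyAttention()
block.add_module("to_k_global", nn.Linear(4, 4, bias=False))  # mirrors to_k's shape
TinyAttention.__call__ = patched_call  # every TinyAttention now uses the patched path

x = torch.randn(2, 4)
print(block(x).shape)                               # plain path
print(block(x, injected=torch.randn(2, 4)).shape)   # injected path
```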
inference_global.sh
ADDED
@@ -0,0 +1,12 @@
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR='./test_datasets/'

CUDA_VISIBLE_DEVICES=6 python inference_global.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --test_data_dir=$DATA_DIR \
  --output_dir="./outputs/global_mapping" \
  --suffix="object" \
  --token_index="0" \
  --template="a photo of a {}" \
  --global_mapper_path="./checkpoints/global_mapper.pt"
```
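A usage note on this script: `CUDA_VISIBLE_DEVICES=6` pins the job to GPU 6 and will need adjusting on most machines, and `./checkpoints/global_mapper.pt` is expected to exist locally (the `checkpoints` directory is git-ignored above, so the trained global mapper must be obtained separately; this commit does not say from where).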
inference_local.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from train_local import Mapper, th2image, MapperLocal
from train_local import inj_forward_text, inj_forward_crossattention, validation
import torch.nn as nn
from datasets import CustomDatasetWithBG

def _pil_from_latents(vae, latents):
    _latents = 1 / 0.18215 * latents.clone()
    image = vae.decode(_latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    ret_pil_images = [Image.fromarray(image) for image in images]

    return ret_pil_images


def pww_load_tools(
    device: str = "cuda:0",
    scheduler_type=LMSDiscreteScheduler,
    mapper_model_path: Optional[str] = None,
    mapper_local_model_path: Optional[str] = None,
    diffusion_model_path: Optional[str] = None,
    model_token: Optional[str] = None,
) -> Tuple[
    UNet2DConditionModel,
    CLIPTextModel,
    CLIPTokenizer,
    AutoencoderKL,
    CLIPVisionModel,
    Mapper,
    MapperLocal,
    LMSDiscreteScheduler,
]:

    # 'CompVis/stable-diffusion-v1-4'
    local_path_only = diffusion_model_path is not None
    vae = AutoencoderKL.from_pretrained(
        diffusion_model_path,
        subfolder="vae",
        use_auth_token=model_token,
        torch_dtype=torch.float16,
        local_files_only=local_path_only,
    )

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
    image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)


    # Load models and create wrapper for stable diffusion
    for _module in text_encoder.modules():
        if _module.__class__.__name__ == "CLIPTextTransformer":
            _module.__class__.__call__ = inj_forward_text

    unet = UNet2DConditionModel.from_pretrained(
        diffusion_model_path,
        subfolder="unet",
        use_auth_token=model_token,
        torch_dtype=torch.float16,
        local_files_only=local_path_only,
    )
    inj_forward_crossattention
    mapper = Mapper(input_dim=1024, output_dim=768)

    mapper_local = MapperLocal(input_dim=1024, output_dim=768)

    for _name, _module in unet.named_modules():
        if _module.__class__.__name__ == "CrossAttention":
            if 'attn1' in _name: continue
            _module.__class__.__call__ = inj_forward_crossattention

            shape = _module.to_k.weight.shape
            to_k_global = nn.Linear(shape[1], shape[0], bias=False)
            mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)

            shape = _module.to_v.weight.shape
            to_v_global = nn.Linear(shape[1], shape[0], bias=False)
            mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)

            to_v_local = nn.Linear(shape[1], shape[0], bias=False)
            mapper_local.add_module(f'{_name.replace(".", "_")}_to_v', to_v_local)

            to_k_local = nn.Linear(shape[1], shape[0], bias=False)
            mapper_local.add_module(f'{_name.replace(".", "_")}_to_k', to_k_local)

    mapper.load_state_dict(torch.load(mapper_model_path, map_location='cpu'))
    mapper.half()

    mapper_local.load_state_dict(torch.load(mapper_local_model_path, map_location='cpu'))
    mapper_local.half()

    for _name, _module in unet.named_modules():
        if 'attn1' in _name: continue
        if _module.__class__.__name__ == "CrossAttention":
            _module.add_module('to_k_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_k'))
            _module.add_module('to_v_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_v'))
            _module.add_module('to_v_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_v'))
            _module.add_module('to_k_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_k'))

    vae.to(device), unet.to(device), text_encoder.to(device), image_encoder.to(device), mapper.to(device), mapper_local.to(device)

    scheduler = scheduler_type(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    vae.eval()
    unet.eval()
    image_encoder.eval()
    text_encoder.eval()
    mapper.eval()
    mapper_local.eval()
    return vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler



def parse_args():

    import argparse
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    parser.add_argument(
        "--global_mapper_path",
        type=str,
        required=True,
        help="Path to pretrained global mapping network.",
    )

    parser.add_argument(
        "--local_mapper_path",
        type=str,
        required=True,
        help="Path to pretrained local mapping network.",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default='outputs',
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--placeholder_token",
        type=str,
        default="S",
        help="A token to use as a placeholder for the concept.",
    )

    parser.add_argument(
        "--template",
        type=str,
        default="a photo of a {}",
        help="Text template for customized generation.",
    )

    parser.add_argument(
        "--test_data_dir", type=str, default=None, required=True, help="A folder containing the testing data."
    )

    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )

    parser.add_argument(
        "--suffix",
        type=str,
        default="object",
        help="Suffix of save directory.",
    )

    parser.add_argument(
        "--selected_data",
        type=int,
        default=-1,
        help="Data index. -1 for all.",
    )

    parser.add_argument(
        "--llambda",
        type=str,
        default="0.8",
        help="Lambda for fusing the global and local features.",
    )

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    save_dir = os.path.join(args.output_dir, f'{args.suffix}_l{args.llambda.replace(".", "p")}')
    os.makedirs(save_dir, exist_ok=True)

    vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler = pww_load_tools(
        "cuda:0",
        LMSDiscreteScheduler,
        diffusion_model_path=args.pretrained_model_name_or_path,
        mapper_model_path=args.global_mapper_path,
        mapper_local_model_path=args.local_mapper_path,
    )

    train_dataset = CustomDatasetWithBG(
        data_root=args.test_data_dir,
        tokenizer=tokenizer,
        size=512,
        placeholder_token=args.placeholder_token,
        template=args.template,
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
    for step, batch in enumerate(train_dataloader):
        if args.selected_data > -1 and step != args.selected_data:
            continue
        batch["pixel_values"] = batch["pixel_values"].to("cuda:0")
        batch["pixel_values_clip"] = batch["pixel_values_clip"].to("cuda:0").half()
        batch["pixel_values_obj"] = batch["pixel_values_obj"].to("cuda:0").half()
        batch["pixel_values_seg"] = batch["pixel_values_seg"].to("cuda:0").half()
        batch["input_ids"] = batch["input_ids"].to("cuda:0")
        batch["index"] = batch["index"].to("cuda:0").long()
        print(step, batch['text'])
        seeds = [0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
                 8269, 87367, 29379, 4658, 39, 598]
        seeds = sorted(seeds)
        for seed in seeds:
            syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae,
                                    batch["pixel_values_clip"].device, 5,
                                    seed=seed, llambda=float(args.llambda))
            concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
            Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
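Note (not part of the repository): the key difference from the global-only pipeline lives in the injected cross-attention of train_local.py, where the local (patch-level) value features are gated by the attention the learned word receives and scaled by --llambda before being added to the global attention output. Below is a small self-contained sketch of that fusion with random tensors and assumed toy shapes; it is an illustration of the mechanism, not repository code.

import torch

heads_x_batch, n_query, dim = 2, 64, 40            # assumed toy sizes
q = torch.randn(heads_x_batch, n_query, dim)        # UNet queries
k_g = torch.randn(heads_x_batch, 77, dim)           # global text context (keys)
v_g = torch.randn(heads_x_batch, 77, dim)           # global text context (values)
k_l = torch.randn(heads_x_batch, 257, dim)          # local CLIP-patch context (keys)
v_l = torch.randn(heads_x_batch, 257, dim)          # local CLIP-patch context (values)
scale, llambda, word_idx = dim ** -0.5, 0.8, 5      # word_idx: position of the learned word

attn_g = (q @ k_g.transpose(-1, -2) * scale).softmax(dim=-1)
attn_l = (q @ k_l.transpose(-1, -2) * scale).softmax(dim=-1)

# attention each spatial location pays to the learned word, normalized by its maximum
mask = attn_g[:, :, word_idx].unsqueeze(-1)
mask = mask / mask.max()

# global attention output plus the masked, lambda-weighted local attention output
out = attn_g @ v_g + llambda * (attn_l * mask) @ v_l
print(out.shape)   # torch.Size([2, 64, 40])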
inference_local.sh
ADDED
@@ -0,0 +1,12 @@
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR='./test_datasets/'
CUDA_VISIBLE_DEVICES=7 python inference_local.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --test_data_dir=$DATA_DIR \
  --output_dir="./outputs/local_mapping" \
  --suffix="object" \
  --template="a photo of a {}" \
  --llambda="0.8" \
  --global_mapper_path="./checkpoints/global_mapper.pt" \
  --local_mapper_path="./checkpoints/local_mapper.pt"
test_datasets/1.jpg
ADDED
test_datasets/10.jpg
ADDED
test_datasets/10_bg.png
ADDED
test_datasets/11.jpg
ADDED
test_datasets/11_bg.png
ADDED
test_datasets/15.jpg
ADDED
test_datasets/15_bg.png
ADDED
test_datasets/16.jpg
ADDED
test_datasets/16_bg.png
ADDED
test_datasets/17.jpg
ADDED
test_datasets/17_bg.png
ADDED
test_datasets/1_bg.png
ADDED
test_datasets/2.jpg
ADDED
test_datasets/20.jpg
ADDED
test_datasets/20_bg.png
ADDED
test_datasets/2_bg.png
ADDED
test_datasets/3.jpg
ADDED
test_datasets/3_bg.png
ADDED
test_datasets/4.png
ADDED
test_datasets/4_bg.png
ADDED
test_datasets/7.jpg
ADDED
test_datasets/7_bg.png
ADDED
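Note (not part of the repository): the test images above come in pairs — every foreground image N.jpg / N.png is accompanied by a matching N_bg.png, which the CustomDatasetWithBG loader is assumed to read from the same folder. A small sketch that lists the pairs:

import glob, os

data_root = "./test_datasets/"
for img in sorted(glob.glob(os.path.join(data_root, "*.jpg")) + glob.glob(os.path.join(data_root, "*.png"))):
    stem, _ = os.path.splitext(img)
    if stem.endswith("_bg"):
        continue
    bg = stem + "_bg.png"
    status = "ok" if os.path.exists(bg) else "missing"
    print(os.path.basename(img), "->", os.path.basename(bg), status)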
train_global.py
ADDED
@@ -0,0 +1,715 @@
1 |
+
import argparse
|
2 |
+
import itertools
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
|
14 |
+
import PIL
|
15 |
+
from accelerate import Accelerator
|
16 |
+
from accelerate.logging import get_logger
|
17 |
+
from accelerate.utils import set_seed
|
18 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, LMSDiscreteScheduler
|
19 |
+
from diffusers.optimization import get_scheduler
|
20 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
21 |
+
|
22 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
23 |
+
from transformers.utils import (
|
24 |
+
add_start_docstrings_to_model_forward,
|
25 |
+
replace_return_docstrings,
|
26 |
+
)
|
27 |
+
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
28 |
+
from transformers.models.clip.modeling_clip import CLIP_TEXT_INPUTS_DOCSTRING, _expand_mask
|
29 |
+
|
30 |
+
from PIL import Image
|
31 |
+
from tqdm.auto import tqdm
|
32 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
33 |
+
|
34 |
+
from typing import Any, Optional, Tuple, Union
|
35 |
+
from datasets import OpenImagesDataset
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
class Mapper(nn.Module):
|
40 |
+
def __init__(self,
|
41 |
+
input_dim: int,
|
42 |
+
output_dim: int,
|
43 |
+
):
|
44 |
+
super(Mapper, self).__init__()
|
45 |
+
|
46 |
+
for i in range(5):
|
47 |
+
setattr(self, f'mapping_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
|
48 |
+
nn.LayerNorm(1024),
|
49 |
+
nn.LeakyReLU(),
|
50 |
+
nn.Linear(1024, 1024),
|
51 |
+
nn.LayerNorm(1024),
|
52 |
+
nn.LeakyReLU(),
|
53 |
+
nn.Linear(1024, output_dim)))
|
54 |
+
|
55 |
+
setattr(self, f'mapping_patch_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
|
56 |
+
nn.LayerNorm(1024),
|
57 |
+
nn.LeakyReLU(),
|
58 |
+
nn.Linear(1024, 1024),
|
59 |
+
nn.LayerNorm(1024),
|
60 |
+
nn.LeakyReLU(),
|
61 |
+
nn.Linear(1024, output_dim)))
|
62 |
+
|
63 |
+
def forward(self, embs):
|
64 |
+
hidden_states = ()
|
65 |
+
for i, emb in enumerate(embs):
|
66 |
+
hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(emb[:, 1:]).mean(dim=1, keepdim=True)
|
67 |
+
hidden_states += (hidden_state, )
|
68 |
+
hidden_states = torch.cat(hidden_states, dim=1)
|
69 |
+
return hidden_states
|
70 |
+
|
71 |
+
|
72 |
+
def _build_causal_attention_mask(bsz, seq_len, dtype):
|
73 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
74 |
+
# pytorch uses additive attention mask; fill with -inf
|
75 |
+
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
|
76 |
+
mask.fill_(torch.tensor(torch.finfo(dtype).min))
|
77 |
+
mask.triu_(1) # zero out the lower diagonal
|
78 |
+
mask = mask.unsqueeze(1) # expand mask
|
79 |
+
return mask
|
80 |
+
|
81 |
+
|
82 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
83 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
84 |
+
def inj_forward_text(
|
85 |
+
self,
|
86 |
+
input_ids: Optional[torch.Tensor] = None,
|
87 |
+
attention_mask: Optional[torch.Tensor] = None,
|
88 |
+
position_ids: Optional[torch.Tensor] = None,
|
89 |
+
output_attentions: Optional[bool] = None,
|
90 |
+
output_hidden_states: Optional[bool] = None,
|
91 |
+
return_dict: Optional[bool] = None,
|
92 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
93 |
+
r"""
|
94 |
+
Returns:
|
95 |
+
"""
|
96 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
97 |
+
output_hidden_states = (
|
98 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
99 |
+
)
|
100 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
101 |
+
|
102 |
+
if input_ids is None:
|
103 |
+
raise ValueError("You have to specify either input_ids")
|
104 |
+
|
105 |
+
r_input_ids = input_ids['input_ids']
|
106 |
+
if 'inj_embedding' in input_ids:
|
107 |
+
inj_embedding = input_ids['inj_embedding']
|
108 |
+
inj_index = input_ids['inj_index']
|
109 |
+
else:
|
110 |
+
inj_embedding = None
|
111 |
+
inj_index = None
|
112 |
+
|
113 |
+
input_shape = r_input_ids.size()
|
114 |
+
r_input_ids = r_input_ids.view(-1, input_shape[-1])
|
115 |
+
|
116 |
+
|
117 |
+
inputs_embeds = self.embeddings.token_embedding(r_input_ids)
|
118 |
+
new_inputs_embeds = inputs_embeds.clone()
|
119 |
+
if inj_embedding is not None:
|
120 |
+
emb_length = inj_embedding.shape[1]
|
121 |
+
for bsz, idx in enumerate(inj_index):
|
122 |
+
lll = new_inputs_embeds[bsz, idx+emb_length:].shape[0]
|
123 |
+
new_inputs_embeds[bsz, idx+emb_length:] = inputs_embeds[bsz, idx+1:idx+1+lll]
|
124 |
+
new_inputs_embeds[bsz, idx:idx+emb_length] = inj_embedding[bsz]
|
125 |
+
|
126 |
+
hidden_states = self.embeddings(input_ids=r_input_ids, position_ids=position_ids, inputs_embeds=new_inputs_embeds)
|
127 |
+
|
128 |
+
bsz, seq_len = input_shape
|
129 |
+
# CLIP's text model uses causal mask, prepare it here.
|
130 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
131 |
+
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
|
132 |
+
hidden_states.device
|
133 |
+
)
|
134 |
+
# expand attention_mask
|
135 |
+
if attention_mask is not None:
|
136 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
137 |
+
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
138 |
+
|
139 |
+
encoder_outputs = self.encoder(
|
140 |
+
inputs_embeds=hidden_states,
|
141 |
+
attention_mask=attention_mask,
|
142 |
+
causal_attention_mask=causal_attention_mask,
|
143 |
+
output_attentions=output_attentions,
|
144 |
+
output_hidden_states=output_hidden_states,
|
145 |
+
return_dict=return_dict,
|
146 |
+
)
|
147 |
+
|
148 |
+
last_hidden_state = encoder_outputs[0]
|
149 |
+
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
150 |
+
|
151 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
152 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
153 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
154 |
+
pooled_output = last_hidden_state[
|
155 |
+
torch.arange(last_hidden_state.shape[0], device=r_input_ids.device), r_input_ids.to(torch.int).argmax(dim=-1)
|
156 |
+
]
|
157 |
+
|
158 |
+
if not return_dict:
|
159 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
160 |
+
|
161 |
+
return BaseModelOutputWithPooling(
|
162 |
+
last_hidden_state=last_hidden_state,
|
163 |
+
pooler_output=pooled_output,
|
164 |
+
hidden_states=encoder_outputs.hidden_states,
|
165 |
+
attentions=encoder_outputs.attentions,
|
166 |
+
)
|
167 |
+
|
168 |
+
def inj_forward_crossattention(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
169 |
+
context = encoder_hidden_states
|
170 |
+
if context is not None:
|
171 |
+
context_tensor = context["CONTEXT_TENSOR"]
|
172 |
+
else:
|
173 |
+
context_tensor = hidden_states
|
174 |
+
|
175 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
176 |
+
|
177 |
+
query = self.to_q(hidden_states)
|
178 |
+
if context is not None:
|
179 |
+
key = self.to_k_global(context_tensor)
|
180 |
+
value = self.to_v_global(context_tensor)
|
181 |
+
else:
|
182 |
+
key = self.to_k(context_tensor)
|
183 |
+
value = self.to_v(context_tensor)
|
184 |
+
|
185 |
+
dim = query.shape[-1]
|
186 |
+
|
187 |
+
query = self.reshape_heads_to_batch_dim(query)
|
188 |
+
key = self.reshape_heads_to_batch_dim(key)
|
189 |
+
value = self.reshape_heads_to_batch_dim(value)
|
190 |
+
|
191 |
+
attention_scores = torch.matmul(query, key.transpose(-1, -2))
|
192 |
+
attention_scores = attention_scores * self.scale
|
193 |
+
|
194 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
195 |
+
|
196 |
+
hidden_states = torch.matmul(attention_probs, value)
|
197 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
198 |
+
|
199 |
+
# linear proj
|
200 |
+
hidden_states = self.to_out[0](hidden_states)
|
201 |
+
# dropout
|
202 |
+
hidden_states = self.to_out[1](hidden_states)
|
203 |
+
|
204 |
+
return hidden_states
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
logger = get_logger(__name__)
|
209 |
+
|
210 |
+
|
211 |
+
def save_progress(mapper, accelerator, args, step=None):
|
212 |
+
logger.info("Saving embeddings")
|
213 |
+
|
214 |
+
state_dict = accelerator.unwrap_model(mapper).state_dict()
|
215 |
+
|
216 |
+
if step is not None:
|
217 |
+
torch.save(state_dict, os.path.join(args.output_dir, f"mapper_{str(step).zfill(6)}.pt"))
|
218 |
+
else:
|
219 |
+
torch.save(state_dict, os.path.join(args.output_dir, "mapper.pt"))
|
220 |
+
|
221 |
+
|
222 |
+
def parse_args():
|
223 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
224 |
+
parser.add_argument(
|
225 |
+
"--save_steps",
|
226 |
+
type=int,
|
227 |
+
default=500,
|
228 |
+
help="Save learned_embeds.bin every X updates steps.",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--pretrained_model_name_or_path",
|
232 |
+
type=str,
|
233 |
+
default=None,
|
234 |
+
required=True,
|
235 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
236 |
+
)
|
237 |
+
parser.add_argument(
|
238 |
+
"--tokenizer_name",
|
239 |
+
type=str,
|
240 |
+
default=None,
|
241 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
242 |
+
)
|
243 |
+
parser.add_argument(
|
244 |
+
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
245 |
+
)
|
246 |
+
parser.add_argument(
|
247 |
+
"--global_mapper_path", type=str, default=None, help="If not none, the training will start from the given checkpoints."
|
248 |
+
)
|
249 |
+
parser.add_argument(
|
250 |
+
"--placeholder_token",
|
251 |
+
type=str,
|
252 |
+
default=None,
|
253 |
+
required=True,
|
254 |
+
help="A token to use as a placeholder for the concept.",
|
255 |
+
)
|
256 |
+
parser.add_argument(
|
257 |
+
"--output_dir",
|
258 |
+
type=str,
|
259 |
+
default="text-inversion-model",
|
260 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
261 |
+
)
|
262 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
263 |
+
parser.add_argument(
|
264 |
+
"--resolution",
|
265 |
+
type=int,
|
266 |
+
default=512,
|
267 |
+
help=(
|
268 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
269 |
+
" resolution"
|
270 |
+
),
|
271 |
+
)
|
272 |
+
parser.add_argument(
|
273 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
274 |
+
)
|
275 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
276 |
+
parser.add_argument(
|
277 |
+
"--max_train_steps",
|
278 |
+
type=int,
|
279 |
+
default=5000,
|
280 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
281 |
+
)
|
282 |
+
parser.add_argument(
|
283 |
+
"--gradient_accumulation_steps",
|
284 |
+
type=int,
|
285 |
+
default=1,
|
286 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--learning_rate",
|
290 |
+
type=float,
|
291 |
+
default=1e-4,
|
292 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
293 |
+
)
|
294 |
+
parser.add_argument(
|
295 |
+
"--scale_lr",
|
296 |
+
action="store_true",
|
297 |
+
default=True,
|
298 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
299 |
+
)
|
300 |
+
parser.add_argument(
|
301 |
+
"--lr_scheduler",
|
302 |
+
type=str,
|
303 |
+
default="constant",
|
304 |
+
help=(
|
305 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
306 |
+
' "constant", "constant_with_warmup"]'
|
307 |
+
),
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
311 |
+
)
|
312 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
313 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
314 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
315 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
316 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
317 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
318 |
+
parser.add_argument(
|
319 |
+
"--hub_model_id",
|
320 |
+
type=str,
|
321 |
+
default=None,
|
322 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--logging_dir",
|
326 |
+
type=str,
|
327 |
+
default="logs",
|
328 |
+
help=(
|
329 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
330 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
331 |
+
),
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--mixed_precision",
|
335 |
+
type=str,
|
336 |
+
default="no",
|
337 |
+
choices=["no", "fp16", "bf16"],
|
338 |
+
help=(
|
339 |
+
"Whether to use mixed precision. Choose"
|
340 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
341 |
+
"and an Nvidia Ampere GPU."
|
342 |
+
),
|
343 |
+
)
|
344 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
345 |
+
|
346 |
+
args = parser.parse_args()
|
347 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
348 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
349 |
+
args.local_rank = env_local_rank
|
350 |
+
|
351 |
+
if args.train_data_dir is None:
|
352 |
+
raise ValueError("You must specify a train data directory.")
|
353 |
+
|
354 |
+
return args
|
355 |
+
|
356 |
+
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
357 |
+
if token is None:
|
358 |
+
token = HfFolder.get_token()
|
359 |
+
if organization is None:
|
360 |
+
username = whoami(token)["name"]
|
361 |
+
return f"{username}/{model_id}"
|
362 |
+
else:
|
363 |
+
return f"{organization}/{model_id}"
|
364 |
+
|
365 |
+
|
366 |
+
def freeze_params(params):
|
367 |
+
for param in params:
|
368 |
+
param.requires_grad = False
|
369 |
+
|
370 |
+
def unfreeze_params(params):
|
371 |
+
for param in params:
|
372 |
+
param.requires_grad = True
|
373 |
+
|
374 |
+
def th2image(image):
|
375 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
376 |
+
image = image.detach().cpu().permute(1, 2, 0).numpy()
|
377 |
+
image = (image * 255).round().astype("uint8")
|
378 |
+
return Image.fromarray(image)
|
379 |
+
|
380 |
+
|
381 |
+
@torch.no_grad()
|
382 |
+
def validation(example, tokenizer, image_encoder, text_encoder, unet, mapper, vae, device, guidance_scale, token_index='full', seed=None):
|
383 |
+
scheduler = LMSDiscreteScheduler(
|
384 |
+
beta_start=0.00085,
|
385 |
+
beta_end=0.012,
|
386 |
+
beta_schedule="scaled_linear",
|
387 |
+
num_train_timesteps=1000,
|
388 |
+
)
|
389 |
+
|
390 |
+
uncond_input = tokenizer(
|
391 |
+
[''] * example["pixel_values"].shape[0],
|
392 |
+
padding="max_length",
|
393 |
+
max_length=tokenizer.model_max_length,
|
394 |
+
return_tensors="pt",
|
395 |
+
)
|
396 |
+
uncond_embeddings = text_encoder({'input_ids':uncond_input.input_ids.to(device)})[0]
|
397 |
+
|
398 |
+
if seed is None:
|
399 |
+
latents = torch.randn(
|
400 |
+
(example["pixel_values"].shape[0], unet.in_channels, 64, 64)
|
401 |
+
)
|
402 |
+
else:
|
403 |
+
generator = torch.manual_seed(seed)
|
404 |
+
latents = torch.randn(
|
405 |
+
(example["pixel_values"].shape[0], unet.in_channels, 64, 64), generator=generator,
|
406 |
+
)
|
407 |
+
|
408 |
+
latents = latents.to(example["pixel_values_clip"])
|
409 |
+
scheduler.set_timesteps(100)
|
410 |
+
latents = latents * scheduler.init_noise_sigma
|
411 |
+
|
412 |
+
placeholder_idx = example["index"]
|
413 |
+
image = F.interpolate(example["pixel_values_clip"], (224, 224), mode='bilinear')
|
414 |
+
|
415 |
+
image_features = image_encoder(image, output_hidden_states=True)
|
416 |
+
image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12],
|
417 |
+
image_features[2][16]]
|
418 |
+
image_embeddings = [emb.detach() for emb in image_embeddings]
|
419 |
+
inj_embedding = mapper(image_embeddings)
|
420 |
+
|
421 |
+
if token_index != 'full':
|
422 |
+
token_index = int(token_index)
|
423 |
+
inj_embedding = inj_embedding[:, token_index:token_index + 1, :]
|
424 |
+
|
425 |
+
encoder_hidden_states = text_encoder({'input_ids': example["input_ids"],
|
426 |
+
"inj_embedding": inj_embedding,
|
427 |
+
"inj_index": placeholder_idx})[0]
|
428 |
+
|
429 |
+
for t in tqdm(scheduler.timesteps):
|
430 |
+
latent_model_input = scheduler.scale_model_input(latents, t)
|
431 |
+
noise_pred_text = unet(
|
432 |
+
latent_model_input,
|
433 |
+
t,
|
434 |
+
encoder_hidden_states={
|
435 |
+
"CONTEXT_TENSOR": encoder_hidden_states,
|
436 |
+
}
|
437 |
+
).sample
|
438 |
+
|
439 |
+
latent_model_input = scheduler.scale_model_input(latents, t)
|
440 |
+
|
441 |
+
noise_pred_uncond = unet(
|
442 |
+
latent_model_input,
|
443 |
+
t,
|
444 |
+
encoder_hidden_states={
|
445 |
+
"CONTEXT_TENSOR": uncond_embeddings,
|
446 |
+
}
|
447 |
+
).sample
|
448 |
+
|
449 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
450 |
+
noise_pred_text - noise_pred_uncond
|
451 |
+
)
|
452 |
+
|
453 |
+
# compute the previous noisy sample x_t -> x_t-1
|
454 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
455 |
+
|
456 |
+
_latents = 1 / 0.18215 * latents.clone()
|
457 |
+
images = vae.decode(_latents).sample
|
458 |
+
ret_pil_images = [th2image(image) for image in images]
|
459 |
+
|
460 |
+
return ret_pil_images
|
461 |
+
|
462 |
+
def main():
|
463 |
+
args = parse_args()
|
464 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
465 |
+
|
466 |
+
accelerator = Accelerator(
|
467 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
468 |
+
mixed_precision=args.mixed_precision,
|
469 |
+
log_with="tensorboard",
|
470 |
+
logging_dir=logging_dir,
|
471 |
+
)
|
472 |
+
|
473 |
+
# If passed along, set the training seed now.
|
474 |
+
if args.seed is not None:
|
475 |
+
set_seed(args.seed)
|
476 |
+
|
477 |
+
# Handle the repository creation
|
478 |
+
if accelerator.is_main_process:
|
479 |
+
if args.push_to_hub:
|
480 |
+
if args.hub_model_id is None:
|
481 |
+
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
482 |
+
else:
|
483 |
+
repo_name = args.hub_model_id
|
484 |
+
repo = Repository(args.output_dir, clone_from=repo_name)
|
485 |
+
|
486 |
+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
487 |
+
if "step_*" not in gitignore:
|
488 |
+
gitignore.write("step_*\n")
|
489 |
+
if "epoch_*" not in gitignore:
|
490 |
+
gitignore.write("epoch_*\n")
|
491 |
+
elif args.output_dir is not None:
|
492 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
493 |
+
|
494 |
+
# Load the tokenizer and add the placeholder token as a additional special token
|
495 |
+
if args.tokenizer_name:
|
496 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
497 |
+
elif args.pretrained_model_name_or_path:
|
498 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
499 |
+
|
500 |
+
# Load models and create wrapper for stable diffusion
|
501 |
+
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
502 |
+
|
503 |
+
# replace the forward method of the text encoder to inject the word embedding
|
504 |
+
for _module in text_encoder.modules():
|
505 |
+
if _module.__class__.__name__ == "CLIPTextTransformer":
|
506 |
+
_module.__class__.__call__ = inj_forward_text
|
507 |
+
|
508 |
+
image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
|
509 |
+
|
510 |
+
mapper = Mapper(input_dim=1024, output_dim=768)
|
511 |
+
|
512 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
513 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
514 |
+
|
515 |
+
# replace the forward method of the crossattention to finetune the to_k and to_v layers
|
516 |
+
for _name, _module in unet.named_modules():
|
517 |
+
if _module.__class__.__name__ == "CrossAttention":
|
518 |
+
if 'attn1' in _name: continue
|
519 |
+
_module.__class__.__call__ = inj_forward_crossattention
|
520 |
+
|
521 |
+
shape = _module.to_k.weight.shape
|
522 |
+
to_k_global = nn.Linear(shape[1], shape[0], bias=False)
|
523 |
+
to_k_global.weight.data = _module.to_k.weight.data.clone()
|
524 |
+
mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)
|
525 |
+
|
526 |
+
shape = _module.to_v.weight.shape
|
527 |
+
to_v_global = nn.Linear(shape[1], shape[0], bias=False)
|
528 |
+
to_v_global.weight.data = _module.to_v.weight.data.clone()
|
529 |
+
mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)
|
530 |
+
|
531 |
+
if args.global_mapper_path is None:
|
532 |
+
_module.add_module('to_k_global', to_k_global)
|
533 |
+
_module.add_module('to_v_global', to_v_global)
|
534 |
+
|
535 |
+
if args.global_mapper_path is not None:
|
536 |
+
mapper.load_state_dict(torch.load(args.global_mapper_path, map_location='cpu'))
|
537 |
+
for _name, _module in unet.named_modules():
|
538 |
+
if _module.__class__.__name__ == "CrossAttention":
|
539 |
+
if 'attn1' in _name: continue
|
540 |
+
_module.add_module('to_k_global', getattr(mapper, f'{_name.replace(".", "_")}_to_k'))
|
541 |
+
_module.add_module('to_v_global', getattr(mapper, f'{_name.replace(".", "_")}_to_v'))
|
542 |
+
|
543 |
+
# Freeze vae and unet, encoder
|
544 |
+
freeze_params(vae.parameters())
|
545 |
+
freeze_params(unet.parameters())
|
546 |
+
freeze_params(text_encoder.parameters())
|
547 |
+
freeze_params(image_encoder.parameters())
|
548 |
+
|
549 |
+
# Unfreeze the mapper
|
550 |
+
unfreeze_params(mapper.parameters())
|
551 |
+
|
552 |
+
if args.scale_lr:
|
553 |
+
args.learning_rate = (
|
554 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
555 |
+
)
|
556 |
+
|
557 |
+
# Initialize the optimizer
|
558 |
+
optimizer = torch.optim.AdamW(
|
559 |
+
itertools.chain(mapper.parameters()), # only optimize the embeddings
|
560 |
+
lr=args.learning_rate,
|
561 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
562 |
+
weight_decay=args.adam_weight_decay,
|
563 |
+
eps=args.adam_epsilon,
|
564 |
+
)
|
565 |
+
|
566 |
+
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
|
567 |
+
|
568 |
+
train_dataset = OpenImagesDataset(
|
569 |
+
data_root=args.train_data_dir,
|
570 |
+
tokenizer=tokenizer,
|
571 |
+
size=args.resolution,
|
572 |
+
placeholder_token=args.placeholder_token,
|
573 |
+
set="test",
|
574 |
+
)
|
575 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
|
576 |
+
|
577 |
+
# Scheduler and math around the number of training steps.
|
578 |
+
overrode_max_train_steps = False
|
579 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
580 |
+
if args.max_train_steps is None:
|
581 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
582 |
+
overrode_max_train_steps = True
|
583 |
+
|
584 |
+
lr_scheduler = get_scheduler(
|
585 |
+
args.lr_scheduler,
|
586 |
+
optimizer=optimizer,
|
587 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
588 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
589 |
+
)
|
590 |
+
|
591 |
+
mapper, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
592 |
+
mapper, optimizer, train_dataloader, lr_scheduler
|
593 |
+
)
|
594 |
+
|
595 |
+
# Move vae, unet, and encoders to device
|
596 |
+
vae.to(accelerator.device)
|
597 |
+
unet.to(accelerator.device)
|
598 |
+
image_encoder.to(accelerator.device)
|
599 |
+
text_encoder.to(accelerator.device)
|
600 |
+
# Keep vae, unet and image_encoder in eval model as we don't train these
|
601 |
+
vae.eval()
|
602 |
+
unet.eval()
|
603 |
+
image_encoder.eval()
|
604 |
+
|
605 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
606 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
607 |
+
if overrode_max_train_steps:
|
608 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
609 |
+
# Afterwards we recalculate our number of training epochs
|
610 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
611 |
+
|
612 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
613 |
+
# The trackers initialize automatically on the main process.
|
614 |
+
if accelerator.is_main_process:
|
615 |
+
accelerator.init_trackers("elite", config=vars(args))
|
616 |
+
|
617 |
+
# Train!
|
618 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
619 |
+
|
620 |
+
logger.info("***** Running training *****")
|
621 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
622 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
623 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
624 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
625 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
626 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
627 |
+
# Only show the progress bar once on each machine.
|
628 |
+
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
629 |
+
progress_bar.set_description("Steps")
|
630 |
+
global_step = 0
|
631 |
+
|
632 |
+
for epoch in range(args.num_train_epochs):
|
633 |
+
mapper.train()
|
634 |
+
for step, batch in enumerate(train_dataloader):
|
635 |
+
with accelerator.accumulate(mapper):
|
636 |
+
# Convert images to latent space
|
637 |
+
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
|
638 |
+
latents = latents * 0.18215
|
639 |
+
|
640 |
+
# Sample noise that we'll add to the latents
|
641 |
+
noise = torch.randn(latents.shape).to(latents.device)
|
642 |
+
bsz = latents.shape[0]
|
643 |
+
# Sample a random timestep for each image
|
644 |
+
timesteps = torch.randint(
|
645 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
646 |
+
).long()
|
647 |
+
|
648 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
649 |
+
# (this is the forward diffusion process)
|
650 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
651 |
+
|
652 |
+
placeholder_idx = batch["index"]
|
653 |
+
image = F.interpolate(batch["pixel_values_clip"], (224, 224), mode='bilinear')
|
654 |
+
|
655 |
+
image_features = image_encoder(image, output_hidden_states=True)
|
656 |
+
image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16]]
|
657 |
+
image_embeddings = [emb.detach() for emb in image_embeddings]
|
658 |
+
inj_embedding = mapper(image_embeddings)
|
659 |
+
|
660 |
+
# Get the text embedding for conditioning
|
661 |
+
encoder_hidden_states = text_encoder({'input_ids': batch["input_ids"],
|
662 |
+
"inj_embedding": inj_embedding,
|
663 |
+
"inj_index": placeholder_idx.detach()})[0]
|
664 |
+
|
665 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states={
|
666 |
+
"CONTEXT_TENSOR": encoder_hidden_states,
|
667 |
+
}).sample
|
668 |
+
|
669 |
+
loss_mle = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
670 |
+
|
671 |
+
loss_reg = torch.mean(torch.abs(inj_embedding)) * 0.01
|
672 |
+
|
673 |
+
loss = loss_mle + loss_reg
|
674 |
+
|
675 |
+
accelerator.backward(loss)
|
676 |
+
|
677 |
+
if accelerator.sync_gradients:
|
678 |
+
accelerator.clip_grad_norm_(mapper.parameters(), 1)
|
679 |
+
|
680 |
+
optimizer.step()
|
681 |
+
lr_scheduler.step()
|
682 |
+
optimizer.zero_grad()
|
683 |
+
|
684 |
+
|
685 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
686 |
+
if accelerator.sync_gradients:
|
687 |
+
progress_bar.update(1)
|
688 |
+
global_step += 1
|
689 |
+
if global_step % args.save_steps == 0:
|
690 |
+
save_progress(mapper, accelerator, args, global_step)
|
691 |
+
syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5)
|
692 |
+
gt_images = [th2image(img) for img in batch["pixel_values"]]
|
693 |
+
img_list = []
|
694 |
+
for syn, gt in zip(syn_images, gt_images):
|
695 |
+
img_list.append(np.concatenate((np.array(syn), np.array(gt)), axis=1))
|
696 |
+
img_list = np.concatenate(img_list, axis=0)
|
697 |
+
Image.fromarray(img_list).save(os.path.join(args.output_dir, f"{str(global_step).zfill(5)}.jpg"))
|
698 |
+
|
699 |
+
logs = {"loss_mle": loss_mle.detach().item(), "loss_reg": loss_reg.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
700 |
+
progress_bar.set_postfix(**logs)
|
701 |
+
accelerator.log(logs, step=global_step)
|
702 |
+
|
703 |
+
if global_step >= args.max_train_steps:
|
704 |
+
break
|
705 |
+
|
706 |
+
accelerator.wait_for_everyone()
|
707 |
+
|
708 |
+
if accelerator.is_main_process:
|
709 |
+
save_progress(mapper, accelerator, args)
|
710 |
+
|
711 |
+
accelerator.end_training()
|
712 |
+
|
713 |
+
|
714 |
+
if __name__ == "__main__":
|
715 |
+
main()
|
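Note (not part of the repository): as a quick sanity check of the shapes involved, the Mapper defined above maps five intermediate CLIP hidden states (the CLS token plus the averaged patch tokens of each selected layer) into five 768-d word embeddings. A minimal sketch with random inputs; the 257-token / 1024-dim shapes assume CLIP ViT-L/14 at 224x224 input.

import torch
from train_global import Mapper

mapper = Mapper(input_dim=1024, output_dim=768)
# five CLIP hidden states: (batch, 1 CLS token + 256 patch tokens, 1024)
embs = [torch.randn(2, 257, 1024) for _ in range(5)]
inj_embedding = mapper(embs)
print(inj_embedding.shape)   # torch.Size([2, 5, 768])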
train_global.sh
ADDED
@@ -0,0 +1,15 @@
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --placeholder_token="S" \
  --resolution=512 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=200000 \
  --learning_rate=1e-06 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="./elite_experiments/global_mapping" \
  --save_steps 200
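Note (not part of the repository): since --scale_lr defaults to true in train_global.py, the learning rate passed above is not the rate the optimizer actually uses; it is multiplied by the gradient-accumulation steps, the per-device batch size, and the number of processes. A quick check with the values from this script (plain arithmetic):

base_lr, grad_accum, batch_size, num_processes = 1e-06, 4, 4, 4
effective_lr = base_lr * grad_accum * batch_size * num_processes
print(effective_lr)   # 1.6e-05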
train_local.py
ADDED
@@ -0,0 +1,709 @@
1 |
+
|
2 |
+
import argparse
|
3 |
+
import itertools
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
|
15 |
+
import PIL
|
16 |
+
from accelerate import Accelerator
|
17 |
+
from accelerate.logging import get_logger
|
18 |
+
from accelerate.utils import set_seed
|
19 |
+
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
|
20 |
+
from diffusers.optimization import get_scheduler
|
21 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
22 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
23 |
+
|
24 |
+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
25 |
+
from PIL import Image
|
26 |
+
from tqdm.auto import tqdm
|
27 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
28 |
+
|
29 |
+
|
30 |
+
from typing import Optional
|
31 |
+
from train_global import inj_forward_text, th2image, Mapper
|
32 |
+
from datasets import OpenImagesDatasetWithMask
|
33 |
+
|
34 |
+
|
35 |
+
class MapperLocal(nn.Module):
|
36 |
+
def __init__(self,
|
37 |
+
input_dim: int,
|
38 |
+
output_dim: int,
|
39 |
+
):
|
40 |
+
super(MapperLocal, self).__init__()
|
41 |
+
|
42 |
+
for i in range(5):
|
43 |
+
setattr(self, f'mapping_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
|
44 |
+
nn.LayerNorm(1024),
|
45 |
+
nn.LeakyReLU(),
|
46 |
+
nn.Linear(1024, 1024),
|
47 |
+
nn.LayerNorm(1024),
|
48 |
+
nn.LeakyReLU(),
|
49 |
+
nn.Linear(1024, output_dim)))
|
50 |
+
|
51 |
+
setattr(self, f'mapping_patch_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
|
52 |
+
nn.LayerNorm(1024),
|
53 |
+
nn.LeakyReLU(),
|
54 |
+
nn.Linear(1024, 1024),
|
55 |
+
nn.LayerNorm(1024),
|
56 |
+
nn.LeakyReLU(),
|
57 |
+
nn.Linear(1024, output_dim)))
|
58 |
+
|
59 |
+
def forward(self, embs):
|
60 |
+
hidden_states = ()
|
61 |
+
for i, emb in enumerate(embs):
|
62 |
+
hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(emb[:, 1:])
|
63 |
+
hidden_states += (hidden_state.unsqueeze(0),)
|
64 |
+
hidden_states = torch.cat(hidden_states, dim=0).mean(dim=0)
|
65 |
+
return hidden_states
|
66 |
+
|
67 |
+
value_local_list = []
|
68 |
+
|
69 |
+
def inj_forward_crossattention(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
70 |
+
|
71 |
+
context = encoder_hidden_states
|
72 |
+
hidden_states_local = hidden_states.clone()
|
73 |
+
if context is not None:
|
74 |
+
context_tensor = context["CONTEXT_TENSOR"]
|
75 |
+
else:
|
76 |
+
context_tensor = hidden_states
|
77 |
+
|
78 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
79 |
+
|
80 |
+
query = self.to_q(hidden_states)
|
81 |
+
|
82 |
+
if context is not None:
|
83 |
+
key = self.to_k_global(context_tensor)
|
84 |
+
value = self.to_v_global(context_tensor)
|
85 |
+
else:
|
86 |
+
key = self.to_k(context_tensor)
|
87 |
+
value = self.to_v(context_tensor)
|
88 |
+
|
89 |
+
dim = query.shape[-1]
|
90 |
+
|
91 |
+
query = self.reshape_heads_to_batch_dim(query)
|
92 |
+
key = self.reshape_heads_to_batch_dim(key)
|
93 |
+
value = self.reshape_heads_to_batch_dim(value)
|
94 |
+
|
95 |
+
|
96 |
+
attention_scores = torch.matmul(query, key.transpose(-1, -2))
|
97 |
+
attention_scores = attention_scores * self.scale
|
98 |
+
|
99 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
100 |
+
|
101 |
+
hidden_states = torch.matmul(attention_probs, value)
|
102 |
+
|
103 |
+
if context is not None and "LOCAL" in context:
|
104 |
+
# Perform cross attention with the local context
|
105 |
+
query_local = self.to_q(hidden_states_local)
|
106 |
+
key_local = self.to_k_local(context["LOCAL"])
|
107 |
+
value_local = self.to_v_local(context["LOCAL"])
|
108 |
+
|
109 |
+
query_local = self.reshape_heads_to_batch_dim(query_local)
|
110 |
+
key_local = self.reshape_heads_to_batch_dim(key_local)
|
111 |
+
value_local = self.reshape_heads_to_batch_dim(value_local)
|
112 |
+
|
113 |
+
attention_scores_local = torch.matmul(query_local, key_local.transpose(-1, -2))
|
114 |
+
attention_scores_local = attention_scores_local * self.scale
|
115 |
+
attention_probs_local = attention_scores_local.softmax(dim=-1)
|
116 |
+
|
117 |
+
# To extract the attmap of learned [w]
|
118 |
+
index_local = context["LOCAL_INDEX"]
|
119 |
+
index_local = index_local.reshape(index_local.shape[0], 1).repeat((1, self.heads)).reshape(-1)
|
120 |
+
attention_probs_clone = attention_probs.clone().permute((0, 2, 1))
|
121 |
+
attention_probs_mask = attention_probs_clone[torch.arange(index_local.shape[0]), index_local]
|
122 |
+
# Normalize the attention map
|
123 |
+
attention_probs_mask = attention_probs_mask.unsqueeze(2) / attention_probs_mask.max()
|
124 |
+
|
125 |
+
if "LAMBDA" in context:
|
126 |
+
_lambda = context["LAMBDA"]
|
127 |
+
else:
|
128 |
+
_lambda = 1
|
129 |
+
|
130 |
+
attention_probs_local = attention_probs_local * attention_probs_mask * _lambda
|
131 |
+
hidden_states += torch.matmul(attention_probs_local, value_local)
|
132 |
+
value_local_list.append(value_local)
|
133 |
+
|
134 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
135 |
+
|
136 |
+
# linear proj
|
137 |
+
hidden_states = self.to_out[0](hidden_states)
|
138 |
+
# dropout
|
139 |
+
hidden_states = self.to_out[1](hidden_states)
|
140 |
+
|
141 |
+
return hidden_states
|
142 |
+
|
143 |
+
# ------------------------------------------------------------------------------
|
144 |
+
|
145 |
+
logger = get_logger(__name__)
|
146 |
+
|
147 |
+
|
148 |
+
def save_progress(mapper, accelerator, args, step=None):
|
149 |
+
logger.info("Saving embeddings")
|
150 |
+
|
151 |
+
state_dict = accelerator.unwrap_model(mapper).state_dict()
|
152 |
+
|
153 |
+
if step is not None:
|
154 |
+
torch.save(state_dict, os.path.join(args.output_dir, f"local_mapper_{str(step).zfill(6)}.pt"))
|
155 |
+
else:
|
156 |
+
torch.save(state_dict, os.path.join(args.output_dir, "local_mapper.pt"))
|
157 |
+
|
158 |
+
|
159 |
+
def parse_args():
|
160 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
161 |
+
parser.add_argument(
|
162 |
+
"--save_steps",
|
163 |
+
type=int,
|
164 |
+
default=500,
|
165 |
+
help="Save learned_embeds.bin every X updates steps.",
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--pretrained_model_name_or_path",
|
169 |
+
type=str,
|
170 |
+
default=None,
|
171 |
+
required=True,
|
172 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--tokenizer_name",
|
176 |
+
type=str,
|
177 |
+
default=None,
|
178 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--global_mapper_path", type=str, default=None,
|
185 |
+
help="If not none, the training will start from the given checkpoints."
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--local_mapper_path", type=str, default=None,
|
189 |
+
help="If not none, the training will start from the given checkpoints."
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--placeholder_token",
|
193 |
+
type=str,
|
194 |
+
default=None,
|
195 |
+
required=True,
|
196 |
+
help="A token to use as a placeholder for the concept.",
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--output_dir",
|
200 |
+
type=str,
|
201 |
+
default="text-inversion-model",
|
202 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
203 |
+
)
|
204 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
205 |
+
parser.add_argument(
|
206 |
+
"--resolution",
|
207 |
+
type=int,
|
208 |
+
default=512,
|
209 |
+
help=(
|
210 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
211 |
+
" resolution"
|
212 |
+
),
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
216 |
+
)
|
217 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
218 |
+
parser.add_argument(
|
219 |
+
"--max_train_steps",
|
220 |
+
type=int,
|
221 |
+
default=5000,
|
222 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--gradient_accumulation_steps",
|
226 |
+
type=int,
|
227 |
+
default=1,
|
228 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--learning_rate",
|
232 |
+
type=float,
|
233 |
+
default=1e-4,
|
234 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--scale_lr",
|
238 |
+
action="store_true",
|
239 |
+
default=True,
|
240 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--lr_scheduler",
|
244 |
+
type=str,
|
245 |
+
default="constant",
|
246 |
+
help=(
|
247 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
248 |
+
' "constant", "constant_with_warmup"]'
|
249 |
+
),
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
253 |
+
)
|
254 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
255 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
256 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
257 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
258 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
259 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
260 |
+
parser.add_argument(
|
261 |
+
"--hub_model_id",
|
262 |
+
type=str,
|
263 |
+
default=None,
|
264 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
265 |
+
)
|
266 |
+
parser.add_argument(
|
267 |
+
"--logging_dir",
|
268 |
+
type=str,
|
269 |
+
default="logs",
|
270 |
+
help=(
|
271 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
272 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
273 |
+
),
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--mixed_precision",
|
277 |
+
type=str,
|
278 |
+
default="no",
|
279 |
+
choices=["no", "fp16", "bf16"],
|
280 |
+
help=(
|
281 |
+
"Whether to use mixed precision. Choose"
|
282 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
283 |
+
"and an Nvidia Ampere GPU."
|
284 |
+
),
|
285 |
+
)
|
286 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
287 |
+
|
288 |
+
args = parser.parse_args()
|
289 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
290 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
291 |
+
args.local_rank = env_local_rank
|
292 |
+
|
293 |
+
if args.train_data_dir is None:
|
294 |
+
raise ValueError("You must specify a train data directory.")
|
295 |
+
|
296 |
+
return args
|
297 |
+
|
298 |
+
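Only three of the flags above are required. A minimal sketch of driving the parser programmatically, assuming it runs in the same module as parse_args(); the model identifier and data path are placeholders.

import sys

# Hypothetical argv: --pretrained_model_name_or_path, --train_data_dir and
# --placeholder_token are the only required arguments defined above.
sys.argv = [
    "train_local.py",
    "--pretrained_model_name_or_path", "CompVis/stable-diffusion-v1-4",
    "--train_data_dir", "/path/to/Open_Images",
    "--placeholder_token", "S",
]
args = parse_args()
print(args.resolution, args.train_batch_size, args.learning_rate)  # 512 16 0.0001 (the defaults above)
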
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def freeze_params(params):
    for param in params:
        param.requires_grad = False


def unfreeze_params(params):
    for param in params:
        param.requires_grad = True


@torch.no_grad()
def validation(example, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae, device, guidance_scale, seed=None, llambda=1):
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    uncond_input = tokenizer(
        [''] * example["pixel_values"].shape[0],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder({'input_ids': uncond_input.input_ids.to(device)})[0]

    if seed is None:
        latents = torch.randn(
            (example["pixel_values"].shape[0], unet.in_channels, 64, 64)
        )
    else:
        generator = torch.manual_seed(seed)
        latents = torch.randn(
            (example["pixel_values"].shape[0], unet.in_channels, 64, 64), generator=generator,
        )

    latents = latents.to(example["pixel_values_clip"])
    scheduler.set_timesteps(100)
    latents = latents * scheduler.init_noise_sigma

    placeholder_idx = example["index"]

    image = F.interpolate(example["pixel_values_clip"], (224, 224), mode='bilinear')
    image_features = image_encoder(image, output_hidden_states=True)
    image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16]]
    image_embeddings = [emb.detach() for emb in image_embeddings]
    inj_embedding = mapper(image_embeddings)

    inj_embedding = inj_embedding[:, 0:1, :]
    encoder_hidden_states = text_encoder({'input_ids': example["input_ids"],
                                          "inj_embedding": inj_embedding,
                                          "inj_index": placeholder_idx})[0]

    image_obj = F.interpolate(example["pixel_values_obj"], (224, 224), mode='bilinear')
    image_features_obj = image_encoder(image_obj, output_hidden_states=True)
    image_embeddings_obj = [image_features_obj[0], image_features_obj[2][4], image_features_obj[2][8],
                            image_features_obj[2][12], image_features_obj[2][16]]
    image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj]

    inj_embedding_local = mapper_local(image_embeddings_obj)
    mask = F.interpolate(example["pixel_values_seg"], (16, 16), mode='nearest')
    mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
    inj_embedding_local = inj_embedding_local * mask

    for t in tqdm(scheduler.timesteps):
        latent_model_input = scheduler.scale_model_input(latents, t)
        noise_pred_text = unet(
            latent_model_input,
            t,
            encoder_hidden_states={
                "CONTEXT_TENSOR": encoder_hidden_states,
                "LOCAL": inj_embedding_local,
                "LOCAL_INDEX": placeholder_idx.detach(),
                "LAMBDA": llambda
            }
        ).sample
        value_local_list.clear()
        latent_model_input = scheduler.scale_model_input(latents, t)

        noise_pred_uncond = unet(
            latent_model_input,
            t,
            encoder_hidden_states={
                "CONTEXT_TENSOR": uncond_embeddings,
            }
        ).sample
        value_local_list.clear()
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    _latents = 1 / 0.18215 * latents.clone()
    images = vae.decode(_latents).sample
    ret_pil_images = [th2image(image) for image in images]

    return ret_pil_images

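The sampling loop above uses standard classifier-free guidance. A small self-contained sketch of that combination step; the tensors are dummies, and guidance_scale=5 matches the value passed from the training loop below.

import torch

guidance_scale = 5.0
noise_pred_uncond = torch.randn(1, 4, 64, 64)  # dummy stand-in for the unconditional UNet pass
noise_pred_text = torch.randn(1, 4, 64, 64)    # dummy stand-in for the subject-conditioned pass

# Same combination as in validation(): move away from the unconditional prediction
# towards the conditioned one, scaled by the guidance weight.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
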
def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer and add the placeholder token as an additional special token
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    for _module in text_encoder.modules():
        if _module.__class__.__name__ == "CLIPTextTransformer":
            _module.__class__.__call__ = inj_forward_text

    image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

    mapper = Mapper(input_dim=1024, output_dim=768)
    mapper_local = MapperLocal(input_dim=1024, output_dim=768)

    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    # Replace the forward method of the cross-attention blocks to finetune the to_k and to_v layers
    for _name, _module in unet.named_modules():
        if _module.__class__.__name__ == "CrossAttention":
            if 'attn1' in _name: continue
            _module.__class__.__call__ = inj_forward_crossattention

            shape = _module.to_k.weight.shape
            to_k_global = nn.Linear(shape[1], shape[0], bias=False)
            to_k_global.weight.data = _module.to_k.weight.data.clone()
            mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)

            shape = _module.to_v.weight.shape
            to_v_global = nn.Linear(shape[1], shape[0], bias=False)
            to_v_global.weight.data = _module.to_v.weight.data.clone()
            mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)

            to_k_local = nn.Linear(shape[1], shape[0], bias=False)
            to_k_local.weight.data = _module.to_k.weight.data.clone()
            mapper_local.add_module(f'{_name.replace(".", "_")}_to_k', to_k_local)
            _module.add_module('to_k_local', to_k_local)

            to_v_local = nn.Linear(shape[1], shape[0], bias=False)
            to_v_local.weight.data = _module.to_v.weight.data.clone()
            mapper_local.add_module(f'{_name.replace(".", "_")}_to_v', to_v_local)
            _module.add_module('to_v_local', to_v_local)

            if args.global_mapper_path is None:
                _module.add_module('to_k_global', to_k_global)
                _module.add_module('to_v_global', to_v_global)

            if args.local_mapper_path is None:
                _module.add_module('to_k_local', to_k_local)
                _module.add_module('to_v_local', to_v_local)

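    # Note on the wiring above (descriptive comment, not from the original source):
    # each cross-attention block gets fresh, bias-free to_k/to_v linear layers whose
    # weights start as clones of the pretrained projections. Registering them on
    # `mapper` / `mapper_local` puts them in the objects that are optimized and saved,
    # while registering them on the CrossAttention module itself lets the injected
    # inj_forward_crossattention pick them up as `to_k_global` / `to_v_local` etc.
    # For example, with a context dimension of 768 and an inner dimension of 320,
    # nn.Linear(768, 320, bias=False) has a weight of shape (320, 768), which is why
    # the layer is built as nn.Linear(shape[1], shape[0], bias=False) from shape = weight.shape.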
    if args.global_mapper_path is not None:
        mapper.load_state_dict(torch.load(args.global_mapper_path, map_location='cpu'))
        for _name, _module in unet.named_modules():
            if _module.__class__.__name__ == "CrossAttention":
                if 'attn1' in _name: continue
                _module.add_module('to_k_global', getattr(mapper, f'{_name.replace(".", "_")}_to_k'))
                _module.add_module('to_v_global', getattr(mapper, f'{_name.replace(".", "_")}_to_v'))

    if args.local_mapper_path is not None:
        mapper_local.load_state_dict(torch.load(args.local_mapper_path, map_location='cpu'))
        for _name, _module in unet.named_modules():
            if _module.__class__.__name__ == "CrossAttention":
                if 'attn1' in _name: continue
                _module.add_module('to_k_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_k'))
                _module.add_module('to_v_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_v'))

    # Freeze vae, unet, the encoders and the global mapper; only the local mapper is trained
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    freeze_params(text_encoder.parameters())
    freeze_params(image_encoder.parameters())
    unfreeze_params(mapper_local.parameters())

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

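    # Worked example (comment added for clarity, values taken from train_local.sh below):
    # with --learning_rate=1e-5, --gradient_accumulation_steps=4, --train_batch_size=2
    # and 4 processes, the scaled learning rate becomes 1e-5 * 4 * 2 * 4 = 3.2e-4.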
    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        itertools.chain(mapper_local.parameters()),  # only optimize the local mapper
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = OpenImagesDatasetWithMask(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        set="test"
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    mapper_local, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        mapper_local, optimizer, train_dataloader, lr_scheduler
    )

    # Move the frozen vae, unet, encoders and the global mapper to the device
    vae.to(accelerator.device)
    unet.to(accelerator.device)
    image_encoder.to(accelerator.device)
    text_encoder.to(accelerator.device)
    mapper.to(accelerator.device)
    # Keep them in eval mode as we don't train these
    vae.eval()
    unet.eval()
    image_encoder.eval()
    mapper.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("elite", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

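    # The loop below assumes each batch from OpenImagesDatasetWithMask provides these keys
    # (descriptive comment added; see datasets.py for the exact contents):
    #   "pixel_values"      - full image, encoded by the VAE as the reconstruction target
    #   "pixel_values_clip" - CLIP-preprocessed image for the global mapper branch
    #   "pixel_values_obj"  - segmented object crop for the local mapper branch
    #   "pixel_values_seg"  - object segmentation map, downsampled to 16x16 as a mask
    #   "mask_values"       - latent-space mask used to restrict the reconstruction loss
    #   "input_ids"         - tokenized prompt containing the placeholder token
    #   "index"             - position of the placeholder token in the prompt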
    for epoch in range(args.num_train_epochs):
        mapper_local.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(mapper_local):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                placeholder_idx = batch["index"]
                image = F.interpolate(batch["pixel_values_clip"], (224, 224), mode='bilinear')
                image_obj = F.interpolate(batch["pixel_values_obj"], (224, 224), mode='bilinear')

                mask = F.interpolate(batch["pixel_values_seg"], (16, 16), mode='nearest')
                mask = mask[:, 0].reshape(mask.shape[0], -1, 1)

                image_features = image_encoder(image, output_hidden_states=True)
                image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16]]
                image_embeddings = [emb.detach() for emb in image_embeddings]
                inj_embedding = mapper(image_embeddings)

                # only use the first word
                inj_embedding = inj_embedding[:, 0:1, :]

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder({'input_ids': batch["input_ids"],
                                                      "inj_embedding": inj_embedding,
                                                      "inj_index": placeholder_idx.detach()})[0]

                image_features_obj = image_encoder(image_obj, output_hidden_states=True)
                image_embeddings_obj = [image_features_obj[0], image_features_obj[2][4], image_features_obj[2][8], image_features_obj[2][12], image_features_obj[2][16]]
                image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj]

                inj_embedding_local = mapper_local(image_embeddings_obj)
                inj_embedding_local = inj_embedding_local * mask

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states={
                    "CONTEXT_TENSOR": encoder_hidden_states,
                    "LOCAL": inj_embedding_local,
                    "LOCAL_INDEX": placeholder_idx.detach()
                }).sample

                mask_values = batch["mask_values"]
                loss_mle = F.mse_loss(noise_pred, noise, reduction="none")
                loss_mle = ((loss_mle * mask_values).sum([1, 2, 3]) / mask_values.sum([1, 2, 3])).mean()

                loss_reg = 0
                for vvv in value_local_list:
                    loss_reg += torch.mean(torch.abs(vvv))
                loss_reg = loss_reg / len(value_local_list) * 0.0001

                loss = loss_mle + loss_reg

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(mapper_local.parameters(), 1)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                value_local_list.clear()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_progress(mapper_local, accelerator, args, global_step)
                    syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae, batch["pixel_values_clip"].device, 5)
                    input_images = [th2image(img) for img in batch["pixel_values"]]
                    clip_images = [th2image(img).resize((512, 512)) for img in batch["pixel_values_clip"]]
                    obj_images = [th2image(img).resize((512, 512)) for img in batch["pixel_values_obj"]]
                    input_masks = torch.cat([mask_values, mask_values, mask_values], dim=1)
                    input_masks = [th2image(img).resize((512, 512)) for img in input_masks]
                    obj_masks = [th2image(img).resize((512, 512)) for img in batch["pixel_values_seg"]]
                    img_list = []
                    for syn, input_img, input_mask, clip_image, obj_image, obj_mask in zip(syn_images, input_images, input_masks, clip_images, obj_images, obj_masks):
                        img_list.append(np.concatenate((np.array(syn), np.array(input_img), np.array(input_mask), np.array(clip_image), np.array(obj_image), np.array(obj_mask)), axis=1))
                    img_list = np.concatenate(img_list, axis=0)
                    Image.fromarray(img_list).save(os.path.join(args.output_dir, f"{str(global_step).zfill(5)}.jpg"))

            logs = {"loss_mle": loss_mle.detach().item(), "loss_reg": loss_reg.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        save_progress(mapper_local, accelerator, args)

    accelerator.end_training()


if __name__ == "__main__":
    main()
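For reference, the masked reconstruction loss used in the training loop above reduces the per-element MSE over the object region only, normalizes by the mask area per sample, and then averages over the batch. A self-contained sketch with dummy tensors:

import torch
import torch.nn.functional as F

# Dummy stand-ins for the UNet prediction, the sampled noise and the latent-space object mask.
noise_pred = torch.randn(2, 4, 64, 64)
noise = torch.randn(2, 4, 64, 64)
mask_values = (torch.rand(2, 1, 64, 64) > 0.5).float()

# Same reduction as in the training loop: per-element squared error, masked,
# normalized by the mask area per sample, then averaged over the batch.
loss_mle = F.mse_loss(noise_pred, noise, reduction="none")
loss_mle = ((loss_mle * mask_values).sum([1, 2, 3]) / mask_values.sum([1, 2, 3])).mean()
print(float(loss_mle))
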
train_local.sh
ADDED
@@ -0,0 +1,16 @@
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --placeholder_token="S" \
  --resolution=512 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=200000 \
  --learning_rate=1e-5 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --global_mapper_path "./elite_experiments/global_mapping/mapper_070000.pt" \
  --output_dir="./elite_experiments/local_mapping" \
  --save_steps 200