Spaces:
Running
on
Zero
Running
on
Zero
邬彦泽
commited on
Commit
•
aa8012e
1
Parent(s):
dbb55f6
This view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +201 -0
- README.md +2 -2
- app.py +298 -0
- eva_clip/__init__.py +10 -0
- eva_clip/constants.py +2 -0
- eva_clip/eva_vit_model.py +633 -0
- eva_clip/factory.py +517 -0
- eva_clip/hf_configs.py +57 -0
- eva_clip/hf_model.py +248 -0
- eva_clip/loss.py +138 -0
- eva_clip/model.py +440 -0
- eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
- eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
- eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
- eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
- eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
- eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
- eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +25 -0
- eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
- eva_clip/modified_resnet.py +181 -0
- eva_clip/openai.py +144 -0
- eva_clip/pretrained.py +332 -0
- eva_clip/rope.py +137 -0
- eva_clip/timm_model.py +123 -0
- eva_clip/tokenizer.py +201 -0
- eva_clip/transform.py +103 -0
- eva_clip/transformer.py +792 -0
- eva_clip/utils.py +326 -0
- example_inputs/hinton.jpeg +0 -0
- example_inputs/lecun.jpg +0 -0
- example_inputs/lifeifei.jpg +0 -0
- example_inputs/liuyifei.png +0 -0
- example_inputs/rihanna.webp +0 -0
- example_inputs/zcy.webp +0 -0
- flux/__init__.py +11 -0
- flux/__main__.py +4 -0
- flux/api.py +194 -0
- flux/cli.py +261 -0
- flux/math.py +31 -0
- flux/model.py +135 -0
- flux/modules/__init__.py +0 -0
- flux/modules/autoencoder.py +312 -0
- flux/modules/conditioner.py +37 -0
- flux/modules/layers.py +253 -0
- flux/sampling.py +161 -0
- flux/util.py +201 -0
- models/.gitkeep +0 -0
- pulid/attention_processor.py +422 -0
- pulid/encoders.py +64 -0
- pulid/encoders_flux.py +207 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title: PuLID
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
+
title: PuLID-FLUX
|
3 |
+
emoji: 🤗
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
from einops import rearrange
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from flux.cli import SamplingOptions
|
11 |
+
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
12 |
+
from flux.util import load_ae, load_clip, load_flow_model, load_t5
|
13 |
+
from pulid.pipeline_flux import PuLIDPipeline
|
14 |
+
from pulid.utils import resize_numpy_image_long
|
15 |
+
|
16 |
+
|
17 |
+
def get_models(name: str, device: torch.device, offload: bool):
|
18 |
+
t5 = load_t5(device, max_length=128)
|
19 |
+
clip = load_clip(device)
|
20 |
+
model = load_flow_model(name, device="cpu" if offload else device)
|
21 |
+
model.eval()
|
22 |
+
ae = load_ae(name, device="cpu" if offload else device)
|
23 |
+
return model, ae, t5, clip
|
24 |
+
|
25 |
+
|
26 |
+
class FluxGenerator:
|
27 |
+
def __init__(self, model_name: str, device: str, offload: bool, args):
|
28 |
+
self.device = torch.device(device)
|
29 |
+
self.offload = offload
|
30 |
+
self.model_name = model_name
|
31 |
+
self.model, self.ae, self.t5, self.clip = get_models(
|
32 |
+
model_name,
|
33 |
+
device=self.device,
|
34 |
+
offload=self.offload,
|
35 |
+
)
|
36 |
+
self.pulid_model = PuLIDPipeline(self.model, device, weight_dtype=torch.bfloat16)
|
37 |
+
self.pulid_model.load_pretrain(args.pretrained_model)
|
38 |
+
|
39 |
+
@spaces.GPU
|
40 |
+
@torch.inference_mode()
|
41 |
+
def generate_image(
|
42 |
+
self,
|
43 |
+
width,
|
44 |
+
height,
|
45 |
+
num_steps,
|
46 |
+
start_step,
|
47 |
+
guidance,
|
48 |
+
seed,
|
49 |
+
prompt,
|
50 |
+
id_image=None,
|
51 |
+
id_weight=1.0,
|
52 |
+
neg_prompt="",
|
53 |
+
true_cfg=1.0,
|
54 |
+
timestep_to_start_cfg=1,
|
55 |
+
max_sequence_length=128,
|
56 |
+
):
|
57 |
+
self.t5.max_length = max_sequence_length
|
58 |
+
|
59 |
+
seed = int(seed)
|
60 |
+
if seed == -1:
|
61 |
+
seed = None
|
62 |
+
|
63 |
+
opts = SamplingOptions(
|
64 |
+
prompt=prompt,
|
65 |
+
width=width,
|
66 |
+
height=height,
|
67 |
+
num_steps=num_steps,
|
68 |
+
guidance=guidance,
|
69 |
+
seed=seed,
|
70 |
+
)
|
71 |
+
|
72 |
+
if opts.seed is None:
|
73 |
+
opts.seed = torch.Generator(device="cpu").seed()
|
74 |
+
print(f"Generating '{opts.prompt}' with seed {opts.seed}")
|
75 |
+
t0 = time.perf_counter()
|
76 |
+
|
77 |
+
use_true_cfg = abs(true_cfg - 1.0) > 1e-2
|
78 |
+
|
79 |
+
if id_image is not None:
|
80 |
+
id_image = resize_numpy_image_long(id_image, 1024)
|
81 |
+
id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
|
82 |
+
else:
|
83 |
+
id_embeddings = None
|
84 |
+
uncond_id_embeddings = None
|
85 |
+
|
86 |
+
# prepare input
|
87 |
+
x = get_noise(
|
88 |
+
1,
|
89 |
+
opts.height,
|
90 |
+
opts.width,
|
91 |
+
device=self.device,
|
92 |
+
dtype=torch.bfloat16,
|
93 |
+
seed=opts.seed,
|
94 |
+
)
|
95 |
+
timesteps = get_schedule(
|
96 |
+
opts.num_steps,
|
97 |
+
x.shape[-1] * x.shape[-2] // 4,
|
98 |
+
shift=True,
|
99 |
+
)
|
100 |
+
|
101 |
+
if self.offload:
|
102 |
+
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
|
103 |
+
inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
|
104 |
+
inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
|
105 |
+
|
106 |
+
# offload TEs to CPU, load model to gpu
|
107 |
+
if self.offload:
|
108 |
+
self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
|
109 |
+
torch.cuda.empty_cache()
|
110 |
+
self.model = self.model.to(self.device)
|
111 |
+
|
112 |
+
# denoise initial noise
|
113 |
+
x = denoise(
|
114 |
+
self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
|
115 |
+
start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
|
116 |
+
timestep_to_start_cfg=timestep_to_start_cfg,
|
117 |
+
neg_txt=inp_neg["txt"] if use_true_cfg else None,
|
118 |
+
neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
|
119 |
+
neg_vec=inp_neg["vec"] if use_true_cfg else None,
|
120 |
+
)
|
121 |
+
|
122 |
+
# offload model, load autoencoder to gpu
|
123 |
+
if self.offload:
|
124 |
+
self.model.cpu()
|
125 |
+
torch.cuda.empty_cache()
|
126 |
+
self.ae.decoder.to(x.device)
|
127 |
+
|
128 |
+
# decode latents to pixel space
|
129 |
+
x = unpack(x.float(), opts.height, opts.width)
|
130 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
|
131 |
+
x = self.ae.decode(x)
|
132 |
+
|
133 |
+
if self.offload:
|
134 |
+
self.ae.decoder.cpu()
|
135 |
+
torch.cuda.empty_cache()
|
136 |
+
|
137 |
+
t1 = time.perf_counter()
|
138 |
+
|
139 |
+
print(f"Done in {t1 - t0:.1f}s.")
|
140 |
+
# bring into PIL format
|
141 |
+
x = x.clamp(-1, 1)
|
142 |
+
# x = embed_watermark(x.float())
|
143 |
+
x = rearrange(x[0], "c h w -> h w c")
|
144 |
+
|
145 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
146 |
+
return img, str(opts.seed), self.pulid_model.debug_img_list
|
147 |
+
|
148 |
+
_HEADER_ = '''
|
149 |
+
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
150 |
+
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
|
151 |
+
<p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
|
152 |
+
</div>
|
153 |
+
|
154 |
+
❗️❗️❗️**Tips:**
|
155 |
+
- `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. If you are not satisfied with the similarity, you can lower this value; conversely, if you are not satisfied with the editability, you can increase this value.
|
156 |
+
- `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale. This is also more efficiency. However, in a few cases, utilizing a true CFG can yield better results. For more detaileds, please refer to XX.
|
157 |
+
- please refer to the <a href='URL_ADDRESS' target='_blank'>github doc</a> for more details and info about the model, we provide the detail explanation about the above two parameters in the doc.
|
158 |
+
- we provide some examples in the bottom, you can try these example prompts first
|
159 |
+
|
160 |
+
''' # noqa E501
|
161 |
+
|
162 |
+
_CITE_ = r"""
|
163 |
+
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks!
|
164 |
+
---
|
165 |
+
|
166 |
+
📧 **Contact**
|
167 |
+
If you have any questions or feedbacks, feel free to open a discussion or contact <b>[email protected]</b>.
|
168 |
+
""" # noqa E501
|
169 |
+
|
170 |
+
|
171 |
+
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
172 |
+
offload: bool = False):
|
173 |
+
generator = FluxGenerator(model_name, device, offload, args)
|
174 |
+
|
175 |
+
with gr.Blocks() as demo:
|
176 |
+
gr.Markdown(_HEADER_)
|
177 |
+
|
178 |
+
with gr.Row():
|
179 |
+
with gr.Column():
|
180 |
+
prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
|
181 |
+
id_image = gr.Image(label="ID Image")
|
182 |
+
id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
|
183 |
+
|
184 |
+
width = gr.Slider(256, 1536, 896, step=16, label="Width")
|
185 |
+
height = gr.Slider(256, 1536, 1152, step=16, label="Height")
|
186 |
+
num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
|
187 |
+
start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
|
188 |
+
guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
|
189 |
+
seed = gr.Textbox(-1, label="Seed (-1 for random)")
|
190 |
+
max_sequence_length = gr.Slider(128, 512, 128, step=128,
|
191 |
+
label="max_sequence_length for prompt (T5), small will be faster")
|
192 |
+
|
193 |
+
with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)", open=False): # noqa E501
|
194 |
+
neg_prompt = gr.Textbox(
|
195 |
+
label="Negative Prompt",
|
196 |
+
value="bad quality, worst quality, text, signature, watermark, extra limbs")
|
197 |
+
true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
|
198 |
+
timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
|
199 |
+
|
200 |
+
generate_btn = gr.Button("Generate")
|
201 |
+
|
202 |
+
with gr.Column():
|
203 |
+
output_image = gr.Image(label="Generated Image")
|
204 |
+
seed_output = gr.Textbox(label="Used Seed")
|
205 |
+
intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
|
206 |
+
gr.Markdown(_CITE_)
|
207 |
+
|
208 |
+
with gr.Row(), gr.Column():
|
209 |
+
gr.Markdown("## Examples")
|
210 |
+
example_inps = [
|
211 |
+
[
|
212 |
+
'a woman holding sign with glowing green text \"PuLID for FLUX\"',
|
213 |
+
'example_inputs/liuyifei.png',
|
214 |
+
4, 4, 2680261499100305976, 1
|
215 |
+
],
|
216 |
+
[
|
217 |
+
'portrait, side view',
|
218 |
+
'example_inputs/liuyifei.png',
|
219 |
+
4, 4, 1205240166692517553, 1
|
220 |
+
],
|
221 |
+
[
|
222 |
+
'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details', # noqa E501
|
223 |
+
'example_inputs/liuyifei.png',
|
224 |
+
4, 4, 6349424134217931066, 1
|
225 |
+
],
|
226 |
+
[
|
227 |
+
'a young child is eating Icecream',
|
228 |
+
'example_inputs/liuyifei.png',
|
229 |
+
4, 4, 10606046113565776207, 1
|
230 |
+
],
|
231 |
+
[
|
232 |
+
'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain',
|
233 |
+
'example_inputs/pengwei.jpg',
|
234 |
+
4, 4, 2410129802683836089, 1
|
235 |
+
],
|
236 |
+
[
|
237 |
+
'portrait, candle light',
|
238 |
+
'example_inputs/pengwei.jpg',
|
239 |
+
4, 4, 17522759474323955700, 1
|
240 |
+
],
|
241 |
+
[
|
242 |
+
'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501
|
243 |
+
'example_inputs/pengwei.jpg',
|
244 |
+
4, 4, 17733156847328193625, 1
|
245 |
+
],
|
246 |
+
[
|
247 |
+
'American Comics, 1boy',
|
248 |
+
'example_inputs/pengwei.jpg',
|
249 |
+
1, 4, 13223174453874179686, 1
|
250 |
+
],
|
251 |
+
[
|
252 |
+
'portrait, pixar',
|
253 |
+
'example_inputs/pengwei.jpg',
|
254 |
+
1, 4, 9445036702517583939, 1
|
255 |
+
],
|
256 |
+
]
|
257 |
+
gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
|
258 |
+
label='fake CFG')
|
259 |
+
|
260 |
+
example_inps = [
|
261 |
+
[
|
262 |
+
'portrait, made of ice sculpture',
|
263 |
+
'example_inputs/lecun.jpg',
|
264 |
+
1, 1, 3811899118709451814, 5
|
265 |
+
],
|
266 |
+
]
|
267 |
+
gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
|
268 |
+
label='true CFG')
|
269 |
+
|
270 |
+
generate_btn.click(
|
271 |
+
fn=generator.generate_image,
|
272 |
+
inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
|
273 |
+
true_cfg, timestep_to_start_cfg, max_sequence_length],
|
274 |
+
outputs=[output_image, seed_output, intermediate_output],
|
275 |
+
)
|
276 |
+
|
277 |
+
return demo
|
278 |
+
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
import argparse
|
282 |
+
|
283 |
+
parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
|
284 |
+
parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
|
285 |
+
help="currently only support flux-dev")
|
286 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
|
287 |
+
help="Device to use")
|
288 |
+
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
|
289 |
+
parser.add_argument("--port", type=int, default=8080, help="Port to use")
|
290 |
+
parser.add_argument("--dev", action='store_true', help="Development mode")
|
291 |
+
parser.add_argument("--pretrained_model", type=str, help='for development')
|
292 |
+
args = parser.parse_args()
|
293 |
+
|
294 |
+
import huggingface_hub
|
295 |
+
huggingface_hub.login(os.getenv('HF_TOKEN'))
|
296 |
+
|
297 |
+
demo = create_demo(args, args.name, args.device, args.offload)
|
298 |
+
demo.launch()
|
eva_clip/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
2 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
|
3 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
4 |
+
from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
|
5 |
+
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
6 |
+
from .openai import load_openai_model, list_openai_models
|
7 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
|
8 |
+
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
9 |
+
from .tokenizer import SimpleTokenizer, tokenize
|
10 |
+
from .transform import image_transform
|
eva_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
eva_clip/eva_vit_model.py
ADDED
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
3 |
+
# --------------------------------------------------------
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
from functools import partial
|
7 |
+
from itertools import repeat
|
8 |
+
import collections.abc
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import warnings
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from .transformer import PatchDropout
|
15 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
16 |
+
|
17 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
18 |
+
try:
|
19 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
20 |
+
except:
|
21 |
+
from torch.utils.checkpoint import checkpoint
|
22 |
+
else:
|
23 |
+
from torch.utils.checkpoint import checkpoint
|
24 |
+
|
25 |
+
try:
|
26 |
+
import xformers
|
27 |
+
import xformers.ops as xops
|
28 |
+
XFORMERS_IS_AVAILBLE = True
|
29 |
+
except:
|
30 |
+
XFORMERS_IS_AVAILBLE = False
|
31 |
+
|
32 |
+
|
33 |
+
def _ntuple(n):
|
34 |
+
def parse(x):
|
35 |
+
if isinstance(x, collections.abc.Iterable):
|
36 |
+
return x
|
37 |
+
return tuple(repeat(x, n))
|
38 |
+
return parse
|
39 |
+
|
40 |
+
to_2tuple = _ntuple(2)
|
41 |
+
|
42 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
43 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
44 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
45 |
+
def norm_cdf(x):
|
46 |
+
# Computes standard normal cumulative distribution function
|
47 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
48 |
+
|
49 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
50 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
51 |
+
"The distribution of values may be incorrect.",
|
52 |
+
stacklevel=2)
|
53 |
+
|
54 |
+
with torch.no_grad():
|
55 |
+
# Values are generated by using a truncated uniform distribution and
|
56 |
+
# then using the inverse CDF for the normal distribution.
|
57 |
+
# Get upper and lower cdf values
|
58 |
+
l = norm_cdf((a - mean) / std)
|
59 |
+
u = norm_cdf((b - mean) / std)
|
60 |
+
|
61 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
62 |
+
# [2l-1, 2u-1].
|
63 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
64 |
+
|
65 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
66 |
+
# standard normal
|
67 |
+
tensor.erfinv_()
|
68 |
+
|
69 |
+
# Transform to proper mean, std
|
70 |
+
tensor.mul_(std * math.sqrt(2.))
|
71 |
+
tensor.add_(mean)
|
72 |
+
|
73 |
+
# Clamp to ensure it's in the proper range
|
74 |
+
tensor.clamp_(min=a, max=b)
|
75 |
+
return tensor
|
76 |
+
|
77 |
+
|
78 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
79 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
80 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
81 |
+
normal distribution. The values are effectively drawn from the
|
82 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
83 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
84 |
+
the bounds. The method used for generating the random values works
|
85 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
86 |
+
Args:
|
87 |
+
tensor: an n-dimensional `torch.Tensor`
|
88 |
+
mean: the mean of the normal distribution
|
89 |
+
std: the standard deviation of the normal distribution
|
90 |
+
a: the minimum cutoff value
|
91 |
+
b: the maximum cutoff value
|
92 |
+
Examples:
|
93 |
+
>>> w = torch.empty(3, 5)
|
94 |
+
>>> nn.init.trunc_normal_(w)
|
95 |
+
"""
|
96 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
97 |
+
|
98 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
|
99 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
100 |
+
|
101 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
102 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
103 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
104 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
105 |
+
'survival rate' as the argument.
|
106 |
+
|
107 |
+
"""
|
108 |
+
if drop_prob == 0. or not training:
|
109 |
+
return x
|
110 |
+
keep_prob = 1 - drop_prob
|
111 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
112 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
113 |
+
if keep_prob > 0.0 and scale_by_keep:
|
114 |
+
random_tensor.div_(keep_prob)
|
115 |
+
return x * random_tensor
|
116 |
+
|
117 |
+
|
118 |
+
class DropPath(nn.Module):
|
119 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
120 |
+
"""
|
121 |
+
def __init__(self, drop_prob=None):
|
122 |
+
super(DropPath, self).__init__()
|
123 |
+
self.drop_prob = drop_prob
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
return drop_path(x, self.drop_prob, self.training)
|
127 |
+
|
128 |
+
def extra_repr(self) -> str:
|
129 |
+
return 'p={}'.format(self.drop_prob)
|
130 |
+
|
131 |
+
|
132 |
+
class Mlp(nn.Module):
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
in_features,
|
136 |
+
hidden_features=None,
|
137 |
+
out_features=None,
|
138 |
+
act_layer=nn.GELU,
|
139 |
+
norm_layer=nn.LayerNorm,
|
140 |
+
drop=0.,
|
141 |
+
subln=False,
|
142 |
+
|
143 |
+
):
|
144 |
+
super().__init__()
|
145 |
+
out_features = out_features or in_features
|
146 |
+
hidden_features = hidden_features or in_features
|
147 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
148 |
+
self.act = act_layer()
|
149 |
+
|
150 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
151 |
+
|
152 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
153 |
+
self.drop = nn.Dropout(drop)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
x = self.fc1(x)
|
157 |
+
x = self.act(x)
|
158 |
+
# x = self.drop(x)
|
159 |
+
# commit this for the orignal BERT implement
|
160 |
+
x = self.ffn_ln(x)
|
161 |
+
|
162 |
+
x = self.fc2(x)
|
163 |
+
x = self.drop(x)
|
164 |
+
return x
|
165 |
+
|
166 |
+
class SwiGLU(nn.Module):
|
167 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
168 |
+
norm_layer=nn.LayerNorm, subln=False):
|
169 |
+
super().__init__()
|
170 |
+
out_features = out_features or in_features
|
171 |
+
hidden_features = hidden_features or in_features
|
172 |
+
|
173 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
174 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
175 |
+
|
176 |
+
self.act = act_layer()
|
177 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
178 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
179 |
+
|
180 |
+
self.drop = nn.Dropout(drop)
|
181 |
+
|
182 |
+
def forward(self, x):
|
183 |
+
x1 = self.w1(x)
|
184 |
+
x2 = self.w2(x)
|
185 |
+
hidden = self.act(x1) * x2
|
186 |
+
x = self.ffn_ln(hidden)
|
187 |
+
x = self.w3(x)
|
188 |
+
x = self.drop(x)
|
189 |
+
return x
|
190 |
+
|
191 |
+
class Attention(nn.Module):
|
192 |
+
def __init__(
|
193 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
194 |
+
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
|
195 |
+
super().__init__()
|
196 |
+
self.num_heads = num_heads
|
197 |
+
head_dim = dim // num_heads
|
198 |
+
if attn_head_dim is not None:
|
199 |
+
head_dim = attn_head_dim
|
200 |
+
all_head_dim = head_dim * self.num_heads
|
201 |
+
self.scale = qk_scale or head_dim ** -0.5
|
202 |
+
|
203 |
+
self.subln = subln
|
204 |
+
if self.subln:
|
205 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
206 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
207 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
208 |
+
else:
|
209 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
210 |
+
|
211 |
+
if qkv_bias:
|
212 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
213 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
214 |
+
else:
|
215 |
+
self.q_bias = None
|
216 |
+
self.v_bias = None
|
217 |
+
|
218 |
+
if window_size:
|
219 |
+
self.window_size = window_size
|
220 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
221 |
+
self.relative_position_bias_table = nn.Parameter(
|
222 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
223 |
+
# cls to token & token 2 cls & cls to cls
|
224 |
+
|
225 |
+
# get pair-wise relative position index for each token inside the window
|
226 |
+
coords_h = torch.arange(window_size[0])
|
227 |
+
coords_w = torch.arange(window_size[1])
|
228 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
229 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
230 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
231 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
232 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
233 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
234 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
235 |
+
relative_position_index = \
|
236 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
237 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
238 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
239 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
240 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
241 |
+
|
242 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
243 |
+
else:
|
244 |
+
self.window_size = None
|
245 |
+
self.relative_position_bias_table = None
|
246 |
+
self.relative_position_index = None
|
247 |
+
|
248 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
249 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
250 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
251 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
252 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
253 |
+
self.xattn = xattn
|
254 |
+
self.xattn_drop = attn_drop
|
255 |
+
|
256 |
+
self.rope = rope
|
257 |
+
|
258 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
259 |
+
B, N, C = x.shape
|
260 |
+
if self.subln:
|
261 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
262 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
263 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
264 |
+
|
265 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
266 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
267 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
268 |
+
else:
|
269 |
+
|
270 |
+
qkv_bias = None
|
271 |
+
if self.q_bias is not None:
|
272 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
273 |
+
|
274 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
275 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
|
276 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
277 |
+
|
278 |
+
if self.rope:
|
279 |
+
# slightly fast impl
|
280 |
+
q_t = q[:, :, 1:, :]
|
281 |
+
ro_q_t = self.rope(q_t)
|
282 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
283 |
+
|
284 |
+
k_t = k[:, :, 1:, :]
|
285 |
+
ro_k_t = self.rope(k_t)
|
286 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
287 |
+
|
288 |
+
if self.xattn:
|
289 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
290 |
+
k = k.permute(0, 2, 1, 3)
|
291 |
+
v = v.permute(0, 2, 1, 3)
|
292 |
+
|
293 |
+
x = xops.memory_efficient_attention(
|
294 |
+
q, k, v,
|
295 |
+
p=self.xattn_drop,
|
296 |
+
scale=self.scale,
|
297 |
+
)
|
298 |
+
x = x.reshape(B, N, -1)
|
299 |
+
x = self.inner_attn_ln(x)
|
300 |
+
x = self.proj(x)
|
301 |
+
x = self.proj_drop(x)
|
302 |
+
else:
|
303 |
+
q = q * self.scale
|
304 |
+
attn = (q @ k.transpose(-2, -1))
|
305 |
+
|
306 |
+
if self.relative_position_bias_table is not None:
|
307 |
+
relative_position_bias = \
|
308 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
309 |
+
self.window_size[0] * self.window_size[1] + 1,
|
310 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
311 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
312 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
313 |
+
|
314 |
+
if rel_pos_bias is not None:
|
315 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
316 |
+
|
317 |
+
if attn_mask is not None:
|
318 |
+
attn_mask = attn_mask.bool()
|
319 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
320 |
+
|
321 |
+
attn = attn.softmax(dim=-1)
|
322 |
+
attn = self.attn_drop(attn)
|
323 |
+
|
324 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
325 |
+
x = self.inner_attn_ln(x)
|
326 |
+
x = self.proj(x)
|
327 |
+
x = self.proj_drop(x)
|
328 |
+
return x
|
329 |
+
|
330 |
+
|
331 |
+
class Block(nn.Module):
|
332 |
+
|
333 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
334 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
335 |
+
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
|
336 |
+
subln=False, naiveswiglu=False):
|
337 |
+
super().__init__()
|
338 |
+
self.norm1 = norm_layer(dim)
|
339 |
+
self.attn = Attention(
|
340 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
341 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
|
342 |
+
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
|
343 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
344 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
345 |
+
self.norm2 = norm_layer(dim)
|
346 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
347 |
+
|
348 |
+
if naiveswiglu:
|
349 |
+
self.mlp = SwiGLU(
|
350 |
+
in_features=dim,
|
351 |
+
hidden_features=mlp_hidden_dim,
|
352 |
+
subln=subln,
|
353 |
+
norm_layer=norm_layer,
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
self.mlp = Mlp(
|
357 |
+
in_features=dim,
|
358 |
+
hidden_features=mlp_hidden_dim,
|
359 |
+
act_layer=act_layer,
|
360 |
+
subln=subln,
|
361 |
+
drop=drop
|
362 |
+
)
|
363 |
+
|
364 |
+
if init_values is not None and init_values > 0:
|
365 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
366 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
367 |
+
else:
|
368 |
+
self.gamma_1, self.gamma_2 = None, None
|
369 |
+
|
370 |
+
self.postnorm = postnorm
|
371 |
+
|
372 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
373 |
+
if self.gamma_1 is None:
|
374 |
+
if self.postnorm:
|
375 |
+
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
376 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
377 |
+
else:
|
378 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
379 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
380 |
+
else:
|
381 |
+
if self.postnorm:
|
382 |
+
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
383 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
384 |
+
else:
|
385 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
386 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
387 |
+
return x
|
388 |
+
|
389 |
+
|
390 |
+
class PatchEmbed(nn.Module):
|
391 |
+
""" Image to Patch Embedding
|
392 |
+
"""
|
393 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
394 |
+
super().__init__()
|
395 |
+
img_size = to_2tuple(img_size)
|
396 |
+
patch_size = to_2tuple(patch_size)
|
397 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
398 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
399 |
+
self.img_size = img_size
|
400 |
+
self.patch_size = patch_size
|
401 |
+
self.num_patches = num_patches
|
402 |
+
|
403 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
404 |
+
|
405 |
+
def forward(self, x, **kwargs):
|
406 |
+
B, C, H, W = x.shape
|
407 |
+
# FIXME look at relaxing size constraints
|
408 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
409 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
410 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
411 |
+
return x
|
412 |
+
|
413 |
+
|
414 |
+
class RelativePositionBias(nn.Module):
|
415 |
+
|
416 |
+
def __init__(self, window_size, num_heads):
|
417 |
+
super().__init__()
|
418 |
+
self.window_size = window_size
|
419 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
420 |
+
self.relative_position_bias_table = nn.Parameter(
|
421 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
422 |
+
# cls to token & token 2 cls & cls to cls
|
423 |
+
|
424 |
+
# get pair-wise relative position index for each token inside the window
|
425 |
+
coords_h = torch.arange(window_size[0])
|
426 |
+
coords_w = torch.arange(window_size[1])
|
427 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
428 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
429 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
430 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
431 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
432 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
433 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
434 |
+
relative_position_index = \
|
435 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
436 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
437 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
438 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
439 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
440 |
+
|
441 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
442 |
+
|
443 |
+
def forward(self):
|
444 |
+
relative_position_bias = \
|
445 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
446 |
+
self.window_size[0] * self.window_size[1] + 1,
|
447 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
448 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
449 |
+
|
450 |
+
|
451 |
+
class EVAVisionTransformer(nn.Module):
|
452 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
453 |
+
"""
|
454 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
455 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
456 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
|
457 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
|
458 |
+
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
|
459 |
+
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
|
460 |
+
super().__init__()
|
461 |
+
|
462 |
+
if not XFORMERS_IS_AVAILBLE:
|
463 |
+
xattn = False
|
464 |
+
|
465 |
+
self.image_size = img_size
|
466 |
+
self.num_classes = num_classes
|
467 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
468 |
+
|
469 |
+
self.patch_embed = PatchEmbed(
|
470 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
471 |
+
num_patches = self.patch_embed.num_patches
|
472 |
+
|
473 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
474 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
475 |
+
if use_abs_pos_emb:
|
476 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
477 |
+
else:
|
478 |
+
self.pos_embed = None
|
479 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
480 |
+
|
481 |
+
if use_shared_rel_pos_bias:
|
482 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
483 |
+
else:
|
484 |
+
self.rel_pos_bias = None
|
485 |
+
|
486 |
+
if rope:
|
487 |
+
half_head_dim = embed_dim // num_heads // 2
|
488 |
+
hw_seq_len = img_size // patch_size
|
489 |
+
self.rope = VisionRotaryEmbeddingFast(
|
490 |
+
dim=half_head_dim,
|
491 |
+
pt_seq_len=pt_hw_seq_len,
|
492 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
493 |
+
# patch_dropout=patch_dropout
|
494 |
+
)
|
495 |
+
else:
|
496 |
+
self.rope = None
|
497 |
+
|
498 |
+
self.naiveswiglu = naiveswiglu
|
499 |
+
|
500 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
501 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
502 |
+
self.blocks = nn.ModuleList([
|
503 |
+
Block(
|
504 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
505 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
506 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
507 |
+
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
|
508 |
+
for i in range(depth)])
|
509 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
510 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
511 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
512 |
+
|
513 |
+
if self.pos_embed is not None:
|
514 |
+
trunc_normal_(self.pos_embed, std=.02)
|
515 |
+
|
516 |
+
trunc_normal_(self.cls_token, std=.02)
|
517 |
+
# trunc_normal_(self.mask_token, std=.02)
|
518 |
+
|
519 |
+
self.apply(self._init_weights)
|
520 |
+
self.fix_init_weight()
|
521 |
+
|
522 |
+
if isinstance(self.head, nn.Linear):
|
523 |
+
trunc_normal_(self.head.weight, std=.02)
|
524 |
+
self.head.weight.data.mul_(init_scale)
|
525 |
+
self.head.bias.data.mul_(init_scale)
|
526 |
+
|
527 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
528 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
529 |
+
|
530 |
+
self.grad_checkpointing = grad_checkpointing
|
531 |
+
|
532 |
+
def fix_init_weight(self):
|
533 |
+
def rescale(param, layer_id):
|
534 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
535 |
+
|
536 |
+
for layer_id, layer in enumerate(self.blocks):
|
537 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
538 |
+
if self.naiveswiglu:
|
539 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
540 |
+
else:
|
541 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
542 |
+
|
543 |
+
def get_cast_dtype(self) -> torch.dtype:
|
544 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
545 |
+
|
546 |
+
def _init_weights(self, m):
|
547 |
+
if isinstance(m, nn.Linear):
|
548 |
+
trunc_normal_(m.weight, std=.02)
|
549 |
+
if m.bias is not None:
|
550 |
+
nn.init.constant_(m.bias, 0)
|
551 |
+
elif isinstance(m, nn.LayerNorm):
|
552 |
+
nn.init.constant_(m.bias, 0)
|
553 |
+
nn.init.constant_(m.weight, 1.0)
|
554 |
+
|
555 |
+
def get_num_layers(self):
|
556 |
+
return len(self.blocks)
|
557 |
+
|
558 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
559 |
+
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
560 |
+
for param in self.parameters():
|
561 |
+
param.requires_grad = False
|
562 |
+
|
563 |
+
@torch.jit.ignore
|
564 |
+
def set_grad_checkpointing(self, enable=True):
|
565 |
+
self.grad_checkpointing = enable
|
566 |
+
|
567 |
+
@torch.jit.ignore
|
568 |
+
def no_weight_decay(self):
|
569 |
+
return {'pos_embed', 'cls_token'}
|
570 |
+
|
571 |
+
def get_classifier(self):
|
572 |
+
return self.head
|
573 |
+
|
574 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
575 |
+
self.num_classes = num_classes
|
576 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
577 |
+
|
578 |
+
def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
|
579 |
+
|
580 |
+
x = self.patch_embed(x)
|
581 |
+
batch_size, seq_len, _ = x.size()
|
582 |
+
|
583 |
+
if shuffle:
|
584 |
+
idx = torch.randperm(x.shape[1]) + 1
|
585 |
+
zero = torch.LongTensor([0, ])
|
586 |
+
idx = torch.cat([zero, idx])
|
587 |
+
pos_embed = self.pos_embed[:, idx]
|
588 |
+
|
589 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
590 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
591 |
+
if shuffle:
|
592 |
+
x = x + pos_embed
|
593 |
+
elif self.pos_embed is not None:
|
594 |
+
x = x + self.pos_embed
|
595 |
+
x = self.pos_drop(x)
|
596 |
+
|
597 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
598 |
+
if os.getenv('RoPE') == '1':
|
599 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
600 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
601 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
|
602 |
+
else:
|
603 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
604 |
+
x = self.patch_dropout(x)
|
605 |
+
else:
|
606 |
+
x = self.patch_dropout(x)
|
607 |
+
|
608 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
609 |
+
hidden_states = []
|
610 |
+
for idx, blk in enumerate(self.blocks):
|
611 |
+
if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
|
612 |
+
hidden_states.append(x)
|
613 |
+
if self.grad_checkpointing:
|
614 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
615 |
+
else:
|
616 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
617 |
+
|
618 |
+
if not return_all_features:
|
619 |
+
x = self.norm(x)
|
620 |
+
if self.fc_norm is not None:
|
621 |
+
return self.fc_norm(x.mean(1)), hidden_states
|
622 |
+
else:
|
623 |
+
return x[:, 0], hidden_states
|
624 |
+
return x
|
625 |
+
|
626 |
+
def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
|
627 |
+
if return_all_features:
|
628 |
+
return self.forward_features(x, return_all_features, return_hidden, shuffle)
|
629 |
+
x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
|
630 |
+
x = self.head(x)
|
631 |
+
if return_hidden:
|
632 |
+
return x, hidden_states
|
633 |
+
return x
|
eva_clip/factory.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Tuple, Union, Dict, Any
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
12 |
+
from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
13 |
+
get_cast_dtype
|
14 |
+
from .openai import load_openai_model
|
15 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
|
16 |
+
from .transform import image_transform
|
17 |
+
from .tokenizer import HFTokenizer, tokenize
|
18 |
+
from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
|
19 |
+
|
20 |
+
|
21 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
22 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
23 |
+
|
24 |
+
|
25 |
+
def _natural_key(string_):
|
26 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
27 |
+
|
28 |
+
|
29 |
+
def _rescan_model_configs():
|
30 |
+
global _MODEL_CONFIGS
|
31 |
+
|
32 |
+
config_ext = ('.json',)
|
33 |
+
config_files = []
|
34 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
35 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
36 |
+
config_files.append(config_path)
|
37 |
+
elif config_path.is_dir():
|
38 |
+
for ext in config_ext:
|
39 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
40 |
+
|
41 |
+
for cf in config_files:
|
42 |
+
with open(cf, "r", encoding="utf8") as f:
|
43 |
+
model_cfg = json.load(f)
|
44 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
45 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
46 |
+
|
47 |
+
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
|
48 |
+
|
49 |
+
|
50 |
+
_rescan_model_configs() # initial populate of model config registry
|
51 |
+
|
52 |
+
|
53 |
+
def list_models():
|
54 |
+
""" enumerate available model architectures based on config files """
|
55 |
+
return list(_MODEL_CONFIGS.keys())
|
56 |
+
|
57 |
+
|
58 |
+
def add_model_config(path):
|
59 |
+
""" add model config path or file and update registry """
|
60 |
+
if not isinstance(path, Path):
|
61 |
+
path = Path(path)
|
62 |
+
_MODEL_CONFIG_PATHS.append(path)
|
63 |
+
_rescan_model_configs()
|
64 |
+
|
65 |
+
|
66 |
+
def get_model_config(model_name):
|
67 |
+
if model_name in _MODEL_CONFIGS:
|
68 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
69 |
+
else:
|
70 |
+
return None
|
71 |
+
|
72 |
+
|
73 |
+
def get_tokenizer(model_name):
|
74 |
+
config = get_model_config(model_name)
|
75 |
+
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
|
76 |
+
return tokenizer
|
77 |
+
|
78 |
+
|
79 |
+
# loading openai CLIP weights when is_openai=True for training
|
80 |
+
def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
|
81 |
+
if is_openai:
|
82 |
+
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
|
83 |
+
state_dict = model.state_dict()
|
84 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
85 |
+
state_dict.pop(key, None)
|
86 |
+
else:
|
87 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
88 |
+
for mk in model_key.split('|'):
|
89 |
+
if isinstance(checkpoint, dict) and mk in checkpoint:
|
90 |
+
state_dict = checkpoint[mk]
|
91 |
+
break
|
92 |
+
else:
|
93 |
+
state_dict = checkpoint
|
94 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
95 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
96 |
+
|
97 |
+
for k in skip_list:
|
98 |
+
if k in list(state_dict.keys()):
|
99 |
+
logging.info(f"Removing key {k} from pretrained checkpoint")
|
100 |
+
del state_dict[k]
|
101 |
+
|
102 |
+
if os.getenv('RoPE') == '1':
|
103 |
+
for k in list(state_dict.keys()):
|
104 |
+
if 'freqs_cos' in k or 'freqs_sin' in k:
|
105 |
+
del state_dict[k]
|
106 |
+
return state_dict
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
|
111 |
+
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
|
112 |
+
# detect old format and make compatible with new format
|
113 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
114 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
115 |
+
if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
|
116 |
+
state_dict['logit_scale'] = state_dict['text.logit_scale']
|
117 |
+
del state_dict['text.logit_scale']
|
118 |
+
|
119 |
+
# resize_clip_pos_embed for CLIP and open CLIP
|
120 |
+
if 'visual.positional_embedding' in state_dict:
|
121 |
+
resize_clip_pos_embed(state_dict, model)
|
122 |
+
# specified to eva_vit_model
|
123 |
+
elif 'visual.pos_embed' in state_dict:
|
124 |
+
resize_evaclip_pos_embed(state_dict, model)
|
125 |
+
|
126 |
+
# resize_clip_pos_embed(state_dict, model)
|
127 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
128 |
+
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
|
129 |
+
return incompatible_keys
|
130 |
+
|
131 |
+
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
|
132 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
133 |
+
|
134 |
+
for k in list(state_dict.keys()):
|
135 |
+
if not k.startswith('visual.'):
|
136 |
+
del state_dict[k]
|
137 |
+
for k in list(state_dict.keys()):
|
138 |
+
if k.startswith('visual.'):
|
139 |
+
new_k = k[7:]
|
140 |
+
state_dict[new_k] = state_dict[k]
|
141 |
+
del state_dict[k]
|
142 |
+
return state_dict
|
143 |
+
|
144 |
+
def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
|
145 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
146 |
+
|
147 |
+
for k in list(state_dict.keys()):
|
148 |
+
if k.startswith('visual.'):
|
149 |
+
del state_dict[k]
|
150 |
+
return state_dict
|
151 |
+
|
152 |
+
def get_pretrained_tag(pretrained_model):
|
153 |
+
pretrained_model = pretrained_model.lower()
|
154 |
+
if "laion" in pretrained_model or "open_clip" in pretrained_model:
|
155 |
+
return "open_clip"
|
156 |
+
elif "openai" in pretrained_model:
|
157 |
+
return "clip"
|
158 |
+
elif "eva" in pretrained_model and "clip" in pretrained_model:
|
159 |
+
return "eva_clip"
|
160 |
+
else:
|
161 |
+
return "other"
|
162 |
+
|
163 |
+
def load_pretrained_checkpoint(
|
164 |
+
model,
|
165 |
+
visual_checkpoint_path,
|
166 |
+
text_checkpoint_path,
|
167 |
+
strict=True,
|
168 |
+
visual_model=None,
|
169 |
+
text_model=None,
|
170 |
+
model_key="model|module|state_dict",
|
171 |
+
skip_list=[]):
|
172 |
+
visual_tag = get_pretrained_tag(visual_model)
|
173 |
+
text_tag = get_pretrained_tag(text_model)
|
174 |
+
|
175 |
+
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
|
176 |
+
visual_incompatible_keys, text_incompatible_keys = None, None
|
177 |
+
if visual_checkpoint_path:
|
178 |
+
if visual_tag == "eva_clip" or visual_tag == "open_clip":
|
179 |
+
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
|
180 |
+
elif visual_tag == "clip":
|
181 |
+
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
|
182 |
+
else:
|
183 |
+
visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
|
184 |
+
|
185 |
+
# resize_clip_pos_embed for CLIP and open CLIP
|
186 |
+
if 'positional_embedding' in visual_state_dict:
|
187 |
+
resize_visual_pos_embed(visual_state_dict, model)
|
188 |
+
# specified to EVA model
|
189 |
+
elif 'pos_embed' in visual_state_dict:
|
190 |
+
resize_eva_pos_embed(visual_state_dict, model)
|
191 |
+
|
192 |
+
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
|
193 |
+
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
|
194 |
+
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
|
195 |
+
|
196 |
+
if text_checkpoint_path:
|
197 |
+
if text_tag == "eva_clip" or text_tag == "open_clip":
|
198 |
+
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
|
199 |
+
elif text_tag == "clip":
|
200 |
+
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
|
201 |
+
else:
|
202 |
+
text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
|
203 |
+
|
204 |
+
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
|
205 |
+
|
206 |
+
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
|
207 |
+
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
|
208 |
+
|
209 |
+
return visual_incompatible_keys, text_incompatible_keys
|
210 |
+
|
211 |
+
def create_model(
|
212 |
+
model_name: str,
|
213 |
+
pretrained: Optional[str] = None,
|
214 |
+
precision: str = 'fp32',
|
215 |
+
device: Union[str, torch.device] = 'cpu',
|
216 |
+
jit: bool = False,
|
217 |
+
force_quick_gelu: bool = False,
|
218 |
+
force_custom_clip: bool = False,
|
219 |
+
force_patch_dropout: Optional[float] = None,
|
220 |
+
pretrained_image: str = '',
|
221 |
+
pretrained_text: str = '',
|
222 |
+
pretrained_hf: bool = True,
|
223 |
+
pretrained_visual_model: str = None,
|
224 |
+
pretrained_text_model: str = None,
|
225 |
+
cache_dir: Optional[str] = None,
|
226 |
+
skip_list: list = [],
|
227 |
+
):
|
228 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
229 |
+
if isinstance(device, str):
|
230 |
+
device = torch.device(device)
|
231 |
+
|
232 |
+
if pretrained and pretrained.lower() == 'openai':
|
233 |
+
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
234 |
+
model = load_openai_model(
|
235 |
+
model_name,
|
236 |
+
precision=precision,
|
237 |
+
device=device,
|
238 |
+
jit=jit,
|
239 |
+
cache_dir=cache_dir,
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
model_cfg = get_model_config(model_name)
|
243 |
+
if model_cfg is not None:
|
244 |
+
logging.info(f'Loaded {model_name} model config.')
|
245 |
+
else:
|
246 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
247 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
248 |
+
|
249 |
+
if 'rope' in model_cfg.get('vision_cfg', {}):
|
250 |
+
if model_cfg['vision_cfg']['rope']:
|
251 |
+
os.environ['RoPE'] = "1"
|
252 |
+
else:
|
253 |
+
os.environ['RoPE'] = "0"
|
254 |
+
|
255 |
+
if force_quick_gelu:
|
256 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
257 |
+
model_cfg["quick_gelu"] = True
|
258 |
+
|
259 |
+
if force_patch_dropout is not None:
|
260 |
+
# override the default patch dropout value
|
261 |
+
model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
|
262 |
+
|
263 |
+
cast_dtype = get_cast_dtype(precision)
|
264 |
+
custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
|
265 |
+
|
266 |
+
|
267 |
+
if custom_clip:
|
268 |
+
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
|
269 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
270 |
+
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
|
271 |
+
else:
|
272 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
273 |
+
|
274 |
+
pretrained_cfg = {}
|
275 |
+
if pretrained:
|
276 |
+
checkpoint_path = ''
|
277 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
278 |
+
if pretrained_cfg:
|
279 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
280 |
+
elif os.path.exists(pretrained):
|
281 |
+
checkpoint_path = pretrained
|
282 |
+
|
283 |
+
if checkpoint_path:
|
284 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
285 |
+
load_checkpoint(model,
|
286 |
+
checkpoint_path,
|
287 |
+
model_key="model|module|state_dict",
|
288 |
+
strict=False
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
error_str = (
|
292 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
293 |
+
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
294 |
+
logging.warning(error_str)
|
295 |
+
raise RuntimeError(error_str)
|
296 |
+
else:
|
297 |
+
visual_checkpoint_path = ''
|
298 |
+
text_checkpoint_path = ''
|
299 |
+
|
300 |
+
if pretrained_image:
|
301 |
+
pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
|
302 |
+
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
|
303 |
+
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
304 |
+
# pretrained weight loading for timm models set via vision_cfg
|
305 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
306 |
+
elif pretrained_image_cfg:
|
307 |
+
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
|
308 |
+
elif os.path.exists(pretrained_image):
|
309 |
+
visual_checkpoint_path = pretrained_image
|
310 |
+
else:
|
311 |
+
logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
|
312 |
+
raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
|
313 |
+
|
314 |
+
if pretrained_text:
|
315 |
+
pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
|
316 |
+
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
|
317 |
+
if pretrained_image_cfg:
|
318 |
+
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
|
319 |
+
elif os.path.exists(pretrained_text):
|
320 |
+
text_checkpoint_path = pretrained_text
|
321 |
+
else:
|
322 |
+
logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
|
323 |
+
raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
|
324 |
+
|
325 |
+
if visual_checkpoint_path:
|
326 |
+
logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
|
327 |
+
if text_checkpoint_path:
|
328 |
+
logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
|
329 |
+
|
330 |
+
if visual_checkpoint_path or text_checkpoint_path:
|
331 |
+
load_pretrained_checkpoint(
|
332 |
+
model,
|
333 |
+
visual_checkpoint_path,
|
334 |
+
text_checkpoint_path,
|
335 |
+
strict=False,
|
336 |
+
visual_model=pretrained_visual_model,
|
337 |
+
text_model=pretrained_text_model,
|
338 |
+
model_key="model|module|state_dict",
|
339 |
+
skip_list=skip_list
|
340 |
+
)
|
341 |
+
|
342 |
+
if "fp16" in precision or "bf16" in precision:
|
343 |
+
logging.info(f'convert precision to {precision}')
|
344 |
+
model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
|
345 |
+
|
346 |
+
model.to(device=device)
|
347 |
+
|
348 |
+
# set image / mean metadata from pretrained_cfg if available, or use default
|
349 |
+
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
350 |
+
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
351 |
+
|
352 |
+
if jit:
|
353 |
+
model = torch.jit.script(model)
|
354 |
+
|
355 |
+
return model
|
356 |
+
|
357 |
+
|
358 |
+
def create_model_and_transforms(
|
359 |
+
model_name: str,
|
360 |
+
pretrained: Optional[str] = None,
|
361 |
+
precision: str = 'fp32',
|
362 |
+
device: Union[str, torch.device] = 'cpu',
|
363 |
+
jit: bool = False,
|
364 |
+
force_quick_gelu: bool = False,
|
365 |
+
force_custom_clip: bool = False,
|
366 |
+
force_patch_dropout: Optional[float] = None,
|
367 |
+
pretrained_image: str = '',
|
368 |
+
pretrained_text: str = '',
|
369 |
+
pretrained_hf: bool = True,
|
370 |
+
pretrained_visual_model: str = None,
|
371 |
+
pretrained_text_model: str = None,
|
372 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
373 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
374 |
+
cache_dir: Optional[str] = None,
|
375 |
+
skip_list: list = [],
|
376 |
+
):
|
377 |
+
model = create_model(
|
378 |
+
model_name,
|
379 |
+
pretrained,
|
380 |
+
precision=precision,
|
381 |
+
device=device,
|
382 |
+
jit=jit,
|
383 |
+
force_quick_gelu=force_quick_gelu,
|
384 |
+
force_custom_clip=force_custom_clip,
|
385 |
+
force_patch_dropout=force_patch_dropout,
|
386 |
+
pretrained_image=pretrained_image,
|
387 |
+
pretrained_text=pretrained_text,
|
388 |
+
pretrained_hf=pretrained_hf,
|
389 |
+
pretrained_visual_model=pretrained_visual_model,
|
390 |
+
pretrained_text_model=pretrained_text_model,
|
391 |
+
cache_dir=cache_dir,
|
392 |
+
skip_list=skip_list,
|
393 |
+
)
|
394 |
+
|
395 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
396 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
397 |
+
preprocess_train = image_transform(
|
398 |
+
model.visual.image_size,
|
399 |
+
is_train=True,
|
400 |
+
mean=image_mean,
|
401 |
+
std=image_std
|
402 |
+
)
|
403 |
+
preprocess_val = image_transform(
|
404 |
+
model.visual.image_size,
|
405 |
+
is_train=False,
|
406 |
+
mean=image_mean,
|
407 |
+
std=image_std
|
408 |
+
)
|
409 |
+
|
410 |
+
return model, preprocess_train, preprocess_val
|
411 |
+
|
412 |
+
|
413 |
+
def create_transforms(
|
414 |
+
model_name: str,
|
415 |
+
pretrained: Optional[str] = None,
|
416 |
+
precision: str = 'fp32',
|
417 |
+
device: Union[str, torch.device] = 'cpu',
|
418 |
+
jit: bool = False,
|
419 |
+
force_quick_gelu: bool = False,
|
420 |
+
force_custom_clip: bool = False,
|
421 |
+
force_patch_dropout: Optional[float] = None,
|
422 |
+
pretrained_image: str = '',
|
423 |
+
pretrained_text: str = '',
|
424 |
+
pretrained_hf: bool = True,
|
425 |
+
pretrained_visual_model: str = None,
|
426 |
+
pretrained_text_model: str = None,
|
427 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
428 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
429 |
+
cache_dir: Optional[str] = None,
|
430 |
+
skip_list: list = [],
|
431 |
+
):
|
432 |
+
model = create_model(
|
433 |
+
model_name,
|
434 |
+
pretrained,
|
435 |
+
precision=precision,
|
436 |
+
device=device,
|
437 |
+
jit=jit,
|
438 |
+
force_quick_gelu=force_quick_gelu,
|
439 |
+
force_custom_clip=force_custom_clip,
|
440 |
+
force_patch_dropout=force_patch_dropout,
|
441 |
+
pretrained_image=pretrained_image,
|
442 |
+
pretrained_text=pretrained_text,
|
443 |
+
pretrained_hf=pretrained_hf,
|
444 |
+
pretrained_visual_model=pretrained_visual_model,
|
445 |
+
pretrained_text_model=pretrained_text_model,
|
446 |
+
cache_dir=cache_dir,
|
447 |
+
skip_list=skip_list,
|
448 |
+
)
|
449 |
+
|
450 |
+
|
451 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
452 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
453 |
+
preprocess_train = image_transform(
|
454 |
+
model.visual.image_size,
|
455 |
+
is_train=True,
|
456 |
+
mean=image_mean,
|
457 |
+
std=image_std
|
458 |
+
)
|
459 |
+
preprocess_val = image_transform(
|
460 |
+
model.visual.image_size,
|
461 |
+
is_train=False,
|
462 |
+
mean=image_mean,
|
463 |
+
std=image_std
|
464 |
+
)
|
465 |
+
del model
|
466 |
+
|
467 |
+
return preprocess_train, preprocess_val
|
468 |
+
|
469 |
+
def create_model_from_pretrained(
|
470 |
+
model_name: str,
|
471 |
+
pretrained: str,
|
472 |
+
precision: str = 'fp32',
|
473 |
+
device: Union[str, torch.device] = 'cpu',
|
474 |
+
jit: bool = False,
|
475 |
+
force_quick_gelu: bool = False,
|
476 |
+
force_custom_clip: bool = False,
|
477 |
+
force_patch_dropout: Optional[float] = None,
|
478 |
+
return_transform: bool = True,
|
479 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
480 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
481 |
+
cache_dir: Optional[str] = None,
|
482 |
+
is_frozen: bool = False,
|
483 |
+
):
|
484 |
+
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
|
485 |
+
raise RuntimeError(
|
486 |
+
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
|
487 |
+
f' Use open_clip.list_pretrained() to find one.')
|
488 |
+
|
489 |
+
model = create_model(
|
490 |
+
model_name,
|
491 |
+
pretrained,
|
492 |
+
precision=precision,
|
493 |
+
device=device,
|
494 |
+
jit=jit,
|
495 |
+
force_quick_gelu=force_quick_gelu,
|
496 |
+
force_custom_clip=force_custom_clip,
|
497 |
+
force_patch_dropout=force_patch_dropout,
|
498 |
+
cache_dir=cache_dir,
|
499 |
+
)
|
500 |
+
|
501 |
+
if is_frozen:
|
502 |
+
for param in model.parameters():
|
503 |
+
param.requires_grad = False
|
504 |
+
|
505 |
+
if not return_transform:
|
506 |
+
return model
|
507 |
+
|
508 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
509 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
510 |
+
preprocess = image_transform(
|
511 |
+
model.visual.image_size,
|
512 |
+
is_train=False,
|
513 |
+
mean=image_mean,
|
514 |
+
std=image_std
|
515 |
+
)
|
516 |
+
|
517 |
+
return model, preprocess
|
eva_clip/hf_configs.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HF architecture dict:
|
2 |
+
arch_dict = {
|
3 |
+
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
4 |
+
"roberta": {
|
5 |
+
"config_names": {
|
6 |
+
"context_length": "max_position_embeddings",
|
7 |
+
"vocab_size": "vocab_size",
|
8 |
+
"width": "hidden_size",
|
9 |
+
"heads": "num_attention_heads",
|
10 |
+
"layers": "num_hidden_layers",
|
11 |
+
"layer_attr": "layer",
|
12 |
+
"token_embeddings_attr": "embeddings"
|
13 |
+
},
|
14 |
+
"pooler": "mean_pooler",
|
15 |
+
},
|
16 |
+
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
17 |
+
"xlm-roberta": {
|
18 |
+
"config_names": {
|
19 |
+
"context_length": "max_position_embeddings",
|
20 |
+
"vocab_size": "vocab_size",
|
21 |
+
"width": "hidden_size",
|
22 |
+
"heads": "num_attention_heads",
|
23 |
+
"layers": "num_hidden_layers",
|
24 |
+
"layer_attr": "layer",
|
25 |
+
"token_embeddings_attr": "embeddings"
|
26 |
+
},
|
27 |
+
"pooler": "mean_pooler",
|
28 |
+
},
|
29 |
+
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
30 |
+
"mt5": {
|
31 |
+
"config_names": {
|
32 |
+
# unlimited seqlen
|
33 |
+
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
34 |
+
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
35 |
+
"context_length": "",
|
36 |
+
"vocab_size": "vocab_size",
|
37 |
+
"width": "d_model",
|
38 |
+
"heads": "num_heads",
|
39 |
+
"layers": "num_layers",
|
40 |
+
"layer_attr": "block",
|
41 |
+
"token_embeddings_attr": "embed_tokens"
|
42 |
+
},
|
43 |
+
"pooler": "mean_pooler",
|
44 |
+
},
|
45 |
+
"bert": {
|
46 |
+
"config_names": {
|
47 |
+
"context_length": "max_position_embeddings",
|
48 |
+
"vocab_size": "vocab_size",
|
49 |
+
"width": "hidden_size",
|
50 |
+
"heads": "num_attention_heads",
|
51 |
+
"layers": "num_hidden_layers",
|
52 |
+
"layer_attr": "layer",
|
53 |
+
"token_embeddings_attr": "embeddings"
|
54 |
+
},
|
55 |
+
"pooler": "mean_pooler",
|
56 |
+
}
|
57 |
+
}
|
eva_clip/hf_model.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import re
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from torch import TensorType
|
12 |
+
try:
|
13 |
+
import transformers
|
14 |
+
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
|
15 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
16 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
17 |
+
except ImportError as e:
|
18 |
+
transformers = None
|
19 |
+
|
20 |
+
|
21 |
+
class BaseModelOutput:
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
class PretrainedConfig:
|
26 |
+
pass
|
27 |
+
|
28 |
+
from .hf_configs import arch_dict
|
29 |
+
|
30 |
+
# utils
|
31 |
+
def _camel2snake(s):
|
32 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
33 |
+
|
34 |
+
# TODO: ?last - for gpt-like models
|
35 |
+
_POOLERS = {}
|
36 |
+
|
37 |
+
def register_pooler(cls):
|
38 |
+
"""Decorator registering pooler class"""
|
39 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
40 |
+
return cls
|
41 |
+
|
42 |
+
|
43 |
+
@register_pooler
|
44 |
+
class MeanPooler(nn.Module):
|
45 |
+
"""Mean pooling"""
|
46 |
+
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
|
47 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
48 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
49 |
+
|
50 |
+
@register_pooler
|
51 |
+
class MaxPooler(nn.Module):
|
52 |
+
"""Max pooling"""
|
53 |
+
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
|
54 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
55 |
+
return masked_output.max(1).values
|
56 |
+
|
57 |
+
@register_pooler
|
58 |
+
class ClsPooler(nn.Module):
|
59 |
+
"""CLS token pooling"""
|
60 |
+
def __init__(self, use_pooler_output=True):
|
61 |
+
super().__init__()
|
62 |
+
self.cls_token_position = 0
|
63 |
+
self.use_pooler_output = use_pooler_output
|
64 |
+
|
65 |
+
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
|
66 |
+
|
67 |
+
if (self.use_pooler_output and
|
68 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
69 |
+
(x.pooler_output is not None)
|
70 |
+
):
|
71 |
+
return x.pooler_output
|
72 |
+
|
73 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
74 |
+
|
75 |
+
class HFTextEncoder(nn.Module):
|
76 |
+
"""HuggingFace model adapter"""
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
model_name_or_path: str,
|
80 |
+
output_dim: int,
|
81 |
+
tokenizer_name: str = None,
|
82 |
+
config: PretrainedConfig = None,
|
83 |
+
pooler_type: str = None,
|
84 |
+
proj: str = None,
|
85 |
+
pretrained: bool = True,
|
86 |
+
masked_language_modeling: bool = False):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
self.output_dim = output_dim
|
90 |
+
|
91 |
+
# TODO: find better way to get this information
|
92 |
+
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
93 |
+
|
94 |
+
if transformers is None:
|
95 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
96 |
+
if config is None:
|
97 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
98 |
+
if masked_language_modeling:
|
99 |
+
create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
|
100 |
+
AutoModelForMaskedLM.from_config, self.config)
|
101 |
+
else:
|
102 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
103 |
+
AutoModel.from_config, self.config)
|
104 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
105 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
106 |
+
self.transformer = create_func(model_args)
|
107 |
+
self.transformer = self.transformer.encoder
|
108 |
+
else:
|
109 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
110 |
+
else:
|
111 |
+
self.config = config
|
112 |
+
if masked_language_modeling:
|
113 |
+
self.transformer = AutoModelForMaskedLM.from_config(config)
|
114 |
+
else:
|
115 |
+
self.transformer = AutoModel.from_config(config)
|
116 |
+
|
117 |
+
if pooler_type is None: # get default arch pooler
|
118 |
+
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
|
119 |
+
else:
|
120 |
+
self.pooler = _POOLERS[pooler_type]()
|
121 |
+
|
122 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
123 |
+
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
124 |
+
self.proj = nn.Identity()
|
125 |
+
elif proj == 'linear':
|
126 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
127 |
+
elif proj == 'mlp':
|
128 |
+
hidden_size = (d_model + output_dim) // 2
|
129 |
+
self.proj = nn.Sequential(
|
130 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
131 |
+
nn.GELU(),
|
132 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
133 |
+
)
|
134 |
+
|
135 |
+
# self.itm_proj = nn.Linear(d_model, 2, bias=False)
|
136 |
+
# self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
|
137 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
138 |
+
|
139 |
+
# def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
|
140 |
+
# image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
|
141 |
+
# attn_mask = (x != self.config.pad_token_id).long()
|
142 |
+
# out = self.transformer(
|
143 |
+
# input_ids=x,
|
144 |
+
# attention_mask=attn_mask,
|
145 |
+
# encoder_hidden_states = image_embeds,
|
146 |
+
# encoder_attention_mask = image_atts,
|
147 |
+
# )
|
148 |
+
# pooled_out = self.pooler(out, attn_mask)
|
149 |
+
|
150 |
+
# return self.itm_proj(pooled_out)
|
151 |
+
|
152 |
+
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
|
153 |
+
if masked_indices is None:
|
154 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
155 |
+
|
156 |
+
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
|
157 |
+
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
|
158 |
+
|
159 |
+
if targets is not None:
|
160 |
+
targets[~masked_indices] = -100 # We only compute loss on masked tokens
|
161 |
+
|
162 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
163 |
+
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
|
164 |
+
input_ids[indices_replaced] = self.tokenizer.mask_token_id
|
165 |
+
|
166 |
+
# 10% of the time, we replace masked input tokens with random word
|
167 |
+
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
168 |
+
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
|
169 |
+
input_ids[indices_random] = random_words[indices_random]
|
170 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
171 |
+
|
172 |
+
if targets is not None:
|
173 |
+
return input_ids, targets
|
174 |
+
else:
|
175 |
+
return input_ids
|
176 |
+
|
177 |
+
def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
|
178 |
+
labels = input_ids.clone()
|
179 |
+
attn_mask = (input_ids != self.config.pad_token_id).long()
|
180 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
|
181 |
+
vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
|
182 |
+
probability_matrix = torch.full(labels.shape, mlm_probability)
|
183 |
+
input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
|
184 |
+
probability_matrix = probability_matrix)
|
185 |
+
mlm_output = self.transformer(input_ids,
|
186 |
+
attention_mask = attn_mask,
|
187 |
+
encoder_hidden_states = image_embeds,
|
188 |
+
encoder_attention_mask = image_atts,
|
189 |
+
return_dict = True,
|
190 |
+
labels = labels,
|
191 |
+
)
|
192 |
+
return mlm_output.loss
|
193 |
+
# mlm_output = self.transformer(input_ids,
|
194 |
+
# attention_mask = attn_mask,
|
195 |
+
# encoder_hidden_states = image_embeds,
|
196 |
+
# encoder_attention_mask = image_atts,
|
197 |
+
# return_dict = True,
|
198 |
+
# ).last_hidden_state
|
199 |
+
# logits = self.mlm_proj(mlm_output)
|
200 |
+
|
201 |
+
# # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
|
202 |
+
# logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
|
203 |
+
# labels = labels[:, 1:].contiguous().view(-1)
|
204 |
+
|
205 |
+
# mlm_loss = F.cross_entropy(
|
206 |
+
# logits,
|
207 |
+
# labels,
|
208 |
+
# # label_smoothing=0.1,
|
209 |
+
# )
|
210 |
+
# return mlm_loss
|
211 |
+
|
212 |
+
|
213 |
+
def forward(self, x:TensorType) -> TensorType:
|
214 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
215 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
216 |
+
pooled_out = self.pooler(out, attn_mask)
|
217 |
+
|
218 |
+
return self.proj(pooled_out)
|
219 |
+
|
220 |
+
def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
|
221 |
+
if not unlocked_layers: # full freezing
|
222 |
+
for n, p in self.transformer.named_parameters():
|
223 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
224 |
+
return
|
225 |
+
|
226 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
227 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
228 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
229 |
+
embeddings = getattr(
|
230 |
+
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
231 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
232 |
+
# freeze layers
|
233 |
+
for module in modules:
|
234 |
+
for n, p in module.named_parameters():
|
235 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
236 |
+
|
237 |
+
|
238 |
+
@torch.jit.ignore
|
239 |
+
def set_grad_checkpointing(self, enable=True):
|
240 |
+
self.transformer.gradient_checkpointing_enable()
|
241 |
+
|
242 |
+
def get_num_layers(self):
|
243 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
244 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
245 |
+
return len(layer_list)
|
246 |
+
|
247 |
+
def init_parameters(self):
|
248 |
+
pass
|
eva_clip/loss.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
try:
|
7 |
+
import torch.distributed.nn
|
8 |
+
from torch import distributed as dist
|
9 |
+
has_distributed = True
|
10 |
+
except ImportError:
|
11 |
+
has_distributed = False
|
12 |
+
|
13 |
+
try:
|
14 |
+
import horovod.torch as hvd
|
15 |
+
except ImportError:
|
16 |
+
hvd = None
|
17 |
+
|
18 |
+
from timm.loss import LabelSmoothingCrossEntropy
|
19 |
+
|
20 |
+
|
21 |
+
def gather_features(
|
22 |
+
image_features,
|
23 |
+
text_features,
|
24 |
+
local_loss=False,
|
25 |
+
gather_with_grad=False,
|
26 |
+
rank=0,
|
27 |
+
world_size=1,
|
28 |
+
use_horovod=False
|
29 |
+
):
|
30 |
+
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
31 |
+
if use_horovod:
|
32 |
+
assert hvd is not None, 'Please install horovod'
|
33 |
+
if gather_with_grad:
|
34 |
+
all_image_features = hvd.allgather(image_features)
|
35 |
+
all_text_features = hvd.allgather(text_features)
|
36 |
+
else:
|
37 |
+
with torch.no_grad():
|
38 |
+
all_image_features = hvd.allgather(image_features)
|
39 |
+
all_text_features = hvd.allgather(text_features)
|
40 |
+
if not local_loss:
|
41 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
42 |
+
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
43 |
+
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
44 |
+
gathered_image_features[rank] = image_features
|
45 |
+
gathered_text_features[rank] = text_features
|
46 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
47 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
48 |
+
else:
|
49 |
+
# We gather tensors from all gpus
|
50 |
+
if gather_with_grad:
|
51 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
52 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
53 |
+
# all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
|
54 |
+
# all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
|
55 |
+
else:
|
56 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
57 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
58 |
+
dist.all_gather(gathered_image_features, image_features)
|
59 |
+
dist.all_gather(gathered_text_features, text_features)
|
60 |
+
if not local_loss:
|
61 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
62 |
+
gathered_image_features[rank] = image_features
|
63 |
+
gathered_text_features[rank] = text_features
|
64 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
65 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
66 |
+
|
67 |
+
return all_image_features, all_text_features
|
68 |
+
|
69 |
+
|
70 |
+
class ClipLoss(nn.Module):
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
local_loss=False,
|
75 |
+
gather_with_grad=False,
|
76 |
+
cache_labels=False,
|
77 |
+
rank=0,
|
78 |
+
world_size=1,
|
79 |
+
use_horovod=False,
|
80 |
+
smoothing=0.,
|
81 |
+
):
|
82 |
+
super().__init__()
|
83 |
+
self.local_loss = local_loss
|
84 |
+
self.gather_with_grad = gather_with_grad
|
85 |
+
self.cache_labels = cache_labels
|
86 |
+
self.rank = rank
|
87 |
+
self.world_size = world_size
|
88 |
+
self.use_horovod = use_horovod
|
89 |
+
self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
|
90 |
+
|
91 |
+
# cache state
|
92 |
+
self.prev_num_logits = 0
|
93 |
+
self.labels = {}
|
94 |
+
|
95 |
+
def forward(self, image_features, text_features, logit_scale=1.):
|
96 |
+
device = image_features.device
|
97 |
+
if self.world_size > 1:
|
98 |
+
all_image_features, all_text_features = gather_features(
|
99 |
+
image_features, text_features,
|
100 |
+
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
101 |
+
|
102 |
+
if self.local_loss:
|
103 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
104 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
105 |
+
else:
|
106 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
107 |
+
logits_per_text = logits_per_image.T
|
108 |
+
else:
|
109 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
110 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
111 |
+
# calculated ground-truth and cache if enabled
|
112 |
+
num_logits = logits_per_image.shape[0]
|
113 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
114 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
115 |
+
if self.world_size > 1 and self.local_loss:
|
116 |
+
labels = labels + num_logits * self.rank
|
117 |
+
if self.cache_labels:
|
118 |
+
self.labels[device] = labels
|
119 |
+
self.prev_num_logits = num_logits
|
120 |
+
else:
|
121 |
+
labels = self.labels[device]
|
122 |
+
|
123 |
+
if self.label_smoothing_cross_entropy:
|
124 |
+
total_loss = (
|
125 |
+
self.label_smoothing_cross_entropy(logits_per_image, labels) +
|
126 |
+
self.label_smoothing_cross_entropy(logits_per_text, labels)
|
127 |
+
) / 2
|
128 |
+
else:
|
129 |
+
total_loss = (
|
130 |
+
F.cross_entropy(logits_per_image, labels) +
|
131 |
+
F.cross_entropy(logits_per_text, labels)
|
132 |
+
) / 2
|
133 |
+
|
134 |
+
acc = None
|
135 |
+
i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
|
136 |
+
t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
|
137 |
+
acc = {"i2t": i2t_acc, "t2i": t2i_acc}
|
138 |
+
return total_loss, acc
|
eva_clip/model.py
ADDED
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Optional, Tuple, Union
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
try:
|
16 |
+
from .hf_model import HFTextEncoder
|
17 |
+
except:
|
18 |
+
HFTextEncoder = None
|
19 |
+
from .modified_resnet import ModifiedResNet
|
20 |
+
# from .timm_model import TimmModel
|
21 |
+
from .eva_vit_model import EVAVisionTransformer
|
22 |
+
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
23 |
+
|
24 |
+
try:
|
25 |
+
from apex.normalization import FusedLayerNorm
|
26 |
+
except:
|
27 |
+
FusedLayerNorm = LayerNorm
|
28 |
+
print("Please 'pip install apex'")
|
29 |
+
|
30 |
+
try:
|
31 |
+
import xformers.ops as xops
|
32 |
+
except ImportError:
|
33 |
+
xops = None
|
34 |
+
print("Please 'pip install xformers'")
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class CLIPVisionCfg:
|
38 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
39 |
+
width: int = 768
|
40 |
+
head_width: int = 64
|
41 |
+
mlp_ratio: float = 4.0
|
42 |
+
patch_size: int = 16
|
43 |
+
image_size: Union[Tuple[int, int], int] = 224
|
44 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
45 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
46 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
47 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
48 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
49 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
50 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
51 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
52 |
+
timm_proj_bias: bool = False # enable bias final projection
|
53 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
54 |
+
qkv_bias: bool = True
|
55 |
+
fusedLN: bool = False
|
56 |
+
xattn: bool = False
|
57 |
+
postnorm: bool = False
|
58 |
+
rope: bool = False
|
59 |
+
pt_hw_seq_len: int = 16 # 224/14
|
60 |
+
intp_freq: bool = False
|
61 |
+
naiveswiglu: bool = False
|
62 |
+
subln: bool = False
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class CLIPTextCfg:
|
67 |
+
context_length: int = 77
|
68 |
+
vocab_size: int = 49408
|
69 |
+
width: int = 512
|
70 |
+
heads: int = 8
|
71 |
+
layers: int = 12
|
72 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
73 |
+
hf_model_name: str = None
|
74 |
+
hf_tokenizer_name: str = None
|
75 |
+
hf_model_pretrained: bool = True
|
76 |
+
proj: str = 'mlp'
|
77 |
+
pooler_type: str = 'mean_pooler'
|
78 |
+
masked_language_modeling: bool = False
|
79 |
+
fusedLN: bool = False
|
80 |
+
xattn: bool = False
|
81 |
+
attn_mask: bool = True
|
82 |
+
|
83 |
+
def get_cast_dtype(precision: str):
|
84 |
+
cast_dtype = None
|
85 |
+
if precision == 'bf16':
|
86 |
+
cast_dtype = torch.bfloat16
|
87 |
+
elif precision == 'fp16':
|
88 |
+
cast_dtype = torch.float16
|
89 |
+
return cast_dtype
|
90 |
+
|
91 |
+
|
92 |
+
def _build_vision_tower(
|
93 |
+
embed_dim: int,
|
94 |
+
vision_cfg: CLIPVisionCfg,
|
95 |
+
quick_gelu: bool = False,
|
96 |
+
cast_dtype: Optional[torch.dtype] = None
|
97 |
+
):
|
98 |
+
if isinstance(vision_cfg, dict):
|
99 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
100 |
+
|
101 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
102 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
103 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
104 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
105 |
+
|
106 |
+
if vision_cfg.eva_model_name:
|
107 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
108 |
+
norm_layer = LayerNorm
|
109 |
+
|
110 |
+
visual = EVAVisionTransformer(
|
111 |
+
img_size=vision_cfg.image_size,
|
112 |
+
patch_size=vision_cfg.patch_size,
|
113 |
+
num_classes=embed_dim,
|
114 |
+
use_mean_pooling=vision_cfg.global_average_pool, #False
|
115 |
+
init_values=vision_cfg.ls_init_value,
|
116 |
+
patch_dropout=vision_cfg.patch_dropout,
|
117 |
+
embed_dim=vision_cfg.width,
|
118 |
+
depth=vision_cfg.layers,
|
119 |
+
num_heads=vision_heads,
|
120 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
121 |
+
qkv_bias=vision_cfg.qkv_bias,
|
122 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
123 |
+
norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
|
124 |
+
xattn=vision_cfg.xattn,
|
125 |
+
rope=vision_cfg.rope,
|
126 |
+
postnorm=vision_cfg.postnorm,
|
127 |
+
pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
|
128 |
+
intp_freq= vision_cfg.intp_freq,
|
129 |
+
naiveswiglu= vision_cfg.naiveswiglu,
|
130 |
+
subln= vision_cfg.subln
|
131 |
+
)
|
132 |
+
elif vision_cfg.timm_model_name:
|
133 |
+
# visual = TimmModel(
|
134 |
+
# vision_cfg.timm_model_name,
|
135 |
+
# pretrained=vision_cfg.timm_model_pretrained,
|
136 |
+
# pool=vision_cfg.timm_pool,
|
137 |
+
# proj=vision_cfg.timm_proj,
|
138 |
+
# proj_bias=vision_cfg.timm_proj_bias,
|
139 |
+
# embed_dim=embed_dim,
|
140 |
+
# image_size=vision_cfg.image_size
|
141 |
+
# )
|
142 |
+
# act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
143 |
+
raise ValueError
|
144 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
145 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
146 |
+
visual = ModifiedResNet(
|
147 |
+
layers=vision_cfg.layers,
|
148 |
+
output_dim=embed_dim,
|
149 |
+
heads=vision_heads,
|
150 |
+
image_size=vision_cfg.image_size,
|
151 |
+
width=vision_cfg.width
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
155 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
156 |
+
visual = VisionTransformer(
|
157 |
+
image_size=vision_cfg.image_size,
|
158 |
+
patch_size=vision_cfg.patch_size,
|
159 |
+
width=vision_cfg.width,
|
160 |
+
layers=vision_cfg.layers,
|
161 |
+
heads=vision_heads,
|
162 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
163 |
+
ls_init_value=vision_cfg.ls_init_value,
|
164 |
+
patch_dropout=vision_cfg.patch_dropout,
|
165 |
+
global_average_pool=vision_cfg.global_average_pool,
|
166 |
+
output_dim=embed_dim,
|
167 |
+
act_layer=act_layer,
|
168 |
+
norm_layer=norm_layer,
|
169 |
+
)
|
170 |
+
|
171 |
+
return visual
|
172 |
+
|
173 |
+
|
174 |
+
def _build_text_tower(
|
175 |
+
embed_dim: int,
|
176 |
+
text_cfg: CLIPTextCfg,
|
177 |
+
quick_gelu: bool = False,
|
178 |
+
cast_dtype: Optional[torch.dtype] = None,
|
179 |
+
):
|
180 |
+
if isinstance(text_cfg, dict):
|
181 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
182 |
+
|
183 |
+
if text_cfg.hf_model_name:
|
184 |
+
text = HFTextEncoder(
|
185 |
+
text_cfg.hf_model_name,
|
186 |
+
output_dim=embed_dim,
|
187 |
+
tokenizer_name=text_cfg.hf_tokenizer_name,
|
188 |
+
proj=text_cfg.proj,
|
189 |
+
pooler_type=text_cfg.pooler_type,
|
190 |
+
masked_language_modeling=text_cfg.masked_language_modeling
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
194 |
+
norm_layer = LayerNorm
|
195 |
+
|
196 |
+
text = TextTransformer(
|
197 |
+
context_length=text_cfg.context_length,
|
198 |
+
vocab_size=text_cfg.vocab_size,
|
199 |
+
width=text_cfg.width,
|
200 |
+
heads=text_cfg.heads,
|
201 |
+
layers=text_cfg.layers,
|
202 |
+
ls_init_value=text_cfg.ls_init_value,
|
203 |
+
output_dim=embed_dim,
|
204 |
+
act_layer=act_layer,
|
205 |
+
norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
|
206 |
+
xattn=text_cfg.xattn,
|
207 |
+
attn_mask=text_cfg.attn_mask,
|
208 |
+
)
|
209 |
+
return text
|
210 |
+
|
211 |
+
class CLIP(nn.Module):
|
212 |
+
def __init__(
|
213 |
+
self,
|
214 |
+
embed_dim: int,
|
215 |
+
vision_cfg: CLIPVisionCfg,
|
216 |
+
text_cfg: CLIPTextCfg,
|
217 |
+
quick_gelu: bool = False,
|
218 |
+
cast_dtype: Optional[torch.dtype] = None,
|
219 |
+
):
|
220 |
+
super().__init__()
|
221 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
222 |
+
|
223 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
224 |
+
self.transformer = text.transformer
|
225 |
+
self.vocab_size = text.vocab_size
|
226 |
+
self.token_embedding = text.token_embedding
|
227 |
+
self.positional_embedding = text.positional_embedding
|
228 |
+
self.ln_final = text.ln_final
|
229 |
+
self.text_projection = text.text_projection
|
230 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
231 |
+
|
232 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
233 |
+
|
234 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
235 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
236 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
237 |
+
|
238 |
+
@torch.jit.ignore
|
239 |
+
def set_grad_checkpointing(self, enable=True):
|
240 |
+
self.visual.set_grad_checkpointing(enable)
|
241 |
+
self.transformer.grad_checkpointing = enable
|
242 |
+
|
243 |
+
@torch.jit.ignore
|
244 |
+
def no_weight_decay(self):
|
245 |
+
return {'logit_scale'}
|
246 |
+
|
247 |
+
def encode_image(self, image, normalize: bool = False):
|
248 |
+
features = self.visual(image)
|
249 |
+
return F.normalize(features, dim=-1) if normalize else features
|
250 |
+
|
251 |
+
def encode_text(self, text, normalize: bool = False):
|
252 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
253 |
+
|
254 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
255 |
+
|
256 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
257 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
258 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
259 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
260 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
261 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
262 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
263 |
+
return F.normalize(x, dim=-1) if normalize else x
|
264 |
+
|
265 |
+
def forward(self, image, text):
|
266 |
+
image_features = self.encode_image(image, normalize=True)
|
267 |
+
text_features = self.encode_text(text, normalize=True)
|
268 |
+
return image_features, text_features, self.logit_scale.exp()
|
269 |
+
|
270 |
+
|
271 |
+
class CustomCLIP(nn.Module):
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
embed_dim: int,
|
275 |
+
vision_cfg: CLIPVisionCfg,
|
276 |
+
text_cfg: CLIPTextCfg,
|
277 |
+
quick_gelu: bool = False,
|
278 |
+
cast_dtype: Optional[torch.dtype] = None,
|
279 |
+
itm_task: bool = False,
|
280 |
+
):
|
281 |
+
super().__init__()
|
282 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
283 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
284 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
285 |
+
|
286 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
287 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
288 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
289 |
+
|
290 |
+
def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
|
291 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
292 |
+
|
293 |
+
@torch.jit.ignore
|
294 |
+
def set_grad_checkpointing(self, enable=True):
|
295 |
+
self.visual.set_grad_checkpointing(enable)
|
296 |
+
self.text.set_grad_checkpointing(enable)
|
297 |
+
|
298 |
+
@torch.jit.ignore
|
299 |
+
def no_weight_decay(self):
|
300 |
+
return {'logit_scale'}
|
301 |
+
|
302 |
+
def encode_image(self, image, normalize: bool = False):
|
303 |
+
features = self.visual(image)
|
304 |
+
return F.normalize(features, dim=-1) if normalize else features
|
305 |
+
|
306 |
+
def encode_text(self, text, normalize: bool = False):
|
307 |
+
features = self.text(text)
|
308 |
+
return F.normalize(features, dim=-1) if normalize else features
|
309 |
+
|
310 |
+
def forward(self, image, text):
|
311 |
+
image_features = self.encode_image(image, normalize=True)
|
312 |
+
text_features = self.encode_text(text, normalize=True)
|
313 |
+
return image_features, text_features, self.logit_scale.exp()
|
314 |
+
|
315 |
+
|
316 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
317 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
318 |
+
|
319 |
+
def _convert_weights(l):
|
320 |
+
|
321 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
322 |
+
l.weight.data = l.weight.data.to(dtype)
|
323 |
+
if l.bias is not None:
|
324 |
+
l.bias.data = l.bias.data.to(dtype)
|
325 |
+
|
326 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
327 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
328 |
+
tensor = getattr(l, attr, None)
|
329 |
+
if tensor is not None:
|
330 |
+
tensor.data = tensor.data.to(dtype)
|
331 |
+
|
332 |
+
if isinstance(l, nn.Parameter):
|
333 |
+
l.data = l.data.to(dtype)
|
334 |
+
|
335 |
+
for name in ["text_projection", "proj"]:
|
336 |
+
if hasattr(l, name) and isinstance(l, nn.Parameter):
|
337 |
+
attr = getattr(l, name, None)
|
338 |
+
if attr is not None:
|
339 |
+
attr.data = attr.data.to(dtype)
|
340 |
+
|
341 |
+
model.apply(_convert_weights)
|
342 |
+
|
343 |
+
|
344 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
345 |
+
|
346 |
+
|
347 |
+
# used to maintain checkpoint compatibility
|
348 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
349 |
+
if 'text_projection' in state_dict:
|
350 |
+
# old format state_dict, move text tower -> .text
|
351 |
+
new_state_dict = {}
|
352 |
+
for k, v in state_dict.items():
|
353 |
+
if any(k.startswith(p) for p in (
|
354 |
+
'text_projection',
|
355 |
+
'positional_embedding',
|
356 |
+
'token_embedding',
|
357 |
+
'transformer',
|
358 |
+
'ln_final',
|
359 |
+
'logit_scale'
|
360 |
+
)):
|
361 |
+
k = 'text.' + k
|
362 |
+
new_state_dict[k] = v
|
363 |
+
return new_state_dict
|
364 |
+
return state_dict
|
365 |
+
|
366 |
+
|
367 |
+
def build_model_from_openai_state_dict(
|
368 |
+
state_dict: dict,
|
369 |
+
quick_gelu=True,
|
370 |
+
cast_dtype=torch.float16,
|
371 |
+
):
|
372 |
+
vit = "visual.proj" in state_dict
|
373 |
+
|
374 |
+
if vit:
|
375 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
376 |
+
vision_layers = len(
|
377 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
378 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
379 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
380 |
+
image_size = vision_patch_size * grid_size
|
381 |
+
else:
|
382 |
+
counts: list = [
|
383 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
384 |
+
vision_layers = tuple(counts)
|
385 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
386 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
387 |
+
vision_patch_size = None
|
388 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
389 |
+
image_size = output_width * 32
|
390 |
+
|
391 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
392 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
393 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
394 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
395 |
+
transformer_heads = transformer_width // 64
|
396 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
397 |
+
|
398 |
+
vision_cfg = CLIPVisionCfg(
|
399 |
+
layers=vision_layers,
|
400 |
+
width=vision_width,
|
401 |
+
patch_size=vision_patch_size,
|
402 |
+
image_size=image_size,
|
403 |
+
)
|
404 |
+
text_cfg = CLIPTextCfg(
|
405 |
+
context_length=context_length,
|
406 |
+
vocab_size=vocab_size,
|
407 |
+
width=transformer_width,
|
408 |
+
heads=transformer_heads,
|
409 |
+
layers=transformer_layers
|
410 |
+
)
|
411 |
+
model = CLIP(
|
412 |
+
embed_dim,
|
413 |
+
vision_cfg=vision_cfg,
|
414 |
+
text_cfg=text_cfg,
|
415 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
416 |
+
cast_dtype=cast_dtype,
|
417 |
+
)
|
418 |
+
|
419 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
420 |
+
state_dict.pop(key, None)
|
421 |
+
|
422 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
423 |
+
model.load_state_dict(state_dict)
|
424 |
+
return model.eval()
|
425 |
+
|
426 |
+
|
427 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
428 |
+
model.eval()
|
429 |
+
image_size = model.visual.image_size
|
430 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
431 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
432 |
+
model = torch.jit.trace_module(
|
433 |
+
model,
|
434 |
+
inputs=dict(
|
435 |
+
forward=(example_images, example_text),
|
436 |
+
encode_text=(example_text,),
|
437 |
+
encode_image=(example_images,)
|
438 |
+
))
|
439 |
+
model.visual.image_size = image_size
|
440 |
+
return model
|
eva_clip/model_configs/EVA01-CLIP-B-16.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16,
|
8 |
+
"eva_model_name": "eva-clip-b-16",
|
9 |
+
"ls_init_value": 0.1,
|
10 |
+
"drop_path_rate": 0.0
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 77,
|
14 |
+
"vocab_size": 49408,
|
15 |
+
"width": 512,
|
16 |
+
"heads": 8,
|
17 |
+
"layers": 12
|
18 |
+
}
|
19 |
+
}
|
eva_clip/model_configs/EVA01-CLIP-g-14-plus.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 40,
|
6 |
+
"width": 1408,
|
7 |
+
"head_width": 88,
|
8 |
+
"mlp_ratio": 4.3637,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-g-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 1024,
|
19 |
+
"heads": 16,
|
20 |
+
"layers": 24,
|
21 |
+
"xattn": false,
|
22 |
+
"fusedLN": true
|
23 |
+
}
|
24 |
+
}
|
eva_clip/model_configs/EVA01-CLIP-g-14.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 40,
|
6 |
+
"width": 1408,
|
7 |
+
"head_width": 88,
|
8 |
+
"mlp_ratio": 4.3637,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-g-14-x",
|
11 |
+
"drop_path_rate": 0.4,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 768,
|
19 |
+
"heads": 12,
|
20 |
+
"layers": 12,
|
21 |
+
"xattn": false,
|
22 |
+
"fusedLN": true
|
23 |
+
}
|
24 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-B-16.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"head_width": 64,
|
8 |
+
"patch_size": 16,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"eva_model_name": "eva-clip-b-16-X",
|
11 |
+
"drop_path_rate": 0.0,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 512,
|
24 |
+
"heads": 8,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": true,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-L-14-336.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 336,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"drop_path_rate": 0,
|
8 |
+
"head_width": 64,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"patch_size": 14,
|
11 |
+
"eva_model_name": "eva-clip-l-14-336",
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 768,
|
24 |
+
"heads": 12,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": false,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-L-14.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"drop_path_rate": 0,
|
8 |
+
"head_width": 64,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"patch_size": 14,
|
11 |
+
"eva_model_name": "eva-clip-l-14",
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 768,
|
24 |
+
"heads": 12,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": false,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 64,
|
6 |
+
"width": 1792,
|
7 |
+
"head_width": 112,
|
8 |
+
"mlp_ratio": 8.571428571428571,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-4b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"postnorm": true,
|
14 |
+
"fusedLN": true
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 1280,
|
20 |
+
"heads": 20,
|
21 |
+
"layers": 32,
|
22 |
+
"xattn": false,
|
23 |
+
"fusedLN": true
|
24 |
+
}
|
25 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-bigE-14.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 64,
|
6 |
+
"width": 1792,
|
7 |
+
"head_width": 112,
|
8 |
+
"mlp_ratio": 8.571428571428571,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-4b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"postnorm": true,
|
14 |
+
"fusedLN": true
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 1024,
|
20 |
+
"heads": 16,
|
21 |
+
"layers": 24,
|
22 |
+
"xattn": false,
|
23 |
+
"fusedLN": true
|
24 |
+
}
|
25 |
+
}
|
eva_clip/modified_resnet.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from eva_clip.utils import freeze_batch_norm_2d
|
8 |
+
|
9 |
+
|
10 |
+
class Bottleneck(nn.Module):
|
11 |
+
expansion = 4
|
12 |
+
|
13 |
+
def __init__(self, inplanes, planes, stride=1):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.act1 = nn.ReLU(inplace=True)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
self.act2 = nn.ReLU(inplace=True)
|
24 |
+
|
25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
26 |
+
|
27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
29 |
+
self.act3 = nn.ReLU(inplace=True)
|
30 |
+
|
31 |
+
self.downsample = None
|
32 |
+
self.stride = stride
|
33 |
+
|
34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
37 |
+
("-1", nn.AvgPool2d(stride)),
|
38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
40 |
+
]))
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor):
|
43 |
+
identity = x
|
44 |
+
|
45 |
+
out = self.act1(self.bn1(self.conv1(x)))
|
46 |
+
out = self.act2(self.bn2(self.conv2(out)))
|
47 |
+
out = self.avgpool(out)
|
48 |
+
out = self.bn3(self.conv3(out))
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
identity = self.downsample(x)
|
52 |
+
|
53 |
+
out += identity
|
54 |
+
out = self.act3(out)
|
55 |
+
return out
|
56 |
+
|
57 |
+
|
58 |
+
class AttentionPool2d(nn.Module):
|
59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
60 |
+
super().__init__()
|
61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
66 |
+
self.num_heads = num_heads
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
71 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
72 |
+
x, _ = F.multi_head_attention_forward(
|
73 |
+
query=x, key=x, value=x,
|
74 |
+
embed_dim_to_check=x.shape[-1],
|
75 |
+
num_heads=self.num_heads,
|
76 |
+
q_proj_weight=self.q_proj.weight,
|
77 |
+
k_proj_weight=self.k_proj.weight,
|
78 |
+
v_proj_weight=self.v_proj.weight,
|
79 |
+
in_proj_weight=None,
|
80 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
81 |
+
bias_k=None,
|
82 |
+
bias_v=None,
|
83 |
+
add_zero_attn=False,
|
84 |
+
dropout_p=0.,
|
85 |
+
out_proj_weight=self.c_proj.weight,
|
86 |
+
out_proj_bias=self.c_proj.bias,
|
87 |
+
use_separate_proj_weight=True,
|
88 |
+
training=self.training,
|
89 |
+
need_weights=False
|
90 |
+
)
|
91 |
+
|
92 |
+
return x[0]
|
93 |
+
|
94 |
+
|
95 |
+
class ModifiedResNet(nn.Module):
|
96 |
+
"""
|
97 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
98 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
99 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
100 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
104 |
+
super().__init__()
|
105 |
+
self.output_dim = output_dim
|
106 |
+
self.image_size = image_size
|
107 |
+
|
108 |
+
# the 3-layer stem
|
109 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
110 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
111 |
+
self.act1 = nn.ReLU(inplace=True)
|
112 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
113 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
114 |
+
self.act2 = nn.ReLU(inplace=True)
|
115 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
116 |
+
self.bn3 = nn.BatchNorm2d(width)
|
117 |
+
self.act3 = nn.ReLU(inplace=True)
|
118 |
+
self.avgpool = nn.AvgPool2d(2)
|
119 |
+
|
120 |
+
# residual layers
|
121 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
122 |
+
self.layer1 = self._make_layer(width, layers[0])
|
123 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
124 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
125 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
126 |
+
|
127 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
128 |
+
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
129 |
+
|
130 |
+
self.init_parameters()
|
131 |
+
|
132 |
+
def _make_layer(self, planes, blocks, stride=1):
|
133 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
134 |
+
|
135 |
+
self._inplanes = planes * Bottleneck.expansion
|
136 |
+
for _ in range(1, blocks):
|
137 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
138 |
+
|
139 |
+
return nn.Sequential(*layers)
|
140 |
+
|
141 |
+
def init_parameters(self):
|
142 |
+
if self.attnpool is not None:
|
143 |
+
std = self.attnpool.c_proj.in_features ** -0.5
|
144 |
+
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
145 |
+
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
146 |
+
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
147 |
+
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
148 |
+
|
149 |
+
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
150 |
+
for name, param in resnet_block.named_parameters():
|
151 |
+
if name.endswith("bn3.weight"):
|
152 |
+
nn.init.zeros_(param)
|
153 |
+
|
154 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
155 |
+
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
156 |
+
for param in self.parameters():
|
157 |
+
param.requires_grad = False
|
158 |
+
if freeze_bn_stats:
|
159 |
+
freeze_batch_norm_2d(self)
|
160 |
+
|
161 |
+
@torch.jit.ignore
|
162 |
+
def set_grad_checkpointing(self, enable=True):
|
163 |
+
# FIXME support for non-transformer
|
164 |
+
pass
|
165 |
+
|
166 |
+
def stem(self, x):
|
167 |
+
x = self.act1(self.bn1(self.conv1(x)))
|
168 |
+
x = self.act2(self.bn2(self.conv2(x)))
|
169 |
+
x = self.act3(self.bn3(self.conv3(x)))
|
170 |
+
x = self.avgpool(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
x = self.stem(x)
|
175 |
+
x = self.layer1(x)
|
176 |
+
x = self.layer2(x)
|
177 |
+
x = self.layer3(x)
|
178 |
+
x = self.layer4(x)
|
179 |
+
x = self.attnpool(x)
|
180 |
+
|
181 |
+
return x
|
eva_clip/openai.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" OpenAI pretrained model functions
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import warnings
|
8 |
+
from typing import List, Optional, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
|
13 |
+
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
|
14 |
+
|
15 |
+
__all__ = ["list_openai_models", "load_openai_model"]
|
16 |
+
|
17 |
+
|
18 |
+
def list_openai_models() -> List[str]:
|
19 |
+
"""Returns the names of available CLIP models"""
|
20 |
+
return list_pretrained_models_by_tag('openai')
|
21 |
+
|
22 |
+
|
23 |
+
def load_openai_model(
|
24 |
+
name: str,
|
25 |
+
precision: Optional[str] = None,
|
26 |
+
device: Optional[Union[str, torch.device]] = None,
|
27 |
+
jit: bool = True,
|
28 |
+
cache_dir: Optional[str] = None,
|
29 |
+
):
|
30 |
+
"""Load a CLIP model
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
name : str
|
35 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
36 |
+
precision: str
|
37 |
+
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
|
38 |
+
device : Union[str, torch.device]
|
39 |
+
The device to put the loaded model
|
40 |
+
jit : bool
|
41 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
42 |
+
cache_dir : Optional[str]
|
43 |
+
The directory to cache the downloaded model weights
|
44 |
+
|
45 |
+
Returns
|
46 |
+
-------
|
47 |
+
model : torch.nn.Module
|
48 |
+
The CLIP model
|
49 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
50 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
51 |
+
"""
|
52 |
+
if device is None:
|
53 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
54 |
+
if precision is None:
|
55 |
+
precision = 'fp32' if device == 'cpu' else 'fp16'
|
56 |
+
|
57 |
+
if get_pretrained_url(name, 'openai'):
|
58 |
+
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
|
59 |
+
elif os.path.isfile(name):
|
60 |
+
model_path = name
|
61 |
+
else:
|
62 |
+
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
|
63 |
+
|
64 |
+
try:
|
65 |
+
# loading JIT archive
|
66 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
67 |
+
state_dict = None
|
68 |
+
except RuntimeError:
|
69 |
+
# loading saved state dict
|
70 |
+
if jit:
|
71 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
72 |
+
jit = False
|
73 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
74 |
+
|
75 |
+
if not jit:
|
76 |
+
# Build a non-jit model from the OpenAI jitted model state dict
|
77 |
+
cast_dtype = get_cast_dtype(precision)
|
78 |
+
try:
|
79 |
+
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
|
80 |
+
except KeyError:
|
81 |
+
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
82 |
+
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
|
83 |
+
|
84 |
+
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
|
85 |
+
model = model.to(device)
|
86 |
+
if precision.startswith('amp') or precision == 'fp32':
|
87 |
+
model.float()
|
88 |
+
elif precision == 'bf16':
|
89 |
+
convert_weights_to_lp(model, dtype=torch.bfloat16)
|
90 |
+
|
91 |
+
return model
|
92 |
+
|
93 |
+
# patch the device names
|
94 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
95 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
96 |
+
|
97 |
+
def patch_device(module):
|
98 |
+
try:
|
99 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
100 |
+
except RuntimeError:
|
101 |
+
graphs = []
|
102 |
+
|
103 |
+
if hasattr(module, "forward1"):
|
104 |
+
graphs.append(module.forward1.graph)
|
105 |
+
|
106 |
+
for graph in graphs:
|
107 |
+
for node in graph.findAllNodes("prim::Constant"):
|
108 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
109 |
+
node.copyAttributes(device_node)
|
110 |
+
|
111 |
+
model.apply(patch_device)
|
112 |
+
patch_device(model.encode_image)
|
113 |
+
patch_device(model.encode_text)
|
114 |
+
|
115 |
+
# patch dtype to float32 (typically for CPU)
|
116 |
+
if precision == 'fp32':
|
117 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
118 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
119 |
+
float_node = float_input.node()
|
120 |
+
|
121 |
+
def patch_float(module):
|
122 |
+
try:
|
123 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
124 |
+
except RuntimeError:
|
125 |
+
graphs = []
|
126 |
+
|
127 |
+
if hasattr(module, "forward1"):
|
128 |
+
graphs.append(module.forward1.graph)
|
129 |
+
|
130 |
+
for graph in graphs:
|
131 |
+
for node in graph.findAllNodes("aten::to"):
|
132 |
+
inputs = list(node.inputs())
|
133 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
134 |
+
if inputs[i].node()["value"] == 5:
|
135 |
+
inputs[i].node().copyAttributes(float_node)
|
136 |
+
|
137 |
+
model.apply(patch_float)
|
138 |
+
patch_float(model.encode_image)
|
139 |
+
patch_float(model.encode_text)
|
140 |
+
model.float()
|
141 |
+
|
142 |
+
# ensure image_size attr available at consistent location for both jit and non-jit
|
143 |
+
model.visual.image_size = model.input_resolution.item()
|
144 |
+
return model
|
eva_clip/pretrained.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from functools import partial
|
6 |
+
from typing import Dict, Union
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
try:
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
_has_hf_hub = True
|
13 |
+
except ImportError:
|
14 |
+
hf_hub_download = None
|
15 |
+
_has_hf_hub = False
|
16 |
+
|
17 |
+
|
18 |
+
def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
|
19 |
+
return dict(
|
20 |
+
url=url,
|
21 |
+
hf_hub=hf_hub,
|
22 |
+
mean=mean,
|
23 |
+
std=std,
|
24 |
+
)
|
25 |
+
|
26 |
+
_VITB32 = dict(
|
27 |
+
openai=_pcfg(
|
28 |
+
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
29 |
+
laion400m_e31=_pcfg(
|
30 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
31 |
+
laion400m_e32=_pcfg(
|
32 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
33 |
+
laion2b_e16=_pcfg(
|
34 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
|
35 |
+
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
|
36 |
+
)
|
37 |
+
|
38 |
+
_VITB32_quickgelu = dict(
|
39 |
+
openai=_pcfg(
|
40 |
+
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
41 |
+
laion400m_e31=_pcfg(
|
42 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
43 |
+
laion400m_e32=_pcfg(
|
44 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
45 |
+
)
|
46 |
+
|
47 |
+
_VITB16 = dict(
|
48 |
+
openai=_pcfg(
|
49 |
+
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
|
50 |
+
laion400m_e31=_pcfg(
|
51 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
|
52 |
+
laion400m_e32=_pcfg(
|
53 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
|
54 |
+
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
|
55 |
+
)
|
56 |
+
|
57 |
+
_EVAB16 = dict(
|
58 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
|
59 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
|
60 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
|
61 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
|
62 |
+
)
|
63 |
+
|
64 |
+
_VITB16_PLUS_240 = dict(
|
65 |
+
laion400m_e31=_pcfg(
|
66 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
|
67 |
+
laion400m_e32=_pcfg(
|
68 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
|
69 |
+
)
|
70 |
+
|
71 |
+
_VITL14 = dict(
|
72 |
+
openai=_pcfg(
|
73 |
+
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
|
74 |
+
laion400m_e31=_pcfg(
|
75 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
|
76 |
+
laion400m_e32=_pcfg(
|
77 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
|
78 |
+
laion2b_s32b_b82k=_pcfg(
|
79 |
+
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
|
80 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
81 |
+
)
|
82 |
+
|
83 |
+
_EVAL14 = dict(
|
84 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
|
85 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
|
86 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
|
87 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
|
88 |
+
)
|
89 |
+
|
90 |
+
_VITL14_336 = dict(
|
91 |
+
openai=_pcfg(
|
92 |
+
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
|
93 |
+
)
|
94 |
+
|
95 |
+
_EVAL14_336 = dict(
|
96 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
|
97 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
|
98 |
+
eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
|
99 |
+
eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
|
100 |
+
)
|
101 |
+
|
102 |
+
_VITH14 = dict(
|
103 |
+
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
|
104 |
+
)
|
105 |
+
|
106 |
+
_VITg14 = dict(
|
107 |
+
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
|
108 |
+
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
|
109 |
+
)
|
110 |
+
|
111 |
+
_EVAg14 = dict(
|
112 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
|
113 |
+
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
|
114 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
|
115 |
+
eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
|
116 |
+
)
|
117 |
+
|
118 |
+
_EVAg14_PLUS = dict(
|
119 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
|
120 |
+
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
|
121 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
|
122 |
+
eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
|
123 |
+
)
|
124 |
+
|
125 |
+
_VITbigG14 = dict(
|
126 |
+
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
|
127 |
+
)
|
128 |
+
|
129 |
+
_EVAbigE14 = dict(
|
130 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
131 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
132 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
|
133 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
|
134 |
+
)
|
135 |
+
|
136 |
+
_EVAbigE14_PLUS = dict(
|
137 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
138 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
139 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
|
140 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
_PRETRAINED = {
|
145 |
+
# "ViT-B-32": _VITB32,
|
146 |
+
"OpenaiCLIP-B-32": _VITB32,
|
147 |
+
"OpenCLIP-B-32": _VITB32,
|
148 |
+
|
149 |
+
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
|
150 |
+
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
|
151 |
+
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
|
152 |
+
|
153 |
+
# "ViT-B-16": _VITB16,
|
154 |
+
"OpenaiCLIP-B-16": _VITB16,
|
155 |
+
"OpenCLIP-B-16": _VITB16,
|
156 |
+
|
157 |
+
"EVA02-B-16": _EVAB16,
|
158 |
+
"EVA02-CLIP-B-16": _EVAB16,
|
159 |
+
|
160 |
+
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
|
161 |
+
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
|
162 |
+
|
163 |
+
# "ViT-L-14": _VITL14,
|
164 |
+
"OpenaiCLIP-L-14": _VITL14,
|
165 |
+
"OpenCLIP-L-14": _VITL14,
|
166 |
+
|
167 |
+
"EVA02-L-14": _EVAL14,
|
168 |
+
"EVA02-CLIP-L-14": _EVAL14,
|
169 |
+
|
170 |
+
# "ViT-L-14-336": _VITL14_336,
|
171 |
+
"OpenaiCLIP-L-14-336": _VITL14_336,
|
172 |
+
|
173 |
+
"EVA02-CLIP-L-14-336": _EVAL14_336,
|
174 |
+
|
175 |
+
# "ViT-H-14": _VITH14,
|
176 |
+
# "ViT-g-14": _VITg14,
|
177 |
+
"OpenCLIP-H-14": _VITH14,
|
178 |
+
"OpenCLIP-g-14": _VITg14,
|
179 |
+
|
180 |
+
"EVA01-CLIP-g-14": _EVAg14,
|
181 |
+
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
|
182 |
+
|
183 |
+
# "ViT-bigG-14": _VITbigG14,
|
184 |
+
"OpenCLIP-bigG-14": _VITbigG14,
|
185 |
+
|
186 |
+
"EVA02-CLIP-bigE-14": _EVAbigE14,
|
187 |
+
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
|
188 |
+
}
|
189 |
+
|
190 |
+
|
191 |
+
def _clean_tag(tag: str):
|
192 |
+
# normalize pretrained tags
|
193 |
+
return tag.lower().replace('-', '_')
|
194 |
+
|
195 |
+
|
196 |
+
def list_pretrained(as_str: bool = False):
|
197 |
+
""" returns list of pretrained models
|
198 |
+
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
199 |
+
"""
|
200 |
+
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
|
201 |
+
|
202 |
+
|
203 |
+
def list_pretrained_models_by_tag(tag: str):
|
204 |
+
""" return all models having the specified pretrain tag """
|
205 |
+
models = []
|
206 |
+
tag = _clean_tag(tag)
|
207 |
+
for k in _PRETRAINED.keys():
|
208 |
+
if tag in _PRETRAINED[k]:
|
209 |
+
models.append(k)
|
210 |
+
return models
|
211 |
+
|
212 |
+
|
213 |
+
def list_pretrained_tags_by_model(model: str):
|
214 |
+
""" return all pretrain tags for the specified model architecture """
|
215 |
+
tags = []
|
216 |
+
if model in _PRETRAINED:
|
217 |
+
tags.extend(_PRETRAINED[model].keys())
|
218 |
+
return tags
|
219 |
+
|
220 |
+
|
221 |
+
def is_pretrained_cfg(model: str, tag: str):
|
222 |
+
if model not in _PRETRAINED:
|
223 |
+
return False
|
224 |
+
return _clean_tag(tag) in _PRETRAINED[model]
|
225 |
+
|
226 |
+
|
227 |
+
def get_pretrained_cfg(model: str, tag: str):
|
228 |
+
if model not in _PRETRAINED:
|
229 |
+
return {}
|
230 |
+
model_pretrained = _PRETRAINED[model]
|
231 |
+
return model_pretrained.get(_clean_tag(tag), {})
|
232 |
+
|
233 |
+
|
234 |
+
def get_pretrained_url(model: str, tag: str):
|
235 |
+
cfg = get_pretrained_cfg(model, _clean_tag(tag))
|
236 |
+
return cfg.get('url', '')
|
237 |
+
|
238 |
+
|
239 |
+
def download_pretrained_from_url(
|
240 |
+
url: str,
|
241 |
+
cache_dir: Union[str, None] = None,
|
242 |
+
):
|
243 |
+
if not cache_dir:
|
244 |
+
cache_dir = os.path.expanduser("~/.cache/clip")
|
245 |
+
os.makedirs(cache_dir, exist_ok=True)
|
246 |
+
filename = os.path.basename(url)
|
247 |
+
|
248 |
+
if 'openaipublic' in url:
|
249 |
+
expected_sha256 = url.split("/")[-2]
|
250 |
+
elif 'mlfoundations' in url:
|
251 |
+
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
252 |
+
else:
|
253 |
+
expected_sha256 = ''
|
254 |
+
|
255 |
+
download_target = os.path.join(cache_dir, filename)
|
256 |
+
|
257 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
258 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
259 |
+
|
260 |
+
if os.path.isfile(download_target):
|
261 |
+
if expected_sha256:
|
262 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
263 |
+
return download_target
|
264 |
+
else:
|
265 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
266 |
+
else:
|
267 |
+
return download_target
|
268 |
+
|
269 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
270 |
+
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
271 |
+
while True:
|
272 |
+
buffer = source.read(8192)
|
273 |
+
if not buffer:
|
274 |
+
break
|
275 |
+
|
276 |
+
output.write(buffer)
|
277 |
+
loop.update(len(buffer))
|
278 |
+
|
279 |
+
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
280 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
281 |
+
|
282 |
+
return download_target
|
283 |
+
|
284 |
+
|
285 |
+
def has_hf_hub(necessary=False):
|
286 |
+
if not _has_hf_hub and necessary:
|
287 |
+
# if no HF Hub module installed, and it is necessary to continue, raise error
|
288 |
+
raise RuntimeError(
|
289 |
+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
290 |
+
return _has_hf_hub
|
291 |
+
|
292 |
+
|
293 |
+
def download_pretrained_from_hf(
|
294 |
+
model_id: str,
|
295 |
+
filename: str = 'open_clip_pytorch_model.bin',
|
296 |
+
revision=None,
|
297 |
+
cache_dir: Union[str, None] = None,
|
298 |
+
):
|
299 |
+
has_hf_hub(True)
|
300 |
+
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
|
301 |
+
return cached_file
|
302 |
+
|
303 |
+
|
304 |
+
def download_pretrained(
|
305 |
+
cfg: Dict,
|
306 |
+
force_hf_hub: bool = False,
|
307 |
+
cache_dir: Union[str, None] = None,
|
308 |
+
):
|
309 |
+
target = ''
|
310 |
+
if not cfg:
|
311 |
+
return target
|
312 |
+
|
313 |
+
download_url = cfg.get('url', '')
|
314 |
+
download_hf_hub = cfg.get('hf_hub', '')
|
315 |
+
if download_hf_hub and force_hf_hub:
|
316 |
+
# use HF hub even if url exists
|
317 |
+
download_url = ''
|
318 |
+
|
319 |
+
if download_url:
|
320 |
+
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
|
321 |
+
elif download_hf_hub:
|
322 |
+
has_hf_hub(True)
|
323 |
+
# we assume the hf_hub entries in pretrained config combine model_id + filename in
|
324 |
+
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
|
325 |
+
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
|
326 |
+
model_id, filename = os.path.split(download_hf_hub)
|
327 |
+
if filename:
|
328 |
+
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
|
329 |
+
else:
|
330 |
+
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
331 |
+
|
332 |
+
return target
|
eva_clip/rope.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import logging
|
6 |
+
|
7 |
+
def broadcat(tensors, dim = -1):
|
8 |
+
num_tensors = len(tensors)
|
9 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
10 |
+
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
|
11 |
+
shape_len = list(shape_lens)[0]
|
12 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
13 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
14 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
15 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
|
16 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
17 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
18 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
19 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
20 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
21 |
+
return torch.cat(tensors, dim = dim)
|
22 |
+
|
23 |
+
def rotate_half(x):
|
24 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
25 |
+
x1, x2 = x.unbind(dim = -1)
|
26 |
+
x = torch.stack((-x2, x1), dim = -1)
|
27 |
+
return rearrange(x, '... d r -> ... (d r)')
|
28 |
+
|
29 |
+
|
30 |
+
class VisionRotaryEmbedding(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
dim,
|
34 |
+
pt_seq_len,
|
35 |
+
ft_seq_len=None,
|
36 |
+
custom_freqs = None,
|
37 |
+
freqs_for = 'lang',
|
38 |
+
theta = 10000,
|
39 |
+
max_freq = 10,
|
40 |
+
num_freqs = 1,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
if custom_freqs:
|
44 |
+
freqs = custom_freqs
|
45 |
+
elif freqs_for == 'lang':
|
46 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
47 |
+
elif freqs_for == 'pixel':
|
48 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
49 |
+
elif freqs_for == 'constant':
|
50 |
+
freqs = torch.ones(num_freqs).float()
|
51 |
+
else:
|
52 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
53 |
+
|
54 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
55 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
56 |
+
|
57 |
+
freqs_h = torch.einsum('..., f -> ... f', t, freqs)
|
58 |
+
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
|
59 |
+
|
60 |
+
freqs_w = torch.einsum('..., f -> ... f', t, freqs)
|
61 |
+
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
|
62 |
+
|
63 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
|
64 |
+
|
65 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
66 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
67 |
+
|
68 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
69 |
+
|
70 |
+
def forward(self, t, start_index = 0):
|
71 |
+
rot_dim = self.freqs_cos.shape[-1]
|
72 |
+
end_index = start_index + rot_dim
|
73 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
74 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
75 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
76 |
+
|
77 |
+
return torch.cat((t_left, t, t_right), dim = -1)
|
78 |
+
|
79 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
dim,
|
83 |
+
pt_seq_len,
|
84 |
+
ft_seq_len=None,
|
85 |
+
custom_freqs = None,
|
86 |
+
freqs_for = 'lang',
|
87 |
+
theta = 10000,
|
88 |
+
max_freq = 10,
|
89 |
+
num_freqs = 1,
|
90 |
+
patch_dropout = 0.
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
if custom_freqs:
|
94 |
+
freqs = custom_freqs
|
95 |
+
elif freqs_for == 'lang':
|
96 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
97 |
+
elif freqs_for == 'pixel':
|
98 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
99 |
+
elif freqs_for == 'constant':
|
100 |
+
freqs = torch.ones(num_freqs).float()
|
101 |
+
else:
|
102 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
103 |
+
|
104 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
105 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
106 |
+
|
107 |
+
freqs = torch.einsum('..., f -> ... f', t, freqs)
|
108 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
109 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
|
110 |
+
|
111 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
112 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
113 |
+
|
114 |
+
self.patch_dropout = patch_dropout
|
115 |
+
|
116 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
117 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
118 |
+
|
119 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
120 |
+
|
121 |
+
def forward(self, t, patch_indices_keep=None):
|
122 |
+
if patch_indices_keep is not None:
|
123 |
+
batch = t.size()[0]
|
124 |
+
batch_indices = torch.arange(batch)
|
125 |
+
batch_indices = batch_indices[..., None]
|
126 |
+
|
127 |
+
freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
128 |
+
freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
129 |
+
|
130 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
131 |
+
freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
|
132 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
133 |
+
freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
|
134 |
+
|
135 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
136 |
+
|
137 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
eva_clip/timm_model.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" timm model adapter
|
2 |
+
|
3 |
+
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
try:
|
12 |
+
import timm
|
13 |
+
from timm.models.layers import Mlp, to_2tuple
|
14 |
+
try:
|
15 |
+
# old timm imports < 0.8.1
|
16 |
+
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
17 |
+
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
18 |
+
except ImportError:
|
19 |
+
# new timm imports >= 0.8.1
|
20 |
+
from timm.layers import RotAttentionPool2d
|
21 |
+
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
22 |
+
except ImportError:
|
23 |
+
timm = None
|
24 |
+
|
25 |
+
from .utils import freeze_batch_norm_2d
|
26 |
+
|
27 |
+
|
28 |
+
class TimmModel(nn.Module):
|
29 |
+
""" timm model adapter
|
30 |
+
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
model_name,
|
36 |
+
embed_dim,
|
37 |
+
image_size=224,
|
38 |
+
pool='avg',
|
39 |
+
proj='linear',
|
40 |
+
proj_bias=False,
|
41 |
+
drop=0.,
|
42 |
+
pretrained=False):
|
43 |
+
super().__init__()
|
44 |
+
if timm is None:
|
45 |
+
# raise RuntimeError("Please `pip install timm` to use timm models.")
|
46 |
+
return
|
47 |
+
|
48 |
+
self.image_size = to_2tuple(image_size)
|
49 |
+
self.trunk = timm.create_model(model_name, pretrained=pretrained)
|
50 |
+
feat_size = self.trunk.default_cfg.get('pool_size', None)
|
51 |
+
feature_ndim = 1 if not feat_size else 2
|
52 |
+
if pool in ('abs_attn', 'rot_attn'):
|
53 |
+
assert feature_ndim == 2
|
54 |
+
# if attn pooling used, remove both classifier and default pool
|
55 |
+
self.trunk.reset_classifier(0, global_pool='')
|
56 |
+
else:
|
57 |
+
# reset global pool if pool config set, otherwise leave as network default
|
58 |
+
reset_kwargs = dict(global_pool=pool) if pool else {}
|
59 |
+
self.trunk.reset_classifier(0, **reset_kwargs)
|
60 |
+
prev_chs = self.trunk.num_features
|
61 |
+
|
62 |
+
head_layers = OrderedDict()
|
63 |
+
if pool == 'abs_attn':
|
64 |
+
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
65 |
+
prev_chs = embed_dim
|
66 |
+
elif pool == 'rot_attn':
|
67 |
+
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
68 |
+
prev_chs = embed_dim
|
69 |
+
else:
|
70 |
+
assert proj, 'projection layer needed if non-attention pooling is used.'
|
71 |
+
|
72 |
+
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
73 |
+
if proj == 'linear':
|
74 |
+
head_layers['drop'] = nn.Dropout(drop)
|
75 |
+
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
76 |
+
elif proj == 'mlp':
|
77 |
+
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
|
78 |
+
|
79 |
+
self.head = nn.Sequential(head_layers)
|
80 |
+
|
81 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
82 |
+
""" lock modules
|
83 |
+
Args:
|
84 |
+
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
85 |
+
"""
|
86 |
+
if not unlocked_groups:
|
87 |
+
# lock full model
|
88 |
+
for param in self.trunk.parameters():
|
89 |
+
param.requires_grad = False
|
90 |
+
if freeze_bn_stats:
|
91 |
+
freeze_batch_norm_2d(self.trunk)
|
92 |
+
else:
|
93 |
+
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
94 |
+
try:
|
95 |
+
# FIXME import here until API stable and in an official release
|
96 |
+
from timm.models.helpers import group_parameters, group_modules
|
97 |
+
except ImportError:
|
98 |
+
raise RuntimeError(
|
99 |
+
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
100 |
+
matcher = self.trunk.group_matcher()
|
101 |
+
gparams = group_parameters(self.trunk, matcher)
|
102 |
+
max_layer_id = max(gparams.keys())
|
103 |
+
max_layer_id = max_layer_id - unlocked_groups
|
104 |
+
for group_idx in range(max_layer_id + 1):
|
105 |
+
group = gparams[group_idx]
|
106 |
+
for param in group:
|
107 |
+
self.trunk.get_parameter(param).requires_grad = False
|
108 |
+
if freeze_bn_stats:
|
109 |
+
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
110 |
+
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
111 |
+
freeze_batch_norm_2d(self.trunk, gmodules)
|
112 |
+
|
113 |
+
@torch.jit.ignore
|
114 |
+
def set_grad_checkpointing(self, enable=True):
|
115 |
+
try:
|
116 |
+
self.trunk.set_grad_checkpointing(enable)
|
117 |
+
except Exception as e:
|
118 |
+
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
x = self.trunk(x)
|
122 |
+
x = self.head(x)
|
123 |
+
return x
|
eva_clip/tokenizer.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP tokenizer
|
2 |
+
|
3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import gzip
|
6 |
+
import html
|
7 |
+
import os
|
8 |
+
from functools import lru_cache
|
9 |
+
from typing import Union, List
|
10 |
+
|
11 |
+
import ftfy
|
12 |
+
import regex as re
|
13 |
+
import torch
|
14 |
+
|
15 |
+
# https://stackoverflow.com/q/62691279
|
16 |
+
import os
|
17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
18 |
+
|
19 |
+
|
20 |
+
@lru_cache()
|
21 |
+
def default_bpe():
|
22 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
23 |
+
|
24 |
+
|
25 |
+
@lru_cache()
|
26 |
+
def bytes_to_unicode():
|
27 |
+
"""
|
28 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
29 |
+
The reversible bpe codes work on unicode strings.
|
30 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
31 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
32 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
33 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
34 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
35 |
+
"""
|
36 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
37 |
+
cs = bs[:]
|
38 |
+
n = 0
|
39 |
+
for b in range(2**8):
|
40 |
+
if b not in bs:
|
41 |
+
bs.append(b)
|
42 |
+
cs.append(2**8+n)
|
43 |
+
n += 1
|
44 |
+
cs = [chr(n) for n in cs]
|
45 |
+
return dict(zip(bs, cs))
|
46 |
+
|
47 |
+
|
48 |
+
def get_pairs(word):
|
49 |
+
"""Return set of symbol pairs in a word.
|
50 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
51 |
+
"""
|
52 |
+
pairs = set()
|
53 |
+
prev_char = word[0]
|
54 |
+
for char in word[1:]:
|
55 |
+
pairs.add((prev_char, char))
|
56 |
+
prev_char = char
|
57 |
+
return pairs
|
58 |
+
|
59 |
+
|
60 |
+
def basic_clean(text):
|
61 |
+
text = ftfy.fix_text(text)
|
62 |
+
text = html.unescape(html.unescape(text))
|
63 |
+
return text.strip()
|
64 |
+
|
65 |
+
|
66 |
+
def whitespace_clean(text):
|
67 |
+
text = re.sub(r'\s+', ' ', text)
|
68 |
+
text = text.strip()
|
69 |
+
return text
|
70 |
+
|
71 |
+
|
72 |
+
class SimpleTokenizer(object):
|
73 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
74 |
+
self.byte_encoder = bytes_to_unicode()
|
75 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
76 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
77 |
+
merges = merges[1:49152-256-2+1]
|
78 |
+
merges = [tuple(merge.split()) for merge in merges]
|
79 |
+
vocab = list(bytes_to_unicode().values())
|
80 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
81 |
+
for merge in merges:
|
82 |
+
vocab.append(''.join(merge))
|
83 |
+
if not special_tokens:
|
84 |
+
special_tokens = ['<start_of_text>', '<end_of_text>']
|
85 |
+
else:
|
86 |
+
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
87 |
+
vocab.extend(special_tokens)
|
88 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
89 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
90 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
91 |
+
self.cache = {t:t for t in special_tokens}
|
92 |
+
special = "|".join(special_tokens)
|
93 |
+
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
94 |
+
|
95 |
+
self.vocab_size = len(self.encoder)
|
96 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
97 |
+
|
98 |
+
def bpe(self, token):
|
99 |
+
if token in self.cache:
|
100 |
+
return self.cache[token]
|
101 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
102 |
+
pairs = get_pairs(word)
|
103 |
+
|
104 |
+
if not pairs:
|
105 |
+
return token+'</w>'
|
106 |
+
|
107 |
+
while True:
|
108 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
109 |
+
if bigram not in self.bpe_ranks:
|
110 |
+
break
|
111 |
+
first, second = bigram
|
112 |
+
new_word = []
|
113 |
+
i = 0
|
114 |
+
while i < len(word):
|
115 |
+
try:
|
116 |
+
j = word.index(first, i)
|
117 |
+
new_word.extend(word[i:j])
|
118 |
+
i = j
|
119 |
+
except:
|
120 |
+
new_word.extend(word[i:])
|
121 |
+
break
|
122 |
+
|
123 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
124 |
+
new_word.append(first+second)
|
125 |
+
i += 2
|
126 |
+
else:
|
127 |
+
new_word.append(word[i])
|
128 |
+
i += 1
|
129 |
+
new_word = tuple(new_word)
|
130 |
+
word = new_word
|
131 |
+
if len(word) == 1:
|
132 |
+
break
|
133 |
+
else:
|
134 |
+
pairs = get_pairs(word)
|
135 |
+
word = ' '.join(word)
|
136 |
+
self.cache[token] = word
|
137 |
+
return word
|
138 |
+
|
139 |
+
def encode(self, text):
|
140 |
+
bpe_tokens = []
|
141 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
142 |
+
for token in re.findall(self.pat, text):
|
143 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
144 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
145 |
+
return bpe_tokens
|
146 |
+
|
147 |
+
def decode(self, tokens):
|
148 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
149 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
150 |
+
return text
|
151 |
+
|
152 |
+
|
153 |
+
_tokenizer = SimpleTokenizer()
|
154 |
+
|
155 |
+
|
156 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
157 |
+
"""
|
158 |
+
Returns the tokenized representation of given input string(s)
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
texts : Union[str, List[str]]
|
163 |
+
An input string or a list of input strings to tokenize
|
164 |
+
context_length : int
|
165 |
+
The context length to use; all CLIP models use 77 as the context length
|
166 |
+
|
167 |
+
Returns
|
168 |
+
-------
|
169 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
170 |
+
"""
|
171 |
+
if isinstance(texts, str):
|
172 |
+
texts = [texts]
|
173 |
+
|
174 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
175 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
176 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
177 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
178 |
+
|
179 |
+
for i, tokens in enumerate(all_tokens):
|
180 |
+
if len(tokens) > context_length:
|
181 |
+
tokens = tokens[:context_length] # Truncate
|
182 |
+
tokens[-1] = eot_token
|
183 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
184 |
+
|
185 |
+
return result
|
186 |
+
|
187 |
+
|
188 |
+
class HFTokenizer:
|
189 |
+
"HuggingFace tokenizer wrapper"
|
190 |
+
def __init__(self, tokenizer_name:str):
|
191 |
+
from transformers import AutoTokenizer
|
192 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
193 |
+
|
194 |
+
def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
|
195 |
+
# same cleaning as for default tokenizer, except lowercasing
|
196 |
+
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
197 |
+
if isinstance(texts, str):
|
198 |
+
texts = [texts]
|
199 |
+
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
200 |
+
input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
|
201 |
+
return input_ids
|
eva_clip/transform.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torchvision.transforms.functional as F
|
6 |
+
|
7 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
8 |
+
CenterCrop
|
9 |
+
|
10 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
11 |
+
|
12 |
+
|
13 |
+
class ResizeMaxSize(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
16 |
+
super().__init__()
|
17 |
+
if not isinstance(max_size, int):
|
18 |
+
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
19 |
+
self.max_size = max_size
|
20 |
+
self.interpolation = interpolation
|
21 |
+
self.fn = min if fn == 'min' else min
|
22 |
+
self.fill = fill
|
23 |
+
|
24 |
+
def forward(self, img):
|
25 |
+
if isinstance(img, torch.Tensor):
|
26 |
+
height, width = img.shape[:2]
|
27 |
+
else:
|
28 |
+
width, height = img.size
|
29 |
+
scale = self.max_size / float(max(height, width))
|
30 |
+
if scale != 1.0:
|
31 |
+
new_size = tuple(round(dim * scale) for dim in (height, width))
|
32 |
+
img = F.resize(img, new_size, self.interpolation)
|
33 |
+
pad_h = self.max_size - new_size[0]
|
34 |
+
pad_w = self.max_size - new_size[1]
|
35 |
+
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
36 |
+
return img
|
37 |
+
|
38 |
+
|
39 |
+
def _convert_to_rgb(image):
|
40 |
+
return image.convert('RGB')
|
41 |
+
|
42 |
+
|
43 |
+
# class CatGen(nn.Module):
|
44 |
+
# def __init__(self, num=4):
|
45 |
+
# self.num = num
|
46 |
+
# def mixgen_batch(image, text):
|
47 |
+
# batch_size = image.shape[0]
|
48 |
+
# index = np.random.permutation(batch_size)
|
49 |
+
|
50 |
+
# cat_images = []
|
51 |
+
# for i in range(batch_size):
|
52 |
+
# # image mixup
|
53 |
+
# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
|
54 |
+
# # text concat
|
55 |
+
# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
|
56 |
+
# text = torch.stack(text)
|
57 |
+
# return image, text
|
58 |
+
|
59 |
+
|
60 |
+
def image_transform(
|
61 |
+
image_size: int,
|
62 |
+
is_train: bool,
|
63 |
+
mean: Optional[Tuple[float, ...]] = None,
|
64 |
+
std: Optional[Tuple[float, ...]] = None,
|
65 |
+
resize_longest_max: bool = False,
|
66 |
+
fill_color: int = 0,
|
67 |
+
):
|
68 |
+
mean = mean or OPENAI_DATASET_MEAN
|
69 |
+
if not isinstance(mean, (list, tuple)):
|
70 |
+
mean = (mean,) * 3
|
71 |
+
|
72 |
+
std = std or OPENAI_DATASET_STD
|
73 |
+
if not isinstance(std, (list, tuple)):
|
74 |
+
std = (std,) * 3
|
75 |
+
|
76 |
+
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
77 |
+
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
78 |
+
image_size = image_size[0]
|
79 |
+
|
80 |
+
normalize = Normalize(mean=mean, std=std)
|
81 |
+
if is_train:
|
82 |
+
return Compose([
|
83 |
+
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
|
84 |
+
_convert_to_rgb,
|
85 |
+
ToTensor(),
|
86 |
+
normalize,
|
87 |
+
])
|
88 |
+
else:
|
89 |
+
if resize_longest_max:
|
90 |
+
transforms = [
|
91 |
+
ResizeMaxSize(image_size, fill=fill_color)
|
92 |
+
]
|
93 |
+
else:
|
94 |
+
transforms = [
|
95 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
96 |
+
CenterCrop(image_size),
|
97 |
+
]
|
98 |
+
transforms.extend([
|
99 |
+
_convert_to_rgb,
|
100 |
+
ToTensor(),
|
101 |
+
normalize,
|
102 |
+
])
|
103 |
+
return Compose(transforms)
|
eva_clip/transformer.py
ADDED
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from collections import OrderedDict
|
4 |
+
import math
|
5 |
+
import warnings
|
6 |
+
from typing import Callable, Optional, Sequence
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
13 |
+
from .utils import to_2tuple
|
14 |
+
|
15 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
16 |
+
try:
|
17 |
+
import deepspeed
|
18 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
19 |
+
except:
|
20 |
+
print("Please 'pip install deepspeed'")
|
21 |
+
deepspeed = None
|
22 |
+
from torch.utils.checkpoint import checkpoint
|
23 |
+
else:
|
24 |
+
from torch.utils.checkpoint import checkpoint
|
25 |
+
|
26 |
+
try:
|
27 |
+
import xformers.ops as xops
|
28 |
+
except ImportError:
|
29 |
+
xops = None
|
30 |
+
print("Please 'pip install xformers'")
|
31 |
+
|
32 |
+
|
33 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
34 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
35 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
36 |
+
def norm_cdf(x):
|
37 |
+
# Computes standard normal cumulative distribution function
|
38 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
39 |
+
|
40 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
41 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
42 |
+
"The distribution of values may be incorrect.",
|
43 |
+
stacklevel=2)
|
44 |
+
|
45 |
+
with torch.no_grad():
|
46 |
+
# Values are generated by using a truncated uniform distribution and
|
47 |
+
# then using the inverse CDF for the normal distribution.
|
48 |
+
# Get upper and lower cdf values
|
49 |
+
l = norm_cdf((a - mean) / std)
|
50 |
+
u = norm_cdf((b - mean) / std)
|
51 |
+
|
52 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
53 |
+
# [2l-1, 2u-1].
|
54 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
55 |
+
|
56 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
57 |
+
# standard normal
|
58 |
+
tensor.erfinv_()
|
59 |
+
|
60 |
+
# Transform to proper mean, std
|
61 |
+
tensor.mul_(std * math.sqrt(2.))
|
62 |
+
tensor.add_(mean)
|
63 |
+
|
64 |
+
# Clamp to ensure it's in the proper range
|
65 |
+
tensor.clamp_(min=a, max=b)
|
66 |
+
return tensor
|
67 |
+
|
68 |
+
|
69 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
70 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
71 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
72 |
+
normal distribution. The values are effectively drawn from the
|
73 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
74 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
75 |
+
the bounds. The method used for generating the random values works
|
76 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
77 |
+
Args:
|
78 |
+
tensor: an n-dimensional `torch.Tensor`
|
79 |
+
mean: the mean of the normal distribution
|
80 |
+
std: the standard deviation of the normal distribution
|
81 |
+
a: the minimum cutoff value
|
82 |
+
b: the maximum cutoff value
|
83 |
+
Examples:
|
84 |
+
>>> w = torch.empty(3, 5)
|
85 |
+
>>> nn.init.trunc_normal_(w)
|
86 |
+
"""
|
87 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
class LayerNormFp32(nn.LayerNorm):
|
92 |
+
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super().__init__(*args, **kwargs)
|
95 |
+
|
96 |
+
def forward(self, x: torch.Tensor):
|
97 |
+
output = F.layer_norm(
|
98 |
+
x.float(),
|
99 |
+
self.normalized_shape,
|
100 |
+
self.weight.float() if self.weight is not None else None,
|
101 |
+
self.bias.float() if self.bias is not None else None,
|
102 |
+
self.eps,
|
103 |
+
)
|
104 |
+
return output.type_as(x)
|
105 |
+
|
106 |
+
|
107 |
+
class LayerNorm(nn.LayerNorm):
|
108 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
109 |
+
|
110 |
+
def forward(self, x: torch.Tensor):
|
111 |
+
orig_type = x.dtype
|
112 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
113 |
+
return x.to(orig_type)
|
114 |
+
|
115 |
+
class QuickGELU(nn.Module):
|
116 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
117 |
+
def forward(self, x: torch.Tensor):
|
118 |
+
return x * torch.sigmoid(1.702 * x)
|
119 |
+
|
120 |
+
|
121 |
+
class LayerScale(nn.Module):
|
122 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
123 |
+
super().__init__()
|
124 |
+
self.inplace = inplace
|
125 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
129 |
+
|
130 |
+
class PatchDropout(nn.Module):
|
131 |
+
"""
|
132 |
+
https://arxiv.org/abs/2212.00794
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, prob, exclude_first_token=True):
|
136 |
+
super().__init__()
|
137 |
+
assert 0 <= prob < 1.
|
138 |
+
self.prob = prob
|
139 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
140 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
if not self.training or self.prob == 0.:
|
144 |
+
return x
|
145 |
+
|
146 |
+
if self.exclude_first_token:
|
147 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
148 |
+
else:
|
149 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
150 |
+
|
151 |
+
batch = x.size()[0]
|
152 |
+
num_tokens = x.size()[1]
|
153 |
+
|
154 |
+
batch_indices = torch.arange(batch)
|
155 |
+
batch_indices = batch_indices[..., None]
|
156 |
+
|
157 |
+
keep_prob = 1 - self.prob
|
158 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
159 |
+
|
160 |
+
rand = torch.randn(batch, num_tokens)
|
161 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
162 |
+
|
163 |
+
x = x[batch_indices, patch_indices_keep]
|
164 |
+
|
165 |
+
if self.exclude_first_token:
|
166 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
167 |
+
|
168 |
+
if self.training and os.getenv('RoPE') == '1':
|
169 |
+
return x, patch_indices_keep
|
170 |
+
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
def _in_projection_packed(
|
175 |
+
q: torch.Tensor,
|
176 |
+
k: torch.Tensor,
|
177 |
+
v: torch.Tensor,
|
178 |
+
w: torch.Tensor,
|
179 |
+
b: Optional[torch.Tensor] = None,
|
180 |
+
):
|
181 |
+
"""
|
182 |
+
https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
|
183 |
+
"""
|
184 |
+
E = q.size(-1)
|
185 |
+
if k is v:
|
186 |
+
if q is k:
|
187 |
+
# self-attention
|
188 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
189 |
+
else:
|
190 |
+
# encoder-decoder attention
|
191 |
+
w_q, w_kv = w.split([E, E * 2])
|
192 |
+
if b is None:
|
193 |
+
b_q = b_kv = None
|
194 |
+
else:
|
195 |
+
b_q, b_kv = b.split([E, E * 2])
|
196 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
197 |
+
else:
|
198 |
+
w_q, w_k, w_v = w.chunk(3)
|
199 |
+
if b is None:
|
200 |
+
b_q = b_k = b_v = None
|
201 |
+
else:
|
202 |
+
b_q, b_k, b_v = b.chunk(3)
|
203 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
204 |
+
|
205 |
+
class Attention(nn.Module):
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
dim,
|
209 |
+
num_heads=8,
|
210 |
+
qkv_bias=True,
|
211 |
+
scaled_cosine=False,
|
212 |
+
scale_heads=False,
|
213 |
+
logit_scale_max=math.log(1. / 0.01),
|
214 |
+
attn_drop=0.,
|
215 |
+
proj_drop=0.,
|
216 |
+
xattn=False,
|
217 |
+
rope=False
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
self.scaled_cosine = scaled_cosine
|
221 |
+
self.scale_heads = scale_heads
|
222 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
223 |
+
self.num_heads = num_heads
|
224 |
+
self.head_dim = dim // num_heads
|
225 |
+
self.scale = self.head_dim ** -0.5
|
226 |
+
self.logit_scale_max = logit_scale_max
|
227 |
+
|
228 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
229 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
230 |
+
if qkv_bias:
|
231 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
232 |
+
else:
|
233 |
+
self.in_proj_bias = None
|
234 |
+
|
235 |
+
if self.scaled_cosine:
|
236 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
237 |
+
else:
|
238 |
+
self.logit_scale = None
|
239 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
240 |
+
if self.scale_heads:
|
241 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
242 |
+
else:
|
243 |
+
self.head_scale = None
|
244 |
+
self.out_proj = nn.Linear(dim, dim)
|
245 |
+
self.out_drop = nn.Dropout(proj_drop)
|
246 |
+
self.xattn = xattn
|
247 |
+
self.xattn_drop = attn_drop
|
248 |
+
self.rope = rope
|
249 |
+
|
250 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
251 |
+
L, N, C = x.shape
|
252 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
253 |
+
if self.xattn:
|
254 |
+
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
255 |
+
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
256 |
+
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
257 |
+
|
258 |
+
x = xops.memory_efficient_attention(
|
259 |
+
q, k, v,
|
260 |
+
p=self.xattn_drop,
|
261 |
+
scale=self.scale if self.logit_scale is None else None,
|
262 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
|
263 |
+
)
|
264 |
+
else:
|
265 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
266 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
267 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
268 |
+
|
269 |
+
if self.logit_scale is not None:
|
270 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
271 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
272 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
273 |
+
attn = attn.view(-1, L, L)
|
274 |
+
else:
|
275 |
+
q = q * self.scale
|
276 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
277 |
+
|
278 |
+
if attn_mask is not None:
|
279 |
+
if attn_mask.dtype == torch.bool:
|
280 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
281 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
282 |
+
attn_mask = new_attn_mask
|
283 |
+
attn += attn_mask
|
284 |
+
|
285 |
+
attn = attn.softmax(dim=-1)
|
286 |
+
attn = self.attn_drop(attn)
|
287 |
+
|
288 |
+
x = torch.bmm(attn, v)
|
289 |
+
|
290 |
+
if self.head_scale is not None:
|
291 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
292 |
+
x = x.view(-1, L, C)
|
293 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
294 |
+
x = self.out_proj(x)
|
295 |
+
x = self.out_drop(x)
|
296 |
+
return x
|
297 |
+
|
298 |
+
class CustomAttention(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
dim,
|
302 |
+
num_heads=8,
|
303 |
+
qkv_bias=True,
|
304 |
+
scaled_cosine=True,
|
305 |
+
scale_heads=False,
|
306 |
+
logit_scale_max=math.log(1. / 0.01),
|
307 |
+
attn_drop=0.,
|
308 |
+
proj_drop=0.,
|
309 |
+
xattn=False
|
310 |
+
):
|
311 |
+
super().__init__()
|
312 |
+
self.scaled_cosine = scaled_cosine
|
313 |
+
self.scale_heads = scale_heads
|
314 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
315 |
+
self.num_heads = num_heads
|
316 |
+
self.head_dim = dim // num_heads
|
317 |
+
self.scale = self.head_dim ** -0.5
|
318 |
+
self.logit_scale_max = logit_scale_max
|
319 |
+
|
320 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
321 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
322 |
+
if qkv_bias:
|
323 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
324 |
+
else:
|
325 |
+
self.in_proj_bias = None
|
326 |
+
|
327 |
+
if self.scaled_cosine:
|
328 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
329 |
+
else:
|
330 |
+
self.logit_scale = None
|
331 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
332 |
+
if self.scale_heads:
|
333 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
334 |
+
else:
|
335 |
+
self.head_scale = None
|
336 |
+
self.out_proj = nn.Linear(dim, dim)
|
337 |
+
self.out_drop = nn.Dropout(proj_drop)
|
338 |
+
self.xattn = xattn
|
339 |
+
self.xattn_drop = attn_drop
|
340 |
+
|
341 |
+
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
342 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
343 |
+
N_q, B_q, C_q = q.shape
|
344 |
+
N_k, B_k, C_k = k.shape
|
345 |
+
N_v, B_v, C_v = v.shape
|
346 |
+
if self.xattn:
|
347 |
+
# B, N, C -> B, N, num_heads, C
|
348 |
+
q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
|
349 |
+
k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
|
350 |
+
v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
|
351 |
+
|
352 |
+
x = xops.memory_efficient_attention(
|
353 |
+
q, k, v,
|
354 |
+
p=self.xattn_drop,
|
355 |
+
scale=self.scale if self.logit_scale is None else None,
|
356 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
# B*H, L, C
|
360 |
+
q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
|
361 |
+
k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
|
362 |
+
v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
|
363 |
+
|
364 |
+
if self.logit_scale is not None:
|
365 |
+
# B*H, N_q, N_k
|
366 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
367 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
368 |
+
attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
|
369 |
+
attn = attn.view(-1, N_q, N_k)
|
370 |
+
else:
|
371 |
+
q = q * self.scale
|
372 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
373 |
+
|
374 |
+
if attn_mask is not None:
|
375 |
+
if attn_mask.dtype == torch.bool:
|
376 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
377 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
378 |
+
attn_mask = new_attn_mask
|
379 |
+
attn += attn_mask
|
380 |
+
|
381 |
+
attn = attn.softmax(dim=-1)
|
382 |
+
attn = self.attn_drop(attn)
|
383 |
+
|
384 |
+
x = torch.bmm(attn, v)
|
385 |
+
|
386 |
+
if self.head_scale is not None:
|
387 |
+
x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
|
388 |
+
x = x.view(-1, N_q, C_q)
|
389 |
+
x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
|
390 |
+
x = self.out_proj(x)
|
391 |
+
x = self.out_drop(x)
|
392 |
+
return x
|
393 |
+
|
394 |
+
class CustomResidualAttentionBlock(nn.Module):
|
395 |
+
def __init__(
|
396 |
+
self,
|
397 |
+
d_model: int,
|
398 |
+
n_head: int,
|
399 |
+
mlp_ratio: float = 4.0,
|
400 |
+
ls_init_value: float = None,
|
401 |
+
act_layer: Callable = nn.GELU,
|
402 |
+
norm_layer: Callable = LayerNorm,
|
403 |
+
scale_cosine_attn: bool = False,
|
404 |
+
scale_heads: bool = False,
|
405 |
+
scale_attn: bool = False,
|
406 |
+
scale_fc: bool = False,
|
407 |
+
cross_attn: bool = False,
|
408 |
+
xattn: bool = False,
|
409 |
+
):
|
410 |
+
super().__init__()
|
411 |
+
|
412 |
+
self.ln_1 = norm_layer(d_model)
|
413 |
+
self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
|
414 |
+
self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
|
415 |
+
self.attn = CustomAttention(
|
416 |
+
d_model, n_head,
|
417 |
+
qkv_bias=True,
|
418 |
+
attn_drop=0.,
|
419 |
+
proj_drop=0.,
|
420 |
+
scaled_cosine=scale_cosine_attn,
|
421 |
+
scale_heads=scale_heads,
|
422 |
+
xattn=xattn
|
423 |
+
)
|
424 |
+
|
425 |
+
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
426 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
427 |
+
|
428 |
+
self.ln_2 = norm_layer(d_model)
|
429 |
+
mlp_width = int(d_model * mlp_ratio)
|
430 |
+
self.mlp = nn.Sequential(OrderedDict([
|
431 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
432 |
+
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
433 |
+
("gelu", act_layer()),
|
434 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
435 |
+
]))
|
436 |
+
|
437 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
438 |
+
|
439 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
440 |
+
q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
|
441 |
+
q = q + self.ls_2(self.mlp(self.ln_2(q)))
|
442 |
+
return q
|
443 |
+
|
444 |
+
class CustomTransformer(nn.Module):
|
445 |
+
def __init__(
|
446 |
+
self,
|
447 |
+
width: int,
|
448 |
+
layers: int,
|
449 |
+
heads: int,
|
450 |
+
mlp_ratio: float = 4.0,
|
451 |
+
ls_init_value: float = None,
|
452 |
+
act_layer: Callable = nn.GELU,
|
453 |
+
norm_layer: Callable = LayerNorm,
|
454 |
+
scale_cosine_attn: bool = True,
|
455 |
+
scale_heads: bool = False,
|
456 |
+
scale_attn: bool = False,
|
457 |
+
scale_fc: bool = False,
|
458 |
+
cross_attn: bool = False,
|
459 |
+
xattn: bool = False,
|
460 |
+
):
|
461 |
+
super().__init__()
|
462 |
+
self.width = width
|
463 |
+
self.layers = layers
|
464 |
+
self.grad_checkpointing = False
|
465 |
+
self.xattn = xattn
|
466 |
+
|
467 |
+
self.resblocks = nn.ModuleList([
|
468 |
+
CustomResidualAttentionBlock(
|
469 |
+
width,
|
470 |
+
heads,
|
471 |
+
mlp_ratio,
|
472 |
+
ls_init_value=ls_init_value,
|
473 |
+
act_layer=act_layer,
|
474 |
+
norm_layer=norm_layer,
|
475 |
+
scale_cosine_attn=scale_cosine_attn,
|
476 |
+
scale_heads=scale_heads,
|
477 |
+
scale_attn=scale_attn,
|
478 |
+
scale_fc=scale_fc,
|
479 |
+
cross_attn=cross_attn,
|
480 |
+
xattn=xattn)
|
481 |
+
for _ in range(layers)
|
482 |
+
])
|
483 |
+
|
484 |
+
def get_cast_dtype(self) -> torch.dtype:
|
485 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
486 |
+
|
487 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
|
488 |
+
if k is None and v is None:
|
489 |
+
k = v = q
|
490 |
+
for r in self.resblocks:
|
491 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
492 |
+
q = checkpoint(r, q, k, v, attn_mask)
|
493 |
+
else:
|
494 |
+
q = r(q, k, v, attn_mask=attn_mask)
|
495 |
+
return q
|
496 |
+
|
497 |
+
|
498 |
+
class ResidualAttentionBlock(nn.Module):
|
499 |
+
def __init__(
|
500 |
+
self,
|
501 |
+
d_model: int,
|
502 |
+
n_head: int,
|
503 |
+
mlp_ratio: float = 4.0,
|
504 |
+
ls_init_value: float = None,
|
505 |
+
act_layer: Callable = nn.GELU,
|
506 |
+
norm_layer: Callable = LayerNorm,
|
507 |
+
xattn: bool = False,
|
508 |
+
):
|
509 |
+
super().__init__()
|
510 |
+
|
511 |
+
self.ln_1 = norm_layer(d_model)
|
512 |
+
if xattn:
|
513 |
+
self.attn = Attention(d_model, n_head, xattn=True)
|
514 |
+
else:
|
515 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
516 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
517 |
+
|
518 |
+
self.ln_2 = norm_layer(d_model)
|
519 |
+
mlp_width = int(d_model * mlp_ratio)
|
520 |
+
self.mlp = nn.Sequential(OrderedDict([
|
521 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
522 |
+
("gelu", act_layer()),
|
523 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
524 |
+
]))
|
525 |
+
|
526 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
527 |
+
self.xattn = xattn
|
528 |
+
|
529 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
530 |
+
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
531 |
+
if self.xattn:
|
532 |
+
return self.attn(x, attn_mask=attn_mask)
|
533 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
534 |
+
|
535 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
536 |
+
x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
|
537 |
+
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
538 |
+
return x
|
539 |
+
|
540 |
+
class Transformer(nn.Module):
|
541 |
+
def __init__(
|
542 |
+
self,
|
543 |
+
width: int,
|
544 |
+
layers: int,
|
545 |
+
heads: int,
|
546 |
+
mlp_ratio: float = 4.0,
|
547 |
+
ls_init_value: float = None,
|
548 |
+
act_layer: Callable = nn.GELU,
|
549 |
+
norm_layer: Callable = LayerNorm,
|
550 |
+
xattn: bool = False,
|
551 |
+
):
|
552 |
+
super().__init__()
|
553 |
+
self.width = width
|
554 |
+
self.layers = layers
|
555 |
+
self.grad_checkpointing = False
|
556 |
+
|
557 |
+
self.resblocks = nn.ModuleList([
|
558 |
+
ResidualAttentionBlock(
|
559 |
+
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
|
560 |
+
for _ in range(layers)
|
561 |
+
])
|
562 |
+
|
563 |
+
def get_cast_dtype(self) -> torch.dtype:
|
564 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
565 |
+
|
566 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
567 |
+
for r in self.resblocks:
|
568 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
569 |
+
x = checkpoint(r, x, attn_mask)
|
570 |
+
else:
|
571 |
+
x = r(x, attn_mask=attn_mask)
|
572 |
+
return x
|
573 |
+
|
574 |
+
|
575 |
+
class VisionTransformer(nn.Module):
|
576 |
+
def __init__(
|
577 |
+
self,
|
578 |
+
image_size: int,
|
579 |
+
patch_size: int,
|
580 |
+
width: int,
|
581 |
+
layers: int,
|
582 |
+
heads: int,
|
583 |
+
mlp_ratio: float,
|
584 |
+
ls_init_value: float = None,
|
585 |
+
patch_dropout: float = 0.,
|
586 |
+
global_average_pool: bool = False,
|
587 |
+
output_dim: int = 512,
|
588 |
+
act_layer: Callable = nn.GELU,
|
589 |
+
norm_layer: Callable = LayerNorm,
|
590 |
+
xattn: bool = False,
|
591 |
+
):
|
592 |
+
super().__init__()
|
593 |
+
self.image_size = to_2tuple(image_size)
|
594 |
+
self.patch_size = to_2tuple(patch_size)
|
595 |
+
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
|
596 |
+
self.output_dim = output_dim
|
597 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
598 |
+
|
599 |
+
scale = width ** -0.5
|
600 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
601 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
602 |
+
|
603 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
604 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
605 |
+
self.ln_pre = norm_layer(width)
|
606 |
+
|
607 |
+
self.transformer = Transformer(
|
608 |
+
width,
|
609 |
+
layers,
|
610 |
+
heads,
|
611 |
+
mlp_ratio,
|
612 |
+
ls_init_value=ls_init_value,
|
613 |
+
act_layer=act_layer,
|
614 |
+
norm_layer=norm_layer,
|
615 |
+
xattn=xattn
|
616 |
+
)
|
617 |
+
|
618 |
+
self.global_average_pool = global_average_pool
|
619 |
+
self.ln_post = norm_layer(width)
|
620 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
621 |
+
|
622 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
623 |
+
for param in self.parameters():
|
624 |
+
param.requires_grad = False
|
625 |
+
|
626 |
+
if unlocked_groups != 0:
|
627 |
+
groups = [
|
628 |
+
[
|
629 |
+
self.conv1,
|
630 |
+
self.class_embedding,
|
631 |
+
self.positional_embedding,
|
632 |
+
self.ln_pre,
|
633 |
+
],
|
634 |
+
*self.transformer.resblocks[:-1],
|
635 |
+
[
|
636 |
+
self.transformer.resblocks[-1],
|
637 |
+
self.ln_post,
|
638 |
+
],
|
639 |
+
self.proj,
|
640 |
+
]
|
641 |
+
|
642 |
+
def _unlock(x):
|
643 |
+
if isinstance(x, Sequence):
|
644 |
+
for g in x:
|
645 |
+
_unlock(g)
|
646 |
+
else:
|
647 |
+
if isinstance(x, torch.nn.Parameter):
|
648 |
+
x.requires_grad = True
|
649 |
+
else:
|
650 |
+
for p in x.parameters():
|
651 |
+
p.requires_grad = True
|
652 |
+
|
653 |
+
_unlock(groups[-unlocked_groups:])
|
654 |
+
|
655 |
+
def get_num_layers(self):
|
656 |
+
return self.transformer.layers
|
657 |
+
|
658 |
+
@torch.jit.ignore
|
659 |
+
def set_grad_checkpointing(self, enable=True):
|
660 |
+
self.transformer.grad_checkpointing = enable
|
661 |
+
|
662 |
+
@torch.jit.ignore
|
663 |
+
def no_weight_decay(self):
|
664 |
+
return {'positional_embedding', 'class_embedding'}
|
665 |
+
|
666 |
+
def forward(self, x: torch.Tensor, return_all_features: bool=False):
|
667 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
668 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
669 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
670 |
+
x = torch.cat(
|
671 |
+
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
672 |
+
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
673 |
+
x = x + self.positional_embedding.to(x.dtype)
|
674 |
+
|
675 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
676 |
+
x = self.patch_dropout(x)
|
677 |
+
x = self.ln_pre(x)
|
678 |
+
|
679 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
680 |
+
x = self.transformer(x)
|
681 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
682 |
+
|
683 |
+
if not return_all_features:
|
684 |
+
if self.global_average_pool:
|
685 |
+
x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
|
686 |
+
else:
|
687 |
+
x = x[:, 0]
|
688 |
+
|
689 |
+
x = self.ln_post(x)
|
690 |
+
|
691 |
+
if self.proj is not None:
|
692 |
+
x = x @ self.proj
|
693 |
+
|
694 |
+
return x
|
695 |
+
|
696 |
+
|
697 |
+
class TextTransformer(nn.Module):
|
698 |
+
def __init__(
|
699 |
+
self,
|
700 |
+
context_length: int = 77,
|
701 |
+
vocab_size: int = 49408,
|
702 |
+
width: int = 512,
|
703 |
+
heads: int = 8,
|
704 |
+
layers: int = 12,
|
705 |
+
ls_init_value: float = None,
|
706 |
+
output_dim: int = 512,
|
707 |
+
act_layer: Callable = nn.GELU,
|
708 |
+
norm_layer: Callable = LayerNorm,
|
709 |
+
xattn: bool= False,
|
710 |
+
attn_mask: bool = True
|
711 |
+
):
|
712 |
+
super().__init__()
|
713 |
+
self.context_length = context_length
|
714 |
+
self.vocab_size = vocab_size
|
715 |
+
self.width = width
|
716 |
+
self.output_dim = output_dim
|
717 |
+
|
718 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
719 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
|
720 |
+
self.transformer = Transformer(
|
721 |
+
width=width,
|
722 |
+
layers=layers,
|
723 |
+
heads=heads,
|
724 |
+
ls_init_value=ls_init_value,
|
725 |
+
act_layer=act_layer,
|
726 |
+
norm_layer=norm_layer,
|
727 |
+
xattn=xattn
|
728 |
+
)
|
729 |
+
|
730 |
+
self.xattn = xattn
|
731 |
+
self.ln_final = norm_layer(width)
|
732 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
733 |
+
|
734 |
+
if attn_mask:
|
735 |
+
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
736 |
+
else:
|
737 |
+
self.attn_mask = None
|
738 |
+
|
739 |
+
self.init_parameters()
|
740 |
+
|
741 |
+
def init_parameters(self):
|
742 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
743 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
744 |
+
|
745 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
746 |
+
attn_std = self.transformer.width ** -0.5
|
747 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
748 |
+
for block in self.transformer.resblocks:
|
749 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
750 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
751 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
752 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
753 |
+
|
754 |
+
if self.text_projection is not None:
|
755 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
756 |
+
|
757 |
+
@torch.jit.ignore
|
758 |
+
def set_grad_checkpointing(self, enable=True):
|
759 |
+
self.transformer.grad_checkpointing = enable
|
760 |
+
|
761 |
+
@torch.jit.ignore
|
762 |
+
def no_weight_decay(self):
|
763 |
+
# return {'positional_embedding', 'token_embedding'}
|
764 |
+
return {'positional_embedding'}
|
765 |
+
|
766 |
+
def get_num_layers(self):
|
767 |
+
return self.transformer.layers
|
768 |
+
|
769 |
+
def build_attention_mask(self):
|
770 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
771 |
+
# pytorch uses additive attention mask; fill with -inf
|
772 |
+
mask = torch.empty(self.context_length, self.context_length)
|
773 |
+
mask.fill_(float("-inf"))
|
774 |
+
mask.triu_(1) # zero out the lower diagonal
|
775 |
+
return mask
|
776 |
+
|
777 |
+
def forward(self, text, return_all_features: bool=False):
|
778 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
779 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
780 |
+
|
781 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
782 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
783 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
784 |
+
# x = self.transformer(x) # no attention mask is applied
|
785 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
786 |
+
x = self.ln_final(x)
|
787 |
+
|
788 |
+
if not return_all_features:
|
789 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
790 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
791 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
792 |
+
return x
|
eva_clip/utils.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import repeat
|
2 |
+
import collections.abc
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn as nn
|
9 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
# open CLIP
|
13 |
+
def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
14 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
15 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
16 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
17 |
+
return
|
18 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
19 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
20 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
21 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
22 |
+
return
|
23 |
+
|
24 |
+
if extra_tokens:
|
25 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
26 |
+
else:
|
27 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
28 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
29 |
+
|
30 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
31 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
32 |
+
pos_emb_img = F.interpolate(
|
33 |
+
pos_emb_img,
|
34 |
+
size=grid_size,
|
35 |
+
mode=interpolation,
|
36 |
+
align_corners=True,
|
37 |
+
)
|
38 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
39 |
+
if pos_emb_tok is not None:
|
40 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
41 |
+
else:
|
42 |
+
new_pos_embed = pos_emb_img
|
43 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|
44 |
+
|
45 |
+
|
46 |
+
def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
47 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
48 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
49 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
50 |
+
return
|
51 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
52 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
53 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
54 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
55 |
+
return
|
56 |
+
|
57 |
+
if extra_tokens:
|
58 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
59 |
+
else:
|
60 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
61 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
62 |
+
|
63 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
64 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
65 |
+
pos_emb_img = F.interpolate(
|
66 |
+
pos_emb_img,
|
67 |
+
size=grid_size,
|
68 |
+
mode=interpolation,
|
69 |
+
align_corners=True,
|
70 |
+
)
|
71 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
72 |
+
if pos_emb_tok is not None:
|
73 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
74 |
+
else:
|
75 |
+
new_pos_embed = pos_emb_img
|
76 |
+
state_dict['positional_embedding'] = new_pos_embed
|
77 |
+
|
78 |
+
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
79 |
+
all_keys = list(state_dict.keys())
|
80 |
+
# interpolate position embedding
|
81 |
+
if 'visual.pos_embed' in state_dict:
|
82 |
+
pos_embed_checkpoint = state_dict['visual.pos_embed']
|
83 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
84 |
+
num_patches = model.visual.patch_embed.num_patches
|
85 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
86 |
+
# height (== width) for the checkpoint position embedding
|
87 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
88 |
+
# height (== width) for the new position embedding
|
89 |
+
new_size = int(num_patches ** 0.5)
|
90 |
+
# class_token and dist_token are kept unchanged
|
91 |
+
if orig_size != new_size:
|
92 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
93 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
94 |
+
# only the position tokens are interpolated
|
95 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
96 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
97 |
+
pos_tokens = torch.nn.functional.interpolate(
|
98 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
99 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
100 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
101 |
+
state_dict['visual.pos_embed'] = new_pos_embed
|
102 |
+
|
103 |
+
patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
|
104 |
+
patch_size = model.visual.patch_embed.patch_size
|
105 |
+
state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
106 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
107 |
+
|
108 |
+
|
109 |
+
def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
110 |
+
all_keys = list(state_dict.keys())
|
111 |
+
# interpolate position embedding
|
112 |
+
if 'pos_embed' in state_dict:
|
113 |
+
pos_embed_checkpoint = state_dict['pos_embed']
|
114 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
115 |
+
num_patches = model.visual.patch_embed.num_patches
|
116 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
117 |
+
# height (== width) for the checkpoint position embedding
|
118 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
119 |
+
# height (== width) for the new position embedding
|
120 |
+
new_size = int(num_patches ** 0.5)
|
121 |
+
# class_token and dist_token are kept unchanged
|
122 |
+
if orig_size != new_size:
|
123 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
124 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
125 |
+
# only the position tokens are interpolated
|
126 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
127 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
128 |
+
pos_tokens = torch.nn.functional.interpolate(
|
129 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
130 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
131 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
132 |
+
state_dict['pos_embed'] = new_pos_embed
|
133 |
+
|
134 |
+
patch_embed_proj = state_dict['patch_embed.proj.weight']
|
135 |
+
patch_size = model.visual.patch_embed.patch_size
|
136 |
+
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
137 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
138 |
+
|
139 |
+
|
140 |
+
def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
141 |
+
all_keys = list(state_dict.keys())
|
142 |
+
for key in all_keys:
|
143 |
+
if "relative_position_index" in key:
|
144 |
+
state_dict.pop(key)
|
145 |
+
|
146 |
+
if "relative_position_bias_table" in key:
|
147 |
+
rel_pos_bias = state_dict[key]
|
148 |
+
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
149 |
+
dst_num_pos, _ = model.visual.state_dict()[key].size()
|
150 |
+
dst_patch_shape = model.visual.patch_embed.patch_shape
|
151 |
+
if dst_patch_shape[0] != dst_patch_shape[1]:
|
152 |
+
raise NotImplementedError()
|
153 |
+
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
|
154 |
+
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
|
155 |
+
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
|
156 |
+
if src_size != dst_size:
|
157 |
+
print("Position interpolate for %s from %dx%d to %dx%d" % (
|
158 |
+
key, src_size, src_size, dst_size, dst_size))
|
159 |
+
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
160 |
+
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
161 |
+
|
162 |
+
def geometric_progression(a, r, n):
|
163 |
+
return a * (1.0 - r ** n) / (1.0 - r)
|
164 |
+
|
165 |
+
left, right = 1.01, 1.5
|
166 |
+
while right - left > 1e-6:
|
167 |
+
q = (left + right) / 2.0
|
168 |
+
gp = geometric_progression(1, q, src_size // 2)
|
169 |
+
if gp > dst_size // 2:
|
170 |
+
right = q
|
171 |
+
else:
|
172 |
+
left = q
|
173 |
+
|
174 |
+
# if q > 1.090307:
|
175 |
+
# q = 1.090307
|
176 |
+
|
177 |
+
dis = []
|
178 |
+
cur = 1
|
179 |
+
for i in range(src_size // 2):
|
180 |
+
dis.append(cur)
|
181 |
+
cur += q ** (i + 1)
|
182 |
+
|
183 |
+
r_ids = [-_ for _ in reversed(dis)]
|
184 |
+
|
185 |
+
x = r_ids + [0] + dis
|
186 |
+
y = r_ids + [0] + dis
|
187 |
+
|
188 |
+
t = dst_size // 2.0
|
189 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
190 |
+
dy = np.arange(-t, t + 0.1, 1.0)
|
191 |
+
|
192 |
+
print("Original positions = %s" % str(x))
|
193 |
+
print("Target positions = %s" % str(dx))
|
194 |
+
|
195 |
+
all_rel_pos_bias = []
|
196 |
+
|
197 |
+
for i in range(num_attn_heads):
|
198 |
+
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
|
199 |
+
f = F.interpolate.interp2d(x, y, z, kind='cubic')
|
200 |
+
all_rel_pos_bias.append(
|
201 |
+
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
|
202 |
+
|
203 |
+
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
|
204 |
+
|
205 |
+
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
|
206 |
+
state_dict[key] = new_rel_pos_bias
|
207 |
+
|
208 |
+
# interpolate position embedding
|
209 |
+
if 'pos_embed' in state_dict:
|
210 |
+
pos_embed_checkpoint = state_dict['pos_embed']
|
211 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
212 |
+
num_patches = model.visual.patch_embed.num_patches
|
213 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
214 |
+
# height (== width) for the checkpoint position embedding
|
215 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
216 |
+
# height (== width) for the new position embedding
|
217 |
+
new_size = int(num_patches ** 0.5)
|
218 |
+
# class_token and dist_token are kept unchanged
|
219 |
+
if orig_size != new_size:
|
220 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
221 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
222 |
+
# only the position tokens are interpolated
|
223 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
224 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
225 |
+
pos_tokens = torch.nn.functional.interpolate(
|
226 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
227 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
228 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
229 |
+
state_dict['pos_embed'] = new_pos_embed
|
230 |
+
|
231 |
+
patch_embed_proj = state_dict['patch_embed.proj.weight']
|
232 |
+
patch_size = model.visual.patch_embed.patch_size
|
233 |
+
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
234 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
235 |
+
|
236 |
+
|
237 |
+
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
238 |
+
"""
|
239 |
+
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
240 |
+
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
241 |
+
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
module (torch.nn.Module): Any PyTorch module.
|
245 |
+
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
246 |
+
name (str): Full module name (prefix)
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
torch.nn.Module: Resulting module
|
250 |
+
|
251 |
+
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
252 |
+
"""
|
253 |
+
res = module
|
254 |
+
is_match = True
|
255 |
+
if module_match:
|
256 |
+
is_match = name in module_match
|
257 |
+
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
258 |
+
res = FrozenBatchNorm2d(module.num_features)
|
259 |
+
res.num_features = module.num_features
|
260 |
+
res.affine = module.affine
|
261 |
+
if module.affine:
|
262 |
+
res.weight.data = module.weight.data.clone().detach()
|
263 |
+
res.bias.data = module.bias.data.clone().detach()
|
264 |
+
res.running_mean.data = module.running_mean.data
|
265 |
+
res.running_var.data = module.running_var.data
|
266 |
+
res.eps = module.eps
|
267 |
+
else:
|
268 |
+
for child_name, child in module.named_children():
|
269 |
+
full_child_name = '.'.join([name, child_name]) if name else child_name
|
270 |
+
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
271 |
+
if new_child is not child:
|
272 |
+
res.add_module(child_name, new_child)
|
273 |
+
return res
|
274 |
+
|
275 |
+
|
276 |
+
# From PyTorch internals
|
277 |
+
def _ntuple(n):
|
278 |
+
def parse(x):
|
279 |
+
if isinstance(x, collections.abc.Iterable):
|
280 |
+
return x
|
281 |
+
return tuple(repeat(x, n))
|
282 |
+
return parse
|
283 |
+
|
284 |
+
|
285 |
+
to_1tuple = _ntuple(1)
|
286 |
+
to_2tuple = _ntuple(2)
|
287 |
+
to_3tuple = _ntuple(3)
|
288 |
+
to_4tuple = _ntuple(4)
|
289 |
+
to_ntuple = lambda n, x: _ntuple(n)(x)
|
290 |
+
|
291 |
+
|
292 |
+
def is_logging(args):
|
293 |
+
def is_global_master(args):
|
294 |
+
return args.rank == 0
|
295 |
+
|
296 |
+
def is_local_master(args):
|
297 |
+
return args.local_rank == 0
|
298 |
+
|
299 |
+
def is_master(args, local=False):
|
300 |
+
return is_local_master(args) if local else is_global_master(args)
|
301 |
+
return is_master
|
302 |
+
|
303 |
+
|
304 |
+
class AllGather(torch.autograd.Function):
|
305 |
+
"""An autograd function that performs allgather on a tensor.
|
306 |
+
Performs all_gather operation on the provided tensors.
|
307 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
308 |
+
"""
|
309 |
+
|
310 |
+
@staticmethod
|
311 |
+
def forward(ctx, tensor, rank, world_size):
|
312 |
+
tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
|
313 |
+
torch.distributed.all_gather(tensors_gather, tensor)
|
314 |
+
ctx.rank = rank
|
315 |
+
ctx.batch_size = tensor.shape[0]
|
316 |
+
return torch.cat(tensors_gather, 0)
|
317 |
+
|
318 |
+
@staticmethod
|
319 |
+
def backward(ctx, grad_output):
|
320 |
+
return (
|
321 |
+
grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
|
322 |
+
None,
|
323 |
+
None
|
324 |
+
)
|
325 |
+
|
326 |
+
allgather = AllGather.apply
|
example_inputs/hinton.jpeg
ADDED
example_inputs/lecun.jpg
ADDED
example_inputs/lifeifei.jpg
ADDED
example_inputs/liuyifei.png
ADDED
example_inputs/rihanna.webp
ADDED
example_inputs/zcy.webp
ADDED
flux/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from ._version import version as __version__ # type: ignore
|
3 |
+
from ._version import version_tuple
|
4 |
+
except ImportError:
|
5 |
+
__version__ = "unknown (no version information available)"
|
6 |
+
version_tuple = (0, 0, "unknown", "noinfo")
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
PACKAGE = __package__.replace("_", "-")
|
11 |
+
PACKAGE_ROOT = Path(__file__).parent
|
flux/__main__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cli import app
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
app()
|
flux/api.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
API_ENDPOINT = "https://api.bfl.ml"
|
10 |
+
|
11 |
+
|
12 |
+
class ApiException(Exception):
|
13 |
+
def __init__(self, status_code: int, detail: str = None):
|
14 |
+
super().__init__()
|
15 |
+
self.detail = detail
|
16 |
+
self.status_code = status_code
|
17 |
+
|
18 |
+
def __str__(self) -> str:
|
19 |
+
return self.__repr__()
|
20 |
+
|
21 |
+
def __repr__(self) -> str:
|
22 |
+
if self.detail is None:
|
23 |
+
message = None
|
24 |
+
elif isinstance(self.detail, str):
|
25 |
+
message = self.detail
|
26 |
+
else:
|
27 |
+
message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
|
28 |
+
return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
|
29 |
+
|
30 |
+
|
31 |
+
class ImageRequest:
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
prompt: str,
|
35 |
+
width: int = 1024,
|
36 |
+
height: int = 1024,
|
37 |
+
name: str = "flux.1-pro",
|
38 |
+
num_steps: int = 50,
|
39 |
+
prompt_upsampling: bool = False,
|
40 |
+
seed: int = None,
|
41 |
+
validate: bool = True,
|
42 |
+
launch: bool = True,
|
43 |
+
api_key: str = None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Manages an image generation request to the API.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
prompt: Prompt to sample
|
50 |
+
width: Width of the image in pixel
|
51 |
+
height: Height of the image in pixel
|
52 |
+
name: Name of the model
|
53 |
+
num_steps: Number of network evaluations
|
54 |
+
prompt_upsampling: Use prompt upsampling
|
55 |
+
seed: Fix the generation seed
|
56 |
+
validate: Run input validation
|
57 |
+
launch: Directly launches request
|
58 |
+
api_key: Your API key if not provided by the environment
|
59 |
+
|
60 |
+
Raises:
|
61 |
+
ValueError: For invalid input
|
62 |
+
ApiException: For errors raised from the API
|
63 |
+
"""
|
64 |
+
if validate:
|
65 |
+
if name not in ["flux.1-pro"]:
|
66 |
+
raise ValueError(f"Invalid model {name}")
|
67 |
+
elif width % 32 != 0:
|
68 |
+
raise ValueError(f"width must be divisible by 32, got {width}")
|
69 |
+
elif not (256 <= width <= 1440):
|
70 |
+
raise ValueError(f"width must be between 256 and 1440, got {width}")
|
71 |
+
elif height % 32 != 0:
|
72 |
+
raise ValueError(f"height must be divisible by 32, got {height}")
|
73 |
+
elif not (256 <= height <= 1440):
|
74 |
+
raise ValueError(f"height must be between 256 and 1440, got {height}")
|
75 |
+
elif not (1 <= num_steps <= 50):
|
76 |
+
raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
|
77 |
+
|
78 |
+
self.request_json = {
|
79 |
+
"prompt": prompt,
|
80 |
+
"width": width,
|
81 |
+
"height": height,
|
82 |
+
"variant": name,
|
83 |
+
"steps": num_steps,
|
84 |
+
"prompt_upsampling": prompt_upsampling,
|
85 |
+
}
|
86 |
+
if seed is not None:
|
87 |
+
self.request_json["seed"] = seed
|
88 |
+
|
89 |
+
self.request_id: str = None
|
90 |
+
self.result: dict = None
|
91 |
+
self._image_bytes: bytes = None
|
92 |
+
self._url: str = None
|
93 |
+
if api_key is None:
|
94 |
+
self.api_key = os.environ.get("BFL_API_KEY")
|
95 |
+
else:
|
96 |
+
self.api_key = api_key
|
97 |
+
|
98 |
+
if launch:
|
99 |
+
self.request()
|
100 |
+
|
101 |
+
def request(self):
|
102 |
+
"""
|
103 |
+
Request to generate the image.
|
104 |
+
"""
|
105 |
+
if self.request_id is not None:
|
106 |
+
return
|
107 |
+
response = requests.post(
|
108 |
+
f"{API_ENDPOINT}/v1/image",
|
109 |
+
headers={
|
110 |
+
"accept": "application/json",
|
111 |
+
"x-key": self.api_key,
|
112 |
+
"Content-Type": "application/json",
|
113 |
+
},
|
114 |
+
json=self.request_json,
|
115 |
+
)
|
116 |
+
result = response.json()
|
117 |
+
if response.status_code != 200:
|
118 |
+
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
|
119 |
+
self.request_id = response.json()["id"]
|
120 |
+
|
121 |
+
def retrieve(self) -> dict:
|
122 |
+
"""
|
123 |
+
Wait for the generation to finish and retrieve response.
|
124 |
+
"""
|
125 |
+
if self.request_id is None:
|
126 |
+
self.request()
|
127 |
+
while self.result is None:
|
128 |
+
response = requests.get(
|
129 |
+
f"{API_ENDPOINT}/v1/get_result",
|
130 |
+
headers={
|
131 |
+
"accept": "application/json",
|
132 |
+
"x-key": self.api_key,
|
133 |
+
},
|
134 |
+
params={
|
135 |
+
"id": self.request_id,
|
136 |
+
},
|
137 |
+
)
|
138 |
+
result = response.json()
|
139 |
+
if "status" not in result:
|
140 |
+
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
|
141 |
+
elif result["status"] == "Ready":
|
142 |
+
self.result = result["result"]
|
143 |
+
elif result["status"] == "Pending":
|
144 |
+
time.sleep(0.5)
|
145 |
+
else:
|
146 |
+
raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
|
147 |
+
return self.result
|
148 |
+
|
149 |
+
@property
|
150 |
+
def bytes(self) -> bytes:
|
151 |
+
"""
|
152 |
+
Generated image as bytes.
|
153 |
+
"""
|
154 |
+
if self._image_bytes is None:
|
155 |
+
response = requests.get(self.url)
|
156 |
+
if response.status_code == 200:
|
157 |
+
self._image_bytes = response.content
|
158 |
+
else:
|
159 |
+
raise ApiException(status_code=response.status_code)
|
160 |
+
return self._image_bytes
|
161 |
+
|
162 |
+
@property
|
163 |
+
def url(self) -> str:
|
164 |
+
"""
|
165 |
+
Public url to retrieve the image from
|
166 |
+
"""
|
167 |
+
if self._url is None:
|
168 |
+
result = self.retrieve()
|
169 |
+
self._url = result["sample"]
|
170 |
+
return self._url
|
171 |
+
|
172 |
+
@property
|
173 |
+
def image(self) -> Image.Image:
|
174 |
+
"""
|
175 |
+
Load the image as a PIL Image
|
176 |
+
"""
|
177 |
+
return Image.open(io.BytesIO(self.bytes))
|
178 |
+
|
179 |
+
def save(self, path: str):
|
180 |
+
"""
|
181 |
+
Save the generated image to a local path
|
182 |
+
"""
|
183 |
+
suffix = Path(self.url).suffix
|
184 |
+
if not path.endswith(suffix):
|
185 |
+
path = path + suffix
|
186 |
+
Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
|
187 |
+
with open(path, "wb") as file:
|
188 |
+
file.write(self.bytes)
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
from fire import Fire
|
193 |
+
|
194 |
+
Fire(ImageRequest)
|
flux/cli.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from glob import iglob
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from einops import rearrange
|
9 |
+
from fire import Fire
|
10 |
+
from PIL import ExifTags, Image
|
11 |
+
from transformers import pipeline
|
12 |
+
|
13 |
+
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
14 |
+
from flux.util import (
|
15 |
+
configs,
|
16 |
+
embed_watermark,
|
17 |
+
load_ae,
|
18 |
+
load_clip,
|
19 |
+
load_flow_model,
|
20 |
+
load_t5,
|
21 |
+
)
|
22 |
+
|
23 |
+
NSFW_THRESHOLD = 0.85
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class SamplingOptions:
|
28 |
+
prompt: str
|
29 |
+
width: int
|
30 |
+
height: int
|
31 |
+
num_steps: int
|
32 |
+
guidance: float
|
33 |
+
seed: int
|
34 |
+
|
35 |
+
|
36 |
+
def parse_prompt(options: SamplingOptions) -> SamplingOptions:
|
37 |
+
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
38 |
+
usage = (
|
39 |
+
"Usage: Either write your prompt directly, leave this field empty "
|
40 |
+
"to repeat the prompt or write a command starting with a slash:\n"
|
41 |
+
"- '/w <width>' will set the width of the generated image\n"
|
42 |
+
"- '/h <height>' will set the height of the generated image\n"
|
43 |
+
"- '/s <seed>' sets the next seed\n"
|
44 |
+
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
45 |
+
"- '/n <steps>' sets the number of steps\n"
|
46 |
+
"- '/q' to quit"
|
47 |
+
)
|
48 |
+
|
49 |
+
while (prompt := input(user_question)).startswith("/"):
|
50 |
+
if prompt.startswith("/w"):
|
51 |
+
if prompt.count(" ") != 1:
|
52 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
53 |
+
continue
|
54 |
+
_, width = prompt.split()
|
55 |
+
options.width = 16 * (int(width) // 16)
|
56 |
+
print(
|
57 |
+
f"Setting resolution to {options.width} x {options.height} "
|
58 |
+
f"({options.height * options.width / 1e6:.2f}MP)"
|
59 |
+
)
|
60 |
+
elif prompt.startswith("/h"):
|
61 |
+
if prompt.count(" ") != 1:
|
62 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
63 |
+
continue
|
64 |
+
_, height = prompt.split()
|
65 |
+
options.height = 16 * (int(height) // 16)
|
66 |
+
print(
|
67 |
+
f"Setting resolution to {options.width} x {options.height} "
|
68 |
+
f"({options.height * options.width / 1e6:.2f}MP)"
|
69 |
+
)
|
70 |
+
elif prompt.startswith("/g"):
|
71 |
+
if prompt.count(" ") != 1:
|
72 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
73 |
+
continue
|
74 |
+
_, guidance = prompt.split()
|
75 |
+
options.guidance = float(guidance)
|
76 |
+
print(f"Setting guidance to {options.guidance}")
|
77 |
+
elif prompt.startswith("/s"):
|
78 |
+
if prompt.count(" ") != 1:
|
79 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
80 |
+
continue
|
81 |
+
_, seed = prompt.split()
|
82 |
+
options.seed = int(seed)
|
83 |
+
print(f"Setting seed to {options.seed}")
|
84 |
+
elif prompt.startswith("/n"):
|
85 |
+
if prompt.count(" ") != 1:
|
86 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
87 |
+
continue
|
88 |
+
_, steps = prompt.split()
|
89 |
+
options.num_steps = int(steps)
|
90 |
+
print(f"Setting seed to {options.num_steps}")
|
91 |
+
elif prompt.startswith("/q"):
|
92 |
+
print("Quitting")
|
93 |
+
return None
|
94 |
+
else:
|
95 |
+
if not prompt.startswith("/h"):
|
96 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
97 |
+
print(usage)
|
98 |
+
if prompt != "":
|
99 |
+
options.prompt = prompt
|
100 |
+
return options
|
101 |
+
|
102 |
+
|
103 |
+
@torch.inference_mode()
|
104 |
+
def main(
|
105 |
+
name: str = "flux-schnell",
|
106 |
+
width: int = 1360,
|
107 |
+
height: int = 768,
|
108 |
+
seed: int = None,
|
109 |
+
prompt: str = (
|
110 |
+
"a photo of a forest with mist swirling around the tree trunks. The word "
|
111 |
+
'"FLUX" is painted over it in big, red brush strokes with visible texture'
|
112 |
+
),
|
113 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
114 |
+
num_steps: int = None,
|
115 |
+
loop: bool = False,
|
116 |
+
guidance: float = 3.5,
|
117 |
+
offload: bool = False,
|
118 |
+
output_dir: str = "output",
|
119 |
+
add_sampling_metadata: bool = True,
|
120 |
+
):
|
121 |
+
"""
|
122 |
+
Sample the flux model. Either interactively (set `--loop`) or run for a
|
123 |
+
single image.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
name: Name of the model to load
|
127 |
+
height: height of the sample in pixels (should be a multiple of 16)
|
128 |
+
width: width of the sample in pixels (should be a multiple of 16)
|
129 |
+
seed: Set a seed for sampling
|
130 |
+
output_name: where to save the output image, `{idx}` will be replaced
|
131 |
+
by the index of the sample
|
132 |
+
prompt: Prompt used for sampling
|
133 |
+
device: Pytorch device
|
134 |
+
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
135 |
+
loop: start an interactive session and sample multiple times
|
136 |
+
guidance: guidance value used for guidance distillation
|
137 |
+
add_sampling_metadata: Add the prompt to the image Exif metadata
|
138 |
+
"""
|
139 |
+
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
|
140 |
+
|
141 |
+
if name not in configs:
|
142 |
+
available = ", ".join(configs.keys())
|
143 |
+
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
144 |
+
|
145 |
+
torch_device = torch.device(device)
|
146 |
+
if num_steps is None:
|
147 |
+
num_steps = 4 if name == "flux-schnell" else 50
|
148 |
+
|
149 |
+
# allow for packing and conversion to latent space
|
150 |
+
height = 16 * (height // 16)
|
151 |
+
width = 16 * (width // 16)
|
152 |
+
|
153 |
+
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
154 |
+
if not os.path.exists(output_dir):
|
155 |
+
os.makedirs(output_dir)
|
156 |
+
idx = 0
|
157 |
+
else:
|
158 |
+
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)]
|
159 |
+
if len(fns) > 0:
|
160 |
+
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
161 |
+
else:
|
162 |
+
idx = 0
|
163 |
+
|
164 |
+
# init all components
|
165 |
+
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
|
166 |
+
clip = load_clip(torch_device)
|
167 |
+
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
168 |
+
ae = load_ae(name, device="cpu" if offload else torch_device)
|
169 |
+
|
170 |
+
rng = torch.Generator(device="cpu")
|
171 |
+
opts = SamplingOptions(
|
172 |
+
prompt=prompt,
|
173 |
+
width=width,
|
174 |
+
height=height,
|
175 |
+
num_steps=num_steps,
|
176 |
+
guidance=guidance,
|
177 |
+
seed=seed,
|
178 |
+
)
|
179 |
+
|
180 |
+
if loop:
|
181 |
+
opts = parse_prompt(opts)
|
182 |
+
|
183 |
+
while opts is not None:
|
184 |
+
if opts.seed is None:
|
185 |
+
opts.seed = rng.seed()
|
186 |
+
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
187 |
+
t0 = time.perf_counter()
|
188 |
+
|
189 |
+
# prepare input
|
190 |
+
x = get_noise(
|
191 |
+
1,
|
192 |
+
opts.height,
|
193 |
+
opts.width,
|
194 |
+
device=torch_device,
|
195 |
+
dtype=torch.bfloat16,
|
196 |
+
seed=opts.seed,
|
197 |
+
)
|
198 |
+
opts.seed = None
|
199 |
+
if offload:
|
200 |
+
ae = ae.cpu()
|
201 |
+
torch.cuda.empty_cache()
|
202 |
+
t5, clip = t5.to(torch_device), clip.to(torch_device)
|
203 |
+
inp = prepare(t5, clip, x, prompt=opts.prompt)
|
204 |
+
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
205 |
+
|
206 |
+
# offload TEs to CPU, load model to gpu
|
207 |
+
if offload:
|
208 |
+
t5, clip = t5.cpu(), clip.cpu()
|
209 |
+
torch.cuda.empty_cache()
|
210 |
+
model = model.to(torch_device)
|
211 |
+
|
212 |
+
# denoise initial noise
|
213 |
+
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
214 |
+
|
215 |
+
# offload model, load autoencoder to gpu
|
216 |
+
if offload:
|
217 |
+
model.cpu()
|
218 |
+
torch.cuda.empty_cache()
|
219 |
+
ae.decoder.to(x.device)
|
220 |
+
|
221 |
+
# decode latents to pixel space
|
222 |
+
x = unpack(x.float(), opts.height, opts.width)
|
223 |
+
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
224 |
+
x = ae.decode(x)
|
225 |
+
t1 = time.perf_counter()
|
226 |
+
|
227 |
+
fn = output_name.format(idx=idx)
|
228 |
+
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
|
229 |
+
# bring into PIL format and save
|
230 |
+
x = x.clamp(-1, 1)
|
231 |
+
x = embed_watermark(x.float())
|
232 |
+
x = rearrange(x[0], "c h w -> h w c")
|
233 |
+
|
234 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
235 |
+
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
|
236 |
+
|
237 |
+
if nsfw_score < NSFW_THRESHOLD:
|
238 |
+
exif_data = Image.Exif()
|
239 |
+
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
|
240 |
+
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
|
241 |
+
exif_data[ExifTags.Base.Model] = name
|
242 |
+
if add_sampling_metadata:
|
243 |
+
exif_data[ExifTags.Base.ImageDescription] = prompt
|
244 |
+
img.save(fn, exif=exif_data, quality=95, subsampling=0)
|
245 |
+
idx += 1
|
246 |
+
else:
|
247 |
+
print("Your generated image may contain NSFW content.")
|
248 |
+
|
249 |
+
if loop:
|
250 |
+
print("-" * 80)
|
251 |
+
opts = parse_prompt(opts)
|
252 |
+
else:
|
253 |
+
opts = None
|
254 |
+
|
255 |
+
|
256 |
+
def app():
|
257 |
+
Fire(main)
|
258 |
+
|
259 |
+
|
260 |
+
if __name__ == "__main__":
|
261 |
+
app()
|
flux/math.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
|
6 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
7 |
+
if pe is not None:
|
8 |
+
q, k = apply_rope(q, k, pe)
|
9 |
+
|
10 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
11 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
12 |
+
|
13 |
+
return x
|
14 |
+
|
15 |
+
|
16 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
17 |
+
assert dim % 2 == 0
|
18 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
19 |
+
omega = 1.0 / (theta**scale)
|
20 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
21 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
22 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
23 |
+
return out.float()
|
24 |
+
|
25 |
+
|
26 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
27 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
28 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
29 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
30 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
31 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
flux/model.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
|
6 |
+
from flux.modules.layers import (
|
7 |
+
DoubleStreamBlock,
|
8 |
+
EmbedND,
|
9 |
+
LastLayer,
|
10 |
+
MLPEmbedder,
|
11 |
+
SingleStreamBlock,
|
12 |
+
timestep_embedding,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class FluxParams:
|
18 |
+
in_channels: int
|
19 |
+
vec_in_dim: int
|
20 |
+
context_in_dim: int
|
21 |
+
hidden_size: int
|
22 |
+
mlp_ratio: float
|
23 |
+
num_heads: int
|
24 |
+
depth: int
|
25 |
+
depth_single_blocks: int
|
26 |
+
axes_dim: list[int]
|
27 |
+
theta: int
|
28 |
+
qkv_bias: bool
|
29 |
+
guidance_embed: bool
|
30 |
+
|
31 |
+
|
32 |
+
class Flux(nn.Module):
|
33 |
+
"""
|
34 |
+
Transformer model for flow matching on sequences.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, params: FluxParams):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.params = params
|
41 |
+
self.in_channels = params.in_channels
|
42 |
+
self.out_channels = self.in_channels
|
43 |
+
if params.hidden_size % params.num_heads != 0:
|
44 |
+
raise ValueError(
|
45 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
46 |
+
)
|
47 |
+
pe_dim = params.hidden_size // params.num_heads
|
48 |
+
if sum(params.axes_dim) != pe_dim:
|
49 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
50 |
+
self.hidden_size = params.hidden_size
|
51 |
+
self.num_heads = params.num_heads
|
52 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
53 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
54 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
55 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
56 |
+
self.guidance_in = (
|
57 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
58 |
+
)
|
59 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
60 |
+
|
61 |
+
self.double_blocks = nn.ModuleList(
|
62 |
+
[
|
63 |
+
DoubleStreamBlock(
|
64 |
+
self.hidden_size,
|
65 |
+
self.num_heads,
|
66 |
+
mlp_ratio=params.mlp_ratio,
|
67 |
+
qkv_bias=params.qkv_bias,
|
68 |
+
)
|
69 |
+
for _ in range(params.depth)
|
70 |
+
]
|
71 |
+
)
|
72 |
+
|
73 |
+
self.single_blocks = nn.ModuleList(
|
74 |
+
[
|
75 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
76 |
+
for _ in range(params.depth_single_blocks)
|
77 |
+
]
|
78 |
+
)
|
79 |
+
|
80 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
81 |
+
|
82 |
+
self.pulid_ca = None
|
83 |
+
self.pulid_double_interval = 2
|
84 |
+
self.pulid_single_interval = 4
|
85 |
+
|
86 |
+
def forward(
|
87 |
+
self,
|
88 |
+
img: Tensor,
|
89 |
+
img_ids: Tensor,
|
90 |
+
txt: Tensor,
|
91 |
+
txt_ids: Tensor,
|
92 |
+
timesteps: Tensor,
|
93 |
+
y: Tensor,
|
94 |
+
guidance: Tensor = None,
|
95 |
+
id: Tensor = None,
|
96 |
+
id_weight: float = 1.0,
|
97 |
+
) -> Tensor:
|
98 |
+
if img.ndim != 3 or txt.ndim != 3:
|
99 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
100 |
+
|
101 |
+
# running on sequences img
|
102 |
+
img = self.img_in(img)
|
103 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
104 |
+
if self.params.guidance_embed:
|
105 |
+
if guidance is None:
|
106 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
107 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
108 |
+
vec = vec + self.vector_in(y)
|
109 |
+
txt = self.txt_in(txt)
|
110 |
+
|
111 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
112 |
+
pe = self.pe_embedder(ids)
|
113 |
+
|
114 |
+
ca_idx = 0
|
115 |
+
for i, block in enumerate(self.double_blocks):
|
116 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
117 |
+
|
118 |
+
if i % self.pulid_double_interval == 0 and id is not None:
|
119 |
+
img = img + id_weight * self.pulid_ca[ca_idx](id, img)
|
120 |
+
ca_idx += 1
|
121 |
+
|
122 |
+
img = torch.cat((txt, img), 1)
|
123 |
+
for i, block in enumerate(self.single_blocks):
|
124 |
+
x = block(img, vec=vec, pe=pe)
|
125 |
+
real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
|
126 |
+
|
127 |
+
if i % self.pulid_single_interval == 0 and id is not None:
|
128 |
+
real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
|
129 |
+
ca_idx += 1
|
130 |
+
|
131 |
+
img = torch.cat((txt, real_img), 1)
|
132 |
+
img = img[:, txt.shape[1] :, ...]
|
133 |
+
|
134 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
135 |
+
return img
|
flux/modules/__init__.py
ADDED
File without changes
|
flux/modules/autoencoder.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class AutoEncoderParams:
|
10 |
+
resolution: int
|
11 |
+
in_channels: int
|
12 |
+
ch: int
|
13 |
+
out_ch: int
|
14 |
+
ch_mult: list[int]
|
15 |
+
num_res_blocks: int
|
16 |
+
z_channels: int
|
17 |
+
scale_factor: float
|
18 |
+
shift_factor: float
|
19 |
+
|
20 |
+
|
21 |
+
def swish(x: Tensor) -> Tensor:
|
22 |
+
return x * torch.sigmoid(x)
|
23 |
+
|
24 |
+
|
25 |
+
class AttnBlock(nn.Module):
|
26 |
+
def __init__(self, in_channels: int):
|
27 |
+
super().__init__()
|
28 |
+
self.in_channels = in_channels
|
29 |
+
|
30 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
31 |
+
|
32 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
33 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
34 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
35 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
36 |
+
|
37 |
+
def attention(self, h_: Tensor) -> Tensor:
|
38 |
+
h_ = self.norm(h_)
|
39 |
+
q = self.q(h_)
|
40 |
+
k = self.k(h_)
|
41 |
+
v = self.v(h_)
|
42 |
+
|
43 |
+
b, c, h, w = q.shape
|
44 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
45 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
46 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
47 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
48 |
+
|
49 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
50 |
+
|
51 |
+
def forward(self, x: Tensor) -> Tensor:
|
52 |
+
return x + self.proj_out(self.attention(x))
|
53 |
+
|
54 |
+
|
55 |
+
class ResnetBlock(nn.Module):
|
56 |
+
def __init__(self, in_channels: int, out_channels: int):
|
57 |
+
super().__init__()
|
58 |
+
self.in_channels = in_channels
|
59 |
+
out_channels = in_channels if out_channels is None else out_channels
|
60 |
+
self.out_channels = out_channels
|
61 |
+
|
62 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
63 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
64 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
65 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
66 |
+
if self.in_channels != self.out_channels:
|
67 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
h = x
|
71 |
+
h = self.norm1(h)
|
72 |
+
h = swish(h)
|
73 |
+
h = self.conv1(h)
|
74 |
+
|
75 |
+
h = self.norm2(h)
|
76 |
+
h = swish(h)
|
77 |
+
h = self.conv2(h)
|
78 |
+
|
79 |
+
if self.in_channels != self.out_channels:
|
80 |
+
x = self.nin_shortcut(x)
|
81 |
+
|
82 |
+
return x + h
|
83 |
+
|
84 |
+
|
85 |
+
class Downsample(nn.Module):
|
86 |
+
def __init__(self, in_channels: int):
|
87 |
+
super().__init__()
|
88 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
89 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
90 |
+
|
91 |
+
def forward(self, x: Tensor):
|
92 |
+
pad = (0, 1, 0, 1)
|
93 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
94 |
+
x = self.conv(x)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class Upsample(nn.Module):
|
99 |
+
def __init__(self, in_channels: int):
|
100 |
+
super().__init__()
|
101 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
102 |
+
|
103 |
+
def forward(self, x: Tensor):
|
104 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
105 |
+
x = self.conv(x)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class Encoder(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
resolution: int,
|
113 |
+
in_channels: int,
|
114 |
+
ch: int,
|
115 |
+
ch_mult: list[int],
|
116 |
+
num_res_blocks: int,
|
117 |
+
z_channels: int,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.ch = ch
|
121 |
+
self.num_resolutions = len(ch_mult)
|
122 |
+
self.num_res_blocks = num_res_blocks
|
123 |
+
self.resolution = resolution
|
124 |
+
self.in_channels = in_channels
|
125 |
+
# downsampling
|
126 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
127 |
+
|
128 |
+
curr_res = resolution
|
129 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
130 |
+
self.in_ch_mult = in_ch_mult
|
131 |
+
self.down = nn.ModuleList()
|
132 |
+
block_in = self.ch
|
133 |
+
for i_level in range(self.num_resolutions):
|
134 |
+
block = nn.ModuleList()
|
135 |
+
attn = nn.ModuleList()
|
136 |
+
block_in = ch * in_ch_mult[i_level]
|
137 |
+
block_out = ch * ch_mult[i_level]
|
138 |
+
for _ in range(self.num_res_blocks):
|
139 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
140 |
+
block_in = block_out
|
141 |
+
down = nn.Module()
|
142 |
+
down.block = block
|
143 |
+
down.attn = attn
|
144 |
+
if i_level != self.num_resolutions - 1:
|
145 |
+
down.downsample = Downsample(block_in)
|
146 |
+
curr_res = curr_res // 2
|
147 |
+
self.down.append(down)
|
148 |
+
|
149 |
+
# middle
|
150 |
+
self.mid = nn.Module()
|
151 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
152 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
153 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
154 |
+
|
155 |
+
# end
|
156 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
157 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
158 |
+
|
159 |
+
def forward(self, x: Tensor) -> Tensor:
|
160 |
+
# downsampling
|
161 |
+
hs = [self.conv_in(x)]
|
162 |
+
for i_level in range(self.num_resolutions):
|
163 |
+
for i_block in range(self.num_res_blocks):
|
164 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
165 |
+
if len(self.down[i_level].attn) > 0:
|
166 |
+
h = self.down[i_level].attn[i_block](h)
|
167 |
+
hs.append(h)
|
168 |
+
if i_level != self.num_resolutions - 1:
|
169 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
170 |
+
|
171 |
+
# middle
|
172 |
+
h = hs[-1]
|
173 |
+
h = self.mid.block_1(h)
|
174 |
+
h = self.mid.attn_1(h)
|
175 |
+
h = self.mid.block_2(h)
|
176 |
+
# end
|
177 |
+
h = self.norm_out(h)
|
178 |
+
h = swish(h)
|
179 |
+
h = self.conv_out(h)
|
180 |
+
return h
|
181 |
+
|
182 |
+
|
183 |
+
class Decoder(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
ch: int,
|
187 |
+
out_ch: int,
|
188 |
+
ch_mult: list[int],
|
189 |
+
num_res_blocks: int,
|
190 |
+
in_channels: int,
|
191 |
+
resolution: int,
|
192 |
+
z_channels: int,
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
self.ch = ch
|
196 |
+
self.num_resolutions = len(ch_mult)
|
197 |
+
self.num_res_blocks = num_res_blocks
|
198 |
+
self.resolution = resolution
|
199 |
+
self.in_channels = in_channels
|
200 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
201 |
+
|
202 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
203 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
204 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
205 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
206 |
+
|
207 |
+
# z to block_in
|
208 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
209 |
+
|
210 |
+
# middle
|
211 |
+
self.mid = nn.Module()
|
212 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
213 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
214 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
215 |
+
|
216 |
+
# upsampling
|
217 |
+
self.up = nn.ModuleList()
|
218 |
+
for i_level in reversed(range(self.num_resolutions)):
|
219 |
+
block = nn.ModuleList()
|
220 |
+
attn = nn.ModuleList()
|
221 |
+
block_out = ch * ch_mult[i_level]
|
222 |
+
for _ in range(self.num_res_blocks + 1):
|
223 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
224 |
+
block_in = block_out
|
225 |
+
up = nn.Module()
|
226 |
+
up.block = block
|
227 |
+
up.attn = attn
|
228 |
+
if i_level != 0:
|
229 |
+
up.upsample = Upsample(block_in)
|
230 |
+
curr_res = curr_res * 2
|
231 |
+
self.up.insert(0, up) # prepend to get consistent order
|
232 |
+
|
233 |
+
# end
|
234 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
235 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
236 |
+
|
237 |
+
def forward(self, z: Tensor) -> Tensor:
|
238 |
+
# z to block_in
|
239 |
+
h = self.conv_in(z)
|
240 |
+
|
241 |
+
# middle
|
242 |
+
h = self.mid.block_1(h)
|
243 |
+
h = self.mid.attn_1(h)
|
244 |
+
h = self.mid.block_2(h)
|
245 |
+
|
246 |
+
# upsampling
|
247 |
+
for i_level in reversed(range(self.num_resolutions)):
|
248 |
+
for i_block in range(self.num_res_blocks + 1):
|
249 |
+
h = self.up[i_level].block[i_block](h)
|
250 |
+
if len(self.up[i_level].attn) > 0:
|
251 |
+
h = self.up[i_level].attn[i_block](h)
|
252 |
+
if i_level != 0:
|
253 |
+
h = self.up[i_level].upsample(h)
|
254 |
+
|
255 |
+
# end
|
256 |
+
h = self.norm_out(h)
|
257 |
+
h = swish(h)
|
258 |
+
h = self.conv_out(h)
|
259 |
+
return h
|
260 |
+
|
261 |
+
|
262 |
+
class DiagonalGaussian(nn.Module):
|
263 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
264 |
+
super().__init__()
|
265 |
+
self.sample = sample
|
266 |
+
self.chunk_dim = chunk_dim
|
267 |
+
|
268 |
+
def forward(self, z: Tensor) -> Tensor:
|
269 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
270 |
+
if self.sample:
|
271 |
+
std = torch.exp(0.5 * logvar)
|
272 |
+
return mean + std * torch.randn_like(mean)
|
273 |
+
else:
|
274 |
+
return mean
|
275 |
+
|
276 |
+
|
277 |
+
class AutoEncoder(nn.Module):
|
278 |
+
def __init__(self, params: AutoEncoderParams):
|
279 |
+
super().__init__()
|
280 |
+
self.encoder = Encoder(
|
281 |
+
resolution=params.resolution,
|
282 |
+
in_channels=params.in_channels,
|
283 |
+
ch=params.ch,
|
284 |
+
ch_mult=params.ch_mult,
|
285 |
+
num_res_blocks=params.num_res_blocks,
|
286 |
+
z_channels=params.z_channels,
|
287 |
+
)
|
288 |
+
self.decoder = Decoder(
|
289 |
+
resolution=params.resolution,
|
290 |
+
in_channels=params.in_channels,
|
291 |
+
ch=params.ch,
|
292 |
+
out_ch=params.out_ch,
|
293 |
+
ch_mult=params.ch_mult,
|
294 |
+
num_res_blocks=params.num_res_blocks,
|
295 |
+
z_channels=params.z_channels,
|
296 |
+
)
|
297 |
+
self.reg = DiagonalGaussian()
|
298 |
+
|
299 |
+
self.scale_factor = params.scale_factor
|
300 |
+
self.shift_factor = params.shift_factor
|
301 |
+
|
302 |
+
def encode(self, x: Tensor) -> Tensor:
|
303 |
+
z = self.reg(self.encoder(x))
|
304 |
+
z = self.scale_factor * (z - self.shift_factor)
|
305 |
+
return z
|
306 |
+
|
307 |
+
def decode(self, z: Tensor) -> Tensor:
|
308 |
+
z = z / self.scale_factor + self.shift_factor
|
309 |
+
return self.decoder(z)
|
310 |
+
|
311 |
+
def forward(self, x: Tensor) -> Tensor:
|
312 |
+
return self.decode(self.encode(x))
|
flux/modules/conditioner.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor, nn
|
2 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
3 |
+
|
4 |
+
|
5 |
+
class HFEmbedder(nn.Module):
|
6 |
+
def __init__(self, version: str, max_length: int, **hf_kwargs):
|
7 |
+
super().__init__()
|
8 |
+
self.is_clip = version.startswith("openai")
|
9 |
+
self.max_length = max_length
|
10 |
+
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
11 |
+
|
12 |
+
if self.is_clip:
|
13 |
+
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
14 |
+
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
15 |
+
else:
|
16 |
+
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
17 |
+
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
18 |
+
|
19 |
+
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
20 |
+
|
21 |
+
def forward(self, text: list[str]) -> Tensor:
|
22 |
+
batch_encoding = self.tokenizer(
|
23 |
+
text,
|
24 |
+
truncation=True,
|
25 |
+
max_length=self.max_length,
|
26 |
+
return_length=False,
|
27 |
+
return_overflowing_tokens=False,
|
28 |
+
padding="max_length",
|
29 |
+
return_tensors="pt",
|
30 |
+
)
|
31 |
+
|
32 |
+
outputs = self.hf_module(
|
33 |
+
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
34 |
+
attention_mask=None,
|
35 |
+
output_hidden_states=False,
|
36 |
+
)
|
37 |
+
return outputs[self.output_key]
|
flux/modules/layers.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from torch import Tensor, nn
|
7 |
+
|
8 |
+
from flux.math import attention, rope
|
9 |
+
|
10 |
+
|
11 |
+
class EmbedND(nn.Module):
|
12 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
13 |
+
super().__init__()
|
14 |
+
self.dim = dim
|
15 |
+
self.theta = theta
|
16 |
+
self.axes_dim = axes_dim
|
17 |
+
|
18 |
+
def forward(self, ids: Tensor) -> Tensor:
|
19 |
+
n_axes = ids.shape[-1]
|
20 |
+
emb = torch.cat(
|
21 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
22 |
+
dim=-3,
|
23 |
+
)
|
24 |
+
|
25 |
+
return emb.unsqueeze(1)
|
26 |
+
|
27 |
+
|
28 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
29 |
+
"""
|
30 |
+
Create sinusoidal timestep embeddings.
|
31 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
32 |
+
These may be fractional.
|
33 |
+
:param dim: the dimension of the output.
|
34 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
35 |
+
:return: an (N, D) Tensor of positional embeddings.
|
36 |
+
"""
|
37 |
+
t = time_factor * t
|
38 |
+
half = dim // 2
|
39 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
40 |
+
t.device
|
41 |
+
)
|
42 |
+
|
43 |
+
args = t[:, None].float() * freqs[None]
|
44 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
45 |
+
if dim % 2:
|
46 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
47 |
+
if torch.is_floating_point(t):
|
48 |
+
embedding = embedding.to(t)
|
49 |
+
return embedding
|
50 |
+
|
51 |
+
|
52 |
+
class MLPEmbedder(nn.Module):
|
53 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
54 |
+
super().__init__()
|
55 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
56 |
+
self.silu = nn.SiLU()
|
57 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
58 |
+
|
59 |
+
def forward(self, x: Tensor) -> Tensor:
|
60 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
61 |
+
|
62 |
+
|
63 |
+
class RMSNorm(torch.nn.Module):
|
64 |
+
def __init__(self, dim: int):
|
65 |
+
super().__init__()
|
66 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
67 |
+
|
68 |
+
def forward(self, x: Tensor):
|
69 |
+
x_dtype = x.dtype
|
70 |
+
x = x.float()
|
71 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
72 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
73 |
+
|
74 |
+
|
75 |
+
class QKNorm(torch.nn.Module):
|
76 |
+
def __init__(self, dim: int):
|
77 |
+
super().__init__()
|
78 |
+
self.query_norm = RMSNorm(dim)
|
79 |
+
self.key_norm = RMSNorm(dim)
|
80 |
+
|
81 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
82 |
+
q = self.query_norm(q)
|
83 |
+
k = self.key_norm(k)
|
84 |
+
return q.to(v), k.to(v)
|
85 |
+
|
86 |
+
|
87 |
+
class SelfAttention(nn.Module):
|
88 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
89 |
+
super().__init__()
|
90 |
+
self.num_heads = num_heads
|
91 |
+
head_dim = dim // num_heads
|
92 |
+
|
93 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
94 |
+
self.norm = QKNorm(head_dim)
|
95 |
+
self.proj = nn.Linear(dim, dim)
|
96 |
+
|
97 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
98 |
+
qkv = self.qkv(x)
|
99 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
100 |
+
q, k = self.norm(q, k, v)
|
101 |
+
x = attention(q, k, v, pe=pe)
|
102 |
+
x = self.proj(x)
|
103 |
+
return x
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class ModulationOut:
|
108 |
+
shift: Tensor
|
109 |
+
scale: Tensor
|
110 |
+
gate: Tensor
|
111 |
+
|
112 |
+
|
113 |
+
class Modulation(nn.Module):
|
114 |
+
def __init__(self, dim: int, double: bool):
|
115 |
+
super().__init__()
|
116 |
+
self.is_double = double
|
117 |
+
self.multiplier = 6 if double else 3
|
118 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
119 |
+
|
120 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
|
121 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
122 |
+
|
123 |
+
return (
|
124 |
+
ModulationOut(*out[:3]),
|
125 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
class DoubleStreamBlock(nn.Module):
|
130 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
131 |
+
super().__init__()
|
132 |
+
|
133 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
134 |
+
self.num_heads = num_heads
|
135 |
+
self.hidden_size = hidden_size
|
136 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
137 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
138 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
139 |
+
|
140 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
141 |
+
self.img_mlp = nn.Sequential(
|
142 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
143 |
+
nn.GELU(approximate="tanh"),
|
144 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
145 |
+
)
|
146 |
+
|
147 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
148 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
149 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
150 |
+
|
151 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
152 |
+
self.txt_mlp = nn.Sequential(
|
153 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
154 |
+
nn.GELU(approximate="tanh"),
|
155 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
156 |
+
)
|
157 |
+
|
158 |
+
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
159 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
160 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
161 |
+
|
162 |
+
# prepare image for attention
|
163 |
+
img_modulated = self.img_norm1(img)
|
164 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
165 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
166 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
167 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
168 |
+
|
169 |
+
# prepare txt for attention
|
170 |
+
txt_modulated = self.txt_norm1(txt)
|
171 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
172 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
173 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
174 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
175 |
+
|
176 |
+
# run actual attention
|
177 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
178 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
179 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
180 |
+
|
181 |
+
attn = attention(q, k, v, pe=pe)
|
182 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
183 |
+
|
184 |
+
# calculate the img bloks
|
185 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
186 |
+
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
187 |
+
|
188 |
+
# calculate the txt bloks
|
189 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
190 |
+
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
191 |
+
return img, txt
|
192 |
+
|
193 |
+
|
194 |
+
class SingleStreamBlock(nn.Module):
|
195 |
+
"""
|
196 |
+
A DiT block with parallel linear layers as described in
|
197 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
198 |
+
"""
|
199 |
+
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
hidden_size: int,
|
203 |
+
num_heads: int,
|
204 |
+
mlp_ratio: float = 4.0,
|
205 |
+
qk_scale: float = None,
|
206 |
+
):
|
207 |
+
super().__init__()
|
208 |
+
self.hidden_dim = hidden_size
|
209 |
+
self.num_heads = num_heads
|
210 |
+
head_dim = hidden_size // num_heads
|
211 |
+
self.scale = qk_scale or head_dim**-0.5
|
212 |
+
|
213 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
214 |
+
# qkv and mlp_in
|
215 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
216 |
+
# proj and mlp_out
|
217 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
218 |
+
|
219 |
+
self.norm = QKNorm(head_dim)
|
220 |
+
|
221 |
+
self.hidden_size = hidden_size
|
222 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
223 |
+
|
224 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
225 |
+
self.modulation = Modulation(hidden_size, double=False)
|
226 |
+
|
227 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
228 |
+
mod, _ = self.modulation(vec)
|
229 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
230 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
231 |
+
|
232 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
233 |
+
q, k = self.norm(q, k, v)
|
234 |
+
|
235 |
+
# compute attention
|
236 |
+
attn = attention(q, k, v, pe=pe)
|
237 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
238 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
239 |
+
return x + mod.gate * output
|
240 |
+
|
241 |
+
|
242 |
+
class LastLayer(nn.Module):
|
243 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
244 |
+
super().__init__()
|
245 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
246 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
247 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
248 |
+
|
249 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
250 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
251 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
252 |
+
x = self.linear(x)
|
253 |
+
return x
|
flux/sampling.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from .model import Flux
|
9 |
+
from .modules.conditioner import HFEmbedder
|
10 |
+
|
11 |
+
|
12 |
+
def get_noise(
|
13 |
+
num_samples: int,
|
14 |
+
height: int,
|
15 |
+
width: int,
|
16 |
+
device: torch.device,
|
17 |
+
dtype: torch.dtype,
|
18 |
+
seed: int,
|
19 |
+
):
|
20 |
+
return torch.randn(
|
21 |
+
num_samples,
|
22 |
+
16,
|
23 |
+
# allow for packing
|
24 |
+
2 * math.ceil(height / 16),
|
25 |
+
2 * math.ceil(width / 16),
|
26 |
+
device=device,
|
27 |
+
dtype=dtype,
|
28 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str) -> dict[str, Tensor]:
|
33 |
+
bs, c, h, w = img.shape
|
34 |
+
if bs == 1 and not isinstance(prompt, str):
|
35 |
+
bs = len(prompt)
|
36 |
+
|
37 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
38 |
+
if img.shape[0] == 1 and bs > 1:
|
39 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
40 |
+
|
41 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
42 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
43 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
44 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
45 |
+
|
46 |
+
if isinstance(prompt, str):
|
47 |
+
prompt = [prompt]
|
48 |
+
txt = t5(prompt)
|
49 |
+
if txt.shape[0] == 1 and bs > 1:
|
50 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
51 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
52 |
+
|
53 |
+
vec = clip(prompt)
|
54 |
+
if vec.shape[0] == 1 and bs > 1:
|
55 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
56 |
+
|
57 |
+
return {
|
58 |
+
"img": img,
|
59 |
+
"img_ids": img_ids.to(img.device),
|
60 |
+
"txt": txt.to(img.device),
|
61 |
+
"txt_ids": txt_ids.to(img.device),
|
62 |
+
"vec": vec.to(img.device),
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
67 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
68 |
+
|
69 |
+
|
70 |
+
def get_lin_function(
|
71 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
72 |
+
) -> Callable[[float], float]:
|
73 |
+
m = (y2 - y1) / (x2 - x1)
|
74 |
+
b = y1 - m * x1
|
75 |
+
return lambda x: m * x + b
|
76 |
+
|
77 |
+
|
78 |
+
def get_schedule(
|
79 |
+
num_steps: int,
|
80 |
+
image_seq_len: int,
|
81 |
+
base_shift: float = 0.5,
|
82 |
+
max_shift: float = 1.15,
|
83 |
+
shift: bool = True,
|
84 |
+
) -> list[float]:
|
85 |
+
# extra step for zero
|
86 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
87 |
+
|
88 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
89 |
+
if shift:
|
90 |
+
# eastimate mu based on linear estimation between two points
|
91 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
92 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
93 |
+
|
94 |
+
return timesteps.tolist()
|
95 |
+
|
96 |
+
|
97 |
+
def denoise(
|
98 |
+
model: Flux,
|
99 |
+
# model input
|
100 |
+
img: Tensor,
|
101 |
+
img_ids: Tensor,
|
102 |
+
txt: Tensor,
|
103 |
+
txt_ids: Tensor,
|
104 |
+
vec: Tensor,
|
105 |
+
timesteps: list[float],
|
106 |
+
guidance: float = 4.0,
|
107 |
+
id_weight=1.0,
|
108 |
+
id=None,
|
109 |
+
start_step=0,
|
110 |
+
uncond_id=None,
|
111 |
+
true_cfg=1.0,
|
112 |
+
timestep_to_start_cfg=1,
|
113 |
+
neg_txt=None,
|
114 |
+
neg_txt_ids=None,
|
115 |
+
neg_vec=None,
|
116 |
+
):
|
117 |
+
# this is ignored for schnell
|
118 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
119 |
+
use_true_cfg = abs(true_cfg - 1.0) > 1e-2
|
120 |
+
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
|
121 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
122 |
+
pred = model(
|
123 |
+
img=img,
|
124 |
+
img_ids=img_ids,
|
125 |
+
txt=txt,
|
126 |
+
txt_ids=txt_ids,
|
127 |
+
y=vec,
|
128 |
+
timesteps=t_vec,
|
129 |
+
guidance=guidance_vec,
|
130 |
+
id=id if i >= start_step else None,
|
131 |
+
id_weight=id_weight,
|
132 |
+
)
|
133 |
+
|
134 |
+
if use_true_cfg and i >= timestep_to_start_cfg:
|
135 |
+
neg_pred = model(
|
136 |
+
img=img,
|
137 |
+
img_ids=img_ids,
|
138 |
+
txt=neg_txt,
|
139 |
+
txt_ids=neg_txt_ids,
|
140 |
+
y=neg_vec,
|
141 |
+
timesteps=t_vec,
|
142 |
+
guidance=guidance_vec,
|
143 |
+
id=uncond_id if i >= start_step else None,
|
144 |
+
id_weight=id_weight,
|
145 |
+
)
|
146 |
+
pred = neg_pred + true_cfg * (pred - neg_pred)
|
147 |
+
|
148 |
+
img = img + (t_prev - t_curr) * pred
|
149 |
+
|
150 |
+
return img
|
151 |
+
|
152 |
+
|
153 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
154 |
+
return rearrange(
|
155 |
+
x,
|
156 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
157 |
+
h=math.ceil(height / 16),
|
158 |
+
w=math.ceil(width / 16),
|
159 |
+
ph=2,
|
160 |
+
pw=2,
|
161 |
+
)
|
flux/util.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from imwatermark import WatermarkEncoder
|
8 |
+
from safetensors.torch import load_file as load_sft
|
9 |
+
|
10 |
+
from flux.model import Flux, FluxParams
|
11 |
+
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
12 |
+
from flux.modules.conditioner import HFEmbedder
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class ModelSpec:
|
17 |
+
params: FluxParams
|
18 |
+
ae_params: AutoEncoderParams
|
19 |
+
ckpt_path: str
|
20 |
+
ae_path: str
|
21 |
+
repo_id: str
|
22 |
+
repo_flow: str
|
23 |
+
repo_ae: str
|
24 |
+
|
25 |
+
|
26 |
+
configs = {
|
27 |
+
"flux-dev": ModelSpec(
|
28 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
29 |
+
repo_flow="flux1-dev.safetensors",
|
30 |
+
repo_ae="ae.safetensors",
|
31 |
+
ckpt_path='models/flux1-dev.safetensors',
|
32 |
+
params=FluxParams(
|
33 |
+
in_channels=64,
|
34 |
+
vec_in_dim=768,
|
35 |
+
context_in_dim=4096,
|
36 |
+
hidden_size=3072,
|
37 |
+
mlp_ratio=4.0,
|
38 |
+
num_heads=24,
|
39 |
+
depth=19,
|
40 |
+
depth_single_blocks=38,
|
41 |
+
axes_dim=[16, 56, 56],
|
42 |
+
theta=10_000,
|
43 |
+
qkv_bias=True,
|
44 |
+
guidance_embed=True,
|
45 |
+
),
|
46 |
+
ae_path='models/ae.safetensors',
|
47 |
+
ae_params=AutoEncoderParams(
|
48 |
+
resolution=256,
|
49 |
+
in_channels=3,
|
50 |
+
ch=128,
|
51 |
+
out_ch=3,
|
52 |
+
ch_mult=[1, 2, 4, 4],
|
53 |
+
num_res_blocks=2,
|
54 |
+
z_channels=16,
|
55 |
+
scale_factor=0.3611,
|
56 |
+
shift_factor=0.1159,
|
57 |
+
),
|
58 |
+
),
|
59 |
+
"flux-schnell": ModelSpec(
|
60 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
61 |
+
repo_flow="flux1-schnell.safetensors",
|
62 |
+
repo_ae="ae.safetensors",
|
63 |
+
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
64 |
+
params=FluxParams(
|
65 |
+
in_channels=64,
|
66 |
+
vec_in_dim=768,
|
67 |
+
context_in_dim=4096,
|
68 |
+
hidden_size=3072,
|
69 |
+
mlp_ratio=4.0,
|
70 |
+
num_heads=24,
|
71 |
+
depth=19,
|
72 |
+
depth_single_blocks=38,
|
73 |
+
axes_dim=[16, 56, 56],
|
74 |
+
theta=10_000,
|
75 |
+
qkv_bias=True,
|
76 |
+
guidance_embed=False,
|
77 |
+
),
|
78 |
+
ae_path=os.getenv("AE"),
|
79 |
+
ae_params=AutoEncoderParams(
|
80 |
+
resolution=256,
|
81 |
+
in_channels=3,
|
82 |
+
ch=128,
|
83 |
+
out_ch=3,
|
84 |
+
ch_mult=[1, 2, 4, 4],
|
85 |
+
num_res_blocks=2,
|
86 |
+
z_channels=16,
|
87 |
+
scale_factor=0.3611,
|
88 |
+
shift_factor=0.1159,
|
89 |
+
),
|
90 |
+
),
|
91 |
+
}
|
92 |
+
|
93 |
+
|
94 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
95 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
96 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
97 |
+
print("\n" + "-" * 79 + "\n")
|
98 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
99 |
+
elif len(missing) > 0:
|
100 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
101 |
+
elif len(unexpected) > 0:
|
102 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
103 |
+
|
104 |
+
|
105 |
+
def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
|
106 |
+
# Loading Flux
|
107 |
+
print("Init model")
|
108 |
+
ckpt_path = configs[name].ckpt_path
|
109 |
+
if (
|
110 |
+
ckpt_path is None
|
111 |
+
and configs[name].repo_id is not None
|
112 |
+
and configs[name].repo_flow is not None
|
113 |
+
and hf_download
|
114 |
+
):
|
115 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
|
116 |
+
|
117 |
+
with torch.device(device):
|
118 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
119 |
+
|
120 |
+
if ckpt_path is not None:
|
121 |
+
print("Loading checkpoint")
|
122 |
+
# load_sft doesn't support torch.device
|
123 |
+
sd = load_sft(ckpt_path, device=str(device))
|
124 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
125 |
+
print_load_warning(missing, unexpected)
|
126 |
+
return model
|
127 |
+
|
128 |
+
|
129 |
+
def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
|
130 |
+
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
131 |
+
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
132 |
+
|
133 |
+
|
134 |
+
def load_clip(device: str = "cuda") -> HFEmbedder:
|
135 |
+
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
|
136 |
+
|
137 |
+
|
138 |
+
def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
|
139 |
+
ckpt_path = configs[name].ae_path
|
140 |
+
if (
|
141 |
+
ckpt_path is None
|
142 |
+
and configs[name].repo_id is not None
|
143 |
+
and configs[name].repo_ae is not None
|
144 |
+
and hf_download
|
145 |
+
):
|
146 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')
|
147 |
+
|
148 |
+
# Loading the autoencoder
|
149 |
+
print("Init AE")
|
150 |
+
with torch.device(device):
|
151 |
+
ae = AutoEncoder(configs[name].ae_params)
|
152 |
+
|
153 |
+
if ckpt_path is not None:
|
154 |
+
sd = load_sft(ckpt_path, device=str(device))
|
155 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False)
|
156 |
+
print_load_warning(missing, unexpected)
|
157 |
+
return ae
|
158 |
+
|
159 |
+
|
160 |
+
class WatermarkEmbedder:
|
161 |
+
def __init__(self, watermark):
|
162 |
+
self.watermark = watermark
|
163 |
+
self.num_bits = len(WATERMARK_BITS)
|
164 |
+
self.encoder = WatermarkEncoder()
|
165 |
+
self.encoder.set_watermark("bits", self.watermark)
|
166 |
+
|
167 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
168 |
+
"""
|
169 |
+
Adds a predefined watermark to the input image
|
170 |
+
|
171 |
+
Args:
|
172 |
+
image: ([N,] B, RGB, H, W) in range [-1, 1]
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
same as input but watermarked
|
176 |
+
"""
|
177 |
+
image = 0.5 * image + 0.5
|
178 |
+
squeeze = len(image.shape) == 4
|
179 |
+
if squeeze:
|
180 |
+
image = image[None, ...]
|
181 |
+
n = image.shape[0]
|
182 |
+
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
|
183 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
184 |
+
# watermarking libary expects input as cv2 BGR format
|
185 |
+
for k in range(image_np.shape[0]):
|
186 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
187 |
+
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
|
188 |
+
image.device
|
189 |
+
)
|
190 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
191 |
+
if squeeze:
|
192 |
+
image = image[0]
|
193 |
+
image = 2 * image - 1
|
194 |
+
return image
|
195 |
+
|
196 |
+
|
197 |
+
# A fixed 48-bit message that was choosen at random
|
198 |
+
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
|
199 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
200 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
201 |
+
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
models/.gitkeep
ADDED
File without changes
|
pulid/attention_processor.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
NUM_ZERO = 0
|
7 |
+
ORTHO = False
|
8 |
+
ORTHO_v2 = False
|
9 |
+
|
10 |
+
|
11 |
+
class AttnProcessor(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
def __call__(
|
16 |
+
self,
|
17 |
+
attn,
|
18 |
+
hidden_states,
|
19 |
+
encoder_hidden_states=None,
|
20 |
+
attention_mask=None,
|
21 |
+
temb=None,
|
22 |
+
id_embedding=None,
|
23 |
+
id_scale=1.0,
|
24 |
+
):
|
25 |
+
residual = hidden_states
|
26 |
+
|
27 |
+
if attn.spatial_norm is not None:
|
28 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
29 |
+
|
30 |
+
input_ndim = hidden_states.ndim
|
31 |
+
|
32 |
+
if input_ndim == 4:
|
33 |
+
batch_size, channel, height, width = hidden_states.shape
|
34 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
35 |
+
|
36 |
+
batch_size, sequence_length, _ = (
|
37 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
38 |
+
)
|
39 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
40 |
+
|
41 |
+
if attn.group_norm is not None:
|
42 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
43 |
+
|
44 |
+
query = attn.to_q(hidden_states)
|
45 |
+
|
46 |
+
if encoder_hidden_states is None:
|
47 |
+
encoder_hidden_states = hidden_states
|
48 |
+
elif attn.norm_cross:
|
49 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
50 |
+
|
51 |
+
key = attn.to_k(encoder_hidden_states)
|
52 |
+
value = attn.to_v(encoder_hidden_states)
|
53 |
+
|
54 |
+
query = attn.head_to_batch_dim(query)
|
55 |
+
key = attn.head_to_batch_dim(key)
|
56 |
+
value = attn.head_to_batch_dim(value)
|
57 |
+
|
58 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
59 |
+
hidden_states = torch.bmm(attention_probs, value)
|
60 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
61 |
+
|
62 |
+
# linear proj
|
63 |
+
hidden_states = attn.to_out[0](hidden_states)
|
64 |
+
# dropout
|
65 |
+
hidden_states = attn.to_out[1](hidden_states)
|
66 |
+
|
67 |
+
if input_ndim == 4:
|
68 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
69 |
+
|
70 |
+
if attn.residual_connection:
|
71 |
+
hidden_states = hidden_states + residual
|
72 |
+
|
73 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
74 |
+
|
75 |
+
return hidden_states
|
76 |
+
|
77 |
+
|
78 |
+
class IDAttnProcessor(nn.Module):
|
79 |
+
r"""
|
80 |
+
Attention processor for ID-Adapater.
|
81 |
+
Args:
|
82 |
+
hidden_size (`int`):
|
83 |
+
The hidden size of the attention layer.
|
84 |
+
cross_attention_dim (`int`):
|
85 |
+
The number of channels in the `encoder_hidden_states`.
|
86 |
+
scale (`float`, defaults to 1.0):
|
87 |
+
the weight scale of image prompt.
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, hidden_size, cross_attention_dim=None):
|
91 |
+
super().__init__()
|
92 |
+
self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
93 |
+
self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
94 |
+
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
attn,
|
98 |
+
hidden_states,
|
99 |
+
encoder_hidden_states=None,
|
100 |
+
attention_mask=None,
|
101 |
+
temb=None,
|
102 |
+
id_embedding=None,
|
103 |
+
id_scale=1.0,
|
104 |
+
):
|
105 |
+
residual = hidden_states
|
106 |
+
|
107 |
+
if attn.spatial_norm is not None:
|
108 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
109 |
+
|
110 |
+
input_ndim = hidden_states.ndim
|
111 |
+
|
112 |
+
if input_ndim == 4:
|
113 |
+
batch_size, channel, height, width = hidden_states.shape
|
114 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
115 |
+
|
116 |
+
batch_size, sequence_length, _ = (
|
117 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
118 |
+
)
|
119 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
120 |
+
|
121 |
+
if attn.group_norm is not None:
|
122 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
123 |
+
|
124 |
+
query = attn.to_q(hidden_states)
|
125 |
+
|
126 |
+
if encoder_hidden_states is None:
|
127 |
+
encoder_hidden_states = hidden_states
|
128 |
+
elif attn.norm_cross:
|
129 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
130 |
+
|
131 |
+
key = attn.to_k(encoder_hidden_states)
|
132 |
+
value = attn.to_v(encoder_hidden_states)
|
133 |
+
|
134 |
+
query = attn.head_to_batch_dim(query)
|
135 |
+
key = attn.head_to_batch_dim(key)
|
136 |
+
value = attn.head_to_batch_dim(value)
|
137 |
+
|
138 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
139 |
+
hidden_states = torch.bmm(attention_probs, value)
|
140 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
141 |
+
|
142 |
+
# for id-adapter
|
143 |
+
if id_embedding is not None:
|
144 |
+
if NUM_ZERO == 0:
|
145 |
+
id_key = self.id_to_k(id_embedding)
|
146 |
+
id_value = self.id_to_v(id_embedding)
|
147 |
+
else:
|
148 |
+
zero_tensor = torch.zeros(
|
149 |
+
(id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
|
150 |
+
dtype=id_embedding.dtype,
|
151 |
+
device=id_embedding.device,
|
152 |
+
)
|
153 |
+
id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1))
|
154 |
+
id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1))
|
155 |
+
|
156 |
+
id_key = attn.head_to_batch_dim(id_key).to(query.dtype)
|
157 |
+
id_value = attn.head_to_batch_dim(id_value).to(query.dtype)
|
158 |
+
|
159 |
+
id_attention_probs = attn.get_attention_scores(query, id_key, None)
|
160 |
+
id_hidden_states = torch.bmm(id_attention_probs, id_value)
|
161 |
+
id_hidden_states = attn.batch_to_head_dim(id_hidden_states)
|
162 |
+
|
163 |
+
if not ORTHO:
|
164 |
+
hidden_states = hidden_states + id_scale * id_hidden_states
|
165 |
+
else:
|
166 |
+
projection = (
|
167 |
+
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
|
168 |
+
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
|
169 |
+
* hidden_states
|
170 |
+
)
|
171 |
+
orthogonal = id_hidden_states - projection
|
172 |
+
hidden_states = hidden_states + id_scale * orthogonal
|
173 |
+
|
174 |
+
# linear proj
|
175 |
+
hidden_states = attn.to_out[0](hidden_states)
|
176 |
+
# dropout
|
177 |
+
hidden_states = attn.to_out[1](hidden_states)
|
178 |
+
|
179 |
+
if input_ndim == 4:
|
180 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
181 |
+
|
182 |
+
if attn.residual_connection:
|
183 |
+
hidden_states = hidden_states + residual
|
184 |
+
|
185 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
186 |
+
|
187 |
+
return hidden_states
|
188 |
+
|
189 |
+
|
190 |
+
class AttnProcessor2_0(nn.Module):
|
191 |
+
r"""
|
192 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(self):
|
196 |
+
super().__init__()
|
197 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
198 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
199 |
+
|
200 |
+
def __call__(
|
201 |
+
self,
|
202 |
+
attn,
|
203 |
+
hidden_states,
|
204 |
+
encoder_hidden_states=None,
|
205 |
+
attention_mask=None,
|
206 |
+
temb=None,
|
207 |
+
id_embedding=None,
|
208 |
+
id_scale=1.0,
|
209 |
+
):
|
210 |
+
residual = hidden_states
|
211 |
+
|
212 |
+
if attn.spatial_norm is not None:
|
213 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
214 |
+
|
215 |
+
input_ndim = hidden_states.ndim
|
216 |
+
|
217 |
+
if input_ndim == 4:
|
218 |
+
batch_size, channel, height, width = hidden_states.shape
|
219 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
220 |
+
|
221 |
+
batch_size, sequence_length, _ = (
|
222 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
223 |
+
)
|
224 |
+
|
225 |
+
if attention_mask is not None:
|
226 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
227 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
228 |
+
# (batch, heads, source_length, target_length)
|
229 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
230 |
+
|
231 |
+
if attn.group_norm is not None:
|
232 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
233 |
+
|
234 |
+
query = attn.to_q(hidden_states)
|
235 |
+
|
236 |
+
if encoder_hidden_states is None:
|
237 |
+
encoder_hidden_states = hidden_states
|
238 |
+
elif attn.norm_cross:
|
239 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
240 |
+
|
241 |
+
key = attn.to_k(encoder_hidden_states)
|
242 |
+
value = attn.to_v(encoder_hidden_states)
|
243 |
+
|
244 |
+
inner_dim = key.shape[-1]
|
245 |
+
head_dim = inner_dim // attn.heads
|
246 |
+
|
247 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
248 |
+
|
249 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
250 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
251 |
+
|
252 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
253 |
+
hidden_states = F.scaled_dot_product_attention(
|
254 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
255 |
+
)
|
256 |
+
|
257 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
258 |
+
hidden_states = hidden_states.to(query.dtype)
|
259 |
+
|
260 |
+
# linear proj
|
261 |
+
hidden_states = attn.to_out[0](hidden_states)
|
262 |
+
# dropout
|
263 |
+
hidden_states = attn.to_out[1](hidden_states)
|
264 |
+
|
265 |
+
if input_ndim == 4:
|
266 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
267 |
+
|
268 |
+
if attn.residual_connection:
|
269 |
+
hidden_states = hidden_states + residual
|
270 |
+
|
271 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
272 |
+
|
273 |
+
return hidden_states
|
274 |
+
|
275 |
+
|
276 |
+
class IDAttnProcessor2_0(torch.nn.Module):
|
277 |
+
r"""
|
278 |
+
Attention processor for ID-Adapater for PyTorch 2.0.
|
279 |
+
Args:
|
280 |
+
hidden_size (`int`):
|
281 |
+
The hidden size of the attention layer.
|
282 |
+
cross_attention_dim (`int`):
|
283 |
+
The number of channels in the `encoder_hidden_states`.
|
284 |
+
"""
|
285 |
+
|
286 |
+
def __init__(self, hidden_size, cross_attention_dim=None):
|
287 |
+
super().__init__()
|
288 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
289 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
290 |
+
|
291 |
+
self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
292 |
+
self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
293 |
+
|
294 |
+
def __call__(
|
295 |
+
self,
|
296 |
+
attn,
|
297 |
+
hidden_states,
|
298 |
+
encoder_hidden_states=None,
|
299 |
+
attention_mask=None,
|
300 |
+
temb=None,
|
301 |
+
id_embedding=None,
|
302 |
+
id_scale=1.0,
|
303 |
+
):
|
304 |
+
residual = hidden_states
|
305 |
+
|
306 |
+
if attn.spatial_norm is not None:
|
307 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
308 |
+
|
309 |
+
input_ndim = hidden_states.ndim
|
310 |
+
|
311 |
+
if input_ndim == 4:
|
312 |
+
batch_size, channel, height, width = hidden_states.shape
|
313 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
314 |
+
|
315 |
+
batch_size, sequence_length, _ = (
|
316 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
317 |
+
)
|
318 |
+
|
319 |
+
if attention_mask is not None:
|
320 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
321 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
322 |
+
# (batch, heads, source_length, target_length)
|
323 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
324 |
+
|
325 |
+
if attn.group_norm is not None:
|
326 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
327 |
+
|
328 |
+
query = attn.to_q(hidden_states)
|
329 |
+
|
330 |
+
if encoder_hidden_states is None:
|
331 |
+
encoder_hidden_states = hidden_states
|
332 |
+
elif attn.norm_cross:
|
333 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
334 |
+
|
335 |
+
key = attn.to_k(encoder_hidden_states)
|
336 |
+
value = attn.to_v(encoder_hidden_states)
|
337 |
+
|
338 |
+
inner_dim = key.shape[-1]
|
339 |
+
head_dim = inner_dim // attn.heads
|
340 |
+
|
341 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
342 |
+
|
343 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
344 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
345 |
+
|
346 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
347 |
+
hidden_states = F.scaled_dot_product_attention(
|
348 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
349 |
+
)
|
350 |
+
|
351 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
352 |
+
hidden_states = hidden_states.to(query.dtype)
|
353 |
+
|
354 |
+
# for id embedding
|
355 |
+
if id_embedding is not None:
|
356 |
+
if NUM_ZERO == 0:
|
357 |
+
id_key = self.id_to_k(id_embedding).to(query.dtype)
|
358 |
+
id_value = self.id_to_v(id_embedding).to(query.dtype)
|
359 |
+
else:
|
360 |
+
zero_tensor = torch.zeros(
|
361 |
+
(id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
|
362 |
+
dtype=id_embedding.dtype,
|
363 |
+
device=id_embedding.device,
|
364 |
+
)
|
365 |
+
id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
|
366 |
+
id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
|
367 |
+
|
368 |
+
id_key = id_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
369 |
+
id_value = id_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
370 |
+
|
371 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
372 |
+
id_hidden_states = F.scaled_dot_product_attention(
|
373 |
+
query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
374 |
+
)
|
375 |
+
|
376 |
+
id_hidden_states = id_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
377 |
+
id_hidden_states = id_hidden_states.to(query.dtype)
|
378 |
+
|
379 |
+
if not ORTHO and not ORTHO_v2:
|
380 |
+
hidden_states = hidden_states + id_scale * id_hidden_states
|
381 |
+
elif ORTHO_v2:
|
382 |
+
orig_dtype = hidden_states.dtype
|
383 |
+
hidden_states = hidden_states.to(torch.float32)
|
384 |
+
id_hidden_states = id_hidden_states.to(torch.float32)
|
385 |
+
attn_map = query @ id_key.transpose(-2, -1)
|
386 |
+
attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
|
387 |
+
attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
|
388 |
+
projection = (
|
389 |
+
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
|
390 |
+
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
|
391 |
+
* hidden_states
|
392 |
+
)
|
393 |
+
orthogonal = id_hidden_states + (attn_mean - 1) * projection
|
394 |
+
hidden_states = hidden_states + id_scale * orthogonal
|
395 |
+
hidden_states = hidden_states.to(orig_dtype)
|
396 |
+
else:
|
397 |
+
orig_dtype = hidden_states.dtype
|
398 |
+
hidden_states = hidden_states.to(torch.float32)
|
399 |
+
id_hidden_states = id_hidden_states.to(torch.float32)
|
400 |
+
projection = (
|
401 |
+
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
|
402 |
+
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
|
403 |
+
* hidden_states
|
404 |
+
)
|
405 |
+
orthogonal = id_hidden_states - projection
|
406 |
+
hidden_states = hidden_states + id_scale * orthogonal
|
407 |
+
hidden_states = hidden_states.to(orig_dtype)
|
408 |
+
|
409 |
+
# linear proj
|
410 |
+
hidden_states = attn.to_out[0](hidden_states)
|
411 |
+
# dropout
|
412 |
+
hidden_states = attn.to_out[1](hidden_states)
|
413 |
+
|
414 |
+
if input_ndim == 4:
|
415 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
416 |
+
|
417 |
+
if attn.residual_connection:
|
418 |
+
hidden_states = hidden_states + residual
|
419 |
+
|
420 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
421 |
+
|
422 |
+
return hidden_states
|
pulid/encoders.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class IDEncoder(nn.Module):
|
6 |
+
def __init__(self, width=1280, context_dim=2048, num_token=5):
|
7 |
+
super().__init__()
|
8 |
+
self.num_token = num_token
|
9 |
+
self.context_dim = context_dim
|
10 |
+
h1 = min((context_dim * num_token) // 4, 1024)
|
11 |
+
h2 = min((context_dim * num_token) // 2, 1024)
|
12 |
+
self.body = nn.Sequential(
|
13 |
+
nn.Linear(width, h1),
|
14 |
+
nn.LayerNorm(h1),
|
15 |
+
nn.LeakyReLU(),
|
16 |
+
nn.Linear(h1, h2),
|
17 |
+
nn.LayerNorm(h2),
|
18 |
+
nn.LeakyReLU(),
|
19 |
+
nn.Linear(h2, context_dim * num_token),
|
20 |
+
)
|
21 |
+
|
22 |
+
for i in range(5):
|
23 |
+
setattr(
|
24 |
+
self,
|
25 |
+
f'mapping_{i}',
|
26 |
+
nn.Sequential(
|
27 |
+
nn.Linear(1024, 1024),
|
28 |
+
nn.LayerNorm(1024),
|
29 |
+
nn.LeakyReLU(),
|
30 |
+
nn.Linear(1024, 1024),
|
31 |
+
nn.LayerNorm(1024),
|
32 |
+
nn.LeakyReLU(),
|
33 |
+
nn.Linear(1024, context_dim),
|
34 |
+
),
|
35 |
+
)
|
36 |
+
|
37 |
+
setattr(
|
38 |
+
self,
|
39 |
+
f'mapping_patch_{i}',
|
40 |
+
nn.Sequential(
|
41 |
+
nn.Linear(1024, 1024),
|
42 |
+
nn.LayerNorm(1024),
|
43 |
+
nn.LeakyReLU(),
|
44 |
+
nn.Linear(1024, 1024),
|
45 |
+
nn.LayerNorm(1024),
|
46 |
+
nn.LeakyReLU(),
|
47 |
+
nn.Linear(1024, context_dim),
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x, y):
|
52 |
+
# x shape [N, C]
|
53 |
+
x = self.body(x)
|
54 |
+
x = x.reshape(-1, self.num_token, self.context_dim)
|
55 |
+
|
56 |
+
hidden_states = ()
|
57 |
+
for i, emb in enumerate(y):
|
58 |
+
hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(
|
59 |
+
emb[:, 1:]
|
60 |
+
).mean(dim=1, keepdim=True)
|
61 |
+
hidden_states += (hidden_state,)
|
62 |
+
hidden_states = torch.cat(hidden_states, dim=1)
|
63 |
+
|
64 |
+
return torch.cat([x, hidden_states], dim=1)
|
pulid/encoders_flux.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
# FFN
|
8 |
+
def FeedForward(dim, mult=4):
|
9 |
+
inner_dim = int(dim * mult)
|
10 |
+
return nn.Sequential(
|
11 |
+
nn.LayerNorm(dim),
|
12 |
+
nn.Linear(dim, inner_dim, bias=False),
|
13 |
+
nn.GELU(),
|
14 |
+
nn.Linear(inner_dim, dim, bias=False),
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def reshape_tensor(x, heads):
|
19 |
+
bs, length, width = x.shape
|
20 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
21 |
+
x = x.view(bs, length, heads, -1)
|
22 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
23 |
+
x = x.transpose(1, 2)
|
24 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
25 |
+
x = x.reshape(bs, heads, length, -1)
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
class PerceiverAttentionCA(nn.Module):
|
30 |
+
def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
|
31 |
+
super().__init__()
|
32 |
+
self.scale = dim_head ** -0.5
|
33 |
+
self.dim_head = dim_head
|
34 |
+
self.heads = heads
|
35 |
+
inner_dim = dim_head * heads
|
36 |
+
|
37 |
+
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
|
38 |
+
self.norm2 = nn.LayerNorm(dim)
|
39 |
+
|
40 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
41 |
+
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
|
42 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
43 |
+
|
44 |
+
def forward(self, x, latents):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
x (torch.Tensor): image features
|
48 |
+
shape (b, n1, D)
|
49 |
+
latent (torch.Tensor): latent features
|
50 |
+
shape (b, n2, D)
|
51 |
+
"""
|
52 |
+
x = self.norm1(x)
|
53 |
+
latents = self.norm2(latents)
|
54 |
+
|
55 |
+
b, seq_len, _ = latents.shape
|
56 |
+
|
57 |
+
q = self.to_q(latents)
|
58 |
+
k, v = self.to_kv(x).chunk(2, dim=-1)
|
59 |
+
|
60 |
+
q = reshape_tensor(q, self.heads)
|
61 |
+
k = reshape_tensor(k, self.heads)
|
62 |
+
v = reshape_tensor(v, self.heads)
|
63 |
+
|
64 |
+
# attention
|
65 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
66 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
67 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
68 |
+
out = weight @ v
|
69 |
+
|
70 |
+
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
|
71 |
+
|
72 |
+
return self.to_out(out)
|
73 |
+
|
74 |
+
|
75 |
+
class PerceiverAttention(nn.Module):
|
76 |
+
def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
|
77 |
+
super().__init__()
|
78 |
+
self.scale = dim_head ** -0.5
|
79 |
+
self.dim_head = dim_head
|
80 |
+
self.heads = heads
|
81 |
+
inner_dim = dim_head * heads
|
82 |
+
|
83 |
+
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
|
84 |
+
self.norm2 = nn.LayerNorm(dim)
|
85 |
+
|
86 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
87 |
+
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
|
88 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
89 |
+
|
90 |
+
def forward(self, x, latents):
|
91 |
+
"""
|
92 |
+
Args:
|
93 |
+
x (torch.Tensor): image features
|
94 |
+
shape (b, n1, D)
|
95 |
+
latent (torch.Tensor): latent features
|
96 |
+
shape (b, n2, D)
|
97 |
+
"""
|
98 |
+
x = self.norm1(x)
|
99 |
+
latents = self.norm2(latents)
|
100 |
+
|
101 |
+
b, seq_len, _ = latents.shape
|
102 |
+
|
103 |
+
q = self.to_q(latents)
|
104 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
105 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
106 |
+
|
107 |
+
q = reshape_tensor(q, self.heads)
|
108 |
+
k = reshape_tensor(k, self.heads)
|
109 |
+
v = reshape_tensor(v, self.heads)
|
110 |
+
|
111 |
+
# attention
|
112 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
113 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
114 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
115 |
+
out = weight @ v
|
116 |
+
|
117 |
+
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
|
118 |
+
|
119 |
+
return self.to_out(out)
|
120 |
+
|
121 |
+
|
122 |
+
class IDFormer(nn.Module):
|
123 |
+
"""
|
124 |
+
- perceiver resampler like arch (compared with previous MLP-like arch)
|
125 |
+
- we concat id embedding (generated by arcface) and query tokens as latents
|
126 |
+
- latents will attend each other and interact with vit features through cross-attention
|
127 |
+
- vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two
|
128 |
+
IDFormer layers
|
129 |
+
"""
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
dim=1024,
|
133 |
+
depth=10,
|
134 |
+
dim_head=64,
|
135 |
+
heads=16,
|
136 |
+
num_id_token=5,
|
137 |
+
num_queries=32,
|
138 |
+
output_dim=2048,
|
139 |
+
ff_mult=4,
|
140 |
+
):
|
141 |
+
super().__init__()
|
142 |
+
|
143 |
+
self.num_id_token = num_id_token
|
144 |
+
self.dim = dim
|
145 |
+
self.num_queries = num_queries
|
146 |
+
assert depth % 5 == 0
|
147 |
+
self.depth = depth // 5
|
148 |
+
scale = dim ** -0.5
|
149 |
+
|
150 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
|
151 |
+
self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
|
152 |
+
|
153 |
+
self.layers = nn.ModuleList([])
|
154 |
+
for _ in range(depth):
|
155 |
+
self.layers.append(
|
156 |
+
nn.ModuleList(
|
157 |
+
[
|
158 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
159 |
+
FeedForward(dim=dim, mult=ff_mult),
|
160 |
+
]
|
161 |
+
)
|
162 |
+
)
|
163 |
+
|
164 |
+
for i in range(5):
|
165 |
+
setattr(
|
166 |
+
self,
|
167 |
+
f'mapping_{i}',
|
168 |
+
nn.Sequential(
|
169 |
+
nn.Linear(1024, 1024),
|
170 |
+
nn.LayerNorm(1024),
|
171 |
+
nn.LeakyReLU(),
|
172 |
+
nn.Linear(1024, 1024),
|
173 |
+
nn.LayerNorm(1024),
|
174 |
+
nn.LeakyReLU(),
|
175 |
+
nn.Linear(1024, dim),
|
176 |
+
),
|
177 |
+
)
|
178 |
+
|
179 |
+
self.id_embedding_mapping = nn.Sequential(
|
180 |
+
nn.Linear(1280, 1024),
|
181 |
+
nn.LayerNorm(1024),
|
182 |
+
nn.LeakyReLU(),
|
183 |
+
nn.Linear(1024, 1024),
|
184 |
+
nn.LayerNorm(1024),
|
185 |
+
nn.LeakyReLU(),
|
186 |
+
nn.Linear(1024, dim * num_id_token),
|
187 |
+
)
|
188 |
+
|
189 |
+
def forward(self, x, y):
|
190 |
+
|
191 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
192 |
+
|
193 |
+
x = self.id_embedding_mapping(x)
|
194 |
+
x = x.reshape(-1, self.num_id_token, self.dim)
|
195 |
+
|
196 |
+
latents = torch.cat((latents, x), dim=1)
|
197 |
+
|
198 |
+
for i in range(5):
|
199 |
+
vit_feature = getattr(self, f'mapping_{i}')(y[i])
|
200 |
+
ctx_feature = torch.cat((x, vit_feature), dim=1)
|
201 |
+
for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:
|
202 |
+
latents = attn(ctx_feature, latents) + latents
|
203 |
+
latents = ff(latents) + latents
|
204 |
+
|
205 |
+
latents = latents[:, :self.num_queries]
|
206 |
+
latents = latents @ self.proj_out
|
207 |
+
return latents
|