Spaces:
Running
on
Zero
Running
on
Zero
hanshu.yan
commited on
Commit
·
2ec72fb
1
Parent(s):
b83e3cf
add app.py
Browse files- LICENSE +21 -0
- README.md +80 -12
- app.py +182 -0
- gradio_app.py +188 -0
- output/.DS_Store +0 -0
- output/0/input.png +0 -0
- output/0/mesh.obj +0 -0
- requirements.txt +17 -0
- requirements2.txt +9 -0
- run.py +162 -0
- src/__pycache__/__init__.cpython-38.pyc +0 -0
- src/__pycache__/scheduler_perflow.cpython-310.pyc +0 -0
- src/__pycache__/scheduler_perflow.cpython-38.pyc +0 -0
- src/__pycache__/utils_perflow.cpython-38.pyc +0 -0
- src/laion_bytenas.py +257 -0
- src/pfode_solver.py +120 -0
- src/scheduler_perflow.py +343 -0
- src/utils_perflow.py +77 -0
- test.yaml +10 -0
- tsr/__pycache__/system.cpython-310.pyc +0 -0
- tsr/__pycache__/system.cpython-38.pyc +0 -0
- tsr/__pycache__/utils.cpython-310.pyc +0 -0
- tsr/__pycache__/utils.cpython-38.pyc +0 -0
- tsr/models/__pycache__/isosurface.cpython-310.pyc +0 -0
- tsr/models/__pycache__/isosurface.cpython-38.pyc +0 -0
- tsr/models/__pycache__/nerf_renderer.cpython-310.pyc +0 -0
- tsr/models/__pycache__/nerf_renderer.cpython-38.pyc +0 -0
- tsr/models/__pycache__/network_utils.cpython-310.pyc +0 -0
- tsr/models/__pycache__/network_utils.cpython-38.pyc +0 -0
- tsr/models/isosurface.py +52 -0
- tsr/models/nerf_renderer.py +180 -0
- tsr/models/network_utils.py +124 -0
- tsr/models/tokenizers/__pycache__/image.cpython-310.pyc +0 -0
- tsr/models/tokenizers/__pycache__/image.cpython-38.pyc +0 -0
- tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc +0 -0
- tsr/models/tokenizers/__pycache__/triplane.cpython-38.pyc +0 -0
- tsr/models/tokenizers/image.py +66 -0
- tsr/models/tokenizers/triplane.py +45 -0
- tsr/models/transformer/__pycache__/attention.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/attention.cpython-38.pyc +0 -0
- tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/basic_transformer_block.cpython-38.pyc +0 -0
- tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/transformer_1d.cpython-38.pyc +0 -0
- tsr/models/transformer/attention.py +653 -0
- tsr/models/transformer/basic_transformer_block.py +334 -0
- tsr/models/transformer/transformer_1d.py +219 -0
- tsr/system.py +203 -0
- tsr/utils.py +474 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Tripo AI & Stability AI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,80 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TripoSR <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <a href="https://huggingface.co/spaces/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/Arxiv-2403.02151-B31B1B.svg"></a>
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<img src="figures/teaser800.gif" alt="Teaser Video">
|
5 |
+
</div>
|
6 |
+
|
7 |
+
This is the official codebase for **TripoSR**, a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
|
8 |
+
<br><br>
|
9 |
+
Leveraging the principles of the [Large Reconstruction Model (LRM)](https://yiconghong.me/LRM/), TripoSR brings to the table key advancements that significantly boost both the speed and quality of 3D reconstruction. Our model is distinguished by its ability to rapidly process inputs, generating high-quality 3D models in less than 0.5 seconds on an NVIDIA A100 GPU. TripoSR has exhibited superior performance in both qualitative and quantitative evaluations, outperforming other open-source alternatives across multiple public datasets. The figures below illustrate visual comparisons and metrics showcasing TripoSR's performance relative to other leading models. Details about the model architecture, training process, and comparisons can be found in this [technical report](https://arxiv.org/abs/2403.02151).
|
10 |
+
|
11 |
+
<!--
|
12 |
+
<div align="center">
|
13 |
+
<img src="figures/comparison800.gif" alt="Teaser Video">
|
14 |
+
</div>
|
15 |
+
-->
|
16 |
+
<p align="center">
|
17 |
+
<img width="800" src="figures/visual_comparisons.jpg"/>
|
18 |
+
</p>
|
19 |
+
|
20 |
+
<p align="center">
|
21 |
+
<img width="450" src="figures/scatter-comparison.png"/>
|
22 |
+
</p>
|
23 |
+
|
24 |
+
|
25 |
+
The model is released under the MIT license, which includes the source code, pretrained models, and an interactive online demo. Our goal is to empower researchers, developers, and creatives to push the boundaries of what's possible in 3D generative AI and 3D content creation.
|
26 |
+
|
27 |
+
## Getting Started
|
28 |
+
### Installation
|
29 |
+
- Python >= 3.8
|
30 |
+
- Install CUDA if available
|
31 |
+
- Install PyTorch according to your platform: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) **[Please make sure that the locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x.]**
|
32 |
+
- Update setuptools by `pip install --upgrade setuptools`
|
33 |
+
- Install other dependencies by `pip install -r requirements.txt`
|
34 |
+
|
35 |
+
### Manual Inference
|
36 |
+
```sh
|
37 |
+
python run.py examples/chair.png --output-dir output/
|
38 |
+
```
|
39 |
+
This will save the reconstructed 3D model to `output/`. You can also specify more than one image path separated by spaces. The default options takes about **6GB VRAM** for a single image input.
|
40 |
+
|
41 |
+
For detailed usage of this script, use `python run.py --help`.
|
42 |
+
|
43 |
+
### Local Gradio App
|
44 |
+
Install Gradio:
|
45 |
+
```sh
|
46 |
+
pip install gradio
|
47 |
+
```
|
48 |
+
Start the Gradio App:
|
49 |
+
```sh
|
50 |
+
python gradio_app.py
|
51 |
+
```
|
52 |
+
|
53 |
+
## Troubleshooting
|
54 |
+
> AttributeError: module 'torchmcubes_module' has no attribute 'mcubes_cuda'
|
55 |
+
|
56 |
+
or
|
57 |
+
|
58 |
+
> torchmcubes was not compiled with CUDA support, use CPU version instead.
|
59 |
+
|
60 |
+
This is because `torchmcubes` is compiled without CUDA support. Please make sure that
|
61 |
+
|
62 |
+
- The locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x.
|
63 |
+
- `setuptools>=49.6.0`. If not, upgrade by `pip install --upgrade setuptools`.
|
64 |
+
|
65 |
+
Then re-install `torchmcubes` by:
|
66 |
+
|
67 |
+
```sh
|
68 |
+
pip uninstall torchmcubes
|
69 |
+
pip install git+https://github.com/tatsy/torchmcubes.git
|
70 |
+
```
|
71 |
+
|
72 |
+
## Citation
|
73 |
+
```BibTeX
|
74 |
+
@article{TripoSR2024,
|
75 |
+
title={TripoSR: Fast 3D Object Reconstruction from a Single Image},
|
76 |
+
author={Tochilkin, Dmitry and Pankratz, David and Liu, Zexiang and Huang, Zixuan and and Letts, Adam and Li, Yangguang and Liang, Ding and Laforte, Christian and Jampani, Varun and Cao, Yan-Pei},
|
77 |
+
journal={arXiv preprint arXiv:2403.02151},
|
78 |
+
year={2024}
|
79 |
+
}
|
80 |
+
```
|
app.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, logging, time, argparse, random, tempfile, rembg
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from functools import partial
|
7 |
+
from tsr.system import TSR
|
8 |
+
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
9 |
+
|
10 |
+
from src.scheduler_perflow import PeRFlowScheduler
|
11 |
+
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
12 |
+
|
13 |
+
def merge_delta_weights_into_unet(pipe, delta_weights, org_alpha = 1.0):
|
14 |
+
unet_weights = pipe.unet.state_dict()
|
15 |
+
for key in delta_weights.keys():
|
16 |
+
dtype = unet_weights[key].dtype
|
17 |
+
try:
|
18 |
+
unet_weights[key] = org_alpha * unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device)
|
19 |
+
except:
|
20 |
+
unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype)
|
21 |
+
unet_weights[key] = unet_weights[key].to(dtype)
|
22 |
+
pipe.unet.load_state_dict(unet_weights, strict=True)
|
23 |
+
return pipe
|
24 |
+
|
25 |
+
def setup_seed(seed):
|
26 |
+
random.seed(seed)
|
27 |
+
np.random.seed(seed)
|
28 |
+
torch.manual_seed(seed)
|
29 |
+
torch.cuda.manual_seed_all(seed)
|
30 |
+
torch.backends.cudnn.deterministic = True
|
31 |
+
|
32 |
+
if torch.cuda.is_available():
|
33 |
+
device = "cuda:0"
|
34 |
+
else:
|
35 |
+
device = "cpu"
|
36 |
+
|
37 |
+
### TripoSR
|
38 |
+
model = TSR.from_pretrained(
|
39 |
+
"stabilityai/TripoSR",
|
40 |
+
config_name="config.yaml",
|
41 |
+
weight_name="model.ckpt",
|
42 |
+
)
|
43 |
+
# adjust the chunk size to balance between speed and memory usage
|
44 |
+
model.renderer.set_chunk_size(8192)
|
45 |
+
model.to(device)
|
46 |
+
|
47 |
+
|
48 |
+
### PeRFlow-T2I
|
49 |
+
# pipe_t2i = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-8", torch_dtype=torch.float16, safety_checker=None)
|
50 |
+
pipe_t2i = StableDiffusionPipeline.from_pretrained("stablediffusionapi/disney-pixar-cartoon", torch_dtype=torch.float16, safety_checker=None)
|
51 |
+
delta_weights = UNet2DConditionModel.from_pretrained("hansyan/piecewise-rectified-flow-delta-weights", torch_dtype=torch.float16, variant="v0-1",).state_dict()
|
52 |
+
pipe_t2i = merge_delta_weights_into_unet(pipe_t2i, delta_weights)
|
53 |
+
pipe_t2i.scheduler = PeRFlowScheduler.from_config(pipe_t2i.scheduler.config, prediction_type="epsilon", num_time_windows=4)
|
54 |
+
pipe_t2i.to('cuda:0', torch.float16)
|
55 |
+
|
56 |
+
|
57 |
+
### gradio
|
58 |
+
rembg_session = rembg.new_session()
|
59 |
+
|
60 |
+
def generate(text, seed):
|
61 |
+
def fill_background(image):
|
62 |
+
image = np.array(image).astype(np.float32) / 255.0
|
63 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
64 |
+
image = Image.fromarray((image * 255.0).astype(np.uint8))
|
65 |
+
return image
|
66 |
+
|
67 |
+
setup_seed(int(seed))
|
68 |
+
# text = text
|
69 |
+
samples = pipe_t2i(
|
70 |
+
prompt = [text],
|
71 |
+
negative_prompt = ["distorted, blur, low-quality, haze, out of focus"],
|
72 |
+
height = 512,
|
73 |
+
width = 512,
|
74 |
+
# num_inference_steps = 4,
|
75 |
+
# guidance_scale = 4.5,
|
76 |
+
num_inference_steps = 6,
|
77 |
+
guidance_scale = 7,
|
78 |
+
output_type = 'pt',
|
79 |
+
).images
|
80 |
+
samples = torch.nn.functional.interpolate(samples, size=768, mode='bilinear')
|
81 |
+
samples = samples.squeeze(0).permute(1, 2, 0).cpu().numpy()*255.
|
82 |
+
samples = samples.astype(np.uint8)
|
83 |
+
samples = Image.fromarray(samples[:, :, :3])
|
84 |
+
|
85 |
+
image = remove_background(samples, rembg_session)
|
86 |
+
image = resize_foreground(image, 0.85)
|
87 |
+
image = fill_background(image)
|
88 |
+
return image
|
89 |
+
|
90 |
+
def render(image, mc_resolution=256, formats=["obj"]):
|
91 |
+
scene_codes = model(image, device=device)
|
92 |
+
mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
|
93 |
+
mesh = to_gradio_3d_orientation(mesh)
|
94 |
+
rv = []
|
95 |
+
for format in formats:
|
96 |
+
mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
|
97 |
+
mesh.export(mesh_path.name)
|
98 |
+
rv.append(mesh_path.name)
|
99 |
+
return rv[0]
|
100 |
+
|
101 |
+
# warm up
|
102 |
+
_ = generate("a bird", 42)
|
103 |
+
|
104 |
+
# layout
|
105 |
+
css = """
|
106 |
+
h1 {
|
107 |
+
text-align: center;
|
108 |
+
display:block;
|
109 |
+
}
|
110 |
+
h2 {
|
111 |
+
text-align: center;
|
112 |
+
display:block;
|
113 |
+
}
|
114 |
+
h3 {
|
115 |
+
text-align: center;
|
116 |
+
display:block;
|
117 |
+
}
|
118 |
+
"""
|
119 |
+
with gr.Blocks(title="TripoSR", css=css) as interface:
|
120 |
+
gr.Markdown(
|
121 |
+
"""
|
122 |
+
# Instant Text-to-3D Mesh Demo
|
123 |
+
|
124 |
+
### [PeRFlow](https://github.com/magic-research/piecewise-rectified-flow)-T2I + [TripoSR](https://github.com/VAST-AI-Research/TripoSR)
|
125 |
+
|
126 |
+
Two-stage synthesis: 1) generating images by PeRFlow-T2I with 6-step inference; 2) rendering 3D assests.
|
127 |
+
"""
|
128 |
+
)
|
129 |
+
|
130 |
+
with gr.Column():
|
131 |
+
with gr.Row():
|
132 |
+
output_image = gr.Image(label='Generated Image', height=384, width=384)
|
133 |
+
|
134 |
+
output_model_obj = gr.Model3D(
|
135 |
+
label="Output 3D Model (OBJ Format)",
|
136 |
+
interactive=False,
|
137 |
+
height=384, width=384,
|
138 |
+
)
|
139 |
+
|
140 |
+
with gr.Row():
|
141 |
+
textbox = gr.Textbox(label="Input Prompt", value="a colorful bird")
|
142 |
+
seed = gr.Textbox(label="Random Seed", value=42)
|
143 |
+
|
144 |
+
# activate
|
145 |
+
textbox.submit(
|
146 |
+
fn=generate,
|
147 |
+
inputs=[textbox, seed],
|
148 |
+
outputs=[output_image],
|
149 |
+
).success(
|
150 |
+
fn=render,
|
151 |
+
inputs=[output_image],
|
152 |
+
outputs=[output_model_obj],
|
153 |
+
)
|
154 |
+
|
155 |
+
seed.submit(
|
156 |
+
fn=generate,
|
157 |
+
inputs=[textbox, seed],
|
158 |
+
outputs=[output_image],
|
159 |
+
).success(
|
160 |
+
fn=render,
|
161 |
+
inputs=[output_image],
|
162 |
+
outputs=[output_model_obj],
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == '__main__':
|
168 |
+
parser = argparse.ArgumentParser()
|
169 |
+
parser.add_argument('--username', type=str, default=None, help='Username for authentication')
|
170 |
+
parser.add_argument('--password', type=str, default=None, help='Password for authentication')
|
171 |
+
parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
|
172 |
+
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
173 |
+
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
174 |
+
parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
|
175 |
+
args = parser.parse_args()
|
176 |
+
interface.queue(max_size=args.queuesize)
|
177 |
+
interface.launch(
|
178 |
+
auth=(args.username, args.password) if (args.username and args.password) else None,
|
179 |
+
share=args.share,
|
180 |
+
server_name="0.0.0.0" if args.listen else None,
|
181 |
+
server_port=args.port
|
182 |
+
)
|
gradio_app.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
import time
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
+
import rembg
|
9 |
+
import torch
|
10 |
+
from PIL import Image
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
from tsr.system import TSR
|
14 |
+
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
|
19 |
+
if torch.cuda.is_available():
|
20 |
+
device = "cuda:0"
|
21 |
+
else:
|
22 |
+
device = "cpu"
|
23 |
+
|
24 |
+
model = TSR.from_pretrained(
|
25 |
+
"stabilityai/TripoSR",
|
26 |
+
config_name="config.yaml",
|
27 |
+
weight_name="model.ckpt",
|
28 |
+
)
|
29 |
+
|
30 |
+
# adjust the chunk size to balance between speed and memory usage
|
31 |
+
model.renderer.set_chunk_size(8192)
|
32 |
+
model.to(device)
|
33 |
+
|
34 |
+
rembg_session = rembg.new_session()
|
35 |
+
|
36 |
+
|
37 |
+
def check_input_image(input_image):
|
38 |
+
if input_image is None:
|
39 |
+
raise gr.Error("No image uploaded!")
|
40 |
+
|
41 |
+
|
42 |
+
def preprocess(input_image, do_remove_background, foreground_ratio):
|
43 |
+
def fill_background(image):
|
44 |
+
image = np.array(image).astype(np.float32) / 255.0
|
45 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
46 |
+
image = Image.fromarray((image * 255.0).astype(np.uint8))
|
47 |
+
return image
|
48 |
+
|
49 |
+
if do_remove_background:
|
50 |
+
image = input_image.convert("RGB")
|
51 |
+
image = remove_background(image, rembg_session)
|
52 |
+
image = resize_foreground(image, foreground_ratio)
|
53 |
+
image = fill_background(image)
|
54 |
+
else:
|
55 |
+
image = input_image
|
56 |
+
if image.mode == "RGBA":
|
57 |
+
image = fill_background(image)
|
58 |
+
return image
|
59 |
+
|
60 |
+
|
61 |
+
def generate(image, mc_resolution, formats=["obj", "glb"]):
|
62 |
+
print(image.shape, image.min(), image.max())
|
63 |
+
scene_codes = model(image, device=device)
|
64 |
+
mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
|
65 |
+
mesh = to_gradio_3d_orientation(mesh)
|
66 |
+
rv = []
|
67 |
+
for format in formats:
|
68 |
+
mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
|
69 |
+
mesh.export(mesh_path.name)
|
70 |
+
rv.append(mesh_path.name)
|
71 |
+
return rv
|
72 |
+
|
73 |
+
|
74 |
+
def run_example(image_pil):
|
75 |
+
preprocessed = preprocess(image_pil, False, 0.9)
|
76 |
+
mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
|
77 |
+
return preprocessed, mesh_name_obj, mesh_name_glb
|
78 |
+
|
79 |
+
|
80 |
+
with gr.Blocks(title="TripoSR") as interface:
|
81 |
+
gr.Markdown(
|
82 |
+
"""
|
83 |
+
# TripoSR Demo
|
84 |
+
[TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
|
85 |
+
|
86 |
+
**Tips:**
|
87 |
+
1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
|
88 |
+
2. You can disable "Remove Background" for the provided examples since they have been already preprocessed.
|
89 |
+
3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
|
90 |
+
"""
|
91 |
+
)
|
92 |
+
with gr.Row(variant="panel"):
|
93 |
+
with gr.Column():
|
94 |
+
with gr.Row():
|
95 |
+
input_image = gr.Image(
|
96 |
+
label="Input Image",
|
97 |
+
image_mode="RGBA",
|
98 |
+
sources="upload",
|
99 |
+
type="pil",
|
100 |
+
elem_id="content_image",
|
101 |
+
)
|
102 |
+
processed_image = gr.Image(label="Processed Image", interactive=False)
|
103 |
+
with gr.Row():
|
104 |
+
with gr.Group():
|
105 |
+
do_remove_background = gr.Checkbox(
|
106 |
+
label="Remove Background", value=True
|
107 |
+
)
|
108 |
+
foreground_ratio = gr.Slider(
|
109 |
+
label="Foreground Ratio",
|
110 |
+
minimum=0.5,
|
111 |
+
maximum=1.0,
|
112 |
+
value=0.85,
|
113 |
+
step=0.05,
|
114 |
+
)
|
115 |
+
mc_resolution = gr.Slider(
|
116 |
+
label="Marching Cubes Resolution",
|
117 |
+
minimum=32,
|
118 |
+
maximum=320,
|
119 |
+
value=256,
|
120 |
+
step=32
|
121 |
+
)
|
122 |
+
with gr.Row():
|
123 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
124 |
+
with gr.Column():
|
125 |
+
with gr.Tab("OBJ"):
|
126 |
+
output_model_obj = gr.Model3D(
|
127 |
+
label="Output Model (OBJ Format)",
|
128 |
+
interactive=False,
|
129 |
+
)
|
130 |
+
gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
|
131 |
+
with gr.Tab("GLB"):
|
132 |
+
output_model_glb = gr.Model3D(
|
133 |
+
label="Output Model (GLB Format)",
|
134 |
+
interactive=False,
|
135 |
+
)
|
136 |
+
gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
|
137 |
+
with gr.Row(variant="panel"):
|
138 |
+
gr.Examples(
|
139 |
+
examples=[
|
140 |
+
"examples/hamburger.png",
|
141 |
+
"examples/poly_fox.png",
|
142 |
+
"examples/robot.png",
|
143 |
+
"examples/teapot.png",
|
144 |
+
"examples/tiger_girl.png",
|
145 |
+
"examples/horse.png",
|
146 |
+
"examples/flamingo.png",
|
147 |
+
"examples/unicorn.png",
|
148 |
+
"examples/chair.png",
|
149 |
+
"examples/iso_house.png",
|
150 |
+
"examples/marble.png",
|
151 |
+
"examples/police_woman.png",
|
152 |
+
"examples/captured_p.png",
|
153 |
+
],
|
154 |
+
inputs=[input_image],
|
155 |
+
outputs=[processed_image, output_model_obj, output_model_glb],
|
156 |
+
cache_examples=False,
|
157 |
+
fn=partial(run_example),
|
158 |
+
label="Examples",
|
159 |
+
examples_per_page=20,
|
160 |
+
)
|
161 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
162 |
+
fn=preprocess,
|
163 |
+
inputs=[input_image, do_remove_background, foreground_ratio],
|
164 |
+
outputs=[processed_image],
|
165 |
+
).success(
|
166 |
+
fn=generate,
|
167 |
+
inputs=[processed_image, mc_resolution],
|
168 |
+
outputs=[output_model_obj, output_model_glb],
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
if __name__ == '__main__':
|
174 |
+
parser = argparse.ArgumentParser()
|
175 |
+
parser.add_argument('--username', type=str, default=None, help='Username for authentication')
|
176 |
+
parser.add_argument('--password', type=str, default=None, help='Password for authentication')
|
177 |
+
parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
|
178 |
+
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
179 |
+
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
180 |
+
parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
|
181 |
+
args = parser.parse_args()
|
182 |
+
interface.queue(max_size=args.queuesize)
|
183 |
+
interface.launch(
|
184 |
+
auth=(args.username, args.password) if (args.username and args.password) else None,
|
185 |
+
share=args.share,
|
186 |
+
server_name="0.0.0.0" if args.listen else None,
|
187 |
+
server_port=args.port
|
188 |
+
)
|
output/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
output/0/input.png
ADDED
output/0/mesh.obj
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.24.0
|
2 |
+
einops==0.7.0
|
3 |
+
gradio==4.20.1
|
4 |
+
huggingface_hub==0.21.4
|
5 |
+
imageio==2.27.0
|
6 |
+
numpy==1.24.3
|
7 |
+
omegaconf==2.3.0
|
8 |
+
packaging==23.2
|
9 |
+
Pillow==10.1.0
|
10 |
+
rembg==2.0.55
|
11 |
+
safetensors==0.3.2
|
12 |
+
torch==2.0.0
|
13 |
+
torchvision==0.15.1
|
14 |
+
tqdm==4.64.1
|
15 |
+
transformers==4.27.0
|
16 |
+
trimesh==4.0.5
|
17 |
+
git+https://github.com/tatsy/torchmcubes.git
|
requirements2.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
omegaconf==2.3.0
|
2 |
+
Pillow==10.1.0
|
3 |
+
einops==0.7.0
|
4 |
+
git+https://github.com/tatsy/torchmcubes.git
|
5 |
+
transformers==4.35.0
|
6 |
+
trimesh==4.0.5
|
7 |
+
rembg
|
8 |
+
huggingface-hub
|
9 |
+
imageio[ffmpeg]
|
run.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import rembg
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from tsr.system import TSR
|
12 |
+
from tsr.utils import remove_background, resize_foreground, save_video
|
13 |
+
|
14 |
+
|
15 |
+
class Timer:
|
16 |
+
def __init__(self):
|
17 |
+
self.items = {}
|
18 |
+
self.time_scale = 1000.0 # ms
|
19 |
+
self.time_unit = "ms"
|
20 |
+
|
21 |
+
def start(self, name: str) -> None:
|
22 |
+
if torch.cuda.is_available():
|
23 |
+
torch.cuda.synchronize()
|
24 |
+
self.items[name] = time.time()
|
25 |
+
logging.info(f"{name} ...")
|
26 |
+
|
27 |
+
def end(self, name: str) -> float:
|
28 |
+
if name not in self.items:
|
29 |
+
return
|
30 |
+
if torch.cuda.is_available():
|
31 |
+
torch.cuda.synchronize()
|
32 |
+
start_time = self.items.pop(name)
|
33 |
+
delta = time.time() - start_time
|
34 |
+
t = delta * self.time_scale
|
35 |
+
logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
|
36 |
+
|
37 |
+
|
38 |
+
timer = Timer()
|
39 |
+
|
40 |
+
|
41 |
+
logging.basicConfig(
|
42 |
+
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
43 |
+
)
|
44 |
+
parser = argparse.ArgumentParser()
|
45 |
+
parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
|
46 |
+
parser.add_argument(
|
47 |
+
"--device",
|
48 |
+
default="cuda:0",
|
49 |
+
type=str,
|
50 |
+
help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"--pretrained-model-name-or-path",
|
54 |
+
default="stabilityai/TripoSR",
|
55 |
+
type=str,
|
56 |
+
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/TripoSR'",
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--chunk-size",
|
60 |
+
default=8192,
|
61 |
+
type=int,
|
62 |
+
help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"--mc-resolution",
|
66 |
+
default=256,
|
67 |
+
type=int,
|
68 |
+
help="Marching cubes grid resolution. Default: 256"
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--no-remove-bg",
|
72 |
+
action="store_true",
|
73 |
+
help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--foreground-ratio",
|
77 |
+
default=0.85,
|
78 |
+
type=float,
|
79 |
+
help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--output-dir",
|
83 |
+
default="output/",
|
84 |
+
type=str,
|
85 |
+
help="Output directory to save the results. Default: 'output/'",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--model-save-format",
|
89 |
+
default="obj",
|
90 |
+
type=str,
|
91 |
+
choices=["obj", "glb"],
|
92 |
+
help="Format to save the extracted mesh. Default: 'obj'",
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--render",
|
96 |
+
action="store_true",
|
97 |
+
help="If specified, save a NeRF-rendered video. Default: false",
|
98 |
+
)
|
99 |
+
args = parser.parse_args()
|
100 |
+
|
101 |
+
output_dir = args.output_dir
|
102 |
+
os.makedirs(output_dir, exist_ok=True)
|
103 |
+
|
104 |
+
device = args.device
|
105 |
+
if not torch.cuda.is_available():
|
106 |
+
device = "cpu"
|
107 |
+
|
108 |
+
timer.start("Initializing model")
|
109 |
+
model = TSR.from_pretrained(
|
110 |
+
args.pretrained_model_name_or_path,
|
111 |
+
config_name="config.yaml",
|
112 |
+
weight_name="model.ckpt",
|
113 |
+
)
|
114 |
+
model.renderer.set_chunk_size(args.chunk_size)
|
115 |
+
model.to(device)
|
116 |
+
timer.end("Initializing model")
|
117 |
+
|
118 |
+
timer.start("Processing images")
|
119 |
+
images = []
|
120 |
+
|
121 |
+
if args.no_remove_bg:
|
122 |
+
rembg_session = None
|
123 |
+
else:
|
124 |
+
rembg_session = rembg.new_session()
|
125 |
+
|
126 |
+
for i, image_path in enumerate(args.image):
|
127 |
+
if args.no_remove_bg:
|
128 |
+
image = np.array(Image.open(image_path).convert("RGB"))
|
129 |
+
else:
|
130 |
+
image = remove_background(Image.open(image_path), rembg_session)
|
131 |
+
image = resize_foreground(image, args.foreground_ratio)
|
132 |
+
image = np.array(image).astype(np.float32) / 255.0
|
133 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
134 |
+
image = Image.fromarray((image * 255.0).astype(np.uint8))
|
135 |
+
if not os.path.exists(os.path.join(output_dir, str(i))):
|
136 |
+
os.makedirs(os.path.join(output_dir, str(i)))
|
137 |
+
image.save(os.path.join(output_dir, str(i), f"input.png"))
|
138 |
+
images.append(image)
|
139 |
+
timer.end("Processing images")
|
140 |
+
|
141 |
+
for i, image in enumerate(images):
|
142 |
+
logging.info(f"Running image {i + 1}/{len(images)} ...")
|
143 |
+
|
144 |
+
timer.start("Running model")
|
145 |
+
with torch.no_grad():
|
146 |
+
scene_codes = model([image], device=device)
|
147 |
+
timer.end("Running model")
|
148 |
+
|
149 |
+
if args.render:
|
150 |
+
timer.start("Rendering")
|
151 |
+
render_images = model.render(scene_codes, n_views=30, return_type="pil")
|
152 |
+
for ri, render_image in enumerate(render_images[0]):
|
153 |
+
render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
|
154 |
+
save_video(
|
155 |
+
render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
|
156 |
+
)
|
157 |
+
timer.end("Rendering")
|
158 |
+
|
159 |
+
timer.start("Exporting mesh")
|
160 |
+
meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
|
161 |
+
meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
|
162 |
+
timer.end("Exporting mesh")
|
src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (147 Bytes). View file
|
|
src/__pycache__/scheduler_perflow.cpython-310.pyc
ADDED
Binary file (12.2 kB). View file
|
|
src/__pycache__/scheduler_perflow.cpython-38.pyc
ADDED
Binary file (12.1 kB). View file
|
|
src/__pycache__/utils_perflow.cpython-38.pyc
ADDED
Binary file (2.64 kB). View file
|
|
src/laion_bytenas.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageStat
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset, get_worker_info
|
9 |
+
from torchvision import transforms as T
|
10 |
+
|
11 |
+
|
12 |
+
### >>>>>>>> >>>>>>>> text related >>>>>>>> >>>>>>>> ###
|
13 |
+
|
14 |
+
class TokenizerWrapper():
|
15 |
+
def __init__(self, tokenizer, is_train, proportion_empty_prompts, use_generic_prompts=False):
|
16 |
+
self.tokenizer = tokenizer
|
17 |
+
self.is_train = is_train
|
18 |
+
self.proportion_empty_prompts = proportion_empty_prompts
|
19 |
+
self.use_generic_prompts = use_generic_prompts
|
20 |
+
|
21 |
+
def __call__(self, prompts):
|
22 |
+
if isinstance(prompts, str):
|
23 |
+
prompts = [prompts]
|
24 |
+
captions = []
|
25 |
+
for caption in prompts:
|
26 |
+
if random.random() < self.proportion_empty_prompts:
|
27 |
+
captions.append("")
|
28 |
+
else:
|
29 |
+
if self.use_generic_prompts:
|
30 |
+
captions.append("best quality, high quality")
|
31 |
+
elif isinstance(caption, str):
|
32 |
+
captions.append(caption)
|
33 |
+
elif isinstance(caption, (list, np.ndarray)):
|
34 |
+
# take a random caption if there are multiple
|
35 |
+
captions.append(random.choice(caption) if self.is_train else caption[0])
|
36 |
+
else:
|
37 |
+
raise ValueError(
|
38 |
+
f"Caption column should contain either strings or lists of strings."
|
39 |
+
)
|
40 |
+
inputs = self.tokenizer(
|
41 |
+
captions, max_length=self.tokenizer.model_max_length, padding="max_length",
|
42 |
+
truncation=True, return_tensors="pt"
|
43 |
+
)
|
44 |
+
return inputs.input_ids
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
### >>>>>>>> >>>>>>>> image related >>>>>>>> >>>>>>>> ###
|
49 |
+
|
50 |
+
MONOCHROMATIC_MAX_VARIANCE = 0.3
|
51 |
+
|
52 |
+
def is_monochromatic_image(pil_img):
|
53 |
+
v = ImageStat.Stat(pil_img.convert('RGB')).var
|
54 |
+
return sum(v)<MONOCHROMATIC_MAX_VARIANCE
|
55 |
+
|
56 |
+
def isnumeric(text):
|
57 |
+
return (''.join(filter(str.isalnum, text))).isnumeric()
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
class TextPromptDataset(IterableDataset):
|
62 |
+
'''
|
63 |
+
The dataset for (text embedding, noise, generated latent) triplets.
|
64 |
+
'''
|
65 |
+
def __init__(self,
|
66 |
+
data_root,
|
67 |
+
tokenizer = None,
|
68 |
+
transform = None,
|
69 |
+
rank = 0,
|
70 |
+
world_size = 1,
|
71 |
+
shuffle = True,
|
72 |
+
):
|
73 |
+
self.tokenizer = tokenizer
|
74 |
+
self.transform = transform
|
75 |
+
|
76 |
+
self.img_root = os.path.join(data_root, 'JPEGImages')
|
77 |
+
self.data_list = []
|
78 |
+
|
79 |
+
print("#### Loading filename list...")
|
80 |
+
json_root = os.path.join(data_root, 'list')
|
81 |
+
json_list = [p for p in os.listdir(json_root) if p.startswith("shard") and p.endswith('.json')]
|
82 |
+
|
83 |
+
# duplicate several shards to make sure each process has the same number of shards
|
84 |
+
assert len(json_list) > world_size
|
85 |
+
duplicate = world_size - len(json_list)%world_size if len(json_list)%world_size>0 else 0
|
86 |
+
json_list = json_list + json_list[:duplicate]
|
87 |
+
json_list = json_list[rank::world_size]
|
88 |
+
|
89 |
+
for json_file in tqdm(json_list):
|
90 |
+
shard_name = os.path.basename(json_file).split('.')[0]
|
91 |
+
with open(os.path.join(json_root, json_file)) as f:
|
92 |
+
key_text_pairs = json.load(f)
|
93 |
+
|
94 |
+
for pair in key_text_pairs:
|
95 |
+
self.data_list.append( [shard_name] + pair )
|
96 |
+
|
97 |
+
print("#### All filename loaded...")
|
98 |
+
|
99 |
+
self.shuffle = shuffle
|
100 |
+
|
101 |
+
def __len__(self):
|
102 |
+
return len(self.data_list)
|
103 |
+
|
104 |
+
|
105 |
+
def __iter__(self):
|
106 |
+
worker_info = get_worker_info()
|
107 |
+
|
108 |
+
if worker_info is None: # single-process data loading, return the full iterator
|
109 |
+
data_list = self.data_list
|
110 |
+
else:
|
111 |
+
len_data = len(self.data_list) - len(self.data_list) % worker_info.num_workers
|
112 |
+
data_list = self.data_list[:len_data][worker_info.id :: worker_info.num_workers]
|
113 |
+
# print(worker_info.num_workers, worker_info.id, len(data_list)/len(self.data_list))
|
114 |
+
|
115 |
+
if self.shuffle:
|
116 |
+
random.shuffle(data_list)
|
117 |
+
|
118 |
+
while True:
|
119 |
+
for idx in range(len(data_list)):
|
120 |
+
# try:
|
121 |
+
shard_name = data_list[idx][0]
|
122 |
+
data = {}
|
123 |
+
|
124 |
+
img_file = data_list[idx][1]
|
125 |
+
img = Image.open(os.path.join(self.img_root, shard_name, img_file+'.jpg')).convert("RGB")
|
126 |
+
|
127 |
+
if is_monochromatic_image(img):
|
128 |
+
continue
|
129 |
+
|
130 |
+
if self.transform is not None:
|
131 |
+
img = self.transform(img)
|
132 |
+
|
133 |
+
data['pixel_values'] = img
|
134 |
+
|
135 |
+
text = data_list[idx][2]
|
136 |
+
if self.tokenizer is not None:
|
137 |
+
if isinstance(self.tokenizer, list):
|
138 |
+
assert len(self.tokenizer)==2
|
139 |
+
data['input_ids'] = self.tokenizer[0](text)[0]
|
140 |
+
data['input_ids_2'] = self.tokenizer[1](text)[0]
|
141 |
+
else:
|
142 |
+
data['input_ids'] = self.tokenizer(text)[0]
|
143 |
+
else:
|
144 |
+
data['input_ids'] = text
|
145 |
+
|
146 |
+
yield data
|
147 |
+
|
148 |
+
# except Exception as e:
|
149 |
+
# raise(e)
|
150 |
+
|
151 |
+
def collate_fn(self, examples):
|
152 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
153 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
154 |
+
|
155 |
+
if self.tokenizer is not None:
|
156 |
+
if isinstance(self.tokenizer, list):
|
157 |
+
assert len(self.tokenizer)==2
|
158 |
+
input_ids = torch.stack([example["input_ids"] for example in examples])
|
159 |
+
input_ids_2 = torch.stack([example["input_ids_2"] for example in examples])
|
160 |
+
return {"pixel_values": pixel_values, "input_ids": input_ids, "input_ids_2": input_ids_2,}
|
161 |
+
else:
|
162 |
+
input_ids = torch.stack([example["input_ids"] for example in examples])
|
163 |
+
return {"pixel_values": pixel_values, "input_ids": input_ids,}
|
164 |
+
else:
|
165 |
+
input_ids = [example["input_ids"] for example in examples]
|
166 |
+
return {"pixel_values": pixel_values, "input_ids": input_ids,}
|
167 |
+
|
168 |
+
|
169 |
+
def make_train_dataset(
|
170 |
+
train_data_path,
|
171 |
+
size = 512,
|
172 |
+
tokenizer=None,
|
173 |
+
cfg_drop_ratio=0,
|
174 |
+
rank=0,
|
175 |
+
world_size=1,
|
176 |
+
shuffle=True,
|
177 |
+
):
|
178 |
+
|
179 |
+
_image_transform = T.Compose([
|
180 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
181 |
+
T.Resize(size),
|
182 |
+
T.CenterCrop((size,size)),
|
183 |
+
T.ToTensor(),
|
184 |
+
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
185 |
+
])
|
186 |
+
|
187 |
+
if tokenizer is not None:
|
188 |
+
if isinstance(tokenizer, list):
|
189 |
+
assert len(tokenizer)==2
|
190 |
+
tokenizer_1 = TokenizerWrapper(
|
191 |
+
tokenizer[0],
|
192 |
+
is_train=True,
|
193 |
+
proportion_empty_prompts=cfg_drop_ratio,
|
194 |
+
use_generic_prompts=False,
|
195 |
+
)
|
196 |
+
tokenizer_2 = TokenizerWrapper(
|
197 |
+
tokenizer[1],
|
198 |
+
is_train=True,
|
199 |
+
proportion_empty_prompts=cfg_drop_ratio,
|
200 |
+
use_generic_prompts=False,
|
201 |
+
)
|
202 |
+
tokenizer = [tokenizer_1, tokenizer_2]
|
203 |
+
|
204 |
+
else:
|
205 |
+
tokenizer = TokenizerWrapper(
|
206 |
+
tokenizer,
|
207 |
+
is_train=True,
|
208 |
+
proportion_empty_prompts=cfg_drop_ratio,
|
209 |
+
use_generic_prompts=False,
|
210 |
+
)
|
211 |
+
|
212 |
+
|
213 |
+
train_dataset = TextPromptDataset(
|
214 |
+
data_root=train_data_path,
|
215 |
+
transform=_image_transform,
|
216 |
+
rank=rank,
|
217 |
+
world_size=world_size,
|
218 |
+
tokenizer=tokenizer,
|
219 |
+
shuffle=shuffle,
|
220 |
+
)
|
221 |
+
return train_dataset
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
### >>>>>>>> >>>>>>>> Test >>>>>>>> >>>>>>>> ###
|
233 |
+
if __name__ == "__main__":
|
234 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
235 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
236 |
+
"/mnt/bn/ic-research-aigc-editing/fast-diffusion-models/assets/public_models/StableDiffusion/stable-diffusion-v1-5",
|
237 |
+
subfolder="tokenizer"
|
238 |
+
)
|
239 |
+
train_dataset = make_train_dataset(tokenizer=tokenizer, rank=0, world_size=10)
|
240 |
+
|
241 |
+
loader = torch.utils.data.DataLoader(
|
242 |
+
train_dataset, batch_size=64, num_workers=0,
|
243 |
+
collate_fn=train_dataset.collect_fn if hasattr(train_dataset, 'collect_fn') else None,
|
244 |
+
)
|
245 |
+
for batch in loader:
|
246 |
+
pixel_values = batch["pixel_values"]
|
247 |
+
prompt_ids = batch['input_ids']
|
248 |
+
from einops import rearrange
|
249 |
+
pixel_values = rearrange(pixel_values, 'b c h w -> b h w c')
|
250 |
+
|
251 |
+
for i in range(pixel_values.shape[0]):
|
252 |
+
import pdb; pdb.set_trace()
|
253 |
+
Image.fromarray(((pixel_values[i] + 1 )/2 * 255 ).numpy().astype(np.uint8)).save('tmp.png')
|
254 |
+
input_id = prompt_ids[i]
|
255 |
+
text = tokenizer.decode(input_id).split('<|startoftext|>')[-1].split('<|endoftext|>')[0]
|
256 |
+
print(text)
|
257 |
+
pass
|
src/pfode_solver.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, math, random, argparse, logging
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union, List, Callable
|
4 |
+
from collections import OrderedDict
|
5 |
+
from packaging import version
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
import torchvision
|
14 |
+
|
15 |
+
|
16 |
+
class PFODESolver():
|
17 |
+
def __init__(self, scheduler, t_initial=1, t_terminal=0,) -> None:
|
18 |
+
self.t_initial = t_initial
|
19 |
+
self.t_terminal = t_terminal
|
20 |
+
self.scheduler = scheduler
|
21 |
+
|
22 |
+
train_step_terminal = 0
|
23 |
+
train_step_initial = train_step_terminal + self.scheduler.config.num_train_timesteps # 0+1000
|
24 |
+
self.stepsize = (t_terminal-t_initial) / (train_step_terminal - train_step_initial) #1/1000
|
25 |
+
|
26 |
+
def get_timesteps(self, t_start, t_end, num_steps):
|
27 |
+
# (b,) -> (b,1)
|
28 |
+
t_start = t_start[:, None]
|
29 |
+
t_end = t_end[:, None]
|
30 |
+
assert t_start.dim() == 2
|
31 |
+
|
32 |
+
timepoints = torch.arange(0, num_steps, 1).expand(t_start.shape[0], num_steps).to(device=t_start.device)
|
33 |
+
interval = (t_end - t_start) / (torch.ones([1], device=t_start.device) * num_steps)
|
34 |
+
timepoints = t_start + interval * timepoints
|
35 |
+
|
36 |
+
timesteps = (self.scheduler.num_train_timesteps - 1) + (timepoints - self.t_initial) / self.stepsize # correspondint to StableDiffusion indexing system, from 999 (t_init) -> 0 (dt)
|
37 |
+
return timesteps.round().long()
|
38 |
+
|
39 |
+
def solve(self,
|
40 |
+
latents,
|
41 |
+
unet,
|
42 |
+
t_start,
|
43 |
+
t_end,
|
44 |
+
prompt_embeds,
|
45 |
+
negative_prompt_embeds,
|
46 |
+
guidance_scale=1.0,
|
47 |
+
num_steps = 2,
|
48 |
+
num_windows = 1,
|
49 |
+
):
|
50 |
+
assert t_start.dim() == 1
|
51 |
+
assert guidance_scale >= 1 and torch.all(torch.gt(t_start, t_end))
|
52 |
+
|
53 |
+
do_classifier_free_guidance = True if guidance_scale > 1 else False
|
54 |
+
bsz = latents.shape[0]
|
55 |
+
|
56 |
+
if do_classifier_free_guidance:
|
57 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
58 |
+
|
59 |
+
timestep_cond = None
|
60 |
+
if unet.config.time_cond_proj_dim is not None:
|
61 |
+
guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(bsz)
|
62 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
63 |
+
guidance_scale_tensor, embedding_dim=unet.config.time_cond_proj_dim
|
64 |
+
).to(device=latents.device, dtype=latents.dtype)
|
65 |
+
|
66 |
+
|
67 |
+
timesteps = self.get_timesteps(t_start, t_end, num_steps).to(device=latents.device)
|
68 |
+
timestep_interval = self.scheduler.config.num_train_timesteps // (num_windows * num_steps)
|
69 |
+
|
70 |
+
# Denoising loop
|
71 |
+
with torch.no_grad():
|
72 |
+
for i in range(num_steps):
|
73 |
+
t = torch.cat([timesteps[:, i]]*2) if do_classifier_free_guidance else timesteps[:, i]
|
74 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
75 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
76 |
+
|
77 |
+
noise_pred = unet(
|
78 |
+
latent_model_input,
|
79 |
+
t,
|
80 |
+
encoder_hidden_states=prompt_embeds,
|
81 |
+
timestep_cond=timestep_cond,
|
82 |
+
return_dict=False,
|
83 |
+
)[0]
|
84 |
+
|
85 |
+
if do_classifier_free_guidance:
|
86 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
87 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
88 |
+
|
89 |
+
##### STEP: compute the previous noisy sample x_t -> x_t-1
|
90 |
+
batch_timesteps = timesteps[:, i].cpu()
|
91 |
+
prev_timestep = batch_timesteps - timestep_interval
|
92 |
+
|
93 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[batch_timesteps]
|
94 |
+
alpha_prod_t_prev = torch.zeros_like(alpha_prod_t)
|
95 |
+
for ib in range(prev_timestep.shape[0]):
|
96 |
+
alpha_prod_t_prev[ib] = self.scheduler.alphas_cumprod[prev_timestep[ib]] if prev_timestep[ib] >= 0 else self.scheduler.final_alpha_cumprod
|
97 |
+
beta_prod_t = 1 - alpha_prod_t
|
98 |
+
|
99 |
+
alpha_prod_t = alpha_prod_t.to(device=latents.device, dtype=latents.dtype)
|
100 |
+
alpha_prod_t_prev = alpha_prod_t_prev.to(device=latents.device, dtype=latents.dtype)
|
101 |
+
beta_prod_t = beta_prod_t.to(device=latents.device, dtype=latents.dtype)
|
102 |
+
|
103 |
+
if self.scheduler.config.prediction_type == "epsilon":
|
104 |
+
pred_original_sample = (latents - beta_prod_t[:,None,None,None] ** (0.5) * noise_pred) / alpha_prod_t[:, None,None,None] ** (0.5)
|
105 |
+
pred_epsilon = noise_pred
|
106 |
+
elif self.scheduler.config.prediction_type == "v_prediction":
|
107 |
+
pred_original_sample = (alpha_prod_t[:,None,None,None]**0.5) * latents - (beta_prod_t[:,None,None,None]**0.5) * noise_pred
|
108 |
+
pred_epsilon = (alpha_prod_t[:,None,None,None]**0.5) * noise_pred + (beta_prod_t[:,None,None,None]**0.5) * latents
|
109 |
+
else:
|
110 |
+
raise NotImplementedError
|
111 |
+
|
112 |
+
pred_sample_direction = (1 - alpha_prod_t_prev[:,None,None,None]) ** (0.5) * pred_epsilon
|
113 |
+
latents = alpha_prod_t_prev[:,None,None,None] ** (0.5) * pred_original_sample + pred_sample_direction
|
114 |
+
|
115 |
+
|
116 |
+
return latents
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
src/scheduler_perflow.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.utils import BaseOutput
|
25 |
+
from diffusers.utils.torch_utils import randn_tensor
|
26 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
27 |
+
|
28 |
+
|
29 |
+
class Time_Windows():
|
30 |
+
def __init__(self, t_initial=1, t_terminal=0, num_windows=4, precision=1./1000) -> None:
|
31 |
+
assert t_terminal < t_initial
|
32 |
+
time_windows = [ 1.*i/num_windows for i in range(1, num_windows+1)][::-1]
|
33 |
+
|
34 |
+
self.window_starts = time_windows # [1.0, 0.75, 0.5, 0.25]
|
35 |
+
self.window_ends = time_windows[1:] + [t_terminal] # [0.75, 0.5, 0.25, 0]
|
36 |
+
self.precision = precision
|
37 |
+
|
38 |
+
def get_window(self, tp):
|
39 |
+
idx = 0
|
40 |
+
# robust to numerical error; e.g, (0.6+1/10000) belongs to [0.6, 0.3)
|
41 |
+
while (tp-0.1*self.precision) <= self.window_ends[idx]:
|
42 |
+
idx += 1
|
43 |
+
return self.window_starts[idx], self.window_ends[idx]
|
44 |
+
|
45 |
+
def lookup_window(self, timepoint):
|
46 |
+
if timepoint.dim() == 0:
|
47 |
+
t_start, t_end = self.get_window(timepoint)
|
48 |
+
t_start = torch.ones_like(timepoint) * t_start
|
49 |
+
t_end = torch.ones_like(timepoint) * t_end
|
50 |
+
else:
|
51 |
+
t_start = torch.zeros_like(timepoint)
|
52 |
+
t_end = torch.zeros_like(timepoint)
|
53 |
+
bsz = timepoint.shape[0]
|
54 |
+
for i in range(bsz):
|
55 |
+
tp = timepoint[i]
|
56 |
+
ts, te = self.get_window(tp)
|
57 |
+
t_start[i] = ts
|
58 |
+
t_end[i] = te
|
59 |
+
return t_start, t_end
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
|
64 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
65 |
+
class PeRFlowSchedulerOutput(BaseOutput):
|
66 |
+
"""
|
67 |
+
Output class for the scheduler's `step` function output.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
71 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
72 |
+
denoising loop.
|
73 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
74 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
75 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
76 |
+
"""
|
77 |
+
|
78 |
+
prev_sample: torch.FloatTensor
|
79 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
83 |
+
def betas_for_alpha_bar(
|
84 |
+
num_diffusion_timesteps,
|
85 |
+
max_beta=0.999,
|
86 |
+
alpha_transform_type="cosine",
|
87 |
+
):
|
88 |
+
"""
|
89 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
90 |
+
(1-beta) over time from t = [0,1].
|
91 |
+
|
92 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
93 |
+
to that part of the diffusion process.
|
94 |
+
|
95 |
+
|
96 |
+
Args:
|
97 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
98 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
99 |
+
prevent singularities.
|
100 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
101 |
+
Choose from `cosine` or `exp`
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
105 |
+
"""
|
106 |
+
if alpha_transform_type == "cosine":
|
107 |
+
|
108 |
+
def alpha_bar_fn(t):
|
109 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
110 |
+
|
111 |
+
elif alpha_transform_type == "exp":
|
112 |
+
|
113 |
+
def alpha_bar_fn(t):
|
114 |
+
return math.exp(t * -12.0)
|
115 |
+
|
116 |
+
else:
|
117 |
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
118 |
+
|
119 |
+
betas = []
|
120 |
+
for i in range(num_diffusion_timesteps):
|
121 |
+
t1 = i / num_diffusion_timesteps
|
122 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
123 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
124 |
+
return torch.tensor(betas, dtype=torch.float32)
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class PeRFlowScheduler(SchedulerMixin, ConfigMixin):
|
129 |
+
"""
|
130 |
+
`ReFlowScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
131 |
+
non-Markovian guidance.
|
132 |
+
|
133 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
134 |
+
methods the library implements for all schedulers such as loading and saving.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
num_train_timesteps (`int`, defaults to 1000):
|
138 |
+
The number of diffusion steps to train the model.
|
139 |
+
beta_start (`float`, defaults to 0.0001):
|
140 |
+
The starting `beta` value of inference.
|
141 |
+
beta_end (`float`, defaults to 0.02):
|
142 |
+
The final `beta` value.
|
143 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
144 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
145 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
146 |
+
trained_betas (`np.ndarray`, *optional*):
|
147 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
148 |
+
set_alpha_to_one (`bool`, defaults to `True`):
|
149 |
+
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
150 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
151 |
+
otherwise it uses the alpha value at step 0.
|
152 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*)
|
153 |
+
"""
|
154 |
+
|
155 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
156 |
+
order = 1
|
157 |
+
|
158 |
+
@register_to_config
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
num_train_timesteps: int = 1000,
|
162 |
+
beta_start: float = 0.00085,
|
163 |
+
beta_end: float = 0.012,
|
164 |
+
beta_schedule: str = "scaled_linear",
|
165 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
166 |
+
set_alpha_to_one: bool = False,
|
167 |
+
prediction_type: str = "epsilon",
|
168 |
+
t_noise: float = 1,
|
169 |
+
t_clean: float = 0,
|
170 |
+
num_time_windows = 4,
|
171 |
+
):
|
172 |
+
if trained_betas is not None:
|
173 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
174 |
+
elif beta_schedule == "linear":
|
175 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
176 |
+
elif beta_schedule == "scaled_linear":
|
177 |
+
# this schedule is very specific to the latent diffusion model.
|
178 |
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
179 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
180 |
+
# Glide cosine schedule
|
181 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
182 |
+
else:
|
183 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
184 |
+
|
185 |
+
self.alphas = 1.0 - self.betas
|
186 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
187 |
+
|
188 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
189 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
190 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
191 |
+
# whether we use the final alpha of the "non-previous" one.
|
192 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
193 |
+
|
194 |
+
# # standard deviation of the initial noise distribution
|
195 |
+
self.init_noise_sigma = 1.0
|
196 |
+
|
197 |
+
self.time_windows = Time_Windows(t_initial=t_noise, t_terminal=t_clean,
|
198 |
+
num_windows=num_time_windows,
|
199 |
+
precision=1./num_train_timesteps)
|
200 |
+
|
201 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
202 |
+
"""
|
203 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
204 |
+
current timestep.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
sample (`torch.FloatTensor`):
|
208 |
+
The input sample.
|
209 |
+
timestep (`int`, *optional*):
|
210 |
+
The current timestep in the diffusion chain.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
`torch.FloatTensor`:
|
214 |
+
A scaled input sample.
|
215 |
+
"""
|
216 |
+
return sample
|
217 |
+
|
218 |
+
|
219 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
220 |
+
"""
|
221 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
222 |
+
|
223 |
+
Args:
|
224 |
+
num_inference_steps (`int`):
|
225 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
226 |
+
"""
|
227 |
+
if num_inference_steps < self.config.num_time_windows:
|
228 |
+
num_inference_steps = self.config.num_time_windows
|
229 |
+
print(f"### We recommend a num_inference_steps not less than num_time_windows. It's set as {self.config.num_time_windows}.")
|
230 |
+
|
231 |
+
timesteps = []
|
232 |
+
for i in range(self.config.num_time_windows):
|
233 |
+
if i < num_inference_steps%self.config.num_time_windows:
|
234 |
+
num_steps_cur_win = num_inference_steps//self.config.num_time_windows+1
|
235 |
+
else:
|
236 |
+
num_steps_cur_win = num_inference_steps//self.config.num_time_windows
|
237 |
+
|
238 |
+
t_s = self.time_windows.window_starts[i]
|
239 |
+
t_e = self.time_windows.window_ends[i]
|
240 |
+
timesteps_cur_win = np.linspace(t_s, t_e, num=num_steps_cur_win, endpoint=False)
|
241 |
+
timesteps.append(timesteps_cur_win)
|
242 |
+
|
243 |
+
timesteps = np.concatenate(timesteps)
|
244 |
+
|
245 |
+
self.timesteps = torch.from_numpy(
|
246 |
+
(timesteps*self.config.num_train_timesteps).astype(np.int64)
|
247 |
+
).to(device)
|
248 |
+
|
249 |
+
def get_window_alpha(self, timestep):
|
250 |
+
time_windows = self.time_windows
|
251 |
+
num_train_timesteps = self.config.num_train_timesteps
|
252 |
+
|
253 |
+
t_win_start, t_win_end = time_windows.lookup_window(timestep / num_train_timesteps)
|
254 |
+
t_win_len = t_win_end - t_win_start
|
255 |
+
t_interval = timestep / num_train_timesteps - t_win_start # NOTE: negative value
|
256 |
+
|
257 |
+
idx_start = (t_win_start*num_train_timesteps - 1 ).long()
|
258 |
+
idx_end = torch.clamp( (t_win_end*num_train_timesteps - 1 ).long(), min=0)
|
259 |
+
alpha_cumprod_s_e = self.alphas_cumprod[idx_start] / self.alphas_cumprod[idx_end]
|
260 |
+
gamma_s_e = alpha_cumprod_s_e ** 0.5
|
261 |
+
|
262 |
+
return t_win_start, t_win_end, t_win_len, t_interval, gamma_s_e
|
263 |
+
|
264 |
+
def step(
|
265 |
+
self,
|
266 |
+
model_output: torch.FloatTensor,
|
267 |
+
timestep: int,
|
268 |
+
sample: torch.FloatTensor,
|
269 |
+
return_dict: bool = True,
|
270 |
+
) -> Union[PeRFlowSchedulerOutput, Tuple]:
|
271 |
+
"""
|
272 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
273 |
+
process from the learned model outputs (most often the predicted noise).
|
274 |
+
|
275 |
+
Args:
|
276 |
+
model_output (`torch.FloatTensor`):
|
277 |
+
The direct output from learned diffusion model.
|
278 |
+
timestep (`float`):
|
279 |
+
The current discrete timestep in the diffusion chain.
|
280 |
+
sample (`torch.FloatTensor`):
|
281 |
+
A current instance of a sample created by the diffusion process.
|
282 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
283 |
+
Whether or not to return a [`~schedulers.scheduling_ddim.PeRFlowSchedulerOutput`] or `tuple`.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
[`~schedulers.scheduling_utils.PeRFlowSchedulerOutput`] or `tuple`:
|
287 |
+
If return_dict is `True`, [`~schedulers.scheduling_ddim.PeRFlowSchedulerOutput`] is returned, otherwise a
|
288 |
+
tuple is returned where the first element is the sample tensor.
|
289 |
+
"""
|
290 |
+
|
291 |
+
if self.config.prediction_type == "epsilon":
|
292 |
+
pred_epsilon = model_output
|
293 |
+
t_win_start, t_win_end, t_win_len, t_interval, gamma_s_e = self.get_window_alpha(timestep)
|
294 |
+
pred_sample_end = ( sample - (1-t_interval/t_win_len) * ((1-gamma_s_e**2)**0.5) * pred_epsilon ) \
|
295 |
+
/ ( gamma_s_e + t_interval / t_win_len * (1-gamma_s_e) )
|
296 |
+
pred_velocity = (pred_sample_end - sample) / (t_win_end - (t_win_start + t_interval))
|
297 |
+
|
298 |
+
elif self.config.prediction_type == "velocity":
|
299 |
+
pred_velocity = model_output
|
300 |
+
else:
|
301 |
+
raise ValueError(
|
302 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `velocity`."
|
303 |
+
)
|
304 |
+
|
305 |
+
# get dt
|
306 |
+
idx = torch.argwhere(torch.where(self.timesteps==timestep, 1,0))
|
307 |
+
prev_step = self.timesteps[idx+1] if (idx+1)<len(self.timesteps) else 0
|
308 |
+
dt = (prev_step - timestep) / self.config.num_train_timesteps
|
309 |
+
dt = dt.to(sample.device, sample.dtype)
|
310 |
+
|
311 |
+
prev_sample = sample + dt * pred_velocity
|
312 |
+
|
313 |
+
if not return_dict:
|
314 |
+
return (prev_sample,)
|
315 |
+
return PeRFlowSchedulerOutput(prev_sample=prev_sample, pred_original_sample=None)
|
316 |
+
|
317 |
+
|
318 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
319 |
+
def add_noise(
|
320 |
+
self,
|
321 |
+
original_samples: torch.FloatTensor,
|
322 |
+
noise: torch.FloatTensor,
|
323 |
+
timesteps: torch.IntTensor,
|
324 |
+
) -> torch.FloatTensor:
|
325 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
326 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
327 |
+
timesteps = timesteps.to(original_samples.device) - 1 # indexing from 0
|
328 |
+
|
329 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
330 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
331 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
332 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
333 |
+
|
334 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
335 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
336 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
337 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
338 |
+
|
339 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
340 |
+
return noisy_samples
|
341 |
+
|
342 |
+
def __len__(self):
|
343 |
+
return self.config.num_train_timesteps
|
src/utils_perflow.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import OrderedDict
|
3 |
+
import torch
|
4 |
+
from safetensors import safe_open
|
5 |
+
from safetensors.torch import save_file
|
6 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
7 |
+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_ldm_clip_checkpoint
|
8 |
+
|
9 |
+
|
10 |
+
def merge_delta_weights_into_unet(pipe, delta_weights):
|
11 |
+
unet_weights = pipe.unet.state_dict()
|
12 |
+
assert unet_weights.keys() == delta_weights.keys()
|
13 |
+
for key in delta_weights.keys():
|
14 |
+
dtype = unet_weights[key].dtype
|
15 |
+
unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device)
|
16 |
+
unet_weights[key] = unet_weights[key].to(dtype)
|
17 |
+
pipe.unet.load_state_dict(unet_weights, strict=True)
|
18 |
+
return pipe
|
19 |
+
|
20 |
+
|
21 |
+
def load_delta_weights_into_unet(
|
22 |
+
pipe,
|
23 |
+
model_path = "hsyan/piecewise-rectified-flow-v0-1",
|
24 |
+
base_path = "runwayml/stable-diffusion-v1-5",
|
25 |
+
):
|
26 |
+
## load delta_weights
|
27 |
+
if os.path.exists(os.path.join(model_path, "delta_weights.safetensors")):
|
28 |
+
print("### delta_weights exists, loading...")
|
29 |
+
delta_weights = OrderedDict()
|
30 |
+
with safe_open(os.path.join(model_path, "delta_weights.safetensors"), framework="pt", device="cpu") as f:
|
31 |
+
for key in f.keys():
|
32 |
+
delta_weights[key] = f.get_tensor(key)
|
33 |
+
|
34 |
+
elif os.path.exists(os.path.join(model_path, "diffusion_pytorch_model.safetensors")):
|
35 |
+
print("### merged_weights exists, loading...")
|
36 |
+
merged_weights = OrderedDict()
|
37 |
+
with safe_open(os.path.join(model_path, "diffusion_pytorch_model.safetensors"), framework="pt", device="cpu") as f:
|
38 |
+
for key in f.keys():
|
39 |
+
merged_weights[key] = f.get_tensor(key)
|
40 |
+
|
41 |
+
base_weights = StableDiffusionPipeline.from_pretrained(
|
42 |
+
base_path, torch_dtype=torch.float16, safety_checker=None).unet.state_dict()
|
43 |
+
assert base_weights.keys() == merged_weights.keys()
|
44 |
+
|
45 |
+
delta_weights = OrderedDict()
|
46 |
+
for key in merged_weights.keys():
|
47 |
+
delta_weights[key] = merged_weights[key] - base_weights[key].to(device=merged_weights[key].device, dtype=merged_weights[key].dtype)
|
48 |
+
|
49 |
+
print("### saving delta_weights...")
|
50 |
+
save_file(delta_weights, os.path.join(model_path, "delta_weights.safetensors"))
|
51 |
+
|
52 |
+
else:
|
53 |
+
raise ValueError(f"{model_path} does not contain delta weights or merged weights")
|
54 |
+
|
55 |
+
## merge delta_weights to the target pipeline
|
56 |
+
pipe = merge_delta_weights_into_unet(pipe, delta_weights)
|
57 |
+
return pipe
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def load_dreambooth_into_pipeline(pipe, sd_dreambooth):
|
63 |
+
assert sd_dreambooth.endswith(".safetensors")
|
64 |
+
state_dict = {}
|
65 |
+
with safe_open(sd_dreambooth, framework="pt", device="cpu") as f:
|
66 |
+
for key in f.keys():
|
67 |
+
state_dict[key] = f.get_tensor(key)
|
68 |
+
|
69 |
+
unet_config = {} # unet, line 449 in convert_ldm_unet_checkpoint
|
70 |
+
for key in pipe.unet.config.keys():
|
71 |
+
if key != 'num_class_embeds':
|
72 |
+
unet_config[key] = pipe.unet.config[key]
|
73 |
+
|
74 |
+
pipe.unet.load_state_dict(convert_ldm_unet_checkpoint(state_dict, unet_config), strict=False)
|
75 |
+
pipe.vae.load_state_dict(convert_ldm_vae_checkpoint(state_dict, pipe.vae.config))
|
76 |
+
pipe.text_encoder = convert_ldm_clip_checkpoint(state_dict, text_encoder=pipe.text_encoder)
|
77 |
+
return pipe
|
test.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: test
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- defaults
|
6 |
+
- conda-forge
|
7 |
+
dependencies:
|
8 |
+
- python=3.10.12
|
9 |
+
- pip=23.2.1
|
10 |
+
- cudatoolkit=11.7
|
tsr/__pycache__/system.cpython-310.pyc
ADDED
Binary file (5.19 kB). View file
|
|
tsr/__pycache__/system.cpython-38.pyc
ADDED
Binary file (5.07 kB). View file
|
|
tsr/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (13.6 kB). View file
|
|
tsr/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (13.5 kB). View file
|
|
tsr/models/__pycache__/isosurface.cpython-310.pyc
ADDED
Binary file (2.27 kB). View file
|
|
tsr/models/__pycache__/isosurface.cpython-38.pyc
ADDED
Binary file (2.23 kB). View file
|
|
tsr/models/__pycache__/nerf_renderer.cpython-310.pyc
ADDED
Binary file (5.32 kB). View file
|
|
tsr/models/__pycache__/nerf_renderer.cpython-38.pyc
ADDED
Binary file (5.31 kB). View file
|
|
tsr/models/__pycache__/network_utils.cpython-310.pyc
ADDED
Binary file (3.44 kB). View file
|
|
tsr/models/__pycache__/network_utils.cpython-38.pyc
ADDED
Binary file (3.39 kB). View file
|
|
tsr/models/isosurface.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchmcubes import marching_cubes
|
7 |
+
|
8 |
+
|
9 |
+
class IsosurfaceHelper(nn.Module):
|
10 |
+
points_range: Tuple[float, float] = (0, 1)
|
11 |
+
|
12 |
+
@property
|
13 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
14 |
+
raise NotImplementedError
|
15 |
+
|
16 |
+
|
17 |
+
class MarchingCubeHelper(IsosurfaceHelper):
|
18 |
+
def __init__(self, resolution: int) -> None:
|
19 |
+
super().__init__()
|
20 |
+
self.resolution = resolution
|
21 |
+
self.mc_func: Callable = marching_cubes
|
22 |
+
self._grid_vertices: Optional[torch.FloatTensor] = None
|
23 |
+
|
24 |
+
@property
|
25 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
26 |
+
if self._grid_vertices is None:
|
27 |
+
# keep the vertices on CPU so that we can support very large resolution
|
28 |
+
x, y, z = (
|
29 |
+
torch.linspace(*self.points_range, self.resolution),
|
30 |
+
torch.linspace(*self.points_range, self.resolution),
|
31 |
+
torch.linspace(*self.points_range, self.resolution),
|
32 |
+
)
|
33 |
+
x, y, z = torch.meshgrid(x, y, z, indexing="ij")
|
34 |
+
verts = torch.cat(
|
35 |
+
[x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
|
36 |
+
).reshape(-1, 3)
|
37 |
+
self._grid_vertices = verts
|
38 |
+
return self._grid_vertices
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
level: torch.FloatTensor,
|
43 |
+
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
44 |
+
level = -level.view(self.resolution, self.resolution, self.resolution)
|
45 |
+
try:
|
46 |
+
v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
|
47 |
+
except AttributeError:
|
48 |
+
print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
|
49 |
+
v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
|
50 |
+
v_pos = v_pos[..., [2, 1, 0]]
|
51 |
+
v_pos = v_pos / (self.resolution - 1.0)
|
52 |
+
return v_pos.to(level.device), t_pos_idx.to(level.device)
|
tsr/models/nerf_renderer.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange, reduce
|
7 |
+
|
8 |
+
from ..utils import (
|
9 |
+
BaseModule,
|
10 |
+
chunk_batch,
|
11 |
+
get_activation,
|
12 |
+
rays_intersect_bbox,
|
13 |
+
scale_tensor,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class TriplaneNeRFRenderer(BaseModule):
|
18 |
+
@dataclass
|
19 |
+
class Config(BaseModule.Config):
|
20 |
+
radius: float
|
21 |
+
|
22 |
+
feature_reduction: str = "concat"
|
23 |
+
density_activation: str = "trunc_exp"
|
24 |
+
density_bias: float = -1.0
|
25 |
+
color_activation: str = "sigmoid"
|
26 |
+
num_samples_per_ray: int = 128
|
27 |
+
randomized: bool = False
|
28 |
+
|
29 |
+
cfg: Config
|
30 |
+
|
31 |
+
def configure(self) -> None:
|
32 |
+
assert self.cfg.feature_reduction in ["concat", "mean"]
|
33 |
+
self.chunk_size = 0
|
34 |
+
|
35 |
+
def set_chunk_size(self, chunk_size: int):
|
36 |
+
assert (
|
37 |
+
chunk_size >= 0
|
38 |
+
), "chunk_size must be a non-negative integer (0 for no chunking)."
|
39 |
+
self.chunk_size = chunk_size
|
40 |
+
|
41 |
+
def query_triplane(
|
42 |
+
self,
|
43 |
+
decoder: torch.nn.Module,
|
44 |
+
positions: torch.Tensor,
|
45 |
+
triplane: torch.Tensor,
|
46 |
+
) -> Dict[str, torch.Tensor]:
|
47 |
+
input_shape = positions.shape[:-1]
|
48 |
+
positions = positions.view(-1, 3)
|
49 |
+
|
50 |
+
# positions in (-radius, radius)
|
51 |
+
# normalized to (-1, 1) for grid sample
|
52 |
+
positions = scale_tensor(
|
53 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
54 |
+
)
|
55 |
+
|
56 |
+
def _query_chunk(x):
|
57 |
+
indices2D: torch.Tensor = torch.stack(
|
58 |
+
(x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
|
59 |
+
dim=-3,
|
60 |
+
)
|
61 |
+
out: torch.Tensor = F.grid_sample(
|
62 |
+
rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
|
63 |
+
rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
|
64 |
+
align_corners=False,
|
65 |
+
mode="bilinear",
|
66 |
+
)
|
67 |
+
if self.cfg.feature_reduction == "concat":
|
68 |
+
out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
|
69 |
+
elif self.cfg.feature_reduction == "mean":
|
70 |
+
out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
|
71 |
+
else:
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
net_out: Dict[str, torch.Tensor] = decoder(out)
|
75 |
+
return net_out
|
76 |
+
|
77 |
+
if self.chunk_size > 0:
|
78 |
+
net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
|
79 |
+
else:
|
80 |
+
net_out = _query_chunk(positions)
|
81 |
+
|
82 |
+
net_out["density_act"] = get_activation(self.cfg.density_activation)(
|
83 |
+
net_out["density"] + self.cfg.density_bias
|
84 |
+
)
|
85 |
+
net_out["color"] = get_activation(self.cfg.color_activation)(
|
86 |
+
net_out["features"]
|
87 |
+
)
|
88 |
+
|
89 |
+
net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
|
90 |
+
|
91 |
+
return net_out
|
92 |
+
|
93 |
+
def _forward(
|
94 |
+
self,
|
95 |
+
decoder: torch.nn.Module,
|
96 |
+
triplane: torch.Tensor,
|
97 |
+
rays_o: torch.Tensor,
|
98 |
+
rays_d: torch.Tensor,
|
99 |
+
**kwargs,
|
100 |
+
):
|
101 |
+
rays_shape = rays_o.shape[:-1]
|
102 |
+
rays_o = rays_o.view(-1, 3)
|
103 |
+
rays_d = rays_d.view(-1, 3)
|
104 |
+
n_rays = rays_o.shape[0]
|
105 |
+
|
106 |
+
t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
|
107 |
+
t_near, t_far = t_near[rays_valid], t_far[rays_valid]
|
108 |
+
|
109 |
+
t_vals = torch.linspace(
|
110 |
+
0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
|
111 |
+
)
|
112 |
+
t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
|
113 |
+
z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
|
114 |
+
|
115 |
+
xyz = (
|
116 |
+
rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
|
117 |
+
) # (N_rays, N_sample, 3)
|
118 |
+
|
119 |
+
mlp_out = self.query_triplane(
|
120 |
+
decoder=decoder,
|
121 |
+
positions=xyz,
|
122 |
+
triplane=triplane,
|
123 |
+
)
|
124 |
+
|
125 |
+
eps = 1e-10
|
126 |
+
# deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
|
127 |
+
deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
|
128 |
+
alpha = 1 - torch.exp(
|
129 |
+
-deltas * mlp_out["density_act"][..., 0]
|
130 |
+
) # (N_rays, N_samples)
|
131 |
+
accum_prod = torch.cat(
|
132 |
+
[
|
133 |
+
torch.ones_like(alpha[:, :1]),
|
134 |
+
torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
|
135 |
+
],
|
136 |
+
dim=-1,
|
137 |
+
)
|
138 |
+
weights = alpha * accum_prod # (N_rays, N_samples)
|
139 |
+
comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
|
140 |
+
opacity_ = weights.sum(dim=-1) # (N_rays)
|
141 |
+
|
142 |
+
comp_rgb = torch.zeros(
|
143 |
+
n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
|
144 |
+
)
|
145 |
+
opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
|
146 |
+
comp_rgb[rays_valid] = comp_rgb_
|
147 |
+
opacity[rays_valid] = opacity_
|
148 |
+
|
149 |
+
comp_rgb += 1 - opacity[..., None]
|
150 |
+
comp_rgb = comp_rgb.view(*rays_shape, 3)
|
151 |
+
|
152 |
+
return comp_rgb
|
153 |
+
|
154 |
+
def forward(
|
155 |
+
self,
|
156 |
+
decoder: torch.nn.Module,
|
157 |
+
triplane: torch.Tensor,
|
158 |
+
rays_o: torch.Tensor,
|
159 |
+
rays_d: torch.Tensor,
|
160 |
+
) -> Dict[str, torch.Tensor]:
|
161 |
+
if triplane.ndim == 4:
|
162 |
+
comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
|
163 |
+
else:
|
164 |
+
comp_rgb = torch.stack(
|
165 |
+
[
|
166 |
+
self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
|
167 |
+
for i in range(triplane.shape[0])
|
168 |
+
],
|
169 |
+
dim=0,
|
170 |
+
)
|
171 |
+
|
172 |
+
return comp_rgb
|
173 |
+
|
174 |
+
def train(self, mode=True):
|
175 |
+
self.randomized = mode and self.cfg.randomized
|
176 |
+
return super().train(mode=mode)
|
177 |
+
|
178 |
+
def eval(self):
|
179 |
+
self.randomized = False
|
180 |
+
return super().eval()
|
tsr/models/network_utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from ..utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class TriplaneUpsampleNetwork(BaseModule):
|
12 |
+
@dataclass
|
13 |
+
class Config(BaseModule.Config):
|
14 |
+
in_channels: int
|
15 |
+
out_channels: int
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.upsample = nn.ConvTranspose2d(
|
21 |
+
self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
|
25 |
+
triplanes_up = rearrange(
|
26 |
+
self.upsample(
|
27 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
28 |
+
),
|
29 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
30 |
+
Np=3,
|
31 |
+
)
|
32 |
+
return triplanes_up
|
33 |
+
|
34 |
+
|
35 |
+
class NeRFMLP(BaseModule):
|
36 |
+
@dataclass
|
37 |
+
class Config(BaseModule.Config):
|
38 |
+
in_channels: int
|
39 |
+
n_neurons: int
|
40 |
+
n_hidden_layers: int
|
41 |
+
activation: str = "relu"
|
42 |
+
bias: bool = True
|
43 |
+
weight_init: Optional[str] = "kaiming_uniform"
|
44 |
+
bias_init: Optional[str] = None
|
45 |
+
|
46 |
+
cfg: Config
|
47 |
+
|
48 |
+
def configure(self) -> None:
|
49 |
+
layers = [
|
50 |
+
self.make_linear(
|
51 |
+
self.cfg.in_channels,
|
52 |
+
self.cfg.n_neurons,
|
53 |
+
bias=self.cfg.bias,
|
54 |
+
weight_init=self.cfg.weight_init,
|
55 |
+
bias_init=self.cfg.bias_init,
|
56 |
+
),
|
57 |
+
self.make_activation(self.cfg.activation),
|
58 |
+
]
|
59 |
+
for i in range(self.cfg.n_hidden_layers - 1):
|
60 |
+
layers += [
|
61 |
+
self.make_linear(
|
62 |
+
self.cfg.n_neurons,
|
63 |
+
self.cfg.n_neurons,
|
64 |
+
bias=self.cfg.bias,
|
65 |
+
weight_init=self.cfg.weight_init,
|
66 |
+
bias_init=self.cfg.bias_init,
|
67 |
+
),
|
68 |
+
self.make_activation(self.cfg.activation),
|
69 |
+
]
|
70 |
+
layers += [
|
71 |
+
self.make_linear(
|
72 |
+
self.cfg.n_neurons,
|
73 |
+
4, # density 1 + features 3
|
74 |
+
bias=self.cfg.bias,
|
75 |
+
weight_init=self.cfg.weight_init,
|
76 |
+
bias_init=self.cfg.bias_init,
|
77 |
+
)
|
78 |
+
]
|
79 |
+
self.layers = nn.Sequential(*layers)
|
80 |
+
|
81 |
+
def make_linear(
|
82 |
+
self,
|
83 |
+
dim_in,
|
84 |
+
dim_out,
|
85 |
+
bias=True,
|
86 |
+
weight_init=None,
|
87 |
+
bias_init=None,
|
88 |
+
):
|
89 |
+
layer = nn.Linear(dim_in, dim_out, bias=bias)
|
90 |
+
|
91 |
+
if weight_init is None:
|
92 |
+
pass
|
93 |
+
elif weight_init == "kaiming_uniform":
|
94 |
+
torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
if bias:
|
99 |
+
if bias_init is None:
|
100 |
+
pass
|
101 |
+
elif bias_init == "zero":
|
102 |
+
torch.nn.init.zeros_(layer.bias)
|
103 |
+
else:
|
104 |
+
raise NotImplementedError
|
105 |
+
|
106 |
+
return layer
|
107 |
+
|
108 |
+
def make_activation(self, activation):
|
109 |
+
if activation == "relu":
|
110 |
+
return nn.ReLU(inplace=True)
|
111 |
+
elif activation == "silu":
|
112 |
+
return nn.SiLU(inplace=True)
|
113 |
+
else:
|
114 |
+
raise NotImplementedError
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
inp_shape = x.shape[:-1]
|
118 |
+
x = x.reshape(-1, x.shape[-1])
|
119 |
+
|
120 |
+
features = self.layers(x)
|
121 |
+
features = features.reshape(*inp_shape, -1)
|
122 |
+
out = {"density": features[..., 0:1], "features": features[..., 1:4]}
|
123 |
+
|
124 |
+
return out
|
tsr/models/tokenizers/__pycache__/image.cpython-310.pyc
ADDED
Binary file (2.42 kB). View file
|
|
tsr/models/tokenizers/__pycache__/image.cpython-38.pyc
ADDED
Binary file (2.39 kB). View file
|
|
tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc
ADDED
Binary file (1.79 kB). View file
|
|
tsr/models/tokenizers/__pycache__/triplane.cpython-38.pyc
ADDED
Binary file (1.77 kB). View file
|
|
tsr/models/tokenizers/image.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from transformers.models.vit.modeling_vit import ViTModel
|
8 |
+
|
9 |
+
from ...utils import BaseModule
|
10 |
+
|
11 |
+
|
12 |
+
class DINOSingleImageTokenizer(BaseModule):
|
13 |
+
@dataclass
|
14 |
+
class Config(BaseModule.Config):
|
15 |
+
pretrained_model_name_or_path: str = "facebook/dino-vitb16"
|
16 |
+
enable_gradient_checkpointing: bool = False
|
17 |
+
|
18 |
+
cfg: Config
|
19 |
+
|
20 |
+
def configure(self) -> None:
|
21 |
+
self.model: ViTModel = ViTModel(
|
22 |
+
ViTModel.config_class.from_pretrained(
|
23 |
+
hf_hub_download(
|
24 |
+
repo_id=self.cfg.pretrained_model_name_or_path,
|
25 |
+
filename="config.json",
|
26 |
+
)
|
27 |
+
)
|
28 |
+
)
|
29 |
+
|
30 |
+
if self.cfg.enable_gradient_checkpointing:
|
31 |
+
self.model.encoder.gradient_checkpointing = True
|
32 |
+
|
33 |
+
self.register_buffer(
|
34 |
+
"image_mean",
|
35 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
36 |
+
persistent=False,
|
37 |
+
)
|
38 |
+
self.register_buffer(
|
39 |
+
"image_std",
|
40 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
41 |
+
persistent=False,
|
42 |
+
)
|
43 |
+
|
44 |
+
def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
45 |
+
packed = False
|
46 |
+
if images.ndim == 4:
|
47 |
+
packed = True
|
48 |
+
images = images.unsqueeze(1)
|
49 |
+
|
50 |
+
batch_size, n_input_views = images.shape[:2]
|
51 |
+
images = (images - self.image_mean) / self.image_std
|
52 |
+
out = self.model(
|
53 |
+
rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
|
54 |
+
)
|
55 |
+
local_features, global_features = out.last_hidden_state, out.pooler_output
|
56 |
+
local_features = local_features.permute(0, 2, 1)
|
57 |
+
local_features = rearrange(
|
58 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
59 |
+
)
|
60 |
+
if packed:
|
61 |
+
local_features = local_features.squeeze(1)
|
62 |
+
|
63 |
+
return local_features
|
64 |
+
|
65 |
+
def detokenize(self, *args, **kwargs):
|
66 |
+
raise NotImplementedError
|
tsr/models/tokenizers/triplane.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
|
8 |
+
from ...utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class Triplane1DTokenizer(BaseModule):
|
12 |
+
@dataclass
|
13 |
+
class Config(BaseModule.Config):
|
14 |
+
plane_size: int
|
15 |
+
num_channels: int
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.embeddings = nn.Parameter(
|
21 |
+
torch.randn(
|
22 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
23 |
+
dtype=torch.float32,
|
24 |
+
)
|
25 |
+
* 1
|
26 |
+
/ math.sqrt(self.cfg.num_channels)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, batch_size: int) -> torch.Tensor:
|
30 |
+
return rearrange(
|
31 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
32 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
33 |
+
)
|
34 |
+
|
35 |
+
def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
|
36 |
+
batch_size, Ct, Nt = tokens.shape
|
37 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
38 |
+
assert Ct == self.cfg.num_channels
|
39 |
+
return rearrange(
|
40 |
+
tokens,
|
41 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
42 |
+
Np=3,
|
43 |
+
Hp=self.cfg.plane_size,
|
44 |
+
Wp=self.cfg.plane_size,
|
45 |
+
)
|
tsr/models/transformer/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (15.3 kB). View file
|
|
tsr/models/transformer/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (15.2 kB). View file
|
|
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc
ADDED
Binary file (9.65 kB). View file
|
|
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-38.pyc
ADDED
Binary file (9.49 kB). View file
|
|
tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc
ADDED
Binary file (4.91 kB). View file
|
|
tsr/models/transformer/__pycache__/transformer_1d.cpython-38.pyc
ADDED
Binary file (4.85 kB). View file
|
|
tsr/models/transformer/attention.py
ADDED
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# --------
|
16 |
+
#
|
17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
18 |
+
#
|
19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
20 |
+
#
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
#
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
#
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
from typing import Optional
|
40 |
+
|
41 |
+
import torch
|
42 |
+
import torch.nn.functional as F
|
43 |
+
from torch import nn
|
44 |
+
|
45 |
+
|
46 |
+
class Attention(nn.Module):
|
47 |
+
r"""
|
48 |
+
A cross attention layer.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
query_dim (`int`):
|
52 |
+
The number of channels in the query.
|
53 |
+
cross_attention_dim (`int`, *optional*):
|
54 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
55 |
+
heads (`int`, *optional*, defaults to 8):
|
56 |
+
The number of heads to use for multi-head attention.
|
57 |
+
dim_head (`int`, *optional*, defaults to 64):
|
58 |
+
The number of channels in each head.
|
59 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
60 |
+
The dropout probability to use.
|
61 |
+
bias (`bool`, *optional*, defaults to False):
|
62 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
63 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
64 |
+
Set to `True` to upcast the attention computation to `float32`.
|
65 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
66 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
67 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
68 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
69 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
70 |
+
The number of groups to use for the group norm in the cross attention.
|
71 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
72 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
73 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
74 |
+
The number of groups to use for the group norm in the attention.
|
75 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
76 |
+
The number of channels to use for the spatial normalization.
|
77 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
78 |
+
Set to `True` to use a bias in the output linear layer.
|
79 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
80 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
81 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
82 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
83 |
+
`added_kv_proj_dim` is not `None`.
|
84 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
85 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
86 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
87 |
+
A factor to rescale the output by dividing it with this value.
|
88 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
89 |
+
Set to `True` to add the residual connection to the output.
|
90 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
91 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
92 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
93 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
94 |
+
`AttnProcessor` otherwise.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
query_dim: int,
|
100 |
+
cross_attention_dim: Optional[int] = None,
|
101 |
+
heads: int = 8,
|
102 |
+
dim_head: int = 64,
|
103 |
+
dropout: float = 0.0,
|
104 |
+
bias: bool = False,
|
105 |
+
upcast_attention: bool = False,
|
106 |
+
upcast_softmax: bool = False,
|
107 |
+
cross_attention_norm: Optional[str] = None,
|
108 |
+
cross_attention_norm_num_groups: int = 32,
|
109 |
+
added_kv_proj_dim: Optional[int] = None,
|
110 |
+
norm_num_groups: Optional[int] = None,
|
111 |
+
out_bias: bool = True,
|
112 |
+
scale_qk: bool = True,
|
113 |
+
only_cross_attention: bool = False,
|
114 |
+
eps: float = 1e-5,
|
115 |
+
rescale_output_factor: float = 1.0,
|
116 |
+
residual_connection: bool = False,
|
117 |
+
_from_deprecated_attn_block: bool = False,
|
118 |
+
processor: Optional["AttnProcessor"] = None,
|
119 |
+
out_dim: int = None,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
123 |
+
self.query_dim = query_dim
|
124 |
+
self.cross_attention_dim = (
|
125 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
126 |
+
)
|
127 |
+
self.upcast_attention = upcast_attention
|
128 |
+
self.upcast_softmax = upcast_softmax
|
129 |
+
self.rescale_output_factor = rescale_output_factor
|
130 |
+
self.residual_connection = residual_connection
|
131 |
+
self.dropout = dropout
|
132 |
+
self.fused_projections = False
|
133 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
134 |
+
|
135 |
+
# we make use of this private variable to know whether this class is loaded
|
136 |
+
# with an deprecated state dict so that we can convert it on the fly
|
137 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
138 |
+
|
139 |
+
self.scale_qk = scale_qk
|
140 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
141 |
+
|
142 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
143 |
+
# for slice_size > 0 the attention score computation
|
144 |
+
# is split across the batch axis to save memory
|
145 |
+
# You can set slice_size with `set_attention_slice`
|
146 |
+
self.sliceable_head_dim = heads
|
147 |
+
|
148 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
149 |
+
self.only_cross_attention = only_cross_attention
|
150 |
+
|
151 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
152 |
+
raise ValueError(
|
153 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
154 |
+
)
|
155 |
+
|
156 |
+
if norm_num_groups is not None:
|
157 |
+
self.group_norm = nn.GroupNorm(
|
158 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
159 |
+
)
|
160 |
+
else:
|
161 |
+
self.group_norm = None
|
162 |
+
|
163 |
+
self.spatial_norm = None
|
164 |
+
|
165 |
+
if cross_attention_norm is None:
|
166 |
+
self.norm_cross = None
|
167 |
+
elif cross_attention_norm == "layer_norm":
|
168 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
169 |
+
elif cross_attention_norm == "group_norm":
|
170 |
+
if self.added_kv_proj_dim is not None:
|
171 |
+
# The given `encoder_hidden_states` are initially of shape
|
172 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
173 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
174 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
175 |
+
# the number of channels for the group norm.
|
176 |
+
norm_cross_num_channels = added_kv_proj_dim
|
177 |
+
else:
|
178 |
+
norm_cross_num_channels = self.cross_attention_dim
|
179 |
+
|
180 |
+
self.norm_cross = nn.GroupNorm(
|
181 |
+
num_channels=norm_cross_num_channels,
|
182 |
+
num_groups=cross_attention_norm_num_groups,
|
183 |
+
eps=1e-5,
|
184 |
+
affine=True,
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
raise ValueError(
|
188 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
189 |
+
)
|
190 |
+
|
191 |
+
linear_cls = nn.Linear
|
192 |
+
|
193 |
+
self.linear_cls = linear_cls
|
194 |
+
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
195 |
+
|
196 |
+
if not self.only_cross_attention:
|
197 |
+
# only relevant for the `AddedKVProcessor` classes
|
198 |
+
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
199 |
+
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
200 |
+
else:
|
201 |
+
self.to_k = None
|
202 |
+
self.to_v = None
|
203 |
+
|
204 |
+
if self.added_kv_proj_dim is not None:
|
205 |
+
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
206 |
+
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
207 |
+
|
208 |
+
self.to_out = nn.ModuleList([])
|
209 |
+
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
210 |
+
self.to_out.append(nn.Dropout(dropout))
|
211 |
+
|
212 |
+
# set attention processor
|
213 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
214 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
215 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
216 |
+
if processor is None:
|
217 |
+
processor = (
|
218 |
+
AttnProcessor2_0()
|
219 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
220 |
+
else AttnProcessor()
|
221 |
+
)
|
222 |
+
self.set_processor(processor)
|
223 |
+
|
224 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
225 |
+
self.processor = processor
|
226 |
+
|
227 |
+
def forward(
|
228 |
+
self,
|
229 |
+
hidden_states: torch.FloatTensor,
|
230 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
231 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
232 |
+
**cross_attention_kwargs,
|
233 |
+
) -> torch.Tensor:
|
234 |
+
r"""
|
235 |
+
The forward method of the `Attention` class.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
hidden_states (`torch.Tensor`):
|
239 |
+
The hidden states of the query.
|
240 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
241 |
+
The hidden states of the encoder.
|
242 |
+
attention_mask (`torch.Tensor`, *optional*):
|
243 |
+
The attention mask to use. If `None`, no mask is applied.
|
244 |
+
**cross_attention_kwargs:
|
245 |
+
Additional keyword arguments to pass along to the cross attention.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
`torch.Tensor`: The output of the attention layer.
|
249 |
+
"""
|
250 |
+
# The `Attention` class can call different attention processors / attention functions
|
251 |
+
# here we simply pass along all tensors to the selected processor class
|
252 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
253 |
+
return self.processor(
|
254 |
+
self,
|
255 |
+
hidden_states,
|
256 |
+
encoder_hidden_states=encoder_hidden_states,
|
257 |
+
attention_mask=attention_mask,
|
258 |
+
**cross_attention_kwargs,
|
259 |
+
)
|
260 |
+
|
261 |
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
262 |
+
r"""
|
263 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
264 |
+
is the number of heads initialized while constructing the `Attention` class.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
`torch.Tensor`: The reshaped tensor.
|
271 |
+
"""
|
272 |
+
head_size = self.heads
|
273 |
+
batch_size, seq_len, dim = tensor.shape
|
274 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
275 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
276 |
+
batch_size // head_size, seq_len, dim * head_size
|
277 |
+
)
|
278 |
+
return tensor
|
279 |
+
|
280 |
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
281 |
+
r"""
|
282 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
|
283 |
+
the number of heads initialized while constructing the `Attention` class.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
287 |
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
288 |
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
`torch.Tensor`: The reshaped tensor.
|
292 |
+
"""
|
293 |
+
head_size = self.heads
|
294 |
+
batch_size, seq_len, dim = tensor.shape
|
295 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
296 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
297 |
+
|
298 |
+
if out_dim == 3:
|
299 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
300 |
+
|
301 |
+
return tensor
|
302 |
+
|
303 |
+
def get_attention_scores(
|
304 |
+
self,
|
305 |
+
query: torch.Tensor,
|
306 |
+
key: torch.Tensor,
|
307 |
+
attention_mask: torch.Tensor = None,
|
308 |
+
) -> torch.Tensor:
|
309 |
+
r"""
|
310 |
+
Compute the attention scores.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
query (`torch.Tensor`): The query tensor.
|
314 |
+
key (`torch.Tensor`): The key tensor.
|
315 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
`torch.Tensor`: The attention probabilities/scores.
|
319 |
+
"""
|
320 |
+
dtype = query.dtype
|
321 |
+
if self.upcast_attention:
|
322 |
+
query = query.float()
|
323 |
+
key = key.float()
|
324 |
+
|
325 |
+
if attention_mask is None:
|
326 |
+
baddbmm_input = torch.empty(
|
327 |
+
query.shape[0],
|
328 |
+
query.shape[1],
|
329 |
+
key.shape[1],
|
330 |
+
dtype=query.dtype,
|
331 |
+
device=query.device,
|
332 |
+
)
|
333 |
+
beta = 0
|
334 |
+
else:
|
335 |
+
baddbmm_input = attention_mask
|
336 |
+
beta = 1
|
337 |
+
|
338 |
+
attention_scores = torch.baddbmm(
|
339 |
+
baddbmm_input,
|
340 |
+
query,
|
341 |
+
key.transpose(-1, -2),
|
342 |
+
beta=beta,
|
343 |
+
alpha=self.scale,
|
344 |
+
)
|
345 |
+
del baddbmm_input
|
346 |
+
|
347 |
+
if self.upcast_softmax:
|
348 |
+
attention_scores = attention_scores.float()
|
349 |
+
|
350 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
351 |
+
del attention_scores
|
352 |
+
|
353 |
+
attention_probs = attention_probs.to(dtype)
|
354 |
+
|
355 |
+
return attention_probs
|
356 |
+
|
357 |
+
def prepare_attention_mask(
|
358 |
+
self,
|
359 |
+
attention_mask: torch.Tensor,
|
360 |
+
target_length: int,
|
361 |
+
batch_size: int,
|
362 |
+
out_dim: int = 3,
|
363 |
+
) -> torch.Tensor:
|
364 |
+
r"""
|
365 |
+
Prepare the attention mask for the attention computation.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
attention_mask (`torch.Tensor`):
|
369 |
+
The attention mask to prepare.
|
370 |
+
target_length (`int`):
|
371 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
372 |
+
batch_size (`int`):
|
373 |
+
The batch size, which is used to repeat the attention mask.
|
374 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
375 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
`torch.Tensor`: The prepared attention mask.
|
379 |
+
"""
|
380 |
+
head_size = self.heads
|
381 |
+
if attention_mask is None:
|
382 |
+
return attention_mask
|
383 |
+
|
384 |
+
current_length: int = attention_mask.shape[-1]
|
385 |
+
if current_length != target_length:
|
386 |
+
if attention_mask.device.type == "mps":
|
387 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
388 |
+
# Instead, we can manually construct the padding tensor.
|
389 |
+
padding_shape = (
|
390 |
+
attention_mask.shape[0],
|
391 |
+
attention_mask.shape[1],
|
392 |
+
target_length,
|
393 |
+
)
|
394 |
+
padding = torch.zeros(
|
395 |
+
padding_shape,
|
396 |
+
dtype=attention_mask.dtype,
|
397 |
+
device=attention_mask.device,
|
398 |
+
)
|
399 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
400 |
+
else:
|
401 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
402 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
403 |
+
# remaining_length: int = target_length - current_length
|
404 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
405 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
406 |
+
|
407 |
+
if out_dim == 3:
|
408 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
409 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
410 |
+
elif out_dim == 4:
|
411 |
+
attention_mask = attention_mask.unsqueeze(1)
|
412 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
413 |
+
|
414 |
+
return attention_mask
|
415 |
+
|
416 |
+
def norm_encoder_hidden_states(
|
417 |
+
self, encoder_hidden_states: torch.Tensor
|
418 |
+
) -> torch.Tensor:
|
419 |
+
r"""
|
420 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
421 |
+
`Attention` class.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
428 |
+
"""
|
429 |
+
assert (
|
430 |
+
self.norm_cross is not None
|
431 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
432 |
+
|
433 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
434 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
435 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
436 |
+
# Group norm norms along the channels dimension and expects
|
437 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
438 |
+
# to norm along the hidden dimension, so we need to move
|
439 |
+
# (batch_size, sequence_length, hidden_size) ->
|
440 |
+
# (batch_size, hidden_size, sequence_length)
|
441 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
442 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
443 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
444 |
+
else:
|
445 |
+
assert False
|
446 |
+
|
447 |
+
return encoder_hidden_states
|
448 |
+
|
449 |
+
@torch.no_grad()
|
450 |
+
def fuse_projections(self, fuse=True):
|
451 |
+
is_cross_attention = self.cross_attention_dim != self.query_dim
|
452 |
+
device = self.to_q.weight.data.device
|
453 |
+
dtype = self.to_q.weight.data.dtype
|
454 |
+
|
455 |
+
if not is_cross_attention:
|
456 |
+
# fetch weight matrices.
|
457 |
+
concatenated_weights = torch.cat(
|
458 |
+
[self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
|
459 |
+
)
|
460 |
+
in_features = concatenated_weights.shape[1]
|
461 |
+
out_features = concatenated_weights.shape[0]
|
462 |
+
|
463 |
+
# create a new single projection layer and copy over the weights.
|
464 |
+
self.to_qkv = self.linear_cls(
|
465 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
466 |
+
)
|
467 |
+
self.to_qkv.weight.copy_(concatenated_weights)
|
468 |
+
|
469 |
+
else:
|
470 |
+
concatenated_weights = torch.cat(
|
471 |
+
[self.to_k.weight.data, self.to_v.weight.data]
|
472 |
+
)
|
473 |
+
in_features = concatenated_weights.shape[1]
|
474 |
+
out_features = concatenated_weights.shape[0]
|
475 |
+
|
476 |
+
self.to_kv = self.linear_cls(
|
477 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
478 |
+
)
|
479 |
+
self.to_kv.weight.copy_(concatenated_weights)
|
480 |
+
|
481 |
+
self.fused_projections = fuse
|
482 |
+
|
483 |
+
|
484 |
+
class AttnProcessor:
|
485 |
+
r"""
|
486 |
+
Default processor for performing attention-related computations.
|
487 |
+
"""
|
488 |
+
|
489 |
+
def __call__(
|
490 |
+
self,
|
491 |
+
attn: Attention,
|
492 |
+
hidden_states: torch.FloatTensor,
|
493 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
494 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
495 |
+
) -> torch.Tensor:
|
496 |
+
residual = hidden_states
|
497 |
+
|
498 |
+
input_ndim = hidden_states.ndim
|
499 |
+
|
500 |
+
if input_ndim == 4:
|
501 |
+
batch_size, channel, height, width = hidden_states.shape
|
502 |
+
hidden_states = hidden_states.view(
|
503 |
+
batch_size, channel, height * width
|
504 |
+
).transpose(1, 2)
|
505 |
+
|
506 |
+
batch_size, sequence_length, _ = (
|
507 |
+
hidden_states.shape
|
508 |
+
if encoder_hidden_states is None
|
509 |
+
else encoder_hidden_states.shape
|
510 |
+
)
|
511 |
+
attention_mask = attn.prepare_attention_mask(
|
512 |
+
attention_mask, sequence_length, batch_size
|
513 |
+
)
|
514 |
+
|
515 |
+
if attn.group_norm is not None:
|
516 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
517 |
+
1, 2
|
518 |
+
)
|
519 |
+
|
520 |
+
query = attn.to_q(hidden_states)
|
521 |
+
|
522 |
+
if encoder_hidden_states is None:
|
523 |
+
encoder_hidden_states = hidden_states
|
524 |
+
elif attn.norm_cross:
|
525 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
526 |
+
encoder_hidden_states
|
527 |
+
)
|
528 |
+
|
529 |
+
key = attn.to_k(encoder_hidden_states)
|
530 |
+
value = attn.to_v(encoder_hidden_states)
|
531 |
+
|
532 |
+
query = attn.head_to_batch_dim(query)
|
533 |
+
key = attn.head_to_batch_dim(key)
|
534 |
+
value = attn.head_to_batch_dim(value)
|
535 |
+
|
536 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
537 |
+
hidden_states = torch.bmm(attention_probs, value)
|
538 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
539 |
+
|
540 |
+
# linear proj
|
541 |
+
hidden_states = attn.to_out[0](hidden_states)
|
542 |
+
# dropout
|
543 |
+
hidden_states = attn.to_out[1](hidden_states)
|
544 |
+
|
545 |
+
if input_ndim == 4:
|
546 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
547 |
+
batch_size, channel, height, width
|
548 |
+
)
|
549 |
+
|
550 |
+
if attn.residual_connection:
|
551 |
+
hidden_states = hidden_states + residual
|
552 |
+
|
553 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
554 |
+
|
555 |
+
return hidden_states
|
556 |
+
|
557 |
+
|
558 |
+
class AttnProcessor2_0:
|
559 |
+
r"""
|
560 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
561 |
+
"""
|
562 |
+
|
563 |
+
def __init__(self):
|
564 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
565 |
+
raise ImportError(
|
566 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
567 |
+
)
|
568 |
+
|
569 |
+
def __call__(
|
570 |
+
self,
|
571 |
+
attn: Attention,
|
572 |
+
hidden_states: torch.FloatTensor,
|
573 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
574 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
575 |
+
) -> torch.FloatTensor:
|
576 |
+
residual = hidden_states
|
577 |
+
|
578 |
+
input_ndim = hidden_states.ndim
|
579 |
+
|
580 |
+
if input_ndim == 4:
|
581 |
+
batch_size, channel, height, width = hidden_states.shape
|
582 |
+
hidden_states = hidden_states.view(
|
583 |
+
batch_size, channel, height * width
|
584 |
+
).transpose(1, 2)
|
585 |
+
|
586 |
+
batch_size, sequence_length, _ = (
|
587 |
+
hidden_states.shape
|
588 |
+
if encoder_hidden_states is None
|
589 |
+
else encoder_hidden_states.shape
|
590 |
+
)
|
591 |
+
|
592 |
+
if attention_mask is not None:
|
593 |
+
attention_mask = attn.prepare_attention_mask(
|
594 |
+
attention_mask, sequence_length, batch_size
|
595 |
+
)
|
596 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
597 |
+
# (batch, heads, source_length, target_length)
|
598 |
+
attention_mask = attention_mask.view(
|
599 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
600 |
+
)
|
601 |
+
|
602 |
+
if attn.group_norm is not None:
|
603 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
604 |
+
1, 2
|
605 |
+
)
|
606 |
+
|
607 |
+
query = attn.to_q(hidden_states)
|
608 |
+
|
609 |
+
if encoder_hidden_states is None:
|
610 |
+
encoder_hidden_states = hidden_states
|
611 |
+
elif attn.norm_cross:
|
612 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
613 |
+
encoder_hidden_states
|
614 |
+
)
|
615 |
+
|
616 |
+
key = attn.to_k(encoder_hidden_states)
|
617 |
+
value = attn.to_v(encoder_hidden_states)
|
618 |
+
|
619 |
+
inner_dim = key.shape[-1]
|
620 |
+
head_dim = inner_dim // attn.heads
|
621 |
+
|
622 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
623 |
+
|
624 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
625 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
626 |
+
|
627 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
628 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
629 |
+
hidden_states = F.scaled_dot_product_attention(
|
630 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
631 |
+
)
|
632 |
+
|
633 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
634 |
+
batch_size, -1, attn.heads * head_dim
|
635 |
+
)
|
636 |
+
hidden_states = hidden_states.to(query.dtype)
|
637 |
+
|
638 |
+
# linear proj
|
639 |
+
hidden_states = attn.to_out[0](hidden_states)
|
640 |
+
# dropout
|
641 |
+
hidden_states = attn.to_out[1](hidden_states)
|
642 |
+
|
643 |
+
if input_ndim == 4:
|
644 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
645 |
+
batch_size, channel, height, width
|
646 |
+
)
|
647 |
+
|
648 |
+
if attn.residual_connection:
|
649 |
+
hidden_states = hidden_states + residual
|
650 |
+
|
651 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
652 |
+
|
653 |
+
return hidden_states
|
tsr/models/transformer/basic_transformer_block.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# --------
|
16 |
+
#
|
17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
18 |
+
#
|
19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
20 |
+
#
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
#
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
#
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
from typing import Optional
|
40 |
+
|
41 |
+
import torch
|
42 |
+
import torch.nn.functional as F
|
43 |
+
from torch import nn
|
44 |
+
|
45 |
+
from .attention import Attention
|
46 |
+
|
47 |
+
|
48 |
+
class BasicTransformerBlock(nn.Module):
|
49 |
+
r"""
|
50 |
+
A basic Transformer block.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
dim (`int`): The number of channels in the input and output.
|
54 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
55 |
+
attention_head_dim (`int`): The number of channels in each head.
|
56 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
57 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
58 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
59 |
+
attention_bias (:
|
60 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
61 |
+
only_cross_attention (`bool`, *optional*):
|
62 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
63 |
+
double_self_attention (`bool`, *optional*):
|
64 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
65 |
+
upcast_attention (`bool`, *optional*):
|
66 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
67 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
69 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
70 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
71 |
+
final_dropout (`bool` *optional*, defaults to False):
|
72 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
dim: int,
|
78 |
+
num_attention_heads: int,
|
79 |
+
attention_head_dim: int,
|
80 |
+
dropout=0.0,
|
81 |
+
cross_attention_dim: Optional[int] = None,
|
82 |
+
activation_fn: str = "geglu",
|
83 |
+
attention_bias: bool = False,
|
84 |
+
only_cross_attention: bool = False,
|
85 |
+
double_self_attention: bool = False,
|
86 |
+
upcast_attention: bool = False,
|
87 |
+
norm_elementwise_affine: bool = True,
|
88 |
+
norm_type: str = "layer_norm",
|
89 |
+
final_dropout: bool = False,
|
90 |
+
):
|
91 |
+
super().__init__()
|
92 |
+
self.only_cross_attention = only_cross_attention
|
93 |
+
|
94 |
+
assert norm_type == "layer_norm"
|
95 |
+
|
96 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
97 |
+
# 1. Self-Attn
|
98 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
99 |
+
self.attn1 = Attention(
|
100 |
+
query_dim=dim,
|
101 |
+
heads=num_attention_heads,
|
102 |
+
dim_head=attention_head_dim,
|
103 |
+
dropout=dropout,
|
104 |
+
bias=attention_bias,
|
105 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
106 |
+
upcast_attention=upcast_attention,
|
107 |
+
)
|
108 |
+
|
109 |
+
# 2. Cross-Attn
|
110 |
+
if cross_attention_dim is not None or double_self_attention:
|
111 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
112 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
113 |
+
# the second cross attention block.
|
114 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
115 |
+
|
116 |
+
self.attn2 = Attention(
|
117 |
+
query_dim=dim,
|
118 |
+
cross_attention_dim=(
|
119 |
+
cross_attention_dim if not double_self_attention else None
|
120 |
+
),
|
121 |
+
heads=num_attention_heads,
|
122 |
+
dim_head=attention_head_dim,
|
123 |
+
dropout=dropout,
|
124 |
+
bias=attention_bias,
|
125 |
+
upcast_attention=upcast_attention,
|
126 |
+
) # is self-attn if encoder_hidden_states is none
|
127 |
+
else:
|
128 |
+
self.norm2 = None
|
129 |
+
self.attn2 = None
|
130 |
+
|
131 |
+
# 3. Feed-forward
|
132 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
133 |
+
self.ff = FeedForward(
|
134 |
+
dim,
|
135 |
+
dropout=dropout,
|
136 |
+
activation_fn=activation_fn,
|
137 |
+
final_dropout=final_dropout,
|
138 |
+
)
|
139 |
+
|
140 |
+
# let chunk size default to None
|
141 |
+
self._chunk_size = None
|
142 |
+
self._chunk_dim = 0
|
143 |
+
|
144 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
145 |
+
# Sets chunk feed-forward
|
146 |
+
self._chunk_size = chunk_size
|
147 |
+
self._chunk_dim = dim
|
148 |
+
|
149 |
+
def forward(
|
150 |
+
self,
|
151 |
+
hidden_states: torch.FloatTensor,
|
152 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
153 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
154 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
155 |
+
) -> torch.FloatTensor:
|
156 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
157 |
+
# 0. Self-Attention
|
158 |
+
norm_hidden_states = self.norm1(hidden_states)
|
159 |
+
|
160 |
+
attn_output = self.attn1(
|
161 |
+
norm_hidden_states,
|
162 |
+
encoder_hidden_states=(
|
163 |
+
encoder_hidden_states if self.only_cross_attention else None
|
164 |
+
),
|
165 |
+
attention_mask=attention_mask,
|
166 |
+
)
|
167 |
+
|
168 |
+
hidden_states = attn_output + hidden_states
|
169 |
+
|
170 |
+
# 3. Cross-Attention
|
171 |
+
if self.attn2 is not None:
|
172 |
+
norm_hidden_states = self.norm2(hidden_states)
|
173 |
+
|
174 |
+
attn_output = self.attn2(
|
175 |
+
norm_hidden_states,
|
176 |
+
encoder_hidden_states=encoder_hidden_states,
|
177 |
+
attention_mask=encoder_attention_mask,
|
178 |
+
)
|
179 |
+
hidden_states = attn_output + hidden_states
|
180 |
+
|
181 |
+
# 4. Feed-forward
|
182 |
+
norm_hidden_states = self.norm3(hidden_states)
|
183 |
+
|
184 |
+
if self._chunk_size is not None:
|
185 |
+
# "feed_forward_chunk_size" can be used to save memory
|
186 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
187 |
+
raise ValueError(
|
188 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
189 |
+
)
|
190 |
+
|
191 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
192 |
+
ff_output = torch.cat(
|
193 |
+
[
|
194 |
+
self.ff(hid_slice)
|
195 |
+
for hid_slice in norm_hidden_states.chunk(
|
196 |
+
num_chunks, dim=self._chunk_dim
|
197 |
+
)
|
198 |
+
],
|
199 |
+
dim=self._chunk_dim,
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
ff_output = self.ff(norm_hidden_states)
|
203 |
+
|
204 |
+
hidden_states = ff_output + hidden_states
|
205 |
+
|
206 |
+
return hidden_states
|
207 |
+
|
208 |
+
|
209 |
+
class FeedForward(nn.Module):
|
210 |
+
r"""
|
211 |
+
A feed-forward layer.
|
212 |
+
|
213 |
+
Parameters:
|
214 |
+
dim (`int`): The number of channels in the input.
|
215 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
216 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
217 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
218 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
219 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
dim: int,
|
225 |
+
dim_out: Optional[int] = None,
|
226 |
+
mult: int = 4,
|
227 |
+
dropout: float = 0.0,
|
228 |
+
activation_fn: str = "geglu",
|
229 |
+
final_dropout: bool = False,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
inner_dim = int(dim * mult)
|
233 |
+
dim_out = dim_out if dim_out is not None else dim
|
234 |
+
linear_cls = nn.Linear
|
235 |
+
|
236 |
+
if activation_fn == "gelu":
|
237 |
+
act_fn = GELU(dim, inner_dim)
|
238 |
+
if activation_fn == "gelu-approximate":
|
239 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
240 |
+
elif activation_fn == "geglu":
|
241 |
+
act_fn = GEGLU(dim, inner_dim)
|
242 |
+
elif activation_fn == "geglu-approximate":
|
243 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
244 |
+
|
245 |
+
self.net = nn.ModuleList([])
|
246 |
+
# project in
|
247 |
+
self.net.append(act_fn)
|
248 |
+
# project dropout
|
249 |
+
self.net.append(nn.Dropout(dropout))
|
250 |
+
# project out
|
251 |
+
self.net.append(linear_cls(inner_dim, dim_out))
|
252 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
253 |
+
if final_dropout:
|
254 |
+
self.net.append(nn.Dropout(dropout))
|
255 |
+
|
256 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
257 |
+
for module in self.net:
|
258 |
+
hidden_states = module(hidden_states)
|
259 |
+
return hidden_states
|
260 |
+
|
261 |
+
|
262 |
+
class GELU(nn.Module):
|
263 |
+
r"""
|
264 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
265 |
+
|
266 |
+
Parameters:
|
267 |
+
dim_in (`int`): The number of channels in the input.
|
268 |
+
dim_out (`int`): The number of channels in the output.
|
269 |
+
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
270 |
+
"""
|
271 |
+
|
272 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
273 |
+
super().__init__()
|
274 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
275 |
+
self.approximate = approximate
|
276 |
+
|
277 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
278 |
+
if gate.device.type != "mps":
|
279 |
+
return F.gelu(gate, approximate=self.approximate)
|
280 |
+
# mps: gelu is not implemented for float16
|
281 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
|
282 |
+
dtype=gate.dtype
|
283 |
+
)
|
284 |
+
|
285 |
+
def forward(self, hidden_states):
|
286 |
+
hidden_states = self.proj(hidden_states)
|
287 |
+
hidden_states = self.gelu(hidden_states)
|
288 |
+
return hidden_states
|
289 |
+
|
290 |
+
|
291 |
+
class GEGLU(nn.Module):
|
292 |
+
r"""
|
293 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
294 |
+
|
295 |
+
Parameters:
|
296 |
+
dim_in (`int`): The number of channels in the input.
|
297 |
+
dim_out (`int`): The number of channels in the output.
|
298 |
+
"""
|
299 |
+
|
300 |
+
def __init__(self, dim_in: int, dim_out: int):
|
301 |
+
super().__init__()
|
302 |
+
linear_cls = nn.Linear
|
303 |
+
|
304 |
+
self.proj = linear_cls(dim_in, dim_out * 2)
|
305 |
+
|
306 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
307 |
+
if gate.device.type != "mps":
|
308 |
+
return F.gelu(gate)
|
309 |
+
# mps: gelu is not implemented for float16
|
310 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
311 |
+
|
312 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
313 |
+
args = ()
|
314 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
315 |
+
return hidden_states * self.gelu(gate)
|
316 |
+
|
317 |
+
|
318 |
+
class ApproximateGELU(nn.Module):
|
319 |
+
r"""
|
320 |
+
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
|
321 |
+
https://arxiv.org/abs/1606.08415.
|
322 |
+
|
323 |
+
Parameters:
|
324 |
+
dim_in (`int`): The number of channels in the input.
|
325 |
+
dim_out (`int`): The number of channels in the output.
|
326 |
+
"""
|
327 |
+
|
328 |
+
def __init__(self, dim_in: int, dim_out: int):
|
329 |
+
super().__init__()
|
330 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
331 |
+
|
332 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
333 |
+
x = self.proj(x)
|
334 |
+
return x * torch.sigmoid(1.702 * x)
|
tsr/models/transformer/transformer_1d.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# --------
|
16 |
+
#
|
17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
18 |
+
#
|
19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
20 |
+
#
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
#
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
#
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
from dataclasses import dataclass
|
40 |
+
from typing import Optional
|
41 |
+
|
42 |
+
import torch
|
43 |
+
import torch.nn.functional as F
|
44 |
+
from torch import nn
|
45 |
+
|
46 |
+
from ...utils import BaseModule
|
47 |
+
from .basic_transformer_block import BasicTransformerBlock
|
48 |
+
|
49 |
+
|
50 |
+
class Transformer1D(BaseModule):
|
51 |
+
@dataclass
|
52 |
+
class Config(BaseModule.Config):
|
53 |
+
num_attention_heads: int = 16
|
54 |
+
attention_head_dim: int = 88
|
55 |
+
in_channels: Optional[int] = None
|
56 |
+
out_channels: Optional[int] = None
|
57 |
+
num_layers: int = 1
|
58 |
+
dropout: float = 0.0
|
59 |
+
norm_num_groups: int = 32
|
60 |
+
cross_attention_dim: Optional[int] = None
|
61 |
+
attention_bias: bool = False
|
62 |
+
activation_fn: str = "geglu"
|
63 |
+
only_cross_attention: bool = False
|
64 |
+
double_self_attention: bool = False
|
65 |
+
upcast_attention: bool = False
|
66 |
+
norm_type: str = "layer_norm"
|
67 |
+
norm_elementwise_affine: bool = True
|
68 |
+
gradient_checkpointing: bool = False
|
69 |
+
|
70 |
+
cfg: Config
|
71 |
+
|
72 |
+
def configure(self) -> None:
|
73 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
74 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
75 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
76 |
+
|
77 |
+
linear_cls = nn.Linear
|
78 |
+
|
79 |
+
# 2. Define input layers
|
80 |
+
self.in_channels = self.cfg.in_channels
|
81 |
+
|
82 |
+
self.norm = torch.nn.GroupNorm(
|
83 |
+
num_groups=self.cfg.norm_num_groups,
|
84 |
+
num_channels=self.cfg.in_channels,
|
85 |
+
eps=1e-6,
|
86 |
+
affine=True,
|
87 |
+
)
|
88 |
+
self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
|
89 |
+
|
90 |
+
# 3. Define transformers blocks
|
91 |
+
self.transformer_blocks = nn.ModuleList(
|
92 |
+
[
|
93 |
+
BasicTransformerBlock(
|
94 |
+
inner_dim,
|
95 |
+
self.num_attention_heads,
|
96 |
+
self.attention_head_dim,
|
97 |
+
dropout=self.cfg.dropout,
|
98 |
+
cross_attention_dim=self.cfg.cross_attention_dim,
|
99 |
+
activation_fn=self.cfg.activation_fn,
|
100 |
+
attention_bias=self.cfg.attention_bias,
|
101 |
+
only_cross_attention=self.cfg.only_cross_attention,
|
102 |
+
double_self_attention=self.cfg.double_self_attention,
|
103 |
+
upcast_attention=self.cfg.upcast_attention,
|
104 |
+
norm_type=self.cfg.norm_type,
|
105 |
+
norm_elementwise_affine=self.cfg.norm_elementwise_affine,
|
106 |
+
)
|
107 |
+
for d in range(self.cfg.num_layers)
|
108 |
+
]
|
109 |
+
)
|
110 |
+
|
111 |
+
# 4. Define output layers
|
112 |
+
self.out_channels = (
|
113 |
+
self.cfg.in_channels
|
114 |
+
if self.cfg.out_channels is None
|
115 |
+
else self.cfg.out_channels
|
116 |
+
)
|
117 |
+
|
118 |
+
self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
|
119 |
+
|
120 |
+
self.gradient_checkpointing = self.cfg.gradient_checkpointing
|
121 |
+
|
122 |
+
def forward(
|
123 |
+
self,
|
124 |
+
hidden_states: torch.Tensor,
|
125 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
126 |
+
attention_mask: Optional[torch.Tensor] = None,
|
127 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
The [`Transformer1DModel`] forward method.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
134 |
+
Input `hidden_states`.
|
135 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
136 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
137 |
+
self-attention.
|
138 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
139 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
140 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
141 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
142 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
143 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
144 |
+
|
145 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
146 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
147 |
+
|
148 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
149 |
+
above. This bias will be added to the cross-attention scores.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
torch.FloatTensor
|
153 |
+
"""
|
154 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
155 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
156 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
157 |
+
# expects mask of shape:
|
158 |
+
# [batch, key_tokens]
|
159 |
+
# adds singleton query_tokens dimension:
|
160 |
+
# [batch, 1, key_tokens]
|
161 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
162 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
163 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
164 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
165 |
+
# assume that mask is expressed as:
|
166 |
+
# (1 = keep, 0 = discard)
|
167 |
+
# convert mask into a bias that can be added to attention scores:
|
168 |
+
# (keep = +0, discard = -10000.0)
|
169 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
170 |
+
attention_mask = attention_mask.unsqueeze(1)
|
171 |
+
|
172 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
173 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
174 |
+
encoder_attention_mask = (
|
175 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
176 |
+
) * -10000.0
|
177 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
178 |
+
|
179 |
+
# 1. Input
|
180 |
+
batch, _, seq_len = hidden_states.shape
|
181 |
+
residual = hidden_states
|
182 |
+
|
183 |
+
hidden_states = self.norm(hidden_states)
|
184 |
+
inner_dim = hidden_states.shape[1]
|
185 |
+
hidden_states = hidden_states.permute(0, 2, 1).reshape(
|
186 |
+
batch, seq_len, inner_dim
|
187 |
+
)
|
188 |
+
hidden_states = self.proj_in(hidden_states)
|
189 |
+
|
190 |
+
# 2. Blocks
|
191 |
+
for block in self.transformer_blocks:
|
192 |
+
if self.training and self.gradient_checkpointing:
|
193 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
194 |
+
block,
|
195 |
+
hidden_states,
|
196 |
+
attention_mask,
|
197 |
+
encoder_hidden_states,
|
198 |
+
encoder_attention_mask,
|
199 |
+
use_reentrant=False,
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
hidden_states = block(
|
203 |
+
hidden_states,
|
204 |
+
attention_mask=attention_mask,
|
205 |
+
encoder_hidden_states=encoder_hidden_states,
|
206 |
+
encoder_attention_mask=encoder_attention_mask,
|
207 |
+
)
|
208 |
+
|
209 |
+
# 3. Output
|
210 |
+
hidden_states = self.proj_out(hidden_states)
|
211 |
+
hidden_states = (
|
212 |
+
hidden_states.reshape(batch, seq_len, inner_dim)
|
213 |
+
.permute(0, 2, 1)
|
214 |
+
.contiguous()
|
215 |
+
)
|
216 |
+
|
217 |
+
output = hidden_states + residual
|
218 |
+
|
219 |
+
return output
|
tsr/system.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import List, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL.Image
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import trimesh
|
11 |
+
from einops import rearrange
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from .models.isosurface import MarchingCubeHelper
|
17 |
+
from .utils import (
|
18 |
+
BaseModule,
|
19 |
+
ImagePreprocessor,
|
20 |
+
find_class,
|
21 |
+
get_spherical_cameras,
|
22 |
+
scale_tensor,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class TSR(BaseModule):
|
27 |
+
@dataclass
|
28 |
+
class Config(BaseModule.Config):
|
29 |
+
cond_image_size: int
|
30 |
+
|
31 |
+
image_tokenizer_cls: str
|
32 |
+
image_tokenizer: dict
|
33 |
+
|
34 |
+
tokenizer_cls: str
|
35 |
+
tokenizer: dict
|
36 |
+
|
37 |
+
backbone_cls: str
|
38 |
+
backbone: dict
|
39 |
+
|
40 |
+
post_processor_cls: str
|
41 |
+
post_processor: dict
|
42 |
+
|
43 |
+
decoder_cls: str
|
44 |
+
decoder: dict
|
45 |
+
|
46 |
+
renderer_cls: str
|
47 |
+
renderer: dict
|
48 |
+
|
49 |
+
cfg: Config
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def from_pretrained(
|
53 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
54 |
+
):
|
55 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
56 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
57 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
58 |
+
else:
|
59 |
+
config_path = hf_hub_download(
|
60 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
61 |
+
)
|
62 |
+
weight_path = hf_hub_download(
|
63 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
64 |
+
)
|
65 |
+
|
66 |
+
cfg = OmegaConf.load(config_path)
|
67 |
+
OmegaConf.resolve(cfg)
|
68 |
+
model = cls(cfg)
|
69 |
+
ckpt = torch.load(weight_path, map_location="cpu")
|
70 |
+
model.load_state_dict(ckpt)
|
71 |
+
return model
|
72 |
+
|
73 |
+
def configure(self):
|
74 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
75 |
+
self.cfg.image_tokenizer
|
76 |
+
)
|
77 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
78 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
79 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
80 |
+
self.cfg.post_processor
|
81 |
+
)
|
82 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
83 |
+
self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
|
84 |
+
self.image_processor = ImagePreprocessor()
|
85 |
+
self.isosurface_helper = None
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
image: Union[
|
90 |
+
PIL.Image.Image,
|
91 |
+
np.ndarray,
|
92 |
+
torch.FloatTensor,
|
93 |
+
List[PIL.Image.Image],
|
94 |
+
List[np.ndarray],
|
95 |
+
List[torch.FloatTensor],
|
96 |
+
],
|
97 |
+
device: str,
|
98 |
+
) -> torch.FloatTensor:
|
99 |
+
rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
|
100 |
+
device
|
101 |
+
)
|
102 |
+
batch_size = rgb_cond.shape[0]
|
103 |
+
|
104 |
+
input_image_tokens: torch.Tensor = self.image_tokenizer(
|
105 |
+
rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
|
106 |
+
)
|
107 |
+
|
108 |
+
input_image_tokens = rearrange(
|
109 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
|
110 |
+
)
|
111 |
+
|
112 |
+
tokens: torch.Tensor = self.tokenizer(batch_size)
|
113 |
+
|
114 |
+
tokens = self.backbone(
|
115 |
+
tokens,
|
116 |
+
encoder_hidden_states=input_image_tokens,
|
117 |
+
)
|
118 |
+
|
119 |
+
scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
|
120 |
+
return scene_codes
|
121 |
+
|
122 |
+
def render(
|
123 |
+
self,
|
124 |
+
scene_codes,
|
125 |
+
n_views: int,
|
126 |
+
elevation_deg: float = 0.0,
|
127 |
+
camera_distance: float = 1.9,
|
128 |
+
fovy_deg: float = 40.0,
|
129 |
+
height: int = 256,
|
130 |
+
width: int = 256,
|
131 |
+
return_type: str = "pil",
|
132 |
+
):
|
133 |
+
rays_o, rays_d = get_spherical_cameras(
|
134 |
+
n_views, elevation_deg, camera_distance, fovy_deg, height, width
|
135 |
+
)
|
136 |
+
rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
|
137 |
+
|
138 |
+
def process_output(image: torch.FloatTensor):
|
139 |
+
if return_type == "pt":
|
140 |
+
return image
|
141 |
+
elif return_type == "np":
|
142 |
+
return image.detach().cpu().numpy()
|
143 |
+
elif return_type == "pil":
|
144 |
+
return Image.fromarray(
|
145 |
+
(image.detach().cpu().numpy() * 255.0).astype(np.uint8)
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
raise NotImplementedError
|
149 |
+
|
150 |
+
images = []
|
151 |
+
for scene_code in scene_codes:
|
152 |
+
images_ = []
|
153 |
+
for i in range(n_views):
|
154 |
+
with torch.no_grad():
|
155 |
+
image = self.renderer(
|
156 |
+
self.decoder, scene_code, rays_o[i], rays_d[i]
|
157 |
+
)
|
158 |
+
images_.append(process_output(image))
|
159 |
+
images.append(images_)
|
160 |
+
|
161 |
+
return images
|
162 |
+
|
163 |
+
def set_marching_cubes_resolution(self, resolution: int):
|
164 |
+
if (
|
165 |
+
self.isosurface_helper is not None
|
166 |
+
and self.isosurface_helper.resolution == resolution
|
167 |
+
):
|
168 |
+
return
|
169 |
+
self.isosurface_helper = MarchingCubeHelper(resolution)
|
170 |
+
|
171 |
+
def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
|
172 |
+
self.set_marching_cubes_resolution(resolution)
|
173 |
+
meshes = []
|
174 |
+
for scene_code in scene_codes:
|
175 |
+
with torch.no_grad():
|
176 |
+
density = self.renderer.query_triplane(
|
177 |
+
self.decoder,
|
178 |
+
scale_tensor(
|
179 |
+
self.isosurface_helper.grid_vertices.to(scene_codes.device),
|
180 |
+
self.isosurface_helper.points_range,
|
181 |
+
(-self.renderer.cfg.radius, self.renderer.cfg.radius),
|
182 |
+
),
|
183 |
+
scene_code,
|
184 |
+
)["density_act"]
|
185 |
+
v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
|
186 |
+
v_pos = scale_tensor(
|
187 |
+
v_pos,
|
188 |
+
self.isosurface_helper.points_range,
|
189 |
+
(-self.renderer.cfg.radius, self.renderer.cfg.radius),
|
190 |
+
)
|
191 |
+
with torch.no_grad():
|
192 |
+
color = self.renderer.query_triplane(
|
193 |
+
self.decoder,
|
194 |
+
v_pos,
|
195 |
+
scene_code,
|
196 |
+
)["color"]
|
197 |
+
mesh = trimesh.Trimesh(
|
198 |
+
vertices=v_pos.cpu().numpy(),
|
199 |
+
faces=t_pos_idx.cpu().numpy(),
|
200 |
+
vertex_colors=color.cpu().numpy(),
|
201 |
+
)
|
202 |
+
meshes.append(mesh)
|
203 |
+
return meshes
|
tsr/utils.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import math
|
3 |
+
from collections import defaultdict
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import imageio
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import rembg
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import trimesh
|
15 |
+
from omegaconf import DictConfig, OmegaConf
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
20 |
+
scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
|
21 |
+
return scfg
|
22 |
+
|
23 |
+
|
24 |
+
def find_class(cls_string):
|
25 |
+
module_string = ".".join(cls_string.split(".")[:-1])
|
26 |
+
cls_name = cls_string.split(".")[-1]
|
27 |
+
module = importlib.import_module(module_string, package=None)
|
28 |
+
cls = getattr(module, cls_name)
|
29 |
+
return cls
|
30 |
+
|
31 |
+
|
32 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
|
33 |
+
focal_length = 0.5 * H / np.tan(0.5 * fov)
|
34 |
+
intrinsic = np.identity(3, dtype=np.float32)
|
35 |
+
intrinsic[0, 0] = focal_length
|
36 |
+
intrinsic[1, 1] = focal_length
|
37 |
+
intrinsic[0, 2] = W / 2.0
|
38 |
+
intrinsic[1, 2] = H / 2.0
|
39 |
+
|
40 |
+
if bs > 0:
|
41 |
+
intrinsic = intrinsic[None].repeat(bs, axis=0)
|
42 |
+
|
43 |
+
return torch.from_numpy(intrinsic)
|
44 |
+
|
45 |
+
|
46 |
+
class BaseModule(nn.Module):
|
47 |
+
@dataclass
|
48 |
+
class Config:
|
49 |
+
pass
|
50 |
+
|
51 |
+
cfg: Config # add this to every subclass of BaseModule to enable static type checking
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
|
55 |
+
) -> None:
|
56 |
+
super().__init__()
|
57 |
+
self.cfg = parse_structured(self.Config, cfg)
|
58 |
+
self.configure(*args, **kwargs)
|
59 |
+
|
60 |
+
def configure(self, *args, **kwargs) -> None:
|
61 |
+
raise NotImplementedError
|
62 |
+
|
63 |
+
|
64 |
+
class ImagePreprocessor:
|
65 |
+
def convert_and_resize(
|
66 |
+
self,
|
67 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
68 |
+
size: int,
|
69 |
+
):
|
70 |
+
if isinstance(image, PIL.Image.Image):
|
71 |
+
image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
72 |
+
elif isinstance(image, np.ndarray):
|
73 |
+
if image.dtype == np.uint8:
|
74 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0)
|
75 |
+
else:
|
76 |
+
image = torch.from_numpy(image)
|
77 |
+
elif isinstance(image, torch.Tensor):
|
78 |
+
pass
|
79 |
+
|
80 |
+
batched = image.ndim == 4
|
81 |
+
|
82 |
+
if not batched:
|
83 |
+
image = image[None, ...]
|
84 |
+
image = F.interpolate(
|
85 |
+
image.permute(0, 3, 1, 2),
|
86 |
+
(size, size),
|
87 |
+
mode="bilinear",
|
88 |
+
align_corners=False,
|
89 |
+
antialias=True,
|
90 |
+
).permute(0, 2, 3, 1)
|
91 |
+
if not batched:
|
92 |
+
image = image[0]
|
93 |
+
return image
|
94 |
+
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
image: Union[
|
98 |
+
PIL.Image.Image,
|
99 |
+
np.ndarray,
|
100 |
+
torch.FloatTensor,
|
101 |
+
List[PIL.Image.Image],
|
102 |
+
List[np.ndarray],
|
103 |
+
List[torch.FloatTensor],
|
104 |
+
],
|
105 |
+
size: int,
|
106 |
+
) -> Any:
|
107 |
+
if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
|
108 |
+
image = self.convert_and_resize(image, size)
|
109 |
+
else:
|
110 |
+
if not isinstance(image, list):
|
111 |
+
image = [image]
|
112 |
+
image = [self.convert_and_resize(im, size) for im in image]
|
113 |
+
image = torch.stack(image, dim=0)
|
114 |
+
return image
|
115 |
+
|
116 |
+
|
117 |
+
def rays_intersect_bbox(
|
118 |
+
rays_o: torch.Tensor,
|
119 |
+
rays_d: torch.Tensor,
|
120 |
+
radius: float,
|
121 |
+
near: float = 0.0,
|
122 |
+
valid_thresh: float = 0.01,
|
123 |
+
):
|
124 |
+
input_shape = rays_o.shape[:-1]
|
125 |
+
rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
|
126 |
+
rays_d_valid = torch.where(
|
127 |
+
rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
|
128 |
+
)
|
129 |
+
if type(radius) in [int, float]:
|
130 |
+
radius = torch.FloatTensor(
|
131 |
+
[[-radius, radius], [-radius, radius], [-radius, radius]]
|
132 |
+
).to(rays_o.device)
|
133 |
+
radius = (
|
134 |
+
1.0 - 1.0e-3
|
135 |
+
) * radius # tighten the radius to make sure the intersection point lies in the bounding box
|
136 |
+
interx0 = (radius[..., 1] - rays_o) / rays_d_valid
|
137 |
+
interx1 = (radius[..., 0] - rays_o) / rays_d_valid
|
138 |
+
t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
|
139 |
+
t_far = torch.maximum(interx0, interx1).amin(dim=-1)
|
140 |
+
|
141 |
+
# check wheter a ray intersects the bbox or not
|
142 |
+
rays_valid = t_far - t_near > valid_thresh
|
143 |
+
|
144 |
+
t_near[torch.where(~rays_valid)] = 0.0
|
145 |
+
t_far[torch.where(~rays_valid)] = 0.0
|
146 |
+
|
147 |
+
t_near = t_near.view(*input_shape, 1)
|
148 |
+
t_far = t_far.view(*input_shape, 1)
|
149 |
+
rays_valid = rays_valid.view(*input_shape)
|
150 |
+
|
151 |
+
return t_near, t_far, rays_valid
|
152 |
+
|
153 |
+
|
154 |
+
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
|
155 |
+
if chunk_size <= 0:
|
156 |
+
return func(*args, **kwargs)
|
157 |
+
B = None
|
158 |
+
for arg in list(args) + list(kwargs.values()):
|
159 |
+
if isinstance(arg, torch.Tensor):
|
160 |
+
B = arg.shape[0]
|
161 |
+
break
|
162 |
+
assert (
|
163 |
+
B is not None
|
164 |
+
), "No tensor found in args or kwargs, cannot determine batch size."
|
165 |
+
out = defaultdict(list)
|
166 |
+
out_type = None
|
167 |
+
# max(1, B) to support B == 0
|
168 |
+
for i in range(0, max(1, B), chunk_size):
|
169 |
+
out_chunk = func(
|
170 |
+
*[
|
171 |
+
arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
172 |
+
for arg in args
|
173 |
+
],
|
174 |
+
**{
|
175 |
+
k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
176 |
+
for k, arg in kwargs.items()
|
177 |
+
},
|
178 |
+
)
|
179 |
+
if out_chunk is None:
|
180 |
+
continue
|
181 |
+
out_type = type(out_chunk)
|
182 |
+
if isinstance(out_chunk, torch.Tensor):
|
183 |
+
out_chunk = {0: out_chunk}
|
184 |
+
elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
|
185 |
+
chunk_length = len(out_chunk)
|
186 |
+
out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
|
187 |
+
elif isinstance(out_chunk, dict):
|
188 |
+
pass
|
189 |
+
else:
|
190 |
+
print(
|
191 |
+
f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
|
192 |
+
)
|
193 |
+
exit(1)
|
194 |
+
for k, v in out_chunk.items():
|
195 |
+
v = v if torch.is_grad_enabled() else v.detach()
|
196 |
+
out[k].append(v)
|
197 |
+
|
198 |
+
if out_type is None:
|
199 |
+
return None
|
200 |
+
|
201 |
+
out_merged: Dict[Any, Optional[torch.Tensor]] = {}
|
202 |
+
for k, v in out.items():
|
203 |
+
if all([vv is None for vv in v]):
|
204 |
+
# allow None in return value
|
205 |
+
out_merged[k] = None
|
206 |
+
elif all([isinstance(vv, torch.Tensor) for vv in v]):
|
207 |
+
out_merged[k] = torch.cat(v, dim=0)
|
208 |
+
else:
|
209 |
+
raise TypeError(
|
210 |
+
f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
|
211 |
+
)
|
212 |
+
|
213 |
+
if out_type is torch.Tensor:
|
214 |
+
return out_merged[0]
|
215 |
+
elif out_type in [tuple, list]:
|
216 |
+
return out_type([out_merged[i] for i in range(chunk_length)])
|
217 |
+
elif out_type is dict:
|
218 |
+
return out_merged
|
219 |
+
|
220 |
+
|
221 |
+
ValidScale = Union[Tuple[float, float], torch.FloatTensor]
|
222 |
+
|
223 |
+
|
224 |
+
def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
|
225 |
+
if inp_scale is None:
|
226 |
+
inp_scale = (0, 1)
|
227 |
+
if tgt_scale is None:
|
228 |
+
tgt_scale = (0, 1)
|
229 |
+
if isinstance(tgt_scale, torch.FloatTensor):
|
230 |
+
assert dat.shape[-1] == tgt_scale.shape[-1]
|
231 |
+
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
232 |
+
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
233 |
+
return dat
|
234 |
+
|
235 |
+
|
236 |
+
def get_activation(name) -> Callable:
|
237 |
+
if name is None:
|
238 |
+
return lambda x: x
|
239 |
+
name = name.lower()
|
240 |
+
if name == "none":
|
241 |
+
return lambda x: x
|
242 |
+
elif name == "exp":
|
243 |
+
return lambda x: torch.exp(x)
|
244 |
+
elif name == "sigmoid":
|
245 |
+
return lambda x: torch.sigmoid(x)
|
246 |
+
elif name == "tanh":
|
247 |
+
return lambda x: torch.tanh(x)
|
248 |
+
elif name == "softplus":
|
249 |
+
return lambda x: F.softplus(x)
|
250 |
+
else:
|
251 |
+
try:
|
252 |
+
return getattr(F, name)
|
253 |
+
except AttributeError:
|
254 |
+
raise ValueError(f"Unknown activation function: {name}")
|
255 |
+
|
256 |
+
|
257 |
+
def get_ray_directions(
|
258 |
+
H: int,
|
259 |
+
W: int,
|
260 |
+
focal: Union[float, Tuple[float, float]],
|
261 |
+
principal: Optional[Tuple[float, float]] = None,
|
262 |
+
use_pixel_centers: bool = True,
|
263 |
+
normalize: bool = True,
|
264 |
+
) -> torch.FloatTensor:
|
265 |
+
"""
|
266 |
+
Get ray directions for all pixels in camera coordinate.
|
267 |
+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
|
268 |
+
ray-tracing-generating-camera-rays/standard-coordinate-systems
|
269 |
+
|
270 |
+
Inputs:
|
271 |
+
H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
|
272 |
+
Outputs:
|
273 |
+
directions: (H, W, 3), the direction of the rays in camera coordinate
|
274 |
+
"""
|
275 |
+
pixel_center = 0.5 if use_pixel_centers else 0
|
276 |
+
|
277 |
+
if isinstance(focal, float):
|
278 |
+
fx, fy = focal, focal
|
279 |
+
cx, cy = W / 2, H / 2
|
280 |
+
else:
|
281 |
+
fx, fy = focal
|
282 |
+
assert principal is not None
|
283 |
+
cx, cy = principal
|
284 |
+
|
285 |
+
i, j = torch.meshgrid(
|
286 |
+
torch.arange(W, dtype=torch.float32) + pixel_center,
|
287 |
+
torch.arange(H, dtype=torch.float32) + pixel_center,
|
288 |
+
indexing="xy",
|
289 |
+
)
|
290 |
+
|
291 |
+
directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
|
292 |
+
|
293 |
+
if normalize:
|
294 |
+
directions = F.normalize(directions, dim=-1)
|
295 |
+
|
296 |
+
return directions
|
297 |
+
|
298 |
+
|
299 |
+
def get_rays(
|
300 |
+
directions,
|
301 |
+
c2w,
|
302 |
+
keepdim=False,
|
303 |
+
normalize=False,
|
304 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
305 |
+
# Rotate ray directions from camera coordinate to the world coordinate
|
306 |
+
assert directions.shape[-1] == 3
|
307 |
+
|
308 |
+
if directions.ndim == 2: # (N_rays, 3)
|
309 |
+
if c2w.ndim == 2: # (4, 4)
|
310 |
+
c2w = c2w[None, :, :]
|
311 |
+
assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
|
312 |
+
rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
|
313 |
+
rays_o = c2w[:, :3, 3].expand(rays_d.shape)
|
314 |
+
elif directions.ndim == 3: # (H, W, 3)
|
315 |
+
assert c2w.ndim in [2, 3]
|
316 |
+
if c2w.ndim == 2: # (4, 4)
|
317 |
+
rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
|
318 |
+
-1
|
319 |
+
) # (H, W, 3)
|
320 |
+
rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
|
321 |
+
elif c2w.ndim == 3: # (B, 4, 4)
|
322 |
+
rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
323 |
+
-1
|
324 |
+
) # (B, H, W, 3)
|
325 |
+
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
326 |
+
elif directions.ndim == 4: # (B, H, W, 3)
|
327 |
+
assert c2w.ndim == 3 # (B, 4, 4)
|
328 |
+
rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
329 |
+
-1
|
330 |
+
) # (B, H, W, 3)
|
331 |
+
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
332 |
+
|
333 |
+
if normalize:
|
334 |
+
rays_d = F.normalize(rays_d, dim=-1)
|
335 |
+
if not keepdim:
|
336 |
+
rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
|
337 |
+
|
338 |
+
return rays_o, rays_d
|
339 |
+
|
340 |
+
|
341 |
+
def get_spherical_cameras(
|
342 |
+
n_views: int,
|
343 |
+
elevation_deg: float,
|
344 |
+
camera_distance: float,
|
345 |
+
fovy_deg: float,
|
346 |
+
height: int,
|
347 |
+
width: int,
|
348 |
+
):
|
349 |
+
azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
|
350 |
+
elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
|
351 |
+
camera_distances = torch.full_like(elevation_deg, camera_distance)
|
352 |
+
|
353 |
+
elevation = elevation_deg * math.pi / 180
|
354 |
+
azimuth = azimuth_deg * math.pi / 180
|
355 |
+
|
356 |
+
# convert spherical coordinates to cartesian coordinates
|
357 |
+
# right hand coordinate system, x back, y right, z up
|
358 |
+
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
|
359 |
+
camera_positions = torch.stack(
|
360 |
+
[
|
361 |
+
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
|
362 |
+
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
|
363 |
+
camera_distances * torch.sin(elevation),
|
364 |
+
],
|
365 |
+
dim=-1,
|
366 |
+
)
|
367 |
+
|
368 |
+
# default scene center at origin
|
369 |
+
center = torch.zeros_like(camera_positions)
|
370 |
+
# default camera up direction as +z
|
371 |
+
up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
|
372 |
+
|
373 |
+
fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
|
374 |
+
|
375 |
+
lookat = F.normalize(center - camera_positions, dim=-1)
|
376 |
+
right = F.normalize(torch.cross(lookat, up), dim=-1)
|
377 |
+
up = F.normalize(torch.cross(right, lookat), dim=-1)
|
378 |
+
c2w3x4 = torch.cat(
|
379 |
+
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
|
380 |
+
dim=-1,
|
381 |
+
)
|
382 |
+
c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
|
383 |
+
c2w[:, 3, 3] = 1.0
|
384 |
+
|
385 |
+
# get directions by dividing directions_unit_focal by focal length
|
386 |
+
focal_length = 0.5 * height / torch.tan(0.5 * fovy)
|
387 |
+
directions_unit_focal = get_ray_directions(
|
388 |
+
H=height,
|
389 |
+
W=width,
|
390 |
+
focal=1.0,
|
391 |
+
)
|
392 |
+
directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
|
393 |
+
directions[:, :, :, :2] = (
|
394 |
+
directions[:, :, :, :2] / focal_length[:, None, None, None]
|
395 |
+
)
|
396 |
+
# must use normalize=True to normalize directions here
|
397 |
+
rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
|
398 |
+
|
399 |
+
return rays_o, rays_d
|
400 |
+
|
401 |
+
|
402 |
+
def remove_background(
|
403 |
+
image: PIL.Image.Image,
|
404 |
+
rembg_session: Any = None,
|
405 |
+
force: bool = False,
|
406 |
+
**rembg_kwargs,
|
407 |
+
) -> PIL.Image.Image:
|
408 |
+
do_remove = True
|
409 |
+
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
410 |
+
do_remove = False
|
411 |
+
do_remove = do_remove or force
|
412 |
+
if do_remove:
|
413 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
414 |
+
return image
|
415 |
+
|
416 |
+
|
417 |
+
def resize_foreground(
|
418 |
+
image: PIL.Image.Image,
|
419 |
+
ratio: float,
|
420 |
+
) -> PIL.Image.Image:
|
421 |
+
image = np.array(image)
|
422 |
+
assert image.shape[-1] == 4
|
423 |
+
alpha = np.where(image[..., 3] > 0)
|
424 |
+
y1, y2, x1, x2 = (
|
425 |
+
alpha[0].min(),
|
426 |
+
alpha[0].max(),
|
427 |
+
alpha[1].min(),
|
428 |
+
alpha[1].max(),
|
429 |
+
)
|
430 |
+
# crop the foreground
|
431 |
+
fg = image[y1:y2, x1:x2]
|
432 |
+
# pad to square
|
433 |
+
size = max(fg.shape[0], fg.shape[1])
|
434 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
435 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
436 |
+
new_image = np.pad(
|
437 |
+
fg,
|
438 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
439 |
+
mode="constant",
|
440 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
441 |
+
)
|
442 |
+
|
443 |
+
# compute padding according to the ratio
|
444 |
+
new_size = int(new_image.shape[0] / ratio)
|
445 |
+
# pad to size, double side
|
446 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
447 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
448 |
+
new_image = np.pad(
|
449 |
+
new_image,
|
450 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
451 |
+
mode="constant",
|
452 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
453 |
+
)
|
454 |
+
new_image = PIL.Image.fromarray(new_image)
|
455 |
+
return new_image
|
456 |
+
|
457 |
+
|
458 |
+
def save_video(
|
459 |
+
frames: List[PIL.Image.Image],
|
460 |
+
output_path: str,
|
461 |
+
fps: int = 30,
|
462 |
+
):
|
463 |
+
# use imageio to save video
|
464 |
+
frames = [np.array(frame) for frame in frames]
|
465 |
+
writer = imageio.get_writer(output_path, fps=fps)
|
466 |
+
for frame in frames:
|
467 |
+
writer.append_data(frame)
|
468 |
+
writer.close()
|
469 |
+
|
470 |
+
|
471 |
+
def to_gradio_3d_orientation(mesh):
|
472 |
+
mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
|
473 |
+
mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
|
474 |
+
return mesh
|