Spaces:
Sleeping
Sleeping
Upload 18 files
Browse files- Dockerfile +54 -0
- app.py +28 -0
- main.py +51 -0
- model_quantized_compressed.pkl.gz +3 -0
- requirements.txt +16 -0
- tmp/output/trash.txt +0 -0
- tsr/bake_texture.py +170 -0
- tsr/models/isosurface.py +52 -0
- tsr/models/nerf_renderer.py +180 -0
- tsr/models/network_utils.py +124 -0
- tsr/models/tokenizers/image.py +66 -0
- tsr/models/tokenizers/triplane.py +45 -0
- tsr/models/transformer/attention.py +653 -0
- tsr/models/transformer/basic_transformer_block.py +334 -0
- tsr/models/transformer/transformer_1d.py +219 -0
- tsr/system.py +205 -0
- tsr/utils.py +474 -0
- utils.py +115 -0
Dockerfile
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stage 1: Build the dependencies
|
2 |
+
FROM python:3.12-bullseye AS builder
|
3 |
+
|
4 |
+
# Install required system packages
|
5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
6 |
+
git \
|
7 |
+
build-essential \
|
8 |
+
cmake \
|
9 |
+
libopenblas-dev \
|
10 |
+
libomp-dev \
|
11 |
+
&& apt-get clean \
|
12 |
+
&& rm -rf /var/lib/apt/lists/*
|
13 |
+
|
14 |
+
# Set the working directory to /app
|
15 |
+
WORKDIR /app
|
16 |
+
|
17 |
+
# Copy requirements and install dependencies
|
18 |
+
COPY requirements.txt /app/
|
19 |
+
|
20 |
+
# Install Python dependencies and torchmcubes
|
21 |
+
RUN pip install --upgrade pip setuptools wheel \
|
22 |
+
&& pip install -r requirements.txt \
|
23 |
+
&& pip install git+https://github.com/tatsy/torchmcubes.git@3aef8afa5f21b113afc4f4ea148baee850cbd472 \
|
24 |
+
&& rm -rf ~/.cache/pip
|
25 |
+
|
26 |
+
# Copy the application files
|
27 |
+
COPY . /app
|
28 |
+
|
29 |
+
# Stage 2: Final image
|
30 |
+
FROM python:3.12-slim-bullseye
|
31 |
+
|
32 |
+
# Set up a new user named "user"
|
33 |
+
RUN useradd user
|
34 |
+
|
35 |
+
# Set the home environment variable and PATH
|
36 |
+
ENV HOME=/home/user \
|
37 |
+
PATH=/home/user/.local/bin:$PATH
|
38 |
+
|
39 |
+
# Set the working directory to the user's home directory
|
40 |
+
WORKDIR $HOME/app
|
41 |
+
|
42 |
+
# Copy the application files and installed packages from the builder stage
|
43 |
+
COPY --from=builder /app $HOME/app
|
44 |
+
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
45 |
+
COPY --from=builder /usr/local/bin /usr/local/bin
|
46 |
+
|
47 |
+
# Change ownership of the app directory to the user
|
48 |
+
RUN chown -R user:user $HOME/app
|
49 |
+
|
50 |
+
# Switch to the "user" user
|
51 |
+
USER user
|
52 |
+
|
53 |
+
# Set the entry point to run the FastAPI application
|
54 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7960"]
|
app.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile
|
2 |
+
from fastapi.responses import FileResponse
|
3 |
+
import os
|
4 |
+
|
5 |
+
from main import load_model, generate_mesh
|
6 |
+
|
7 |
+
## create a new FASTAPI app instance
|
8 |
+
app=FastAPI()
|
9 |
+
|
10 |
+
model = load_model()
|
11 |
+
|
12 |
+
@app.get("/")
|
13 |
+
def home():
|
14 |
+
return {"message":"Hello World"}
|
15 |
+
|
16 |
+
# Define a function to handle the GET request at `/generate`
|
17 |
+
@app.post("/generate")
|
18 |
+
async def generate(image: UploadFile = File(...)):
|
19 |
+
|
20 |
+
# Save the uploaded image to a temporary location
|
21 |
+
temp_image_path = f"tmp/output/{image.filename}"
|
22 |
+
with open(temp_image_path, "wb") as f:
|
23 |
+
f.write(await image.read())
|
24 |
+
|
25 |
+
output_file_path = generate_mesh(image_path=temp_image_path ,output_dir='tmp/output/' ,model=model)
|
26 |
+
|
27 |
+
## return the generate text in Json reposne
|
28 |
+
return FileResponse(output_file_path, media_type='application/octet-stream', filename="output.obj")
|
main.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from utils import process_image, run_model
|
3 |
+
from boto3 import Session
|
4 |
+
import torch
|
5 |
+
import pickle
|
6 |
+
import datetime
|
7 |
+
import gzip
|
8 |
+
|
9 |
+
# Retrieve credentials from environment variables
|
10 |
+
session = Session(
|
11 |
+
aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
|
12 |
+
aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
|
13 |
+
region_name=os.getenv('AWS_DEFAULT_REGION')
|
14 |
+
)
|
15 |
+
s3 = session.client('s3')
|
16 |
+
|
17 |
+
def load_model():
|
18 |
+
with gzip.open('model_quantized_compressed.pkl.gz', 'rb') as f_in:
|
19 |
+
model_data = f_in.read()
|
20 |
+
|
21 |
+
model = pickle.loads(model_data)
|
22 |
+
print("Model Loaded")
|
23 |
+
return model
|
24 |
+
|
25 |
+
def upload_to_s3(file_path, bucket_name, s3_key):
|
26 |
+
with open(file_path, 'rb') as f:
|
27 |
+
s3.upload_fileobj(f, bucket_name, s3_key)
|
28 |
+
s3_url = f's3://{bucket_name}/{s3_key}'
|
29 |
+
return s3_url
|
30 |
+
|
31 |
+
def generate_mesh(image_path, output_dir, model):
|
32 |
+
print('Process start')
|
33 |
+
# Process the image
|
34 |
+
image = process_image(image_path, output_dir)
|
35 |
+
print('Process end')
|
36 |
+
|
37 |
+
print('Run start')
|
38 |
+
output_file_path = run_model(model, image, output_dir)
|
39 |
+
print('Run end')
|
40 |
+
|
41 |
+
# Upload the input image and generated mesh file to S3
|
42 |
+
bucket_name = 'vasana-bkt1'
|
43 |
+
input_s3_key = f'input_images/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}-{os.path.basename(image_path)}'
|
44 |
+
output_s3_key = f'output_meshes/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}-{os.path.basename(output_file_path)}'
|
45 |
+
|
46 |
+
input_s3_url = upload_to_s3(image_path, bucket_name, input_s3_key)
|
47 |
+
output_s3_url = upload_to_s3(output_file_path, bucket_name, output_s3_key)
|
48 |
+
|
49 |
+
print(f'Files uploaded to S3:\nInput Image: {input_s3_url}\nOutput Mesh: {output_s3_url}')
|
50 |
+
|
51 |
+
return output_file_path
|
model_quantized_compressed.pkl.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc8f71c24a7db66a83bd712463be983e9bff078b05153fb60dba3e84d5781955
|
3 |
+
size 320691353
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
omegaconf==2.3.0
|
2 |
+
Pillow==10.1.0
|
3 |
+
einops==0.7.0
|
4 |
+
transformers==4.35.0
|
5 |
+
trimesh==4.0.5
|
6 |
+
rembg
|
7 |
+
huggingface-hub
|
8 |
+
imageio[ffmpeg]
|
9 |
+
xatlas==0.0.9
|
10 |
+
moderngl==5.10.0
|
11 |
+
torch==2.3.1
|
12 |
+
numpy==1.26.4
|
13 |
+
boto3==1.34.161
|
14 |
+
uvicorn==0.30.6
|
15 |
+
fastapi==0.112.2
|
16 |
+
python-multipart==0.0.9
|
tmp/output/trash.txt
ADDED
File without changes
|
tsr/bake_texture.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import xatlas
|
4 |
+
import trimesh
|
5 |
+
import moderngl
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
def make_atlas(mesh, texture_resolution, texture_padding):
|
10 |
+
atlas = xatlas.Atlas()
|
11 |
+
atlas.add_mesh(mesh.vertices, mesh.faces)
|
12 |
+
options = xatlas.PackOptions()
|
13 |
+
options.resolution = texture_resolution
|
14 |
+
options.padding = texture_padding
|
15 |
+
options.bilinear = True
|
16 |
+
atlas.generate(pack_options=options)
|
17 |
+
vmapping, indices, uvs = atlas[0]
|
18 |
+
return {
|
19 |
+
"vmapping": vmapping,
|
20 |
+
"indices": indices,
|
21 |
+
"uvs": uvs,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
def rasterize_position_atlas(
|
26 |
+
mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
|
27 |
+
):
|
28 |
+
ctx = moderngl.create_context(standalone=True)
|
29 |
+
basic_prog = ctx.program(
|
30 |
+
vertex_shader="""
|
31 |
+
#version 330
|
32 |
+
in vec2 in_uv;
|
33 |
+
in vec3 in_pos;
|
34 |
+
out vec3 v_pos;
|
35 |
+
void main() {
|
36 |
+
v_pos = in_pos;
|
37 |
+
gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
|
38 |
+
}
|
39 |
+
""",
|
40 |
+
fragment_shader="""
|
41 |
+
#version 330
|
42 |
+
in vec3 v_pos;
|
43 |
+
out vec4 o_col;
|
44 |
+
void main() {
|
45 |
+
o_col = vec4(v_pos, 1.0);
|
46 |
+
}
|
47 |
+
""",
|
48 |
+
)
|
49 |
+
gs_prog = ctx.program(
|
50 |
+
vertex_shader="""
|
51 |
+
#version 330
|
52 |
+
in vec2 in_uv;
|
53 |
+
in vec3 in_pos;
|
54 |
+
out vec3 vg_pos;
|
55 |
+
void main() {
|
56 |
+
vg_pos = in_pos;
|
57 |
+
gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
|
58 |
+
}
|
59 |
+
""",
|
60 |
+
geometry_shader="""
|
61 |
+
#version 330
|
62 |
+
uniform float u_resolution;
|
63 |
+
uniform float u_dilation;
|
64 |
+
layout (triangles) in;
|
65 |
+
layout (triangle_strip, max_vertices = 12) out;
|
66 |
+
in vec3 vg_pos[];
|
67 |
+
out vec3 vf_pos;
|
68 |
+
void lineSegment(int aidx, int bidx) {
|
69 |
+
vec2 a = gl_in[aidx].gl_Position.xy;
|
70 |
+
vec2 b = gl_in[bidx].gl_Position.xy;
|
71 |
+
vec3 aCol = vg_pos[aidx];
|
72 |
+
vec3 bCol = vg_pos[bidx];
|
73 |
+
|
74 |
+
vec2 dir = normalize((b - a) * u_resolution);
|
75 |
+
vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
|
76 |
+
|
77 |
+
gl_Position = vec4(a + offset, 0.0, 1.0);
|
78 |
+
vf_pos = aCol;
|
79 |
+
EmitVertex();
|
80 |
+
gl_Position = vec4(a - offset, 0.0, 1.0);
|
81 |
+
vf_pos = aCol;
|
82 |
+
EmitVertex();
|
83 |
+
gl_Position = vec4(b + offset, 0.0, 1.0);
|
84 |
+
vf_pos = bCol;
|
85 |
+
EmitVertex();
|
86 |
+
gl_Position = vec4(b - offset, 0.0, 1.0);
|
87 |
+
vf_pos = bCol;
|
88 |
+
EmitVertex();
|
89 |
+
}
|
90 |
+
void main() {
|
91 |
+
lineSegment(0, 1);
|
92 |
+
lineSegment(1, 2);
|
93 |
+
lineSegment(2, 0);
|
94 |
+
EndPrimitive();
|
95 |
+
}
|
96 |
+
""",
|
97 |
+
fragment_shader="""
|
98 |
+
#version 330
|
99 |
+
in vec3 vf_pos;
|
100 |
+
out vec4 o_col;
|
101 |
+
void main() {
|
102 |
+
o_col = vec4(vf_pos, 1.0);
|
103 |
+
}
|
104 |
+
""",
|
105 |
+
)
|
106 |
+
uvs = atlas_uvs.flatten().astype("f4")
|
107 |
+
pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
|
108 |
+
indices = atlas_indices.flatten().astype("i4")
|
109 |
+
vbo_uvs = ctx.buffer(uvs)
|
110 |
+
vbo_pos = ctx.buffer(pos)
|
111 |
+
ibo = ctx.buffer(indices)
|
112 |
+
vao_content = [
|
113 |
+
vbo_uvs.bind("in_uv", layout="2f"),
|
114 |
+
vbo_pos.bind("in_pos", layout="3f"),
|
115 |
+
]
|
116 |
+
basic_vao = ctx.vertex_array(basic_prog, vao_content, ibo)
|
117 |
+
gs_vao = ctx.vertex_array(gs_prog, vao_content, ibo)
|
118 |
+
fbo = ctx.framebuffer(
|
119 |
+
color_attachments=[
|
120 |
+
ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
|
121 |
+
]
|
122 |
+
)
|
123 |
+
fbo.use()
|
124 |
+
fbo.clear(0.0, 0.0, 0.0, 0.0)
|
125 |
+
gs_prog["u_resolution"].value = texture_resolution
|
126 |
+
gs_prog["u_dilation"].value = texture_padding
|
127 |
+
gs_vao.render()
|
128 |
+
basic_vao.render()
|
129 |
+
|
130 |
+
fbo_bytes = fbo.color_attachments[0].read()
|
131 |
+
fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
|
132 |
+
texture_resolution, texture_resolution, 4
|
133 |
+
)
|
134 |
+
return fbo_np
|
135 |
+
|
136 |
+
|
137 |
+
def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
|
138 |
+
positions = torch.tensor(positions_texture.reshape(-1, 4)[:, :-1])
|
139 |
+
with torch.no_grad():
|
140 |
+
queried_grid = model.renderer.query_triplane(
|
141 |
+
model.decoder,
|
142 |
+
positions,
|
143 |
+
scene_code,
|
144 |
+
)
|
145 |
+
rgb_f = queried_grid["color"].numpy().reshape(-1, 3)
|
146 |
+
rgba_f = np.insert(rgb_f, 3, positions_texture.reshape(-1, 4)[:, -1], axis=1)
|
147 |
+
rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
|
148 |
+
return rgba_f.reshape(texture_resolution, texture_resolution, 4)
|
149 |
+
|
150 |
+
|
151 |
+
def bake_texture(mesh, model, scene_code, texture_resolution):
|
152 |
+
texture_padding = round(max(2, texture_resolution / 256))
|
153 |
+
atlas = make_atlas(mesh, texture_resolution, texture_padding)
|
154 |
+
positions_texture = rasterize_position_atlas(
|
155 |
+
mesh,
|
156 |
+
atlas["vmapping"],
|
157 |
+
atlas["indices"],
|
158 |
+
atlas["uvs"],
|
159 |
+
texture_resolution,
|
160 |
+
texture_padding,
|
161 |
+
)
|
162 |
+
colors_texture = positions_to_colors(
|
163 |
+
model, scene_code, positions_texture, texture_resolution
|
164 |
+
)
|
165 |
+
return {
|
166 |
+
"vmapping": atlas["vmapping"],
|
167 |
+
"indices": atlas["indices"],
|
168 |
+
"uvs": atlas["uvs"],
|
169 |
+
"colors": colors_texture,
|
170 |
+
}
|
tsr/models/isosurface.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchmcubes import marching_cubes
|
7 |
+
|
8 |
+
|
9 |
+
class IsosurfaceHelper(nn.Module):
|
10 |
+
points_range: Tuple[float, float] = (0, 1)
|
11 |
+
|
12 |
+
@property
|
13 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
14 |
+
raise NotImplementedError
|
15 |
+
|
16 |
+
|
17 |
+
class MarchingCubeHelper(IsosurfaceHelper):
|
18 |
+
def __init__(self, resolution: int) -> None:
|
19 |
+
super().__init__()
|
20 |
+
self.resolution = resolution
|
21 |
+
self.mc_func: Callable = marching_cubes
|
22 |
+
self._grid_vertices: Optional[torch.FloatTensor] = None
|
23 |
+
|
24 |
+
@property
|
25 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
26 |
+
if self._grid_vertices is None:
|
27 |
+
# keep the vertices on CPU so that we can support very large resolution
|
28 |
+
x, y, z = (
|
29 |
+
torch.linspace(*self.points_range, self.resolution),
|
30 |
+
torch.linspace(*self.points_range, self.resolution),
|
31 |
+
torch.linspace(*self.points_range, self.resolution),
|
32 |
+
)
|
33 |
+
x, y, z = torch.meshgrid(x, y, z, indexing="ij")
|
34 |
+
verts = torch.cat(
|
35 |
+
[x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
|
36 |
+
).reshape(-1, 3)
|
37 |
+
self._grid_vertices = verts
|
38 |
+
return self._grid_vertices
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
level: torch.FloatTensor,
|
43 |
+
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
44 |
+
level = -level.view(self.resolution, self.resolution, self.resolution)
|
45 |
+
try:
|
46 |
+
v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
|
47 |
+
except AttributeError:
|
48 |
+
print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
|
49 |
+
v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
|
50 |
+
v_pos = v_pos[..., [2, 1, 0]]
|
51 |
+
v_pos = v_pos / (self.resolution - 1.0)
|
52 |
+
return v_pos.to(level.device), t_pos_idx.to(level.device)
|
tsr/models/nerf_renderer.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange, reduce
|
7 |
+
|
8 |
+
from ..utils import (
|
9 |
+
BaseModule,
|
10 |
+
chunk_batch,
|
11 |
+
get_activation,
|
12 |
+
rays_intersect_bbox,
|
13 |
+
scale_tensor,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class TriplaneNeRFRenderer(BaseModule):
|
18 |
+
@dataclass
|
19 |
+
class Config(BaseModule.Config):
|
20 |
+
radius: float
|
21 |
+
|
22 |
+
feature_reduction: str = "concat"
|
23 |
+
density_activation: str = "trunc_exp"
|
24 |
+
density_bias: float = -1.0
|
25 |
+
color_activation: str = "sigmoid"
|
26 |
+
num_samples_per_ray: int = 128
|
27 |
+
randomized: bool = False
|
28 |
+
|
29 |
+
cfg: Config
|
30 |
+
|
31 |
+
def configure(self) -> None:
|
32 |
+
assert self.cfg.feature_reduction in ["concat", "mean"]
|
33 |
+
self.chunk_size = 0
|
34 |
+
|
35 |
+
def set_chunk_size(self, chunk_size: int):
|
36 |
+
assert (
|
37 |
+
chunk_size >= 0
|
38 |
+
), "chunk_size must be a non-negative integer (0 for no chunking)."
|
39 |
+
self.chunk_size = chunk_size
|
40 |
+
|
41 |
+
def query_triplane(
|
42 |
+
self,
|
43 |
+
decoder: torch.nn.Module,
|
44 |
+
positions: torch.Tensor,
|
45 |
+
triplane: torch.Tensor,
|
46 |
+
) -> Dict[str, torch.Tensor]:
|
47 |
+
input_shape = positions.shape[:-1]
|
48 |
+
positions = positions.view(-1, 3)
|
49 |
+
|
50 |
+
# positions in (-radius, radius)
|
51 |
+
# normalized to (-1, 1) for grid sample
|
52 |
+
positions = scale_tensor(
|
53 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
54 |
+
)
|
55 |
+
|
56 |
+
def _query_chunk(x):
|
57 |
+
indices2D: torch.Tensor = torch.stack(
|
58 |
+
(x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
|
59 |
+
dim=-3,
|
60 |
+
)
|
61 |
+
out: torch.Tensor = F.grid_sample(
|
62 |
+
rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
|
63 |
+
rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
|
64 |
+
align_corners=False,
|
65 |
+
mode="bilinear",
|
66 |
+
)
|
67 |
+
if self.cfg.feature_reduction == "concat":
|
68 |
+
out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
|
69 |
+
elif self.cfg.feature_reduction == "mean":
|
70 |
+
out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
|
71 |
+
else:
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
net_out: Dict[str, torch.Tensor] = decoder(out)
|
75 |
+
return net_out
|
76 |
+
|
77 |
+
if self.chunk_size > 0:
|
78 |
+
net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
|
79 |
+
else:
|
80 |
+
net_out = _query_chunk(positions)
|
81 |
+
|
82 |
+
net_out["density_act"] = get_activation(self.cfg.density_activation)(
|
83 |
+
net_out["density"] + self.cfg.density_bias
|
84 |
+
)
|
85 |
+
net_out["color"] = get_activation(self.cfg.color_activation)(
|
86 |
+
net_out["features"]
|
87 |
+
)
|
88 |
+
|
89 |
+
net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
|
90 |
+
|
91 |
+
return net_out
|
92 |
+
|
93 |
+
def _forward(
|
94 |
+
self,
|
95 |
+
decoder: torch.nn.Module,
|
96 |
+
triplane: torch.Tensor,
|
97 |
+
rays_o: torch.Tensor,
|
98 |
+
rays_d: torch.Tensor,
|
99 |
+
**kwargs,
|
100 |
+
):
|
101 |
+
rays_shape = rays_o.shape[:-1]
|
102 |
+
rays_o = rays_o.view(-1, 3)
|
103 |
+
rays_d = rays_d.view(-1, 3)
|
104 |
+
n_rays = rays_o.shape[0]
|
105 |
+
|
106 |
+
t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
|
107 |
+
t_near, t_far = t_near[rays_valid], t_far[rays_valid]
|
108 |
+
|
109 |
+
t_vals = torch.linspace(
|
110 |
+
0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
|
111 |
+
)
|
112 |
+
t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
|
113 |
+
z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
|
114 |
+
|
115 |
+
xyz = (
|
116 |
+
rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
|
117 |
+
) # (N_rays, N_sample, 3)
|
118 |
+
|
119 |
+
mlp_out = self.query_triplane(
|
120 |
+
decoder=decoder,
|
121 |
+
positions=xyz,
|
122 |
+
triplane=triplane,
|
123 |
+
)
|
124 |
+
|
125 |
+
eps = 1e-10
|
126 |
+
# deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
|
127 |
+
deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
|
128 |
+
alpha = 1 - torch.exp(
|
129 |
+
-deltas * mlp_out["density_act"][..., 0]
|
130 |
+
) # (N_rays, N_samples)
|
131 |
+
accum_prod = torch.cat(
|
132 |
+
[
|
133 |
+
torch.ones_like(alpha[:, :1]),
|
134 |
+
torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
|
135 |
+
],
|
136 |
+
dim=-1,
|
137 |
+
)
|
138 |
+
weights = alpha * accum_prod # (N_rays, N_samples)
|
139 |
+
comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
|
140 |
+
opacity_ = weights.sum(dim=-1) # (N_rays)
|
141 |
+
|
142 |
+
comp_rgb = torch.zeros(
|
143 |
+
n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
|
144 |
+
)
|
145 |
+
opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
|
146 |
+
comp_rgb[rays_valid] = comp_rgb_
|
147 |
+
opacity[rays_valid] = opacity_
|
148 |
+
|
149 |
+
comp_rgb += 1 - opacity[..., None]
|
150 |
+
comp_rgb = comp_rgb.view(*rays_shape, 3)
|
151 |
+
|
152 |
+
return comp_rgb
|
153 |
+
|
154 |
+
def forward(
|
155 |
+
self,
|
156 |
+
decoder: torch.nn.Module,
|
157 |
+
triplane: torch.Tensor,
|
158 |
+
rays_o: torch.Tensor,
|
159 |
+
rays_d: torch.Tensor,
|
160 |
+
) -> Dict[str, torch.Tensor]:
|
161 |
+
if triplane.ndim == 4:
|
162 |
+
comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
|
163 |
+
else:
|
164 |
+
comp_rgb = torch.stack(
|
165 |
+
[
|
166 |
+
self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
|
167 |
+
for i in range(triplane.shape[0])
|
168 |
+
],
|
169 |
+
dim=0,
|
170 |
+
)
|
171 |
+
|
172 |
+
return comp_rgb
|
173 |
+
|
174 |
+
def train(self, mode=True):
|
175 |
+
self.randomized = mode and self.cfg.randomized
|
176 |
+
return super().train(mode=mode)
|
177 |
+
|
178 |
+
def eval(self):
|
179 |
+
self.randomized = False
|
180 |
+
return super().eval()
|
tsr/models/network_utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from ..utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class TriplaneUpsampleNetwork(BaseModule):
|
12 |
+
@dataclass
|
13 |
+
class Config(BaseModule.Config):
|
14 |
+
in_channels: int
|
15 |
+
out_channels: int
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.upsample = nn.ConvTranspose2d(
|
21 |
+
self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
|
25 |
+
triplanes_up = rearrange(
|
26 |
+
self.upsample(
|
27 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
28 |
+
),
|
29 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
30 |
+
Np=3,
|
31 |
+
)
|
32 |
+
return triplanes_up
|
33 |
+
|
34 |
+
|
35 |
+
class NeRFMLP(BaseModule):
|
36 |
+
@dataclass
|
37 |
+
class Config(BaseModule.Config):
|
38 |
+
in_channels: int
|
39 |
+
n_neurons: int
|
40 |
+
n_hidden_layers: int
|
41 |
+
activation: str = "relu"
|
42 |
+
bias: bool = True
|
43 |
+
weight_init: Optional[str] = "kaiming_uniform"
|
44 |
+
bias_init: Optional[str] = None
|
45 |
+
|
46 |
+
cfg: Config
|
47 |
+
|
48 |
+
def configure(self) -> None:
|
49 |
+
layers = [
|
50 |
+
self.make_linear(
|
51 |
+
self.cfg.in_channels,
|
52 |
+
self.cfg.n_neurons,
|
53 |
+
bias=self.cfg.bias,
|
54 |
+
weight_init=self.cfg.weight_init,
|
55 |
+
bias_init=self.cfg.bias_init,
|
56 |
+
),
|
57 |
+
self.make_activation(self.cfg.activation),
|
58 |
+
]
|
59 |
+
for i in range(self.cfg.n_hidden_layers - 1):
|
60 |
+
layers += [
|
61 |
+
self.make_linear(
|
62 |
+
self.cfg.n_neurons,
|
63 |
+
self.cfg.n_neurons,
|
64 |
+
bias=self.cfg.bias,
|
65 |
+
weight_init=self.cfg.weight_init,
|
66 |
+
bias_init=self.cfg.bias_init,
|
67 |
+
),
|
68 |
+
self.make_activation(self.cfg.activation),
|
69 |
+
]
|
70 |
+
layers += [
|
71 |
+
self.make_linear(
|
72 |
+
self.cfg.n_neurons,
|
73 |
+
4, # density 1 + features 3
|
74 |
+
bias=self.cfg.bias,
|
75 |
+
weight_init=self.cfg.weight_init,
|
76 |
+
bias_init=self.cfg.bias_init,
|
77 |
+
)
|
78 |
+
]
|
79 |
+
self.layers = nn.Sequential(*layers)
|
80 |
+
|
81 |
+
def make_linear(
|
82 |
+
self,
|
83 |
+
dim_in,
|
84 |
+
dim_out,
|
85 |
+
bias=True,
|
86 |
+
weight_init=None,
|
87 |
+
bias_init=None,
|
88 |
+
):
|
89 |
+
layer = nn.Linear(dim_in, dim_out, bias=bias)
|
90 |
+
|
91 |
+
if weight_init is None:
|
92 |
+
pass
|
93 |
+
elif weight_init == "kaiming_uniform":
|
94 |
+
torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
if bias:
|
99 |
+
if bias_init is None:
|
100 |
+
pass
|
101 |
+
elif bias_init == "zero":
|
102 |
+
torch.nn.init.zeros_(layer.bias)
|
103 |
+
else:
|
104 |
+
raise NotImplementedError
|
105 |
+
|
106 |
+
return layer
|
107 |
+
|
108 |
+
def make_activation(self, activation):
|
109 |
+
if activation == "relu":
|
110 |
+
return nn.ReLU(inplace=True)
|
111 |
+
elif activation == "silu":
|
112 |
+
return nn.SiLU(inplace=True)
|
113 |
+
else:
|
114 |
+
raise NotImplementedError
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
inp_shape = x.shape[:-1]
|
118 |
+
x = x.reshape(-1, x.shape[-1])
|
119 |
+
|
120 |
+
features = self.layers(x)
|
121 |
+
features = features.reshape(*inp_shape, -1)
|
122 |
+
out = {"density": features[..., 0:1], "features": features[..., 1:4]}
|
123 |
+
|
124 |
+
return out
|
tsr/models/tokenizers/image.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from transformers.models.vit.modeling_vit import ViTModel
|
8 |
+
|
9 |
+
from ...utils import BaseModule
|
10 |
+
|
11 |
+
|
12 |
+
class DINOSingleImageTokenizer(BaseModule):
|
13 |
+
@dataclass
|
14 |
+
class Config(BaseModule.Config):
|
15 |
+
pretrained_model_name_or_path: str = "facebook/dino-vitb16"
|
16 |
+
enable_gradient_checkpointing: bool = False
|
17 |
+
|
18 |
+
cfg: Config
|
19 |
+
|
20 |
+
def configure(self) -> None:
|
21 |
+
self.model: ViTModel = ViTModel(
|
22 |
+
ViTModel.config_class.from_pretrained(
|
23 |
+
hf_hub_download(
|
24 |
+
repo_id=self.cfg.pretrained_model_name_or_path,
|
25 |
+
filename="config.json",
|
26 |
+
)
|
27 |
+
)
|
28 |
+
)
|
29 |
+
|
30 |
+
if self.cfg.enable_gradient_checkpointing:
|
31 |
+
self.model.encoder.gradient_checkpointing = True
|
32 |
+
|
33 |
+
self.register_buffer(
|
34 |
+
"image_mean",
|
35 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
36 |
+
persistent=False,
|
37 |
+
)
|
38 |
+
self.register_buffer(
|
39 |
+
"image_std",
|
40 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
41 |
+
persistent=False,
|
42 |
+
)
|
43 |
+
|
44 |
+
def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
45 |
+
packed = False
|
46 |
+
if images.ndim == 4:
|
47 |
+
packed = True
|
48 |
+
images = images.unsqueeze(1)
|
49 |
+
|
50 |
+
batch_size, n_input_views = images.shape[:2]
|
51 |
+
images = (images - self.image_mean) / self.image_std
|
52 |
+
out = self.model(
|
53 |
+
rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
|
54 |
+
)
|
55 |
+
local_features, global_features = out.last_hidden_state, out.pooler_output
|
56 |
+
local_features = local_features.permute(0, 2, 1)
|
57 |
+
local_features = rearrange(
|
58 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
59 |
+
)
|
60 |
+
if packed:
|
61 |
+
local_features = local_features.squeeze(1)
|
62 |
+
|
63 |
+
return local_features
|
64 |
+
|
65 |
+
def detokenize(self, *args, **kwargs):
|
66 |
+
raise NotImplementedError
|
tsr/models/tokenizers/triplane.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
|
8 |
+
from ...utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class Triplane1DTokenizer(BaseModule):
|
12 |
+
@dataclass
|
13 |
+
class Config(BaseModule.Config):
|
14 |
+
plane_size: int
|
15 |
+
num_channels: int
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.embeddings = nn.Parameter(
|
21 |
+
torch.randn(
|
22 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
23 |
+
dtype=torch.float32,
|
24 |
+
)
|
25 |
+
* 1
|
26 |
+
/ math.sqrt(self.cfg.num_channels)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, batch_size: int) -> torch.Tensor:
|
30 |
+
return rearrange(
|
31 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
32 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
33 |
+
)
|
34 |
+
|
35 |
+
def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
|
36 |
+
batch_size, Ct, Nt = tokens.shape
|
37 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
38 |
+
assert Ct == self.cfg.num_channels
|
39 |
+
return rearrange(
|
40 |
+
tokens,
|
41 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
42 |
+
Np=3,
|
43 |
+
Hp=self.cfg.plane_size,
|
44 |
+
Wp=self.cfg.plane_size,
|
45 |
+
)
|
tsr/models/transformer/attention.py
ADDED
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# --------
|
16 |
+
#
|
17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
18 |
+
#
|
19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
20 |
+
#
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
#
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
#
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
from typing import Optional
|
40 |
+
|
41 |
+
import torch
|
42 |
+
import torch.nn.functional as F
|
43 |
+
from torch import nn
|
44 |
+
|
45 |
+
|
46 |
+
class Attention(nn.Module):
|
47 |
+
r"""
|
48 |
+
A cross attention layer.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
query_dim (`int`):
|
52 |
+
The number of channels in the query.
|
53 |
+
cross_attention_dim (`int`, *optional*):
|
54 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
55 |
+
heads (`int`, *optional*, defaults to 8):
|
56 |
+
The number of heads to use for multi-head attention.
|
57 |
+
dim_head (`int`, *optional*, defaults to 64):
|
58 |
+
The number of channels in each head.
|
59 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
60 |
+
The dropout probability to use.
|
61 |
+
bias (`bool`, *optional*, defaults to False):
|
62 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
63 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
64 |
+
Set to `True` to upcast the attention computation to `float32`.
|
65 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
66 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
67 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
68 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
69 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
70 |
+
The number of groups to use for the group norm in the cross attention.
|
71 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
72 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
73 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
74 |
+
The number of groups to use for the group norm in the attention.
|
75 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
76 |
+
The number of channels to use for the spatial normalization.
|
77 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
78 |
+
Set to `True` to use a bias in the output linear layer.
|
79 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
80 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
81 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
82 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
83 |
+
`added_kv_proj_dim` is not `None`.
|
84 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
85 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
86 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
87 |
+
A factor to rescale the output by dividing it with this value.
|
88 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
89 |
+
Set to `True` to add the residual connection to the output.
|
90 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
91 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
92 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
93 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
94 |
+
`AttnProcessor` otherwise.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
query_dim: int,
|
100 |
+
cross_attention_dim: Optional[int] = None,
|
101 |
+
heads: int = 8,
|
102 |
+
dim_head: int = 64,
|
103 |
+
dropout: float = 0.0,
|
104 |
+
bias: bool = False,
|
105 |
+
upcast_attention: bool = False,
|
106 |
+
upcast_softmax: bool = False,
|
107 |
+
cross_attention_norm: Optional[str] = None,
|
108 |
+
cross_attention_norm_num_groups: int = 32,
|
109 |
+
added_kv_proj_dim: Optional[int] = None,
|
110 |
+
norm_num_groups: Optional[int] = None,
|
111 |
+
out_bias: bool = True,
|
112 |
+
scale_qk: bool = True,
|
113 |
+
only_cross_attention: bool = False,
|
114 |
+
eps: float = 1e-5,
|
115 |
+
rescale_output_factor: float = 1.0,
|
116 |
+
residual_connection: bool = False,
|
117 |
+
_from_deprecated_attn_block: bool = False,
|
118 |
+
processor: Optional["AttnProcessor"] = None,
|
119 |
+
out_dim: int = None,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
123 |
+
self.query_dim = query_dim
|
124 |
+
self.cross_attention_dim = (
|
125 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
126 |
+
)
|
127 |
+
self.upcast_attention = upcast_attention
|
128 |
+
self.upcast_softmax = upcast_softmax
|
129 |
+
self.rescale_output_factor = rescale_output_factor
|
130 |
+
self.residual_connection = residual_connection
|
131 |
+
self.dropout = dropout
|
132 |
+
self.fused_projections = False
|
133 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
134 |
+
|
135 |
+
# we make use of this private variable to know whether this class is loaded
|
136 |
+
# with an deprecated state dict so that we can convert it on the fly
|
137 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
138 |
+
|
139 |
+
self.scale_qk = scale_qk
|
140 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
141 |
+
|
142 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
143 |
+
# for slice_size > 0 the attention score computation
|
144 |
+
# is split across the batch axis to save memory
|
145 |
+
# You can set slice_size with `set_attention_slice`
|
146 |
+
self.sliceable_head_dim = heads
|
147 |
+
|
148 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
149 |
+
self.only_cross_attention = only_cross_attention
|
150 |
+
|
151 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
152 |
+
raise ValueError(
|
153 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
154 |
+
)
|
155 |
+
|
156 |
+
if norm_num_groups is not None:
|
157 |
+
self.group_norm = nn.GroupNorm(
|
158 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
159 |
+
)
|
160 |
+
else:
|
161 |
+
self.group_norm = None
|
162 |
+
|
163 |
+
self.spatial_norm = None
|
164 |
+
|
165 |
+
if cross_attention_norm is None:
|
166 |
+
self.norm_cross = None
|
167 |
+
elif cross_attention_norm == "layer_norm":
|
168 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
169 |
+
elif cross_attention_norm == "group_norm":
|
170 |
+
if self.added_kv_proj_dim is not None:
|
171 |
+
# The given `encoder_hidden_states` are initially of shape
|
172 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
173 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
174 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
175 |
+
# the number of channels for the group norm.
|
176 |
+
norm_cross_num_channels = added_kv_proj_dim
|
177 |
+
else:
|
178 |
+
norm_cross_num_channels = self.cross_attention_dim
|
179 |
+
|
180 |
+
self.norm_cross = nn.GroupNorm(
|
181 |
+
num_channels=norm_cross_num_channels,
|
182 |
+
num_groups=cross_attention_norm_num_groups,
|
183 |
+
eps=1e-5,
|
184 |
+
affine=True,
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
raise ValueError(
|
188 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
189 |
+
)
|
190 |
+
|
191 |
+
linear_cls = nn.Linear
|
192 |
+
|
193 |
+
self.linear_cls = linear_cls
|
194 |
+
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
195 |
+
|
196 |
+
if not self.only_cross_attention:
|
197 |
+
# only relevant for the `AddedKVProcessor` classes
|
198 |
+
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
199 |
+
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
200 |
+
else:
|
201 |
+
self.to_k = None
|
202 |
+
self.to_v = None
|
203 |
+
|
204 |
+
if self.added_kv_proj_dim is not None:
|
205 |
+
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
206 |
+
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
207 |
+
|
208 |
+
self.to_out = nn.ModuleList([])
|
209 |
+
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
210 |
+
self.to_out.append(nn.Dropout(dropout))
|
211 |
+
|
212 |
+
# set attention processor
|
213 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
214 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
215 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
216 |
+
if processor is None:
|
217 |
+
processor = (
|
218 |
+
AttnProcessor2_0()
|
219 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
220 |
+
else AttnProcessor()
|
221 |
+
)
|
222 |
+
self.set_processor(processor)
|
223 |
+
|
224 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
225 |
+
self.processor = processor
|
226 |
+
|
227 |
+
def forward(
|
228 |
+
self,
|
229 |
+
hidden_states: torch.FloatTensor,
|
230 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
231 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
232 |
+
**cross_attention_kwargs,
|
233 |
+
) -> torch.Tensor:
|
234 |
+
r"""
|
235 |
+
The forward method of the `Attention` class.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
hidden_states (`torch.Tensor`):
|
239 |
+
The hidden states of the query.
|
240 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
241 |
+
The hidden states of the encoder.
|
242 |
+
attention_mask (`torch.Tensor`, *optional*):
|
243 |
+
The attention mask to use. If `None`, no mask is applied.
|
244 |
+
**cross_attention_kwargs:
|
245 |
+
Additional keyword arguments to pass along to the cross attention.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
`torch.Tensor`: The output of the attention layer.
|
249 |
+
"""
|
250 |
+
# The `Attention` class can call different attention processors / attention functions
|
251 |
+
# here we simply pass along all tensors to the selected processor class
|
252 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
253 |
+
return self.processor(
|
254 |
+
self,
|
255 |
+
hidden_states,
|
256 |
+
encoder_hidden_states=encoder_hidden_states,
|
257 |
+
attention_mask=attention_mask,
|
258 |
+
**cross_attention_kwargs,
|
259 |
+
)
|
260 |
+
|
261 |
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
262 |
+
r"""
|
263 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
264 |
+
is the number of heads initialized while constructing the `Attention` class.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
`torch.Tensor`: The reshaped tensor.
|
271 |
+
"""
|
272 |
+
head_size = self.heads
|
273 |
+
batch_size, seq_len, dim = tensor.shape
|
274 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
275 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
276 |
+
batch_size // head_size, seq_len, dim * head_size
|
277 |
+
)
|
278 |
+
return tensor
|
279 |
+
|
280 |
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
281 |
+
r"""
|
282 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
|
283 |
+
the number of heads initialized while constructing the `Attention` class.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
287 |
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
288 |
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
`torch.Tensor`: The reshaped tensor.
|
292 |
+
"""
|
293 |
+
head_size = self.heads
|
294 |
+
batch_size, seq_len, dim = tensor.shape
|
295 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
296 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
297 |
+
|
298 |
+
if out_dim == 3:
|
299 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
300 |
+
|
301 |
+
return tensor
|
302 |
+
|
303 |
+
def get_attention_scores(
|
304 |
+
self,
|
305 |
+
query: torch.Tensor,
|
306 |
+
key: torch.Tensor,
|
307 |
+
attention_mask: torch.Tensor = None,
|
308 |
+
) -> torch.Tensor:
|
309 |
+
r"""
|
310 |
+
Compute the attention scores.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
query (`torch.Tensor`): The query tensor.
|
314 |
+
key (`torch.Tensor`): The key tensor.
|
315 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
`torch.Tensor`: The attention probabilities/scores.
|
319 |
+
"""
|
320 |
+
dtype = query.dtype
|
321 |
+
if self.upcast_attention:
|
322 |
+
query = query.float()
|
323 |
+
key = key.float()
|
324 |
+
|
325 |
+
if attention_mask is None:
|
326 |
+
baddbmm_input = torch.empty(
|
327 |
+
query.shape[0],
|
328 |
+
query.shape[1],
|
329 |
+
key.shape[1],
|
330 |
+
dtype=query.dtype,
|
331 |
+
device=query.device,
|
332 |
+
)
|
333 |
+
beta = 0
|
334 |
+
else:
|
335 |
+
baddbmm_input = attention_mask
|
336 |
+
beta = 1
|
337 |
+
|
338 |
+
attention_scores = torch.baddbmm(
|
339 |
+
baddbmm_input,
|
340 |
+
query,
|
341 |
+
key.transpose(-1, -2),
|
342 |
+
beta=beta,
|
343 |
+
alpha=self.scale,
|
344 |
+
)
|
345 |
+
del baddbmm_input
|
346 |
+
|
347 |
+
if self.upcast_softmax:
|
348 |
+
attention_scores = attention_scores.float()
|
349 |
+
|
350 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
351 |
+
del attention_scores
|
352 |
+
|
353 |
+
attention_probs = attention_probs.to(dtype)
|
354 |
+
|
355 |
+
return attention_probs
|
356 |
+
|
357 |
+
def prepare_attention_mask(
|
358 |
+
self,
|
359 |
+
attention_mask: torch.Tensor,
|
360 |
+
target_length: int,
|
361 |
+
batch_size: int,
|
362 |
+
out_dim: int = 3,
|
363 |
+
) -> torch.Tensor:
|
364 |
+
r"""
|
365 |
+
Prepare the attention mask for the attention computation.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
attention_mask (`torch.Tensor`):
|
369 |
+
The attention mask to prepare.
|
370 |
+
target_length (`int`):
|
371 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
372 |
+
batch_size (`int`):
|
373 |
+
The batch size, which is used to repeat the attention mask.
|
374 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
375 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
`torch.Tensor`: The prepared attention mask.
|
379 |
+
"""
|
380 |
+
head_size = self.heads
|
381 |
+
if attention_mask is None:
|
382 |
+
return attention_mask
|
383 |
+
|
384 |
+
current_length: int = attention_mask.shape[-1]
|
385 |
+
if current_length != target_length:
|
386 |
+
if attention_mask.device.type == "mps":
|
387 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
388 |
+
# Instead, we can manually construct the padding tensor.
|
389 |
+
padding_shape = (
|
390 |
+
attention_mask.shape[0],
|
391 |
+
attention_mask.shape[1],
|
392 |
+
target_length,
|
393 |
+
)
|
394 |
+
padding = torch.zeros(
|
395 |
+
padding_shape,
|
396 |
+
dtype=attention_mask.dtype,
|
397 |
+
device=attention_mask.device,
|
398 |
+
)
|
399 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
400 |
+
else:
|
401 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
402 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
403 |
+
# remaining_length: int = target_length - current_length
|
404 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
405 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
406 |
+
|
407 |
+
if out_dim == 3:
|
408 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
409 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
410 |
+
elif out_dim == 4:
|
411 |
+
attention_mask = attention_mask.unsqueeze(1)
|
412 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
413 |
+
|
414 |
+
return attention_mask
|
415 |
+
|
416 |
+
def norm_encoder_hidden_states(
|
417 |
+
self, encoder_hidden_states: torch.Tensor
|
418 |
+
) -> torch.Tensor:
|
419 |
+
r"""
|
420 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
421 |
+
`Attention` class.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
428 |
+
"""
|
429 |
+
assert (
|
430 |
+
self.norm_cross is not None
|
431 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
432 |
+
|
433 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
434 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
435 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
436 |
+
# Group norm norms along the channels dimension and expects
|
437 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
438 |
+
# to norm along the hidden dimension, so we need to move
|
439 |
+
# (batch_size, sequence_length, hidden_size) ->
|
440 |
+
# (batch_size, hidden_size, sequence_length)
|
441 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
442 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
443 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
444 |
+
else:
|
445 |
+
assert False
|
446 |
+
|
447 |
+
return encoder_hidden_states
|
448 |
+
|
449 |
+
@torch.no_grad()
|
450 |
+
def fuse_projections(self, fuse=True):
|
451 |
+
is_cross_attention = self.cross_attention_dim != self.query_dim
|
452 |
+
device = self.to_q.weight.data.device
|
453 |
+
dtype = self.to_q.weight.data.dtype
|
454 |
+
|
455 |
+
if not is_cross_attention:
|
456 |
+
# fetch weight matrices.
|
457 |
+
concatenated_weights = torch.cat(
|
458 |
+
[self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
|
459 |
+
)
|
460 |
+
in_features = concatenated_weights.shape[1]
|
461 |
+
out_features = concatenated_weights.shape[0]
|
462 |
+
|
463 |
+
# create a new single projection layer and copy over the weights.
|
464 |
+
self.to_qkv = self.linear_cls(
|
465 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
466 |
+
)
|
467 |
+
self.to_qkv.weight.copy_(concatenated_weights)
|
468 |
+
|
469 |
+
else:
|
470 |
+
concatenated_weights = torch.cat(
|
471 |
+
[self.to_k.weight.data, self.to_v.weight.data]
|
472 |
+
)
|
473 |
+
in_features = concatenated_weights.shape[1]
|
474 |
+
out_features = concatenated_weights.shape[0]
|
475 |
+
|
476 |
+
self.to_kv = self.linear_cls(
|
477 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
478 |
+
)
|
479 |
+
self.to_kv.weight.copy_(concatenated_weights)
|
480 |
+
|
481 |
+
self.fused_projections = fuse
|
482 |
+
|
483 |
+
|
484 |
+
class AttnProcessor:
|
485 |
+
r"""
|
486 |
+
Default processor for performing attention-related computations.
|
487 |
+
"""
|
488 |
+
|
489 |
+
def __call__(
|
490 |
+
self,
|
491 |
+
attn: Attention,
|
492 |
+
hidden_states: torch.FloatTensor,
|
493 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
494 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
495 |
+
) -> torch.Tensor:
|
496 |
+
residual = hidden_states
|
497 |
+
|
498 |
+
input_ndim = hidden_states.ndim
|
499 |
+
|
500 |
+
if input_ndim == 4:
|
501 |
+
batch_size, channel, height, width = hidden_states.shape
|
502 |
+
hidden_states = hidden_states.view(
|
503 |
+
batch_size, channel, height * width
|
504 |
+
).transpose(1, 2)
|
505 |
+
|
506 |
+
batch_size, sequence_length, _ = (
|
507 |
+
hidden_states.shape
|
508 |
+
if encoder_hidden_states is None
|
509 |
+
else encoder_hidden_states.shape
|
510 |
+
)
|
511 |
+
attention_mask = attn.prepare_attention_mask(
|
512 |
+
attention_mask, sequence_length, batch_size
|
513 |
+
)
|
514 |
+
|
515 |
+
if attn.group_norm is not None:
|
516 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
517 |
+
1, 2
|
518 |
+
)
|
519 |
+
|
520 |
+
query = attn.to_q(hidden_states)
|
521 |
+
|
522 |
+
if encoder_hidden_states is None:
|
523 |
+
encoder_hidden_states = hidden_states
|
524 |
+
elif attn.norm_cross:
|
525 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
526 |
+
encoder_hidden_states
|
527 |
+
)
|
528 |
+
|
529 |
+
key = attn.to_k(encoder_hidden_states)
|
530 |
+
value = attn.to_v(encoder_hidden_states)
|
531 |
+
|
532 |
+
query = attn.head_to_batch_dim(query)
|
533 |
+
key = attn.head_to_batch_dim(key)
|
534 |
+
value = attn.head_to_batch_dim(value)
|
535 |
+
|
536 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
537 |
+
hidden_states = torch.bmm(attention_probs, value)
|
538 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
539 |
+
|
540 |
+
# linear proj
|
541 |
+
hidden_states = attn.to_out[0](hidden_states)
|
542 |
+
# dropout
|
543 |
+
hidden_states = attn.to_out[1](hidden_states)
|
544 |
+
|
545 |
+
if input_ndim == 4:
|
546 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
547 |
+
batch_size, channel, height, width
|
548 |
+
)
|
549 |
+
|
550 |
+
if attn.residual_connection:
|
551 |
+
hidden_states = hidden_states + residual
|
552 |
+
|
553 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
554 |
+
|
555 |
+
return hidden_states
|
556 |
+
|
557 |
+
|
558 |
+
class AttnProcessor2_0:
|
559 |
+
r"""
|
560 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
561 |
+
"""
|
562 |
+
|
563 |
+
def __init__(self):
|
564 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
565 |
+
raise ImportError(
|
566 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
567 |
+
)
|
568 |
+
|
569 |
+
def __call__(
|
570 |
+
self,
|
571 |
+
attn: Attention,
|
572 |
+
hidden_states: torch.FloatTensor,
|
573 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
574 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
575 |
+
) -> torch.FloatTensor:
|
576 |
+
residual = hidden_states
|
577 |
+
|
578 |
+
input_ndim = hidden_states.ndim
|
579 |
+
|
580 |
+
if input_ndim == 4:
|
581 |
+
batch_size, channel, height, width = hidden_states.shape
|
582 |
+
hidden_states = hidden_states.view(
|
583 |
+
batch_size, channel, height * width
|
584 |
+
).transpose(1, 2)
|
585 |
+
|
586 |
+
batch_size, sequence_length, _ = (
|
587 |
+
hidden_states.shape
|
588 |
+
if encoder_hidden_states is None
|
589 |
+
else encoder_hidden_states.shape
|
590 |
+
)
|
591 |
+
|
592 |
+
if attention_mask is not None:
|
593 |
+
attention_mask = attn.prepare_attention_mask(
|
594 |
+
attention_mask, sequence_length, batch_size
|
595 |
+
)
|
596 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
597 |
+
# (batch, heads, source_length, target_length)
|
598 |
+
attention_mask = attention_mask.view(
|
599 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
600 |
+
)
|
601 |
+
|
602 |
+
if attn.group_norm is not None:
|
603 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
604 |
+
1, 2
|
605 |
+
)
|
606 |
+
|
607 |
+
query = attn.to_q(hidden_states)
|
608 |
+
|
609 |
+
if encoder_hidden_states is None:
|
610 |
+
encoder_hidden_states = hidden_states
|
611 |
+
elif attn.norm_cross:
|
612 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
613 |
+
encoder_hidden_states
|
614 |
+
)
|
615 |
+
|
616 |
+
key = attn.to_k(encoder_hidden_states)
|
617 |
+
value = attn.to_v(encoder_hidden_states)
|
618 |
+
|
619 |
+
inner_dim = key.shape[-1]
|
620 |
+
head_dim = inner_dim // attn.heads
|
621 |
+
|
622 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
623 |
+
|
624 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
625 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
626 |
+
|
627 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
628 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
629 |
+
hidden_states = F.scaled_dot_product_attention(
|
630 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
631 |
+
)
|
632 |
+
|
633 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
634 |
+
batch_size, -1, attn.heads * head_dim
|
635 |
+
)
|
636 |
+
hidden_states = hidden_states.to(query.dtype)
|
637 |
+
|
638 |
+
# linear proj
|
639 |
+
hidden_states = attn.to_out[0](hidden_states)
|
640 |
+
# dropout
|
641 |
+
hidden_states = attn.to_out[1](hidden_states)
|
642 |
+
|
643 |
+
if input_ndim == 4:
|
644 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
645 |
+
batch_size, channel, height, width
|
646 |
+
)
|
647 |
+
|
648 |
+
if attn.residual_connection:
|
649 |
+
hidden_states = hidden_states + residual
|
650 |
+
|
651 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
652 |
+
|
653 |
+
return hidden_states
|
tsr/models/transformer/basic_transformer_block.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# --------
|
16 |
+
#
|
17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
18 |
+
#
|
19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
20 |
+
#
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
#
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
#
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
from typing import Optional
|
40 |
+
|
41 |
+
import torch
|
42 |
+
import torch.nn.functional as F
|
43 |
+
from torch import nn
|
44 |
+
|
45 |
+
from .attention import Attention
|
46 |
+
|
47 |
+
|
48 |
+
class BasicTransformerBlock(nn.Module):
|
49 |
+
r"""
|
50 |
+
A basic Transformer block.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
dim (`int`): The number of channels in the input and output.
|
54 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
55 |
+
attention_head_dim (`int`): The number of channels in each head.
|
56 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
57 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
58 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
59 |
+
attention_bias (:
|
60 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
61 |
+
only_cross_attention (`bool`, *optional*):
|
62 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
63 |
+
double_self_attention (`bool`, *optional*):
|
64 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
65 |
+
upcast_attention (`bool`, *optional*):
|
66 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
67 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
69 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
70 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
71 |
+
final_dropout (`bool` *optional*, defaults to False):
|
72 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
dim: int,
|
78 |
+
num_attention_heads: int,
|
79 |
+
attention_head_dim: int,
|
80 |
+
dropout=0.0,
|
81 |
+
cross_attention_dim: Optional[int] = None,
|
82 |
+
activation_fn: str = "geglu",
|
83 |
+
attention_bias: bool = False,
|
84 |
+
only_cross_attention: bool = False,
|
85 |
+
double_self_attention: bool = False,
|
86 |
+
upcast_attention: bool = False,
|
87 |
+
norm_elementwise_affine: bool = True,
|
88 |
+
norm_type: str = "layer_norm",
|
89 |
+
final_dropout: bool = False,
|
90 |
+
):
|
91 |
+
super().__init__()
|
92 |
+
self.only_cross_attention = only_cross_attention
|
93 |
+
|
94 |
+
assert norm_type == "layer_norm"
|
95 |
+
|
96 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
97 |
+
# 1. Self-Attn
|
98 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
99 |
+
self.attn1 = Attention(
|
100 |
+
query_dim=dim,
|
101 |
+
heads=num_attention_heads,
|
102 |
+
dim_head=attention_head_dim,
|
103 |
+
dropout=dropout,
|
104 |
+
bias=attention_bias,
|
105 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
106 |
+
upcast_attention=upcast_attention,
|
107 |
+
)
|
108 |
+
|
109 |
+
# 2. Cross-Attn
|
110 |
+
if cross_attention_dim is not None or double_self_attention:
|
111 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
112 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
113 |
+
# the second cross attention block.
|
114 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
115 |
+
|
116 |
+
self.attn2 = Attention(
|
117 |
+
query_dim=dim,
|
118 |
+
cross_attention_dim=(
|
119 |
+
cross_attention_dim if not double_self_attention else None
|
120 |
+
),
|
121 |
+
heads=num_attention_heads,
|
122 |
+
dim_head=attention_head_dim,
|
123 |
+
dropout=dropout,
|
124 |
+
bias=attention_bias,
|
125 |
+
upcast_attention=upcast_attention,
|
126 |
+
) # is self-attn if encoder_hidden_states is none
|
127 |
+
else:
|
128 |
+
self.norm2 = None
|
129 |
+
self.attn2 = None
|
130 |
+
|
131 |
+
# 3. Feed-forward
|
132 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
133 |
+
self.ff = FeedForward(
|
134 |
+
dim,
|
135 |
+
dropout=dropout,
|
136 |
+
activation_fn=activation_fn,
|
137 |
+
final_dropout=final_dropout,
|
138 |
+
)
|
139 |
+
|
140 |
+
# let chunk size default to None
|
141 |
+
self._chunk_size = None
|
142 |
+
self._chunk_dim = 0
|
143 |
+
|
144 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
145 |
+
# Sets chunk feed-forward
|
146 |
+
self._chunk_size = chunk_size
|
147 |
+
self._chunk_dim = dim
|
148 |
+
|
149 |
+
def forward(
|
150 |
+
self,
|
151 |
+
hidden_states: torch.FloatTensor,
|
152 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
153 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
154 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
155 |
+
) -> torch.FloatTensor:
|
156 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
157 |
+
# 0. Self-Attention
|
158 |
+
norm_hidden_states = self.norm1(hidden_states)
|
159 |
+
|
160 |
+
attn_output = self.attn1(
|
161 |
+
norm_hidden_states,
|
162 |
+
encoder_hidden_states=(
|
163 |
+
encoder_hidden_states if self.only_cross_attention else None
|
164 |
+
),
|
165 |
+
attention_mask=attention_mask,
|
166 |
+
)
|
167 |
+
|
168 |
+
hidden_states = attn_output + hidden_states
|
169 |
+
|
170 |
+
# 3. Cross-Attention
|
171 |
+
if self.attn2 is not None:
|
172 |
+
norm_hidden_states = self.norm2(hidden_states)
|
173 |
+
|
174 |
+
attn_output = self.attn2(
|
175 |
+
norm_hidden_states,
|
176 |
+
encoder_hidden_states=encoder_hidden_states,
|
177 |
+
attention_mask=encoder_attention_mask,
|
178 |
+
)
|
179 |
+
hidden_states = attn_output + hidden_states
|
180 |
+
|
181 |
+
# 4. Feed-forward
|
182 |
+
norm_hidden_states = self.norm3(hidden_states)
|
183 |
+
|
184 |
+
if self._chunk_size is not None:
|
185 |
+
# "feed_forward_chunk_size" can be used to save memory
|
186 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
187 |
+
raise ValueError(
|
188 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
189 |
+
)
|
190 |
+
|
191 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
192 |
+
ff_output = torch.cat(
|
193 |
+
[
|
194 |
+
self.ff(hid_slice)
|
195 |
+
for hid_slice in norm_hidden_states.chunk(
|
196 |
+
num_chunks, dim=self._chunk_dim
|
197 |
+
)
|
198 |
+
],
|
199 |
+
dim=self._chunk_dim,
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
ff_output = self.ff(norm_hidden_states)
|
203 |
+
|
204 |
+
hidden_states = ff_output + hidden_states
|
205 |
+
|
206 |
+
return hidden_states
|
207 |
+
|
208 |
+
|
209 |
+
class FeedForward(nn.Module):
|
210 |
+
r"""
|
211 |
+
A feed-forward layer.
|
212 |
+
|
213 |
+
Parameters:
|
214 |
+
dim (`int`): The number of channels in the input.
|
215 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
216 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
217 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
218 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
219 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
dim: int,
|
225 |
+
dim_out: Optional[int] = None,
|
226 |
+
mult: int = 4,
|
227 |
+
dropout: float = 0.0,
|
228 |
+
activation_fn: str = "geglu",
|
229 |
+
final_dropout: bool = False,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
inner_dim = int(dim * mult)
|
233 |
+
dim_out = dim_out if dim_out is not None else dim
|
234 |
+
linear_cls = nn.Linear
|
235 |
+
|
236 |
+
if activation_fn == "gelu":
|
237 |
+
act_fn = GELU(dim, inner_dim)
|
238 |
+
if activation_fn == "gelu-approximate":
|
239 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
240 |
+
elif activation_fn == "geglu":
|
241 |
+
act_fn = GEGLU(dim, inner_dim)
|
242 |
+
elif activation_fn == "geglu-approximate":
|
243 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
244 |
+
|
245 |
+
self.net = nn.ModuleList([])
|
246 |
+
# project in
|
247 |
+
self.net.append(act_fn)
|
248 |
+
# project dropout
|
249 |
+
self.net.append(nn.Dropout(dropout))
|
250 |
+
# project out
|
251 |
+
self.net.append(linear_cls(inner_dim, dim_out))
|
252 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
253 |
+
if final_dropout:
|
254 |
+
self.net.append(nn.Dropout(dropout))
|
255 |
+
|
256 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
257 |
+
for module in self.net:
|
258 |
+
hidden_states = module(hidden_states)
|
259 |
+
return hidden_states
|
260 |
+
|
261 |
+
|
262 |
+
class GELU(nn.Module):
|
263 |
+
r"""
|
264 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
265 |
+
|
266 |
+
Parameters:
|
267 |
+
dim_in (`int`): The number of channels in the input.
|
268 |
+
dim_out (`int`): The number of channels in the output.
|
269 |
+
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
270 |
+
"""
|
271 |
+
|
272 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
273 |
+
super().__init__()
|
274 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
275 |
+
self.approximate = approximate
|
276 |
+
|
277 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
278 |
+
if gate.device.type != "mps":
|
279 |
+
return F.gelu(gate, approximate=self.approximate)
|
280 |
+
# mps: gelu is not implemented for float16
|
281 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
|
282 |
+
dtype=gate.dtype
|
283 |
+
)
|
284 |
+
|
285 |
+
def forward(self, hidden_states):
|
286 |
+
hidden_states = self.proj(hidden_states)
|
287 |
+
hidden_states = self.gelu(hidden_states)
|
288 |
+
return hidden_states
|
289 |
+
|
290 |
+
|
291 |
+
class GEGLU(nn.Module):
|
292 |
+
r"""
|
293 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
294 |
+
|
295 |
+
Parameters:
|
296 |
+
dim_in (`int`): The number of channels in the input.
|
297 |
+
dim_out (`int`): The number of channels in the output.
|
298 |
+
"""
|
299 |
+
|
300 |
+
def __init__(self, dim_in: int, dim_out: int):
|
301 |
+
super().__init__()
|
302 |
+
linear_cls = nn.Linear
|
303 |
+
|
304 |
+
self.proj = linear_cls(dim_in, dim_out * 2)
|
305 |
+
|
306 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
307 |
+
if gate.device.type != "mps":
|
308 |
+
return F.gelu(gate)
|
309 |
+
# mps: gelu is not implemented for float16
|
310 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
311 |
+
|
312 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
313 |
+
args = ()
|
314 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
315 |
+
return hidden_states * self.gelu(gate)
|
316 |
+
|
317 |
+
|
318 |
+
class ApproximateGELU(nn.Module):
|
319 |
+
r"""
|
320 |
+
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
|
321 |
+
https://arxiv.org/abs/1606.08415.
|
322 |
+
|
323 |
+
Parameters:
|
324 |
+
dim_in (`int`): The number of channels in the input.
|
325 |
+
dim_out (`int`): The number of channels in the output.
|
326 |
+
"""
|
327 |
+
|
328 |
+
def __init__(self, dim_in: int, dim_out: int):
|
329 |
+
super().__init__()
|
330 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
331 |
+
|
332 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
333 |
+
x = self.proj(x)
|
334 |
+
return x * torch.sigmoid(1.702 * x)
|
tsr/models/transformer/transformer_1d.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# --------
|
16 |
+
#
|
17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
18 |
+
#
|
19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
20 |
+
#
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
#
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
#
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
from dataclasses import dataclass
|
40 |
+
from typing import Optional
|
41 |
+
|
42 |
+
import torch
|
43 |
+
import torch.nn.functional as F
|
44 |
+
from torch import nn
|
45 |
+
|
46 |
+
from ...utils import BaseModule
|
47 |
+
from .basic_transformer_block import BasicTransformerBlock
|
48 |
+
|
49 |
+
|
50 |
+
class Transformer1D(BaseModule):
|
51 |
+
@dataclass
|
52 |
+
class Config(BaseModule.Config):
|
53 |
+
num_attention_heads: int = 16
|
54 |
+
attention_head_dim: int = 88
|
55 |
+
in_channels: Optional[int] = None
|
56 |
+
out_channels: Optional[int] = None
|
57 |
+
num_layers: int = 1
|
58 |
+
dropout: float = 0.0
|
59 |
+
norm_num_groups: int = 32
|
60 |
+
cross_attention_dim: Optional[int] = None
|
61 |
+
attention_bias: bool = False
|
62 |
+
activation_fn: str = "geglu"
|
63 |
+
only_cross_attention: bool = False
|
64 |
+
double_self_attention: bool = False
|
65 |
+
upcast_attention: bool = False
|
66 |
+
norm_type: str = "layer_norm"
|
67 |
+
norm_elementwise_affine: bool = True
|
68 |
+
gradient_checkpointing: bool = False
|
69 |
+
|
70 |
+
cfg: Config
|
71 |
+
|
72 |
+
def configure(self) -> None:
|
73 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
74 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
75 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
76 |
+
|
77 |
+
linear_cls = nn.Linear
|
78 |
+
|
79 |
+
# 2. Define input layers
|
80 |
+
self.in_channels = self.cfg.in_channels
|
81 |
+
|
82 |
+
self.norm = torch.nn.GroupNorm(
|
83 |
+
num_groups=self.cfg.norm_num_groups,
|
84 |
+
num_channels=self.cfg.in_channels,
|
85 |
+
eps=1e-6,
|
86 |
+
affine=True,
|
87 |
+
)
|
88 |
+
self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
|
89 |
+
|
90 |
+
# 3. Define transformers blocks
|
91 |
+
self.transformer_blocks = nn.ModuleList(
|
92 |
+
[
|
93 |
+
BasicTransformerBlock(
|
94 |
+
inner_dim,
|
95 |
+
self.num_attention_heads,
|
96 |
+
self.attention_head_dim,
|
97 |
+
dropout=self.cfg.dropout,
|
98 |
+
cross_attention_dim=self.cfg.cross_attention_dim,
|
99 |
+
activation_fn=self.cfg.activation_fn,
|
100 |
+
attention_bias=self.cfg.attention_bias,
|
101 |
+
only_cross_attention=self.cfg.only_cross_attention,
|
102 |
+
double_self_attention=self.cfg.double_self_attention,
|
103 |
+
upcast_attention=self.cfg.upcast_attention,
|
104 |
+
norm_type=self.cfg.norm_type,
|
105 |
+
norm_elementwise_affine=self.cfg.norm_elementwise_affine,
|
106 |
+
)
|
107 |
+
for d in range(self.cfg.num_layers)
|
108 |
+
]
|
109 |
+
)
|
110 |
+
|
111 |
+
# 4. Define output layers
|
112 |
+
self.out_channels = (
|
113 |
+
self.cfg.in_channels
|
114 |
+
if self.cfg.out_channels is None
|
115 |
+
else self.cfg.out_channels
|
116 |
+
)
|
117 |
+
|
118 |
+
self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
|
119 |
+
|
120 |
+
self.gradient_checkpointing = self.cfg.gradient_checkpointing
|
121 |
+
|
122 |
+
def forward(
|
123 |
+
self,
|
124 |
+
hidden_states: torch.Tensor,
|
125 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
126 |
+
attention_mask: Optional[torch.Tensor] = None,
|
127 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
The [`Transformer1DModel`] forward method.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
134 |
+
Input `hidden_states`.
|
135 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
136 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
137 |
+
self-attention.
|
138 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
139 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
140 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
141 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
142 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
143 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
144 |
+
|
145 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
146 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
147 |
+
|
148 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
149 |
+
above. This bias will be added to the cross-attention scores.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
torch.FloatTensor
|
153 |
+
"""
|
154 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
155 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
156 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
157 |
+
# expects mask of shape:
|
158 |
+
# [batch, key_tokens]
|
159 |
+
# adds singleton query_tokens dimension:
|
160 |
+
# [batch, 1, key_tokens]
|
161 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
162 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
163 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
164 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
165 |
+
# assume that mask is expressed as:
|
166 |
+
# (1 = keep, 0 = discard)
|
167 |
+
# convert mask into a bias that can be added to attention scores:
|
168 |
+
# (keep = +0, discard = -10000.0)
|
169 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
170 |
+
attention_mask = attention_mask.unsqueeze(1)
|
171 |
+
|
172 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
173 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
174 |
+
encoder_attention_mask = (
|
175 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
176 |
+
) * -10000.0
|
177 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
178 |
+
|
179 |
+
# 1. Input
|
180 |
+
batch, _, seq_len = hidden_states.shape
|
181 |
+
residual = hidden_states
|
182 |
+
|
183 |
+
hidden_states = self.norm(hidden_states)
|
184 |
+
inner_dim = hidden_states.shape[1]
|
185 |
+
hidden_states = hidden_states.permute(0, 2, 1).reshape(
|
186 |
+
batch, seq_len, inner_dim
|
187 |
+
)
|
188 |
+
hidden_states = self.proj_in(hidden_states)
|
189 |
+
|
190 |
+
# 2. Blocks
|
191 |
+
for block in self.transformer_blocks:
|
192 |
+
if self.training and self.gradient_checkpointing:
|
193 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
194 |
+
block,
|
195 |
+
hidden_states,
|
196 |
+
attention_mask,
|
197 |
+
encoder_hidden_states,
|
198 |
+
encoder_attention_mask,
|
199 |
+
use_reentrant=False,
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
hidden_states = block(
|
203 |
+
hidden_states,
|
204 |
+
attention_mask=attention_mask,
|
205 |
+
encoder_hidden_states=encoder_hidden_states,
|
206 |
+
encoder_attention_mask=encoder_attention_mask,
|
207 |
+
)
|
208 |
+
|
209 |
+
# 3. Output
|
210 |
+
hidden_states = self.proj_out(hidden_states)
|
211 |
+
hidden_states = (
|
212 |
+
hidden_states.reshape(batch, seq_len, inner_dim)
|
213 |
+
.permute(0, 2, 1)
|
214 |
+
.contiguous()
|
215 |
+
)
|
216 |
+
|
217 |
+
output = hidden_states + residual
|
218 |
+
|
219 |
+
return output
|
tsr/system.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import List, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL.Image
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import trimesh
|
11 |
+
from einops import rearrange
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from .models.isosurface import MarchingCubeHelper
|
17 |
+
from .utils import (
|
18 |
+
BaseModule,
|
19 |
+
ImagePreprocessor,
|
20 |
+
find_class,
|
21 |
+
get_spherical_cameras,
|
22 |
+
scale_tensor,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class TSR(BaseModule):
|
27 |
+
@dataclass
|
28 |
+
class Config(BaseModule.Config):
|
29 |
+
cond_image_size: int
|
30 |
+
|
31 |
+
image_tokenizer_cls: str
|
32 |
+
image_tokenizer: dict
|
33 |
+
|
34 |
+
tokenizer_cls: str
|
35 |
+
tokenizer: dict
|
36 |
+
|
37 |
+
backbone_cls: str
|
38 |
+
backbone: dict
|
39 |
+
|
40 |
+
post_processor_cls: str
|
41 |
+
post_processor: dict
|
42 |
+
|
43 |
+
decoder_cls: str
|
44 |
+
decoder: dict
|
45 |
+
|
46 |
+
renderer_cls: str
|
47 |
+
renderer: dict
|
48 |
+
|
49 |
+
cfg: Config
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def from_pretrained(
|
53 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
54 |
+
):
|
55 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
56 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
57 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
58 |
+
else:
|
59 |
+
config_path = hf_hub_download(
|
60 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
61 |
+
)
|
62 |
+
weight_path = hf_hub_download(
|
63 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
64 |
+
)
|
65 |
+
|
66 |
+
cfg = OmegaConf.load(config_path)
|
67 |
+
OmegaConf.resolve(cfg)
|
68 |
+
model = cls(cfg)
|
69 |
+
ckpt = torch.load(weight_path, map_location="cpu")
|
70 |
+
model.load_state_dict(ckpt)
|
71 |
+
return model
|
72 |
+
|
73 |
+
def configure(self):
|
74 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
75 |
+
self.cfg.image_tokenizer
|
76 |
+
)
|
77 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
78 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
79 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
80 |
+
self.cfg.post_processor
|
81 |
+
)
|
82 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
83 |
+
self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
|
84 |
+
self.image_processor = ImagePreprocessor()
|
85 |
+
self.isosurface_helper = None
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
image: Union[
|
90 |
+
PIL.Image.Image,
|
91 |
+
np.ndarray,
|
92 |
+
torch.FloatTensor,
|
93 |
+
List[PIL.Image.Image],
|
94 |
+
List[np.ndarray],
|
95 |
+
List[torch.FloatTensor],
|
96 |
+
],
|
97 |
+
device: str,
|
98 |
+
) -> torch.FloatTensor:
|
99 |
+
rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
|
100 |
+
device
|
101 |
+
)
|
102 |
+
batch_size = rgb_cond.shape[0]
|
103 |
+
|
104 |
+
input_image_tokens: torch.Tensor = self.image_tokenizer(
|
105 |
+
rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
|
106 |
+
)
|
107 |
+
|
108 |
+
input_image_tokens = rearrange(
|
109 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
|
110 |
+
)
|
111 |
+
|
112 |
+
tokens: torch.Tensor = self.tokenizer(batch_size)
|
113 |
+
|
114 |
+
tokens = self.backbone(
|
115 |
+
tokens,
|
116 |
+
encoder_hidden_states=input_image_tokens,
|
117 |
+
)
|
118 |
+
|
119 |
+
scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
|
120 |
+
return scene_codes
|
121 |
+
|
122 |
+
def render(
|
123 |
+
self,
|
124 |
+
scene_codes,
|
125 |
+
n_views: int,
|
126 |
+
elevation_deg: float = 0.0,
|
127 |
+
camera_distance: float = 1.9,
|
128 |
+
fovy_deg: float = 40.0,
|
129 |
+
height: int = 256,
|
130 |
+
width: int = 256,
|
131 |
+
return_type: str = "pil",
|
132 |
+
):
|
133 |
+
rays_o, rays_d = get_spherical_cameras(
|
134 |
+
n_views, elevation_deg, camera_distance, fovy_deg, height, width
|
135 |
+
)
|
136 |
+
rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
|
137 |
+
|
138 |
+
def process_output(image: torch.FloatTensor):
|
139 |
+
if return_type == "pt":
|
140 |
+
return image
|
141 |
+
elif return_type == "np":
|
142 |
+
return image.detach().cpu().numpy()
|
143 |
+
elif return_type == "pil":
|
144 |
+
return Image.fromarray(
|
145 |
+
(image.detach().cpu().numpy() * 255.0).astype(np.uint8)
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
raise NotImplementedError
|
149 |
+
|
150 |
+
images = []
|
151 |
+
for scene_code in scene_codes:
|
152 |
+
images_ = []
|
153 |
+
for i in range(n_views):
|
154 |
+
with torch.no_grad():
|
155 |
+
image = self.renderer(
|
156 |
+
self.decoder, scene_code, rays_o[i], rays_d[i]
|
157 |
+
)
|
158 |
+
images_.append(process_output(image))
|
159 |
+
images.append(images_)
|
160 |
+
|
161 |
+
return images
|
162 |
+
|
163 |
+
def set_marching_cubes_resolution(self, resolution: int):
|
164 |
+
if (
|
165 |
+
self.isosurface_helper is not None
|
166 |
+
and self.isosurface_helper.resolution == resolution
|
167 |
+
):
|
168 |
+
return
|
169 |
+
self.isosurface_helper = MarchingCubeHelper(resolution)
|
170 |
+
|
171 |
+
def extract_mesh(self, scene_codes, has_vertex_color, resolution: int = 256, threshold: float = 25.0):
|
172 |
+
self.set_marching_cubes_resolution(resolution)
|
173 |
+
meshes = []
|
174 |
+
for scene_code in scene_codes:
|
175 |
+
with torch.no_grad():
|
176 |
+
density = self.renderer.query_triplane(
|
177 |
+
self.decoder,
|
178 |
+
scale_tensor(
|
179 |
+
self.isosurface_helper.grid_vertices.to(scene_codes.device),
|
180 |
+
self.isosurface_helper.points_range,
|
181 |
+
(-self.renderer.cfg.radius, self.renderer.cfg.radius),
|
182 |
+
),
|
183 |
+
scene_code,
|
184 |
+
)["density_act"]
|
185 |
+
v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
|
186 |
+
v_pos = scale_tensor(
|
187 |
+
v_pos,
|
188 |
+
self.isosurface_helper.points_range,
|
189 |
+
(-self.renderer.cfg.radius, self.renderer.cfg.radius),
|
190 |
+
)
|
191 |
+
color = None
|
192 |
+
if has_vertex_color:
|
193 |
+
with torch.no_grad():
|
194 |
+
color = self.renderer.query_triplane(
|
195 |
+
self.decoder,
|
196 |
+
v_pos,
|
197 |
+
scene_code,
|
198 |
+
)["color"]
|
199 |
+
mesh = trimesh.Trimesh(
|
200 |
+
vertices=v_pos.cpu().numpy(),
|
201 |
+
faces=t_pos_idx.cpu().numpy(),
|
202 |
+
vertex_colors=color.cpu().numpy() if has_vertex_color else None,
|
203 |
+
)
|
204 |
+
meshes.append(mesh)
|
205 |
+
return meshes
|
tsr/utils.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import math
|
3 |
+
from collections import defaultdict
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import imageio
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
# import rembg
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import trimesh
|
15 |
+
from omegaconf import DictConfig, OmegaConf
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
20 |
+
scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
|
21 |
+
return scfg
|
22 |
+
|
23 |
+
|
24 |
+
def find_class(cls_string):
|
25 |
+
module_string = ".".join(cls_string.split(".")[:-1])
|
26 |
+
cls_name = cls_string.split(".")[-1]
|
27 |
+
module = importlib.import_module(module_string, package=None)
|
28 |
+
cls = getattr(module, cls_name)
|
29 |
+
return cls
|
30 |
+
|
31 |
+
|
32 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
|
33 |
+
focal_length = 0.5 * H / np.tan(0.5 * fov)
|
34 |
+
intrinsic = np.identity(3, dtype=np.float32)
|
35 |
+
intrinsic[0, 0] = focal_length
|
36 |
+
intrinsic[1, 1] = focal_length
|
37 |
+
intrinsic[0, 2] = W / 2.0
|
38 |
+
intrinsic[1, 2] = H / 2.0
|
39 |
+
|
40 |
+
if bs > 0:
|
41 |
+
intrinsic = intrinsic[None].repeat(bs, axis=0)
|
42 |
+
|
43 |
+
return torch.from_numpy(intrinsic)
|
44 |
+
|
45 |
+
|
46 |
+
class BaseModule(nn.Module):
|
47 |
+
@dataclass
|
48 |
+
class Config:
|
49 |
+
pass
|
50 |
+
|
51 |
+
cfg: Config # add this to every subclass of BaseModule to enable static type checking
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
|
55 |
+
) -> None:
|
56 |
+
super().__init__()
|
57 |
+
self.cfg = parse_structured(self.Config, cfg)
|
58 |
+
self.configure(*args, **kwargs)
|
59 |
+
|
60 |
+
def configure(self, *args, **kwargs) -> None:
|
61 |
+
raise NotImplementedError
|
62 |
+
|
63 |
+
|
64 |
+
class ImagePreprocessor:
|
65 |
+
def convert_and_resize(
|
66 |
+
self,
|
67 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
68 |
+
size: int,
|
69 |
+
):
|
70 |
+
if isinstance(image, PIL.Image.Image):
|
71 |
+
image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
72 |
+
elif isinstance(image, np.ndarray):
|
73 |
+
if image.dtype == np.uint8:
|
74 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0)
|
75 |
+
else:
|
76 |
+
image = torch.from_numpy(image)
|
77 |
+
elif isinstance(image, torch.Tensor):
|
78 |
+
pass
|
79 |
+
|
80 |
+
batched = image.ndim == 4
|
81 |
+
|
82 |
+
if not batched:
|
83 |
+
image = image[None, ...]
|
84 |
+
image = F.interpolate(
|
85 |
+
image.permute(0, 3, 1, 2),
|
86 |
+
(size, size),
|
87 |
+
mode="bilinear",
|
88 |
+
align_corners=False,
|
89 |
+
antialias=True,
|
90 |
+
).permute(0, 2, 3, 1)
|
91 |
+
if not batched:
|
92 |
+
image = image[0]
|
93 |
+
return image
|
94 |
+
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
image: Union[
|
98 |
+
PIL.Image.Image,
|
99 |
+
np.ndarray,
|
100 |
+
torch.FloatTensor,
|
101 |
+
List[PIL.Image.Image],
|
102 |
+
List[np.ndarray],
|
103 |
+
List[torch.FloatTensor],
|
104 |
+
],
|
105 |
+
size: int,
|
106 |
+
) -> Any:
|
107 |
+
if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
|
108 |
+
image = self.convert_and_resize(image, size)
|
109 |
+
else:
|
110 |
+
if not isinstance(image, list):
|
111 |
+
image = [image]
|
112 |
+
image = [self.convert_and_resize(im, size) for im in image]
|
113 |
+
image = torch.stack(image, dim=0)
|
114 |
+
return image
|
115 |
+
|
116 |
+
|
117 |
+
def rays_intersect_bbox(
|
118 |
+
rays_o: torch.Tensor,
|
119 |
+
rays_d: torch.Tensor,
|
120 |
+
radius: float,
|
121 |
+
near: float = 0.0,
|
122 |
+
valid_thresh: float = 0.01,
|
123 |
+
):
|
124 |
+
input_shape = rays_o.shape[:-1]
|
125 |
+
rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
|
126 |
+
rays_d_valid = torch.where(
|
127 |
+
rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
|
128 |
+
)
|
129 |
+
if type(radius) in [int, float]:
|
130 |
+
radius = torch.FloatTensor(
|
131 |
+
[[-radius, radius], [-radius, radius], [-radius, radius]]
|
132 |
+
).to(rays_o.device)
|
133 |
+
radius = (
|
134 |
+
1.0 - 1.0e-3
|
135 |
+
) * radius # tighten the radius to make sure the intersection point lies in the bounding box
|
136 |
+
interx0 = (radius[..., 1] - rays_o) / rays_d_valid
|
137 |
+
interx1 = (radius[..., 0] - rays_o) / rays_d_valid
|
138 |
+
t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
|
139 |
+
t_far = torch.maximum(interx0, interx1).amin(dim=-1)
|
140 |
+
|
141 |
+
# check wheter a ray intersects the bbox or not
|
142 |
+
rays_valid = t_far - t_near > valid_thresh
|
143 |
+
|
144 |
+
t_near[torch.where(~rays_valid)] = 0.0
|
145 |
+
t_far[torch.where(~rays_valid)] = 0.0
|
146 |
+
|
147 |
+
t_near = t_near.view(*input_shape, 1)
|
148 |
+
t_far = t_far.view(*input_shape, 1)
|
149 |
+
rays_valid = rays_valid.view(*input_shape)
|
150 |
+
|
151 |
+
return t_near, t_far, rays_valid
|
152 |
+
|
153 |
+
|
154 |
+
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
|
155 |
+
if chunk_size <= 0:
|
156 |
+
return func(*args, **kwargs)
|
157 |
+
B = None
|
158 |
+
for arg in list(args) + list(kwargs.values()):
|
159 |
+
if isinstance(arg, torch.Tensor):
|
160 |
+
B = arg.shape[0]
|
161 |
+
break
|
162 |
+
assert (
|
163 |
+
B is not None
|
164 |
+
), "No tensor found in args or kwargs, cannot determine batch size."
|
165 |
+
out = defaultdict(list)
|
166 |
+
out_type = None
|
167 |
+
# max(1, B) to support B == 0
|
168 |
+
for i in range(0, max(1, B), chunk_size):
|
169 |
+
out_chunk = func(
|
170 |
+
*[
|
171 |
+
arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
172 |
+
for arg in args
|
173 |
+
],
|
174 |
+
**{
|
175 |
+
k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
176 |
+
for k, arg in kwargs.items()
|
177 |
+
},
|
178 |
+
)
|
179 |
+
if out_chunk is None:
|
180 |
+
continue
|
181 |
+
out_type = type(out_chunk)
|
182 |
+
if isinstance(out_chunk, torch.Tensor):
|
183 |
+
out_chunk = {0: out_chunk}
|
184 |
+
elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
|
185 |
+
chunk_length = len(out_chunk)
|
186 |
+
out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
|
187 |
+
elif isinstance(out_chunk, dict):
|
188 |
+
pass
|
189 |
+
else:
|
190 |
+
print(
|
191 |
+
f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
|
192 |
+
)
|
193 |
+
exit(1)
|
194 |
+
for k, v in out_chunk.items():
|
195 |
+
v = v if torch.is_grad_enabled() else v.detach()
|
196 |
+
out[k].append(v)
|
197 |
+
|
198 |
+
if out_type is None:
|
199 |
+
return None
|
200 |
+
|
201 |
+
out_merged: Dict[Any, Optional[torch.Tensor]] = {}
|
202 |
+
for k, v in out.items():
|
203 |
+
if all([vv is None for vv in v]):
|
204 |
+
# allow None in return value
|
205 |
+
out_merged[k] = None
|
206 |
+
elif all([isinstance(vv, torch.Tensor) for vv in v]):
|
207 |
+
out_merged[k] = torch.cat(v, dim=0)
|
208 |
+
else:
|
209 |
+
raise TypeError(
|
210 |
+
f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
|
211 |
+
)
|
212 |
+
|
213 |
+
if out_type is torch.Tensor:
|
214 |
+
return out_merged[0]
|
215 |
+
elif out_type in [tuple, list]:
|
216 |
+
return out_type([out_merged[i] for i in range(chunk_length)])
|
217 |
+
elif out_type is dict:
|
218 |
+
return out_merged
|
219 |
+
|
220 |
+
|
221 |
+
ValidScale = Union[Tuple[float, float], torch.FloatTensor]
|
222 |
+
|
223 |
+
|
224 |
+
def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
|
225 |
+
if inp_scale is None:
|
226 |
+
inp_scale = (0, 1)
|
227 |
+
if tgt_scale is None:
|
228 |
+
tgt_scale = (0, 1)
|
229 |
+
if isinstance(tgt_scale, torch.FloatTensor):
|
230 |
+
assert dat.shape[-1] == tgt_scale.shape[-1]
|
231 |
+
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
232 |
+
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
233 |
+
return dat
|
234 |
+
|
235 |
+
|
236 |
+
def get_activation(name) -> Callable:
|
237 |
+
if name is None:
|
238 |
+
return lambda x: x
|
239 |
+
name = name.lower()
|
240 |
+
if name == "none":
|
241 |
+
return lambda x: x
|
242 |
+
elif name == "exp":
|
243 |
+
return lambda x: torch.exp(x)
|
244 |
+
elif name == "sigmoid":
|
245 |
+
return lambda x: torch.sigmoid(x)
|
246 |
+
elif name == "tanh":
|
247 |
+
return lambda x: torch.tanh(x)
|
248 |
+
elif name == "softplus":
|
249 |
+
return lambda x: F.softplus(x)
|
250 |
+
else:
|
251 |
+
try:
|
252 |
+
return getattr(F, name)
|
253 |
+
except AttributeError:
|
254 |
+
raise ValueError(f"Unknown activation function: {name}")
|
255 |
+
|
256 |
+
|
257 |
+
def get_ray_directions(
|
258 |
+
H: int,
|
259 |
+
W: int,
|
260 |
+
focal: Union[float, Tuple[float, float]],
|
261 |
+
principal: Optional[Tuple[float, float]] = None,
|
262 |
+
use_pixel_centers: bool = True,
|
263 |
+
normalize: bool = True,
|
264 |
+
) -> torch.FloatTensor:
|
265 |
+
"""
|
266 |
+
Get ray directions for all pixels in camera coordinate.
|
267 |
+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
|
268 |
+
ray-tracing-generating-camera-rays/standard-coordinate-systems
|
269 |
+
|
270 |
+
Inputs:
|
271 |
+
H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
|
272 |
+
Outputs:
|
273 |
+
directions: (H, W, 3), the direction of the rays in camera coordinate
|
274 |
+
"""
|
275 |
+
pixel_center = 0.5 if use_pixel_centers else 0
|
276 |
+
|
277 |
+
if isinstance(focal, float):
|
278 |
+
fx, fy = focal, focal
|
279 |
+
cx, cy = W / 2, H / 2
|
280 |
+
else:
|
281 |
+
fx, fy = focal
|
282 |
+
assert principal is not None
|
283 |
+
cx, cy = principal
|
284 |
+
|
285 |
+
i, j = torch.meshgrid(
|
286 |
+
torch.arange(W, dtype=torch.float32) + pixel_center,
|
287 |
+
torch.arange(H, dtype=torch.float32) + pixel_center,
|
288 |
+
indexing="xy",
|
289 |
+
)
|
290 |
+
|
291 |
+
directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
|
292 |
+
|
293 |
+
if normalize:
|
294 |
+
directions = F.normalize(directions, dim=-1)
|
295 |
+
|
296 |
+
return directions
|
297 |
+
|
298 |
+
|
299 |
+
def get_rays(
|
300 |
+
directions,
|
301 |
+
c2w,
|
302 |
+
keepdim=False,
|
303 |
+
normalize=False,
|
304 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
305 |
+
# Rotate ray directions from camera coordinate to the world coordinate
|
306 |
+
assert directions.shape[-1] == 3
|
307 |
+
|
308 |
+
if directions.ndim == 2: # (N_rays, 3)
|
309 |
+
if c2w.ndim == 2: # (4, 4)
|
310 |
+
c2w = c2w[None, :, :]
|
311 |
+
assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
|
312 |
+
rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
|
313 |
+
rays_o = c2w[:, :3, 3].expand(rays_d.shape)
|
314 |
+
elif directions.ndim == 3: # (H, W, 3)
|
315 |
+
assert c2w.ndim in [2, 3]
|
316 |
+
if c2w.ndim == 2: # (4, 4)
|
317 |
+
rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
|
318 |
+
-1
|
319 |
+
) # (H, W, 3)
|
320 |
+
rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
|
321 |
+
elif c2w.ndim == 3: # (B, 4, 4)
|
322 |
+
rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
323 |
+
-1
|
324 |
+
) # (B, H, W, 3)
|
325 |
+
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
326 |
+
elif directions.ndim == 4: # (B, H, W, 3)
|
327 |
+
assert c2w.ndim == 3 # (B, 4, 4)
|
328 |
+
rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
329 |
+
-1
|
330 |
+
) # (B, H, W, 3)
|
331 |
+
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
332 |
+
|
333 |
+
if normalize:
|
334 |
+
rays_d = F.normalize(rays_d, dim=-1)
|
335 |
+
if not keepdim:
|
336 |
+
rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
|
337 |
+
|
338 |
+
return rays_o, rays_d
|
339 |
+
|
340 |
+
|
341 |
+
def get_spherical_cameras(
|
342 |
+
n_views: int,
|
343 |
+
elevation_deg: float,
|
344 |
+
camera_distance: float,
|
345 |
+
fovy_deg: float,
|
346 |
+
height: int,
|
347 |
+
width: int,
|
348 |
+
):
|
349 |
+
azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
|
350 |
+
elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
|
351 |
+
camera_distances = torch.full_like(elevation_deg, camera_distance)
|
352 |
+
|
353 |
+
elevation = elevation_deg * math.pi / 180
|
354 |
+
azimuth = azimuth_deg * math.pi / 180
|
355 |
+
|
356 |
+
# convert spherical coordinates to cartesian coordinates
|
357 |
+
# right hand coordinate system, x back, y right, z up
|
358 |
+
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
|
359 |
+
camera_positions = torch.stack(
|
360 |
+
[
|
361 |
+
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
|
362 |
+
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
|
363 |
+
camera_distances * torch.sin(elevation),
|
364 |
+
],
|
365 |
+
dim=-1,
|
366 |
+
)
|
367 |
+
|
368 |
+
# default scene center at origin
|
369 |
+
center = torch.zeros_like(camera_positions)
|
370 |
+
# default camera up direction as +z
|
371 |
+
up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
|
372 |
+
|
373 |
+
fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
|
374 |
+
|
375 |
+
lookat = F.normalize(center - camera_positions, dim=-1)
|
376 |
+
right = F.normalize(torch.cross(lookat, up), dim=-1)
|
377 |
+
up = F.normalize(torch.cross(right, lookat), dim=-1)
|
378 |
+
c2w3x4 = torch.cat(
|
379 |
+
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
|
380 |
+
dim=-1,
|
381 |
+
)
|
382 |
+
c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
|
383 |
+
c2w[:, 3, 3] = 1.0
|
384 |
+
|
385 |
+
# get directions by dividing directions_unit_focal by focal length
|
386 |
+
focal_length = 0.5 * height / torch.tan(0.5 * fovy)
|
387 |
+
directions_unit_focal = get_ray_directions(
|
388 |
+
H=height,
|
389 |
+
W=width,
|
390 |
+
focal=1.0,
|
391 |
+
)
|
392 |
+
directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
|
393 |
+
directions[:, :, :, :2] = (
|
394 |
+
directions[:, :, :, :2] / focal_length[:, None, None, None]
|
395 |
+
)
|
396 |
+
# must use normalize=True to normalize directions here
|
397 |
+
rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
|
398 |
+
|
399 |
+
return rays_o, rays_d
|
400 |
+
|
401 |
+
|
402 |
+
def remove_background(
|
403 |
+
image: PIL.Image.Image,
|
404 |
+
rembg_session: Any = None,
|
405 |
+
force: bool = False,
|
406 |
+
**rembg_kwargs,
|
407 |
+
) -> PIL.Image.Image:
|
408 |
+
do_remove = True
|
409 |
+
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
410 |
+
do_remove = False
|
411 |
+
do_remove = do_remove or force
|
412 |
+
if do_remove:
|
413 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
414 |
+
return image
|
415 |
+
|
416 |
+
|
417 |
+
def resize_foreground(
|
418 |
+
image: PIL.Image.Image,
|
419 |
+
ratio: float,
|
420 |
+
) -> PIL.Image.Image:
|
421 |
+
image = np.array(image)
|
422 |
+
assert image.shape[-1] == 4
|
423 |
+
alpha = np.where(image[..., 3] > 0)
|
424 |
+
y1, y2, x1, x2 = (
|
425 |
+
alpha[0].min(),
|
426 |
+
alpha[0].max(),
|
427 |
+
alpha[1].min(),
|
428 |
+
alpha[1].max(),
|
429 |
+
)
|
430 |
+
# crop the foreground
|
431 |
+
fg = image[y1:y2, x1:x2]
|
432 |
+
# pad to square
|
433 |
+
size = max(fg.shape[0], fg.shape[1])
|
434 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
435 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
436 |
+
new_image = np.pad(
|
437 |
+
fg,
|
438 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
439 |
+
mode="constant",
|
440 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
441 |
+
)
|
442 |
+
|
443 |
+
# compute padding according to the ratio
|
444 |
+
new_size = int(new_image.shape[0] / ratio)
|
445 |
+
# pad to size, double side
|
446 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
447 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
448 |
+
new_image = np.pad(
|
449 |
+
new_image,
|
450 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
451 |
+
mode="constant",
|
452 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
453 |
+
)
|
454 |
+
new_image = PIL.Image.fromarray(new_image)
|
455 |
+
return new_image
|
456 |
+
|
457 |
+
|
458 |
+
def save_video(
|
459 |
+
frames: List[PIL.Image.Image],
|
460 |
+
output_path: str,
|
461 |
+
fps: int = 30,
|
462 |
+
):
|
463 |
+
# use imageio to save video
|
464 |
+
frames = [np.array(frame) for frame in frames]
|
465 |
+
writer = imageio.get_writer(output_path, fps=fps)
|
466 |
+
for frame in frames:
|
467 |
+
writer.append_data(frame)
|
468 |
+
writer.close()
|
469 |
+
|
470 |
+
|
471 |
+
def to_gradio_3d_orientation(mesh):
|
472 |
+
mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
|
473 |
+
mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
|
474 |
+
return mesh
|
utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
# from PIL import Image, ImageOps
|
7 |
+
# import numpy as np
|
8 |
+
import torch
|
9 |
+
import xatlas
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from tsr.system import TSR
|
13 |
+
from tsr.utils import save_video
|
14 |
+
from tsr.bake_texture import bake_texture
|
15 |
+
|
16 |
+
|
17 |
+
class Timer:
|
18 |
+
def __init__(self):
|
19 |
+
self.items = {}
|
20 |
+
self.time_scale = 1000.0 # ms
|
21 |
+
self.time_unit = "ms"
|
22 |
+
|
23 |
+
def start(self, name: str) -> None:
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
torch.cuda.synchronize()
|
26 |
+
self.items[name] = time.time()
|
27 |
+
logging.info(f"{name} ...")
|
28 |
+
|
29 |
+
def end(self, name: str) -> float:
|
30 |
+
if name not in self.items:
|
31 |
+
return
|
32 |
+
if torch.cuda.is_available():
|
33 |
+
torch.cuda.synchronize()
|
34 |
+
start_time = self.items.pop(name)
|
35 |
+
delta = time.time() - start_time
|
36 |
+
t = delta * self.time_scale
|
37 |
+
logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
|
38 |
+
|
39 |
+
|
40 |
+
def initialize_model(pretrained_model_name_or_path="stabilityai/TripoSR",
|
41 |
+
chunk_size=8192,
|
42 |
+
device="cuda:0" if torch.cuda.is_available() else "cpu"):
|
43 |
+
timer.start("Initializing model")
|
44 |
+
model = TSR.from_pretrained(
|
45 |
+
pretrained_model_name_or_path,
|
46 |
+
config_name="config.yaml",
|
47 |
+
weight_name="model.ckpt",
|
48 |
+
)
|
49 |
+
model.renderer.set_chunk_size(chunk_size)
|
50 |
+
model.to(device)
|
51 |
+
timer.end("Initializing model")
|
52 |
+
return model
|
53 |
+
|
54 |
+
|
55 |
+
def process_image(image_path, output_dir, no_remove_bg=True, foreground_ratio=0.85):
|
56 |
+
timer.start("Processing image")
|
57 |
+
|
58 |
+
if no_remove_bg:
|
59 |
+
rembg_session = None
|
60 |
+
image = np.array(Image.open(image_path).convert("RGB"))
|
61 |
+
else:
|
62 |
+
image = remove_background(image_path)
|
63 |
+
|
64 |
+
# Save the processed image
|
65 |
+
os.makedirs(output_dir, exist_ok=True)
|
66 |
+
image.save(os.path.join(output_dir, "processed_input.png"))
|
67 |
+
|
68 |
+
timer.end("Processing image")
|
69 |
+
return image
|
70 |
+
|
71 |
+
|
72 |
+
def run_model(model, image, output_dir, device="cuda:0" if torch.cuda.is_available() else "cpu", render=False, mc_resolution=256, model_save_format='obj', bake_texture_flag=False, texture_resolution=2048):
|
73 |
+
logging.info("Running model...")
|
74 |
+
|
75 |
+
timer.start("Running model")
|
76 |
+
with torch.no_grad():
|
77 |
+
scene_codes = model([image], device=device)
|
78 |
+
timer.end("Running model")
|
79 |
+
|
80 |
+
if render:
|
81 |
+
timer.start("Rendering")
|
82 |
+
render_images = model.render(scene_codes, n_views=30, return_type="pil")
|
83 |
+
for ri, render_image in enumerate(render_images[0]):
|
84 |
+
render_image.save(os.path.join(output_dir, f"render_{ri:03d}.png"))
|
85 |
+
save_video(
|
86 |
+
render_images[0], os.path.join(output_dir, "render.mp4"), fps=30
|
87 |
+
)
|
88 |
+
timer.end("Rendering")
|
89 |
+
|
90 |
+
timer.start("Extracting mesh")
|
91 |
+
meshes = model.extract_mesh(scene_codes, not bake_texture_flag, resolution=mc_resolution)
|
92 |
+
timer.end("Extracting mesh")
|
93 |
+
|
94 |
+
out_mesh_path = os.path.join(output_dir, f"mesh.{model_save_format}")
|
95 |
+
if bake_texture_flag:
|
96 |
+
out_texture_path = os.path.join(output_dir, "texture.png")
|
97 |
+
|
98 |
+
timer.start("Baking texture")
|
99 |
+
bake_output = bake_texture(meshes[0], model, scene_codes[0], texture_resolution)
|
100 |
+
timer.end("Baking texture")
|
101 |
+
|
102 |
+
timer.start("Exporting mesh and texture")
|
103 |
+
xatlas.export(out_mesh_path, meshes[0].vertices[bake_output["vmapping"]], bake_output["indices"], bake_output["uvs"], meshes[0].vertex_normals[bake_output["vmapping"]])
|
104 |
+
Image.fromarray((bake_output["colors"] * 255.0).astype(np.uint8)).transpose(Image.FLIP_TOP_BOTTOM).save(out_texture_path)
|
105 |
+
timer.end("Exporting mesh and texture")
|
106 |
+
else:
|
107 |
+
timer.start("Exporting mesh")
|
108 |
+
meshes[0].export(out_mesh_path)
|
109 |
+
timer.end("Exporting mesh")
|
110 |
+
|
111 |
+
return out_mesh_path
|
112 |
+
|
113 |
+
|
114 |
+
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
|
115 |
+
timer = Timer()
|