wondervictor committed
Commit 2d1e0bb · verified · 1 Parent(s): e829283

Update model_new.py

Files changed (1)
  1. model_new.py +257 -0
model_new.py CHANGED
@@ -0,0 +1,257 @@
+ import gc
+ import spaces
+ from safetensors.torch import load_file
+ from autoregressive.models.gpt_t2i import GPT_models
+ from tokenizer.tokenizer_image.vq_model import VQ_models
+ from language.t5 import T5Embedder
+ import torch
+ import numpy as np
+ import PIL
+ from PIL import Image
+ from condition.canny import CannyDetector
+ import time
+ from autoregressive.models.generate import generate
+ from condition.midas.depth import MidasDetector
+ 
+ 
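+ # Checkpoint paths for the condition-specific GPT models (loaded in load_gpt below).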
+ models = {
+     "canny": "checkpoints/canny_MR.safetensors",
+     "depth": "checkpoints/depth_MR.safetensors",
+ }
+ 
+ 
+ def resize_image_to_16_multiple(image, condition_type='canny'):
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+     # image = Image.open(image_path)
+     width, height = image.size
+ 
+     if condition_type == 'depth':  # The depth model requires a side length that is a multiple of 32
+         new_width = (width + 31) // 32 * 32
+         new_height = (height + 31) // 32 * 32
+     else:
+         new_width = (width + 15) // 16 * 16
+         new_height = (height + 15) // 16 * 16
+ 
+     resized_image = image.resize((new_width, new_height))
+     return resized_image
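+ # Illustrative example of the rounding above: a 720x540 input becomes
+ # 720x544 for 'canny' (multiples of 16) and 736x544 for 'depth'
+ # (multiples of 32).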
+ 
+ 
+ class Model:
+ 
+     def __init__(self):
+         self.device = torch.device("cuda")
+         self.base_model_id = ""
+         self.task_name = ""
+         self.vq_model = self.load_vq()
+         self.t5_model = self.load_t5()
+         self.gpt_model_canny = self.load_gpt(condition_type='canny')
+         # self.gpt_model_depth = self.load_gpt(condition_type='depth')
+         # process_canny() relies on self.get_control_canny, so the detector is
+         # instantiated here (CannyDetector is assumed to take no constructor
+         # arguments); the depth GPT model and depth detector stay disabled in
+         # this revision.
+         self.get_control_canny = CannyDetector()
+ 
+     def to(self, device):
+         # The device argument is currently ignored; the canny model is always
+         # moved to CUDA.
+         self.gpt_model_canny.to('cuda')
+ 
+     def load_vq(self):
+         vq_model = VQ_models["VQ-16"](codebook_size=16384,
+                                       codebook_embed_dim=8)
+         vq_model.eval()
+         checkpoint = torch.load("checkpoints/vq_ds16_t2i.pt",
+                                 map_location="cpu")
+         vq_model.load_state_dict(checkpoint["model"])
+         del checkpoint
+         print("image tokenizer is loaded")
+         return vq_model
+ 
+     def load_gpt(self, condition_type='canny'):
+         gpt_ckpt = models[condition_type]
+         # precision = torch.bfloat16
+         precision = torch.float32
+         latent_size = 768 // 16
+         gpt_model = GPT_models["GPT-XL"](
+             block_size=latent_size**2,
+             cls_token_num=120,
+             model_type='t2i',
+             condition_type=condition_type,
+         ).to(device='cpu', dtype=precision)
+ 
+         model_weight = load_file(gpt_ckpt)
+         print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
+         gpt_model.load_state_dict(model_weight, strict=True)
+         gpt_model.eval()
+         print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
+         print("gpt model is loaded")
+         return gpt_model
+ 
+     def load_t5(self):
+         # precision = torch.bfloat16
+         precision = torch.float32
+         t5_model = T5Embedder(
+             device=self.device,
+             local_cache=True,
+             cache_dir='checkpoints/flan-t5-xl',
+             dir_or_name='flan-t5-xl',
+             torch_dtype=precision,
+             model_max_length=120,
+         )
+         return t5_model
+ 
+     @torch.no_grad()
+     @spaces.GPU(enable_queue=True)
+     def process_canny(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         cfg_scale: float,
+         temperature: float,
+         top_k: int,
+         top_p: int,
+         seed: int,
+         low_threshold: int,
+         high_threshold: int,
+     ) -> list[PIL.Image.Image]:
+         print(image)
+         image = resize_image_to_16_multiple(image, 'canny')
+         W, H = image.size
+         print(W, H)
+         self.t5_model.model.to('cuda').to(torch.bfloat16)
+         self.gpt_model_canny.to('cuda').to(torch.bfloat16)
+         self.vq_model.to('cuda')
+ 
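+         # Build the conditioning tensor: run the Canny detector on the resized
+         # image, replicate the single-channel edge map to 3 channels and a
+         # batch of 2 (one per prompt copy), then normalize from [0, 255] to [-1, 1].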
+         condition_img = self.get_control_canny(np.array(image), low_threshold,
+                                                high_threshold)
+         condition_img = torch.from_numpy(condition_img[None, None,
+                                                        ...]).repeat(
+                                                            2, 3, 1, 1)
+         condition_img = condition_img.to(self.device)
+         condition_img = 2 * (condition_img / 255 - 0.5)
+         prompts = [prompt] * 2
+         caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
+ 
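+         # Convert the right-padded T5 embeddings to left-padding: flip the
+         # attention masks and rotate each sequence so its valid tokens sit at
+         # the end and the padding moves to the front.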
+         print(f"processing left-padding...")
+         new_emb_masks = torch.flip(emb_masks, dims=[-1])
+         new_caption_embs = []
+         for idx, (caption_emb,
+                   emb_mask) in enumerate(zip(caption_embs, emb_masks)):
+             valid_num = int(emb_mask.sum().item())
+             print(f' prompt {idx} token len: {valid_num}')
+             new_caption_emb = torch.cat(
+                 [caption_emb[valid_num:], caption_emb[:valid_num]])
+             new_caption_embs.append(new_caption_emb)
+         new_caption_embs = torch.stack(new_caption_embs)
+         c_indices = new_caption_embs * new_emb_masks[:, :, None]
+         c_emb_masks = new_emb_masks
+         qzshape = [len(c_indices), 8, H // 16, W // 16]
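+         # Latent shape for decoding: [batch, codebook_embed_dim=8, H/16, W/16],
+         # matching the VQ-16 tokenizer configured in load_vq().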
+         t1 = time.time()
+         print(caption_embs.device)
+         index_sample = generate(
+             self.gpt_model_canny,
+             c_indices,
+             (H // 16) * (W // 16),
+             c_emb_masks,
+             condition=condition_img,
+             cfg_scale=cfg_scale,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             sample_logits=True,
+         )
+         sampling_time = time.time() - t1
+         print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+ 
+         t2 = time.time()
+         print(index_sample.shape)
+         samples = self.vq_model.decode_code(
+             index_sample, qzshape)  # output value is between [-1, 1]
+         decoder_time = time.time() - t2
+         print(f"decoder takes about {decoder_time:.2f} seconds.")
+ 
+         samples = torch.cat((condition_img[0:1], samples), dim=0)
+         samples = 255 * (samples * 0.5 + 0.5)
+         samples = [image] + [
+             Image.fromarray(
+                 sample.permute(1, 2, 0).cpu().detach().numpy().clip(
+                     0, 255).astype(np.uint8)) for sample in samples
+         ]
+         del condition_img
+         torch.cuda.empty_cache()
+         return samples
+ 
+     @torch.no_grad()
+     @spaces.GPU(enable_queue=True)
+     def process_depth(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         cfg_scale: float,
+         temperature: float,
+         top_k: int,
+         top_p: int,
+         seed: int,
+     ) -> list[PIL.Image.Image]:
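+         # NOTE: this path expects self.gpt_model_depth and self.get_control_depth
+         # to have been initialized (both are disabled in __init__ in this revision),
+         # and it normalizes the incoming image directly into the conditioning
+         # tensor: the MiDaS detector is moved to the GPU below but not invoked.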
+         image = resize_image_to_16_multiple(image, 'depth')
+         W, H = image.size
+         print(W, H)
+         self.gpt_model_canny.to('cpu')
+         self.t5_model.model.to(self.device)
+         self.gpt_model_depth.to(self.device)
+         self.get_control_depth.model.to(self.device)
+         self.vq_model.to(self.device)
+         image_tensor = torch.from_numpy(np.array(image)).to(self.device)
+ 
+         condition_img = 2 * (image_tensor / 255 - 0.5)
+         print(condition_img.shape)
+         condition_img = condition_img.permute(2, 0, 1).unsqueeze(0).repeat(2, 1, 1, 1)
+ 
+         prompts = [prompt] * 2
+         caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
+ 
+         print(f"processing left-padding...")
+         new_emb_masks = torch.flip(emb_masks, dims=[-1])
+         new_caption_embs = []
+         for idx, (caption_emb,
+                   emb_mask) in enumerate(zip(caption_embs, emb_masks)):
+             valid_num = int(emb_mask.sum().item())
+             print(f' prompt {idx} token len: {valid_num}')
+             new_caption_emb = torch.cat(
+                 [caption_emb[valid_num:], caption_emb[:valid_num]])
+             new_caption_embs.append(new_caption_emb)
+         new_caption_embs = torch.stack(new_caption_embs)
+ 
+         c_indices = new_caption_embs * new_emb_masks[:, :, None]
+         c_emb_masks = new_emb_masks
+         qzshape = [len(c_indices), 8, H // 16, W // 16]
+         t1 = time.time()
+         index_sample = generate(
+             self.gpt_model_depth,
+             c_indices,
+             (H // 16) * (W // 16),
+             c_emb_masks,
+             condition=condition_img,
+             cfg_scale=cfg_scale,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             sample_logits=True,
+         )
+         sampling_time = time.time() - t1
+         print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+ 
+         t2 = time.time()
+         print(index_sample.shape)
+         samples = self.vq_model.decode_code(index_sample, qzshape)
+         decoder_time = time.time() - t2
+         print(f"decoder takes about {decoder_time:.2f} seconds.")
+         condition_img = condition_img.cpu()
+         samples = samples.cpu()
+         samples = torch.cat((condition_img[0:1], samples), dim=0)
+         samples = 255 * (samples * 0.5 + 0.5)
+         samples = [image] + [
+             Image.fromarray(
+                 sample.permute(1, 2, 0).cpu().detach().numpy().clip(0, 255).astype(np.uint8))
+             for sample in samples
+         ]
+         del image_tensor
+         del condition_img
+         torch.cuda.empty_cache()
+         return samples
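+ 
+ 
+ # A minimal, untested usage sketch for the Canny path (the file name
+ # "input.jpg" and all sampling parameters below are illustrative, and the
+ # checkpoints referenced above plus the Spaces GPU runtime are assumed to be
+ # available):
+ #
+ #     model = Model()
+ #     image = np.array(Image.open("input.jpg").convert("RGB"))
+ #     outputs = model.process_canny(
+ #         image, prompt="a photo of a cat", cfg_scale=4.0,
+ #         temperature=1.0, top_k=2000, top_p=1.0, seed=0,
+ #         low_threshold=100, high_threshold=200)
+ #     # outputs[0] is the resized input, outputs[1] the Canny map,
+ #     # and the remaining entries are the generated samples.
+ #     outputs[-1].save("result.png")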