Linoy Tsaban committed · commit 76c83ed
Parent: 5eb8981

Update preprocess_utils.py
preprocess_utils.py  CHANGED  (+73 -2)
@@ -219,7 +219,7 @@ class Preprocess(nn.Module):
 
         return_inverted_latents = self.config["frames"] is not None
         for i, t in enumerate(tqdm(timesteps)):
-            for b in range(0, latent_frames.shape[0], batch_size):
+            for b in range(0, latent_frames.shape[0], int(batch_size)):
                 x_batch = latent_frames[b:b + batch_size]
                 model_input = x_batch
                 cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
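The only functional change in this first hunk is the explicit int() cast on the stride of the batching loop. A plausible motivation (an assumption; the commit message doesn't say) is that in the demo setting batch_size can arrive as a float, e.g. from a UI slider or a parsed config, and range() rejects non-integer steps. A minimal sketch of the failure and the fix:

    # Sketch only: assumes batch_size can show up as a float (e.g. from a UI slider).
    batch_size = 8.0
    n_frames = 40

    try:
        list(range(0, n_frames, batch_size))
    except TypeError as err:
        print(err)  # 'float' object cannot be interpreted as an integer

    print(list(range(0, n_frames, int(batch_size))))  # [0, 8, 16, 24, 32]

Note that the slice on the next line, latent_frames[b:b + batch_size], still uses the uncast value; a genuinely float batch_size would make that slice raise the same TypeError, so the cast only hardens the loop header itself.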
@@ -320,5 +320,76 @@ class Preprocess(nn.Module):
         return self.frames, self.latents, self.total_inverted_latents, None
 
 
-
-
+def prep(opt):
+    # timesteps to save
+    if opt["sd_version"] == '2.1':
+        model_key = "stabilityai/stable-diffusion-2-1-base"
+    elif opt["sd_version"] == '2.0':
+        model_key = "stabilityai/stable-diffusion-2-base"
+    elif opt["sd_version"] == '1.5' or opt["sd_version"] == 'ControlNet':
+        model_key = "runwayml/stable-diffusion-v1-5"
+    elif opt["sd_version"] == 'depth':
+        model_key = "stabilityai/stable-diffusion-2-depth"
+    toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+    toy_scheduler.set_timesteps(opt["save_steps"])
+    timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=opt["save_steps"],
+                                                           strength=1.0,
+                                                           device=device)
+
+    seed_everything(opt["seed"])
+    if not opt["frames"]:  # original non demo setting
+        save_path = os.path.join(opt["save_dir"],
+                                 f'sd_{opt["sd_version"]}',
+                                 Path(opt["data_path"]).stem,
+                                 f'steps_{opt["steps"]}',
+                                 f'nframes_{opt["n_frames"]}')
+        os.makedirs(os.path.join(save_path, f'latents'), exist_ok=True)
+        add_dict_to_yaml_file(os.path.join(opt["save_dir"], 'inversion_prompts.yaml'), Path(opt["data_path"]).stem, opt["inversion_prompt"])
+        # save inversion prompt in a txt file
+        with open(os.path.join(save_path, 'inversion_prompt.txt'), 'w') as f:
+            f.write(opt["inversion_prompt"])
+    else:
+        save_path = None
+
+    model = Preprocess(device, opt)
+
+    frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
+        num_steps=model.config["steps"],
+        save_path=save_path,
+        batch_size=model.config["batch_size"],
+        timesteps_to_save=timesteps_to_save,
+        inversion_prompt=model.config["inversion_prompt"],
+    )
+
+    return frames, latents, total_inverted_latents, rgb_reconstruction
+    # if not os.path.isdir(os.path.join(save_path, f'frames')):
+    #     os.mkdir(os.path.join(save_path, f'frames'))
+    # for i, frame in enumerate(recon_frames):
+    #     T.ToPILImage()(frame).save(os.path.join(save_path, f'frames', f'{i:05d}.png'))
+    # frames = (recon_frames * 255).to(torch.uint8).cpu().permute(0, 2, 3, 1)
+    # write_video(os.path.join(save_path, f'inverted.mp4'), frames, fps=10)
+
+
+# if __name__ == "__main__":
+#     device = 'cuda'
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument('--data_path', type=str,
+#                         default='data/woman-running.mp4')
+#     parser.add_argument('--H', type=int, default=512,
+#                         help='for non-square videos, we recommand using 672 x 384 or 384 x 672, aspect ratio 1.75')
+#     parser.add_argument('--W', type=int, default=512,
+#                         help='for non-square videos, we recommand using 672 x 384 or 384 x 672, aspect ratio 1.75')
+#     parser.add_argument('--save_dir', type=str, default='latents')
+#     parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1', 'ControlNet', 'depth'],
+#                         help="stable diffusion version")
+#     parser.add_argument('--steps', type=int, default=500)
+#     parser.add_argument('--batch_size', type=int, default=40)
+#     parser.add_argument('--save_steps', type=int, default=50)
+#     parser.add_argument('--n_frames', type=int, default=40)
+#     parser.add_argument('--inversion_prompt', type=str, default='a woman running')
+#     opt = parser.parse_args()
+#     video_path = opt.data_path
+#     save_video_frames(video_path, img_size=(opt.H, opt.W))
+#     opt.data_path = os.path.join('data', Path(video_path).stem)
+#     prep(opt)
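The bulk of the commit is the new prep(opt) entry point, which makes preprocessing callable in-process: it takes a plain dict of options instead of the argparse namespace used by the now commented-out CLI block, and returns the frames and inverted latents directly, which is what a demo Space needs. Two things worth noting: with strength=1.0 the get_timesteps call should, by the usual diffusers img2img convention, hand back the full save_steps schedule; and prep() relies on a module-level device (the old __main__ block set device = 'cuda'), so the module must define it elsewhere. A minimal, hypothetical invocation follows; the keys are the ones prep() and extract_latents() read in the diff, the values are illustrative placeholders, and Preprocess presumably expects additional keys (such as H and W) that are not visible in this hunk:

    # Hypothetical call into the new entry point (a sketch, not part of the commit).
    opt = {
        "sd_version": "2.1",                # selects the model_key branch
        "data_path": "data/woman-running",  # directory of extracted video frames
        "save_dir": "latents",
        "steps": 500,                       # DDIM inversion steps
        "save_steps": 50,                   # how many timesteps' latents to keep
        "batch_size": 40,
        "n_frames": 40,
        "seed": 1,
        "frames": None,                     # None -> original non-demo setting, writes to save_dir
        "inversion_prompt": "a woman running",
    }

    frames, latents, total_inverted_latents, rgb_reconstruction = prep(opt)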