import os

import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize
import imageio
from einops import rearrange
import cv2
from PIL import Image

from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.openpose import OpenposeDetector

import decord
decord.bridge.set_bridge('torch')


apply_canny = CannyDetector()
apply_openpose = OpenposeDetector()


def add_watermark(image, im_size_h, im_size_w, watermark_path="__assets__/picsart_watermark.jpg",
                  wmsize=16, bbuf=5, opacity=0.9):
    '''
    Creates a watermark on the saved inference image.
    We request that you do not remove this to properly assign credit to
    Shi-Lab's work.
    '''
    watermark = Image.open(watermark_path).resize((wmsize, wmsize))
    # Paste the watermark into the bottom-right corner, `bbuf` pixels in from the
    # border. `image` is expected to be an H x W x C uint8 array.
    loc_h = im_size_h - wmsize - bbuf
    loc_w = im_size_w - wmsize - bbuf
    image[loc_h:-bbuf, loc_w:-bbuf, :] = watermark
    return image


def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
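    '''
    Converts each frame of `input_video` (frames in (c, h, w) layout) into a Canny
    edge map and returns the maps as a float tensor in (f, c, h, w) layout, scaled
    to [0, 1], ready to be used as edge conditioning.
    '''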
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        detected_map = apply_canny(img, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)
        detected_maps.append(detected_map[None])
    detected_maps = np.concatenate(detected_maps)
    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def pre_process_pose(input_video, apply_pose_detect: bool = True):
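    '''
    Runs OpenPose keypoint detection on each frame of `input_video` (frames in
    (c, h, w) layout). If `apply_pose_detect` is False, the frames themselves are
    passed through instead. Returns a float tensor in (f, c, h, w) layout, scaled
    to [0, 1], ready to be used as pose conditioning.
    '''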
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        img = HWC3(img)
        if apply_pose_detect:
            detected_map, _ = apply_openpose(img)
        else:
            detected_map = img
        detected_map = HWC3(detected_map)
        H, W, C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        detected_maps.append(detected_map[None])
    detected_maps = np.concatenate(detected_maps)
    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def create_video(frames, fps, rescale=False, path=None):
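    '''
    Writes `frames` to an mp4 file and returns its path. Each frame is tiled with
    torchvision's make_grid, optionally rescaled from [-1, 1] to [0, 1], converted
    to uint8 and watermarked before the clip is saved at `fps` frames per second.
    If `path` is None, the video is written to temporal/movie.mp4.
    '''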
    if path is None:
        dir = "temporal"
        os.makedirs(dir, exist_ok=True)
        path = os.path.join(dir, 'movie.mp4')

    outputs = []
    for i, x in enumerate(frames):
        x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # map [-1, 1] back to [0, 1]
        x = (x * 255).numpy().astype(np.uint8)

        h_, w_, _ = x.shape
        x = add_watermark(x, im_size_h=h_, im_size_w=w_)
        outputs.append(x)

    imageio.mimsave(path, outputs, fps=fps)
    return path


def create_gif(frames, fps, rescale=False):
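    '''
    Same post-processing as create_video, but the grid frames are saved as a GIF
    at temporal/canny_db.gif. Returns the output path.
    '''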
dir = "temporal" |
|
os.makedirs(dir, exist_ok=True) |
|
path = os.path.join(dir, 'canny_db.gif') |
|
|
|
outputs = [] |
|
for i, x in enumerate(frames): |
|
x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4) |
|
if rescale: |
|
x = (x + 1.0) / 2.0 |
|
x = (x * 255).numpy().astype(np.uint8) |
|
h_, w_, _ = x.shape |
|
x = add_watermark(x, im_size_h=h_, im_size_w=w_) |
|
outputs.append(x) |
|
|
|
|
|
imageio.mimsave(path, outputs, fps=fps) |
|
return path |
|
|
|
def prepare_video(video_path: str, resolution: int, device, dtype,
                  normalize=True, start_t: float = 0, end_t: float = -1, output_fps: int = -1):
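    '''
    Loads the video at `video_path` with decord, keeps the [start_t, end_t) time
    window (the whole clip by default), resamples it to `output_fps` (source fps
    by default), and resizes it so the longer side equals `resolution`, with both
    sides rounded down to multiples of 8. If `normalize` is True, pixel values are
    mapped from [0, 255] to [-1, 1]. Returns the (f, c, h, w) tensor on `device`
    with `dtype`, together with the effective output fps.
    '''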
    vr = decord.VideoReader(video_path)
    # With the torch bridge enabled, get_batch returns a torch tensor rather than a
    # decord NDArray, so handle both cases before converting to numpy.
    video = vr.get_batch(range(0, len(vr)))
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    if end_t == -1:
        end_t = len(vr) / initial_fps
    else:
        end_t = min(len(vr) / initial_fps, end_t)
    assert 0 <= start_t < end_t
    assert output_fps > 0
    f, h, w, c = video.shape
    # Resample the frames inside [start_t, end_t) at the requested output fps.
    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    num_f = int((end_t - start_t) * output_fps)
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
    video = video[sample_idx]
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)
    # Scale so the longer side matches `resolution`; snap both sides to multiples of 8.
    if h > w:
        w = int(w * resolution / h)
        w = w - w % 8
        h = resolution - resolution % 8
    else:
        h = int(h * resolution / w)
        h = h - h % 8
        w = resolution - resolution % 8
    video = Resize((h, w))(video)
    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps


def post_process_gif(list_of_results, image_resolution):
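    '''
    Saves `list_of_results` to /tmp/ddxk.gif at 4 fps and returns the file path.
    (`image_resolution` is currently unused.)
    '''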
    output_file = "/tmp/ddxk.gif"
    imageio.mimsave(output_file, list_of_results, fps=4)
    return output_file


class CrossFrameAttnProcessor:
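    '''
    Attention processor implementing cross-frame attention: in every self-attention
    layer, the keys and values of all frames are replaced by those of the first
    frame, so each frame attends to frame 0. Cross-attention to the text encoder
    states is left unchanged. `unet_chunk_size` is the number of prompt copies in
    the batch (2 when classifier-free guidance doubles the batch).

    Illustrative usage, assuming a diffusers UNet2DConditionModel named `unet`
    (not part of this file):

        unet.set_attn_processor(CrossFrameAttnProcessor(unet_chunk_size=2))
    '''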
    def __init__(self, unet_chunk_size=2):
        self.unet_chunk_size = unet_chunk_size

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Cross-frame attention: for self-attention layers, swap every frame's keys
        # and values for those of the first frame so all frames attend to frame 0.
        if not is_cross_attention:
            video_length = key.size()[0] // self.unet_chunk_size
            former_frame_index = [0] * video_length
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            key = key[:, former_frame_index]
            key = rearrange(key, "b f d c -> (b f) d c")
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            value = value[:, former_frame_index]
            value = rearrange(value, "b f d c -> (b f) d c")

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # output projection followed by dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
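
# Illustrative end-to-end sketch (not part of the original module): the helpers above
# are meant to be chained roughly as follows; the video path and device below are
# hypothetical placeholders.
#
#   video, fps = prepare_video("__assets__/example.mp4", resolution=512,
#                              device="cuda", dtype=torch.float16, normalize=False)
#   control = pre_process_canny(video)   # (f, c, h, w) Canny maps in [0, 1] for conditioning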