Upload 7 files
- app_canny_db.py +103 -0
- app_text_to_video.py +97 -0
- config.py +1 -0
- gradio_utils.py +98 -0
- hf_utils.py +39 -0
- style (1).css +3 -0
- utils (1).py +207 -0
app_canny_db.py
ADDED
@@ -0,0 +1,103 @@
import gradio as gr
from model import Model
import gradio_utils
import os

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"


examples = [
    ['Anime DB', "woman1", "Portrait of detailed 1girl, feminine, soldier cinematic shot on canon 5d ultra realistic skin intricate clothes accurate hands Rory Lewis Artgerm WLOP Jeremy Lipking Jane Ansell studio lighting"],
    ['Arcane DB', "woman1", "Oil painting of a beautiful girl arcane style, masterpiece, a high-quality, detailed, and professional photo"],
    ['GTA-5 DB', "man1", "gtav style"],
    ['GTA-5 DB', "woman3", "gtav style"],
    ['Avatar DB', "woman2", "oil painting of a beautiful girl avatar style"],
]


def load_db_model(evt: gr.SelectData):
    db_name = gradio_utils.get_db_name_from_id(evt.index)
    return db_name


def canny_select(evt: gr.SelectData):
    canny_name = gradio_utils.get_canny_name_from_id(evt.index)
    return canny_name


def create_demo(model: Model):

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(
                '## Text, Canny-Edge and DreamBooth Conditional Video Generation')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left; auto;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: Our current release supports only four predefined DreamBooth models and four "motion edges". Choose one DreamBooth model and one of the "motion edges" below, or use one of the examples. The keywords <b>1girl</b>, <b>arcane style</b>, <b>gtav</b>, and <b>avatar style</b> correspond to the models from left to right.
                </h2>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                # input_video_path = gr.Video(source='upload', format="mp4", visible=False)
                gr.Markdown("## Selection")
                db_text_field = gr.Markdown('DB Model: **Anime DB**')
                canny_text_field = gr.Markdown('Motion: **woman1**')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero",
                                          "None"], label="Watermark", value='Picsart AI Research')
                    chunk_size = gr.Slider(
                        label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)
            with gr.Column():
                result = gr.Image(label="Generated Video").style(height=400)

        with gr.Row():
            gallery_db = gr.Gallery(label="Db models", value=[('__assets__/db_files/anime.jpg', "anime"), ('__assets__/db_files/arcane.jpg', "Arcane"), (
                '__assets__/db_files/gta.jpg', "GTA-5 (Man)"), ('__assets__/db_files/avatar.jpg', "Avatar DB")]).style(grid=[4], height=50)
        with gr.Row():
            gallery_canny = gr.Gallery(label="Motions", value=[('__assets__/db_files/woman1.gif', "woman1"), ('__assets__/db_files/woman2.gif', "woman2"), (
                '__assets__/db_files/man1.gif', "man1"), ('__assets__/db_files/woman3.gif', "woman3")]).style(grid=[4], height=50)

        db_selection = gr.Textbox(label="DB Model", visible=False)
        canny_selection = gr.Textbox(
            label="One of the above defined motions", visible=False)

        gallery_db.select(load_db_model, None, db_selection)
        gallery_canny.select(canny_select, None, canny_selection)

        db_selection.change(on_db_selection_update, None, db_text_field)
        canny_selection.change(on_canny_selection_update,
                               None, canny_text_field)

        inputs = [
            db_selection,
            canny_selection,
            prompt,
            chunk_size,
            watermark,
        ]

        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=model.process_controlnet_canny_db,
                    cache_examples=on_huggingspace,
                    )

        run_button.click(fn=model.process_controlnet_canny_db,
                         inputs=inputs,
                         outputs=result)
    return demo


def on_db_selection_update(evt: gr.EventData):
    return f"DB model: **{evt._data}**"


def on_canny_selection_update(evt: gr.EventData):
    return f"Motion: **{evt._data}**"
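Note: the Model class is imported from model.py, which is not part of this upload. A minimal launcher sketch for this tab alone, assuming Model() can be constructed without arguments (the real constructor lives in model.py):

# Hypothetical launcher, not part of this commit; Model() arguments are an assumption.
from model import Model
import app_canny_db

model = Model()  # device/dtype handling lives in model.py (not shown here)
demo = app_canny_db.create_demo(model)
demo.queue()     # queuing is optional but typical for long-running Space inference
demo.launch()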
app_text_to_video.py
ADDED
@@ -0,0 +1,97 @@
import gradio as gr
from model import Model
import os
from hf_utils import get_model_list

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

examples = [
    ["an astronaut waving the arm on the moon"],
    ["a sloth surfing on a wakeboard"],
    ["an astronaut walking on a street"],
    ["a cute cat walking on grass"],
    ["a horse is galloping on a street"],
    ["an astronaut is skiing down the hill"],
    ["a gorilla walking alone down the street"],
    ["a gorilla dancing on times square"],
    ["A panda dancing like crazy on Times Square"],
]


def create_demo(model: Model):

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Text2Video-Zero: Video Generation')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left; auto;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: Simply input <b>any textual prompt</b> to generate videos right away and unleash your creativity and imagination! You can also select from the examples below. For performance purposes, our current preview release allows generating up to 16 frames, configurable in the Advanced Options.
                </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                model_name = gr.Dropdown(
                    label="Model",
                    choices=get_model_list(),
                    value="dreamlike-art/dreamlike-photoreal-2.0",
                )
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero",
                                          "None"], label="Watermark", value='Picsart AI Research')

                    if on_huggingspace:
                        video_length = gr.Slider(
                            label="Video length", minimum=8, maximum=16, step=1)
                    else:
                        video_length = gr.Number(
                            label="Video length", value=8, precision=0)
                    chunk_size = gr.Slider(
                        label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)

                    motion_field_strength_x = gr.Slider(
                        label=r'Global Translation $\delta_{x}$', minimum=-20, maximum=20, value=12, step=1)
                    motion_field_strength_y = gr.Slider(
                        label=r'Global Translation $\delta_{y}$', minimum=-20, maximum=20, value=12, step=1)

                    t0 = gr.Slider(label="Timestep t0", minimum=0,
                                   maximum=49, value=44, step=1)
                    t1 = gr.Slider(label="Timestep t1", minimum=0,
                                   maximum=49, value=47, step=1)

                    n_prompt = gr.Textbox(
                        label="Optional Negative Prompt", value='')
            with gr.Column():
                result = gr.Video(label="Generated Video")

        inputs = [
            prompt,
            model_name,
            motion_field_strength_x,
            motion_field_strength_y,
            t0,
            t1,
            n_prompt,
            chunk_size,
            video_length,
            watermark,
        ]

        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=model.process_text2video,
                    run_on_click=False,
                    cache_examples=on_huggingspace,
                    )

        run_button.click(fn=model.process_text2video,
                         inputs=inputs,
                         outputs=result)
    return demo
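The order of the inputs list above determines the positional arguments Gradio passes to model.process_text2video, whose definition lives in model.py and is not shown here. A hypothetical signature sketch, only to make the wiring explicit (parameter names mirror the Gradio variables; the real signature may differ):

# Hypothetical signature, for illustration only — see model.py for the real one.
def process_text2video(prompt, model_name, motion_field_strength_x,
                       motion_field_strength_y, t0, t1, n_prompt,
                       chunk_size, video_length, watermark):
    ...  # returns a video file path for the gr.Video output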
config.py
ADDED
@@ -0,0 +1 @@
save_memory = False
gradio_utils.py
ADDED
@@ -0,0 +1,98 @@
import os

# App Canny utils


def edge_path_to_video_path(edge_path):
    video_path = edge_path

    vid_name = edge_path.split("/")[-1]
    if vid_name == "butterfly.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/butterfly.mp4"
    elif vid_name == "deer.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/deer.mp4"
    elif vid_name == "fox.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/fox.mp4"
    elif vid_name == "girl_dancing.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/girl_dancing.mp4"
    elif vid_name == "girl_turning.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/girl_turning.mp4"
    elif vid_name == "halloween.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/halloween.mp4"
    elif vid_name == "santa.mp4":
        video_path = "__assets__/canny_videos_mp4_2fps/santa.mp4"

    assert os.path.isfile(video_path)
    return video_path


# App Pose utils
def motion_to_video_path(motion):
    videos = [
        "__assets__/poses_skeleton_gifs/dance1_corr.mp4",
        "__assets__/poses_skeleton_gifs/dance2_corr.mp4",
        "__assets__/poses_skeleton_gifs/dance3_corr.mp4",
        "__assets__/poses_skeleton_gifs/dance4_corr.mp4",
        "__assets__/poses_skeleton_gifs/dance5_corr.mp4"
    ]
    if len(motion.split(" ")) > 1 and motion.split(" ")[1].isnumeric():
        id = int(motion.split(" ")[1]) - 1
        return videos[id]
    else:
        return motion


# App Canny Dreambooth utils
def get_video_from_canny_selection(canny_selection):
    if canny_selection == "woman1":
        input_video_path = "__assets__/db_files_2fps/woman1.mp4"
    elif canny_selection == "woman2":
        input_video_path = "__assets__/db_files_2fps/woman2.mp4"
    elif canny_selection == "man1":
        input_video_path = "__assets__/db_files_2fps/man1.mp4"
    elif canny_selection == "woman3":
        input_video_path = "__assets__/db_files_2fps/woman3.mp4"
    else:
        input_video_path = canny_selection

    assert os.path.isfile(input_video_path)
    return input_video_path


def get_model_from_db_selection(db_selection):
    if db_selection == "Anime DB":
        input_video_path = 'PAIR/text2video-zero-controlnet-canny-anime'
    elif db_selection == "Avatar DB":
        input_video_path = 'PAIR/text2video-zero-controlnet-canny-avatar'
    elif db_selection == "GTA-5 DB":
        input_video_path = 'PAIR/text2video-zero-controlnet-canny-gta5'
    elif db_selection == "Arcane DB":
        input_video_path = 'PAIR/text2video-zero-controlnet-canny-arcane'
    else:
        input_video_path = db_selection

    return input_video_path


def get_db_name_from_id(id):
    db_names = ["Anime DB", "Arcane DB", "GTA-5 DB", "Avatar DB"]
    return db_names[id]


def get_canny_name_from_id(id):
    canny_names = ["woman1", "woman2", "man1", "woman3"]
    return canny_names[id]


def logo_name_to_path(name):
    logo_paths = {
        'Picsart AI Research': '__assets__/pair_watermark.png',
        'Text2Video-Zero': '__assets__/t2v-z_watermark.png',
        'None': None
    }
    if name in logo_paths:
        return logo_paths[name]
    return name
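For reference, a short usage sketch of the selection helpers above, as app_canny_db.py would exercise them (it assumes the __assets__ directory is present locally, since get_video_from_canny_selection asserts that the file exists):

# Illustrative only; values match the galleries defined in app_canny_db.py.
from gradio_utils import get_model_from_db_selection, get_video_from_canny_selection

model_id = get_model_from_db_selection("Arcane DB")
# -> 'PAIR/text2video-zero-controlnet-canny-arcane'
video_path = get_video_from_canny_selection("woman1")
# -> '__assets__/db_files_2fps/woman1.mp4' (raises AssertionError if the asset is missing)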
hf_utils.py
ADDED
@@ -0,0 +1,39 @@
from bs4 import BeautifulSoup
import requests


def model_url_list():
    url_list = []
    for i in range(0, 5):
        url_list.append(
            f"https://huggingface.co/models?p={i}&sort=downloads&search=dreambooth")
    return url_list


def data_scraping(url_list):
    model_list = []
    for url in url_list:
        response = requests.get(url)
        soup = BeautifulSoup(response.text, "html.parser")
        div_class = 'grid grid-cols-1 gap-5 2xl:grid-cols-2'
        div = soup.find('div', {'class': div_class})
        for a in div.find_all('a', href=True):
            model_list.append(a['href'])
    return model_list


def get_model_list():
    model_list = data_scraping(model_url_list())
    for i in range(len(model_list)):
        model_list[i] = model_list[i][1:]

    best_model_list = [
        "dreamlike-art/dreamlike-photoreal-2.0",
        "dreamlike-art/dreamlike-diffusion-1.0",
        "runwayml/stable-diffusion-v1-5",
        "CompVis/stable-diffusion-v1-4",
        "prompthero/openjourney",
    ]

    model_list = best_model_list + model_list
    return model_list
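get_model_list() scrapes huggingface.co at startup and assumes the models page still uses the 'grid grid-cols-1 gap-5 2xl:grid-cols-2' container, so a network failure or a markup change would raise. A hedged fallback wrapper, not part of this upload, could look like:

# Hypothetical defensive wrapper (not in the original file).
def get_model_list_safe():
    try:
        return get_model_list()
    except Exception:
        # Fall back to the hard-coded list when scraping fails.
        return [
            "dreamlike-art/dreamlike-photoreal-2.0",
            "dreamlike-art/dreamlike-diffusion-1.0",
            "runwayml/stable-diffusion-v1-5",
            "CompVis/stable-diffusion-v1-4",
            "prompthero/openjourney",
        ]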
style (1).css
ADDED
@@ -0,0 +1,3 @@
h1 {
    text-align: center;
}
utils (1).py
ADDED
@@ -0,0 +1,207 @@
import os

import PIL.Image
import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize, InterpolationMode
import imageio
from einops import rearrange
import cv2
from PIL import Image
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.openpose import OpenposeDetector
import decord
# decord.bridge.set_bridge('torch')

apply_canny = CannyDetector()
apply_openpose = OpenposeDetector()


def add_watermark(image, watermark_path, wm_rel_size=1/16, boundary=5):
    '''
    Creates a watermark on the saved inference image.
    We request that you do not remove this to properly assign credit to
    Shi-Lab's work.
    '''
    watermark = Image.open(watermark_path)
    w_0, h_0 = watermark.size
    H, W, _ = image.shape
    wmsize = int(max(H, W) * wm_rel_size)
    aspect = h_0 / w_0
    if aspect > 1.0:
        watermark = watermark.resize((wmsize, int(aspect * wmsize)), Image.LANCZOS)
    else:
        watermark = watermark.resize((int(wmsize / aspect), wmsize), Image.LANCZOS)
    w, h = watermark.size
    loc_h = H - h - boundary
    loc_w = W - w - boundary
    image = Image.fromarray(image)
    mask = watermark if watermark.mode in ('RGBA', 'LA') else None
    image.paste(watermark, (loc_w, loc_h), mask)
    return image


def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        detected_map = apply_canny(img, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)
        detected_maps.append(detected_map[None])
    detected_maps = np.concatenate(detected_maps)
    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def pre_process_pose(input_video, apply_pose_detect: bool = True):
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        img = HWC3(img)
        if apply_pose_detect:
            detected_map, _ = apply_openpose(img)
        else:
            detected_map = img
        detected_map = HWC3(detected_map)
        H, W, C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        detected_maps.append(detected_map[None])
    detected_maps = np.concatenate(detected_maps)
    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def create_video(frames, fps, rescale=False, path=None, watermark=None):
    if path is None:
        dir = "temporal"
        os.makedirs(dir, exist_ok=True)
        path = os.path.join(dir, 'movie.mp4')

    outputs = []
    for i, x in enumerate(frames):
        x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)

        if watermark is not None:
            x = add_watermark(x, watermark)
        outputs.append(x)
        # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)

    imageio.mimsave(path, outputs, fps=fps)
    return path


def create_gif(frames, fps, rescale=False, path=None, watermark=None):
    if path is None:
        dir = "temporal"
        os.makedirs(dir, exist_ok=True)
        path = os.path.join(dir, 'canny_db.gif')

    outputs = []
    for i, x in enumerate(frames):
        x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        if watermark is not None:
            x = add_watermark(x, watermark)
        outputs.append(x)
        # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)

    imageio.mimsave(path, outputs, fps=fps)
    return path


def prepare_video(video_path: str, resolution: int, device, dtype, normalize=True, start_t: float = 0, end_t: float = -1, output_fps: int = -1):
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    if end_t == -1:
        end_t = len(vr) / initial_fps
    else:
        end_t = min(len(vr) / initial_fps, end_t)
    assert 0 <= start_t < end_t
    assert output_fps > 0
    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    num_f = int((end_t - start_t) * output_fps)
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
    video = vr.get_batch(sample_idx)
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()
    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)
    if h > w:
        w = int(w * resolution / h)
        w = w - w % 8
        h = resolution - resolution % 8
    else:
        h = int(h * resolution / w)
        h = h - h % 8
        w = resolution - resolution % 8
    video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)(video)
    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps


def post_process_gif(list_of_results, image_resolution):
    output_file = "/tmp/ddxk.gif"
    imageio.mimsave(output_file, list_of_results, fps=4)
    return output_file


class CrossFrameAttnProcessor:
    def __init__(self, unet_chunk_size=2):
        self.unet_chunk_size = unet_chunk_size

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        # Sparse Attention
        if not is_cross_attention:
            video_length = key.size()[0] // self.unet_chunk_size
            # former_frame_index = torch.arange(video_length) - 1
            # former_frame_index[0] = 0
            former_frame_index = [0] * video_length
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            key = key[:, former_frame_index]
            key = rearrange(key, "b f d c -> (b f) d c")
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            value = value[:, former_frame_index]
            value = rearrange(value, "b f d c -> (b f) d c")

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
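CrossFrameAttnProcessor follows the diffusers attention-processor call signature and replaces each frame's self-attention keys and values with those of the first frame, i.e. the cross-frame attention used by Text2Video-Zero. A minimal wiring sketch, assuming a diffusers version whose attention blocks expose the attribute names used above (e.g. cross_attention_norm) and where unet_chunk_size=2 accounts for the unconditional/conditional halves of classifier-free guidance:

# Sketch only; the actual pipeline setup lives in model.py, which is not in this upload.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(unet_chunk_size=2))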