# Spaces: Running on Zero
import gradio as gr | |
import numpy as np | |
import mlxu | |
import os | |
import re | |
import torch | |
from io import BytesIO | |
from natsort import natsorted | |
from PIL import Image | |
from inference import LocalInferenceModel | |
# Command-line flags for the demo, registered through mlxu with the defaults
# below. Only FLAGS is kept; the second return value is discarded.
# NOTE(review): host, port and torch_devices are not read by the code in this
# script -- demo.launch() is called without server arguments and main() builds
# the model on a hard-coded torch.device("cuda").
FLAGS, _ = mlxu.define_flags_with_default(**{
    'host': '0.0.0.0',
    'port': 5000,
    'dtype': 'float16',
    'checkpoint': 'Emma02/LVM_ckpts',
    'torch_devices': '',
    'context_frames': 16,
})
def natural_sort_key(s):
    """Return a key that sorts strings with embedded numbers in natural order.

    ``re.split`` with a capturing group yields alternating pieces: non-digit
    text at even indices (lower-cased for case-insensitive ordering) and
    captured ASCII digit runs at odd indices (converted to ``int``). For
    example, ``"Img10.PNG"`` -> ``['img', 10, '.png']``. Because every key
    alternates str/int at fixed positions, any two keys compare without a
    ``TypeError``.
    """
    parts = re.split(r'([0-9]+)', s)
    # Choose by position instead of ``str.isdigit()``: isdigit() is also True
    # for characters such as '²' that int() cannot parse, and for non-ASCII
    # digits that would put an int where other keys have a str.
    return [int(part) if i % 2 else part.lower() for i, part in enumerate(parts)]
def load_example_image_groups(directory):
    """Load example prompt sequences, one group per subdirectory of ``directory``.

    Args:
        directory: Path whose immediate subdirectories each hold one example
            sequence of ``.png`` / ``.jpg`` / ``.jpeg`` files (extension matched
            case-insensitively). Plain files directly inside ``directory`` are
            ignored.

    Returns:
        dict mapping subdirectory name -> list of PIL images in natural
        filename order. Groups are inserted in natural order of their names,
        so callers iterating the dict see a stable order.
    """
    example_groups = {}
    # os.listdir() order is arbitrary; sort so example rows appear consistently.
    for subdir in natsorted(os.listdir(directory), key=natural_sort_key):
        subdir_path = os.path.join(directory, subdir)
        if not os.path.isdir(subdir_path):
            continue
        filenames = natsorted(
            (f for f in os.listdir(subdir_path)
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))),
            key=natural_sort_key,
        )
        images = []
        for filename in filenames:
            # Image.open() is lazy and holds the file open until pixels are
            # read; copy() loads the data so the file can be closed right away.
            with Image.open(os.path.join(subdir_path, filename)) as img:
                images.append(img.copy())
        example_groups[subdir] = images
    return example_groups
def main(_):
    """Load the LVM checkpoint and serve the interactive Gradio demo.

    ``_`` is the positional argument supplied by ``mlxu.run`` and is unused.

    Raises:
        ValueError: if the ``checkpoint`` flag is an empty string.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if FLAGS.checkpoint == '':
        raise ValueError('The checkpoint flag must not be empty.')
    model = LocalInferenceModel(
        checkpoint=FLAGS.checkpoint,
        # NOTE(review): device is hard-coded; FLAGS.torch_devices is not consulted.
        torch_device=torch.device("cuda"),
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=False,
    )

    # Vertical separator drawn between frames of one candidate: 8x8 black/white
    # squares, 3 squares wide and 32 tall, i.e. a (256, 24, 3) float32 array.
    checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
    checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
    checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)

    def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
        """Run the model on a prompt sequence; return one tiled PIL image per candidate.

        Each prompt image is converted to RGB, resized to 256x256 and scaled to
        [0, 1], and the results are stacked into one sequence. The frames of
        each returned candidate are joined left-to-right with the checkerboard
        separator between neighbours.

        Raises:
            gr.Error: if ``input_images`` is empty (shown to the user in the UI).
        """
        if not input_images:
            raise gr.Error('Add at least one input image before generating.')
        frames = np.stack(
            [
                np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
                for img in input_images
            ],
            axis=0,
        )
        output_images = model([frames], n_new_frames, n_candidates, temperature, top_p)[0]

        generated_images = []
        for candidate in output_images:
            tiles = []
            for i, img in enumerate(candidate):
                if i > 0:
                    tiles.append(checkerboard)
                tiles.append(img)
            strip = np.concatenate(tiles, axis=1)
            # Clip before the uint8 cast: values outside [0, 1] would otherwise
            # wrap around instead of saturating.
            generated_images.append(
                Image.fromarray((np.clip(strip, 0.0, 1.0) * 255).astype(np.uint8))
            )
        return generated_images

    css = """
    .small-button {
        padding: 5px 10px;
        min-width: 80px;
    }
    .large-gallery img {
        width: 100%;
        height: auto;
        max-height: 150px;
    }
    """
    with gr.Blocks(css=css) as demo:
        with gr.Column():
            # Per-session list of PIL images forming the prompt sequence.
            image_list = gr.State([])
            gr.Markdown('# LVM Demo')
            gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
            gr.Markdown('## Inputs')
            with gr.Row():
                upload_drag = gr.File(
                    type='binary',
                    file_types=['image'],
                    file_count='multiple',
                )
                with gr.Column():
                    gen_length_slider = gr.Slider(
                        label='Generation length',
                        minimum=1,
                        maximum=32,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    n_candidates_slider = gr.Slider(
                        label='Number of candidates',
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        interactive=True,
                    )
                    temp_slider = gr.Slider(
                        label='Temperature',
                        minimum=0,
                        maximum=2.0,
                        value=1.0,
                        interactive=True,
                    )
                    top_p_slider = gr.Slider(
                        label='Top p',
                        minimum=0,
                        maximum=1.0,
                        value=0.9,
                        interactive=True,
                    )
                    clear_btn = gr.Button(
                        value='Clear',
                        elem_classes=['small-button'],
                    )
                    generate_btn = gr.Button(
                        value='Generate',
                        interactive=False,
                        elem_classes=['small-button'],
                    )
            input_gallery = gr.Gallery(
                columns=7,
                rows=1,
                object_fit='scale-down',
                label="Input image sequence"
            )
            gr.Markdown('## Outputs (multi candidates)')
            output_gallery = gr.Gallery(
                columns=4,
                object_fit='scale-down',
                label="Output image"
            )

        def upload_image_fn(files, images):
            """Append uploaded image files (raw bytes) to the prompt sequence."""
            # Build a new list rather than mutating the session state in place,
            # matching add_image_group_fn below.
            new_images = images + [Image.open(BytesIO(file)) for file in files or []]
            return {
                upload_drag: None,
                image_list: new_images,
                input_gallery: new_images,
                generate_btn: gr.update(interactive=bool(new_images)),
            }

        def clear_fn():
            """Reset the prompt sequence and outputs; disable Generate."""
            return {
                image_list: [],
                input_gallery: [],
                generate_btn: gr.update(interactive=False),
                output_gallery: [],
            }

        def disable_generate_btn():
            """Disable Generate while a request is running."""
            return {
                generate_btn: gr.update(interactive=False),
            }

        def enable_generate_btn(images):
            """Re-enable Generate whenever there is something to generate from."""
            return {
                generate_btn: gr.update(interactive=bool(images)),
            }

        def generate_fn(images, n_candidates, gen_length, temperature, top_p):
            """Generate candidate continuations of the current prompt sequence."""
            new_images = generate_images(
                images,
                # Cast so the model always receives integer counts from the sliders.
                int(gen_length),
                n_candidates=int(n_candidates),
                temperature=temperature,
                top_p=top_p,
            )
            return {
                output_gallery: new_images,
            }

        upload_drag.upload(
            upload_image_fn,
            inputs=[upload_drag, image_list],
            outputs=[upload_drag, image_list, input_gallery, generate_btn],
        )
        clear_btn.click(
            clear_fn,
            inputs=None,
            outputs=[image_list, input_gallery, generate_btn, output_gallery],
        )
        generate_btn.click(
            disable_generate_btn,
            inputs=None,
            outputs=[generate_btn],
        ).then(
            generate_fn,
            inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
            outputs=[output_gallery],
        ).then(
            # .then() also runs when generation raised, so the button is never
            # left permanently disabled after an error.
            enable_generate_btn,
            inputs=[image_list],
            outputs=[generate_btn],
        )

        example_groups = load_example_image_groups('prompts')

        def add_image_group_fn(group_name, images):
            """Append every image of the named example group to the prompt sequence."""
            new_images = images + example_groups[group_name]
            return {
                image_list: new_images,
                input_gallery: new_images,
                generate_btn: gr.update(interactive=True),
            }

        for group_name, group_images in example_groups.items():
            with gr.Row():
                with gr.Column(scale=3):
                    add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
                with gr.Column(scale=7):
                    gr.Gallery(
                        value=[Image.fromarray(np.array(img)) for img in group_images],
                        columns=5,
                        rows=1,
                        object_fit='scale-down',
                        label=group_name,
                        elem_classes=['large-gallery'],
                    )
            # The group name travels as a per-button State value, so each button
            # adds its own group (no late-binding of the loop variable).
            add_button.click(
                add_image_group_fn,
                inputs=[gr.State(group_name), image_list],
                outputs=[image_list, input_gallery, generate_btn],
            )

    demo.launch()


if __name__ == "__main__":
    mlxu.run(main)