from functools import partial
from PIL import Image
import numpy as np
import gradio as gr
import torch
import os
import fire
from omegaconf import OmegaConf

from ldm.models.diffusion.sync_dreamer import SyncDDIMSampler, SyncMultiviewDiffusion
from ldm.util import add_margin, instantiate_from_config
from sam_utils import sam_init, sam_out_nosave

_TITLE = '''HarmonyView: Harmonizing Consistency and Diversity in One-Image-to-3D'''
_DESCRIPTION = '''
Given a single-view image, HarmonyView is able to generate diverse and multiview-consistent images, resulting in creating plausible 3D contents with NeuS or NeRF
Procedure:
**Step 1**. Upload an image. ==> The foreground is masked out by SAM.
**Step 2**. Select the input to HarmonyView (Input image or SAM output). ==> Then, we crop it as inputs.
**Step 3**. Select "Elevation angle "and click "Run generation". ==> Generate multiview images. The **Elevation angle** is the elevation of the Input image. (This costs about 45s.)
You may adjust the **Crop size** and **Elevation angle** to get a better result!
To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/byeongjun-park/HarmonyView).
We have heavily borrowed codes from [Syncdreamer](https://huggingface.co/spaces/liuyuan-pal/SyncDreamer), which is an our strong baseline.
'''

# When False, the heavy models are not loaded and the callbacks return
# placeholder outputs, so the UI can be exercised without a GPU.
deployed = True

if deployed:
    print(f"Is CUDA available: {torch.cuda.is_available()}")
    # get_device_name() raises when no CUDA device exists; only query it when
    # one is actually available so that importing this module does not crash.
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


class BackgroundRemoval:
    """Callable wrapper around carvekit's ``HiInterface`` background remover."""

    def __init__(self, device='cuda'):
        # Imported lazily so carvekit is only required when the model is built.
        from carvekit.api.high import HiInterface

        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )

    @torch.no_grad()
    def __call__(self, image):
        """Run background removal on a single image and return the result.

        ``image`` is an RGB image (``sam_predict`` passes a PIL image). The
        interface is called with a one-element batch and its first result is
        returned.
        """
        return self.interface([image])[0]


def resize_inputs(original_image, sam_image, crop_size, background_removal):
    """Crop the chosen image to its foreground and rescale it for the model.

    Args:
        original_image: RGBA PIL image as uploaded, or None.
        sam_image: RGBA PIL image produced by ``sam_predict``, or None.
        crop_size: target length in pixels of the longer side of the crop.
        background_removal: ``"Input image"`` selects ``original_image``;
            any other value selects ``sam_image``.

    Returns:
        The result of ``add_margin(crop, size=256)``, or None when the selected
        image is missing or has no pixel with non-zero alpha.
    """
    image_input = original_image if background_removal == "Input image" else sam_image
    if image_input is None:
        return None

    alpha_np = np.asarray(image_input)[:, :, 3]
    rows, cols = np.nonzero(alpha_np)
    if rows.size == 0:
        # Fully transparent image: there is no foreground to crop around.
        return None

    min_x, max_x = int(cols.min()), int(cols.max())
    min_y, max_y = int(rows.min()), int(rows.max())
    # PIL crop boxes are exclusive on the right/lower edge, so add 1 to keep
    # the last foreground row and column inside the crop.
    ref_img_ = image_input.crop((min_x, min_y, max_x + 1, max_y + 1))

    h, w = ref_img_.height, ref_img_.width
    scale = crop_size / max(h, w)
    # Clamp to at least one pixel so very thin objects do not resize to zero.
    h_, w_ = max(1, int(scale * h)), max(1, int(scale * w))
    ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
    return add_margin(ref_img_, size=256)


def generate(model, cfg_scale_1, cfg_scale_2, seed, image_input, elevation_input):
    """Sample multiview images from the model and tile them into one image.

    Args:
        model: a ``SyncMultiviewDiffusion`` instance (ignored when the module
            flag ``deployed`` is False).
        cfg_scale_1: first classifier-free guidance scale.
        cfg_scale_2: second classifier-free guidance scale.
        seed: random seed; converted with ``int``.
        image_input: RGBA PIL image to condition on, or None.
        elevation_input: elevation of the input view, in degrees.

    Returns:
        A PIL image with the generated views laid side by side (one row per
        sample), a black placeholder when not deployed, or None when there is
        no input image.
    """
    sample_num = 1
    sample_steps = 50
    batch_view_num = 16

    if not deployed:
        return Image.fromarray(np.zeros([sample_num * 256, 16 * 256, 3], np.uint8))
    if image_input is None:
        # "Run generation" was clicked before any input image was prepared.
        return None

    assert isinstance(model, SyncMultiviewDiffusion)
    seed = int(seed)
    torch.random.manual_seed(seed)
    np.random.seed(seed)

    # prepare data
    rgba = np.asarray(image_input).astype(np.float32) / 255.0
    alpha_values = rgba[:, :, 3:]
    rgb = alpha_values * rgba[:, :, :3] + 1 - alpha_values  # white background
    rgb = rgb * 2.0 - 1.0  # rescale [0, 1] -> [-1, 1]
    image_tensor = torch.from_numpy(rgb.astype(np.float32))
    elevation_tensor = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))

    # Add a batch dimension, move to the GPU and repeat once per sample.
    data = {
        key: torch.repeat_interleave(value.unsqueeze(0).cuda(), sample_num, dim=0)
        for key, value in {"input_image": image_tensor, "input_elevation": elevation_tensor}.items()
    }

    sampler = SyncDDIMSampler(model, sample_steps)
    x_sample = model.sample(sampler, data, (cfg_scale_1, cfg_scale_2), batch_view_num)

    # x_sample is unpacked as (B, N, C, H, W) by the permute below; map
    # [-1, 1] -> [0, 255] and move channels last.
    x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5
    x_sample = x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255
    x_sample = x_sample.astype(np.uint8)

    # Views of one sample are concatenated horizontally, samples vertically.
    sample_rows = [np.concatenate(list(views), 1) for views in x_sample]
    return Image.fromarray(np.concatenate(sample_rows, 0))


def sam_predict(predictor, removal, raw_im):
    """Segment the foreground of ``raw_im`` and composite it on white.

    Args:
        predictor: SAM predictor passed through to ``sam_out_nosave``.
        removal: callable background remover (see ``BackgroundRemoval``).
        raw_im: uploaded PIL image, or None.

    Returns:
        An RGBA PIL image whose RGB is the foreground over a white background
        and whose alpha is the mask; ``raw_im`` unchanged when not deployed;
        None when there is no input or no foreground was found.
    """
    if raw_im is None:
        return None
    if not deployed:
        return raw_im

    # Work on a copy: thumbnail() resizes in place and the caller's image
    # should not be modified.
    image = raw_im.copy()
    image.thumbnail([512, 512], Image.Resampling.LANCZOS)
    image_nobg = removal(image.convert('RGB'))

    # Bounding box of the non-zero region of the last (alpha) channel.
    arr = np.asarray(image_nobg)[:, :, -1]
    x_nonzero = np.nonzero(arr.sum(axis=0))[0]
    y_nonzero = np.nonzero(arr.sum(axis=1))[0]
    if x_nonzero.size == 0 or y_nonzero.size == 0:
        # Background removal left nothing; there is no box to prompt SAM with.
        return None
    x_min, x_max = int(x_nonzero.min()), int(x_nonzero.max())
    y_min, y_max = int(y_nonzero.min()), int(y_nonzero.max())

    image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
    image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))

    image_sam = np.asarray(image_sam, np.float32) / 255
    out_mask = image_sam[:, :, 3:]
    out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask  # white background
    out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)

    image_sam = Image.fromarray(out_img, mode='RGBA')
    torch.cuda.empty_cache()
    return image_sam


def run_demo():
    """Load the models (when deployed), build the Gradio UI and launch it."""
    cfg = 'configs/syncdreamer.yaml'
    ckpt = 'ckpt/syncdreamer-pretrain.ckpt'
    config = OmegaConf.load(cfg)

    if deployed:
        model = instantiate_from_config(config.model)
        print(f'loading model from {ckpt} ...')
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # obtained from a trusted source.
        state = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'], strict=True)
        model = model.cuda().eval()
        del state  # release the CPU copy of the weights
        mask_predictor = sam_init()
        removal = BackgroundRemoval()
    else:
        model = None
        mask_predictor = None
        removal = None

    # NOTE: Examples must match inputs
    examples_full = [
        ['hf_demo/examples/dragon.png', 30, 200, "Input image"],
        ['hf_demo/examples/drum_kids.png', 15, 240, "Input image"],
        ['hf_demo/examples/table.png', 30, 200, "Input image"],
        ['hf_demo/examples/panda_back.png', 15, 240, "SAM output"],
        ['hf_demo/examples/boxer_toy.png', 30, 220, "SAM output"],
        ['hf_demo/examples/rose.png', 30, 200, "Input image"],
        ['hf_demo/examples/monkey.png', 30, 200, "SAM output"],
        ['hf_demo/examples/forest.png', 30, 200, "SAM output"],
        ['hf_demo/examples/flower.png', 0, 200, "SAM output"],
        ['hf_demo/examples/teapot.png', 20, 200, "SAM output"],
    ]

    # These components are created up front (so gr.Examples can reference
    # them) and rendered later inside the layout.
    image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
    elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
    crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
    # The default must be one of the choice strings (not a list) for the
    # option to actually be selected.
    background_removal = gr.Radio(["Input image", "SAM output"], value="SAM output", label="Input to HarmonyView",
                                  info="Which image do you want for the input to HarmonyView?")

    # Compose demo layout & data flow.
    with gr.Blocks(title=_TITLE, css="hf_demo/style.css") as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        with gr.Row(variant='panel'):
            with gr.Column(scale=1.2):
                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block, elevation, crop_size, background_removal],
                    outputs=[image_block, elevation, crop_size, background_removal],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=5,
                )

            with gr.Column(scale=0.8):
                image_block.render()
                crop_size.render()
                gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256,
                         show_label=False, tool=None, interactive=False)

            with gr.Column(scale=0.8):
                sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
                elevation.render()
                gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256,
                         show_label=False, tool=None, interactive=False)

            with gr.Column(scale=0.8):
                input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to HarmonyView", height=256,
                                       interactive=False)
                background_removal.render()
                with gr.Accordion('Advanced options', open=False):
                    cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance 1',
                                            info='How consistent to be with the Input image', interactive=True)
                    cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance 2',
                                            info='How diverse a novel view to create', interactive=True)
                    seed = gr.Number(6033, label='Random seed', interactive=True)
                run_btn = gr.Button('Run generation', variant='primary', interactive=True)

        output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of HarmonyView", height=256,
                                interactive=False)

        # Shared wiring for every event that must refresh the model input.
        resize_event = dict(
            fn=resize_inputs,
            inputs=[image_block, sam_block, crop_size, background_removal],
            outputs=[input_block],
            queue=True,
        )

        # A new upload first refreshes the SAM output and then, once that has
        # succeeded, the cropped model input; otherwise the model input would
        # stay stale until the user touched another control.
        image_block.change(fn=partial(sam_predict, mask_predictor, removal), inputs=[image_block],
                           outputs=[sam_block], queue=True).success(**resize_event)
        background_removal.change(**resize_event)
        crop_size.change(**resize_event)

        run_btn.click(partial(generate, model), inputs=[cfg_scale_1, cfg_scale_2, seed, input_block, elevation],
                      outputs=[output_block], queue=True)

    demo.queue().launch(share=False, max_threads=80)  # auth=("admin", os.environ['PASSWD'])


if __name__ == "__main__":
    fire.Fire(run_demo)