|
import torch |
|
from libs.base_utils import do_resize_content |
|
from imagedream.ldm.util import ( |
|
instantiate_from_config, |
|
get_obj_from_str, |
|
) |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
class TwoStagePipeline(object):
    """Two-stage image-to-multiview generation pipeline.

    - Stage 1 is conditioned on a single RGB reference image and generates
      multi-view RGB images (configured by the v2pp config).
    - Stage 2 is conditioned on the multi-view images generated by stage 1
      (plus the reference image) and generates the final images (configured
      by the stage2-test config).
    """

    def __init__(
        self,
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
        device="cuda",
        dtype=torch.float16,
        resize_rate=1,
    ) -> None:
        """Build both models and their samplers.

        Args:
            stage1_model_config, stage2_model_config: objects exposing
                ``.config`` (path to an OmegaConf YAML whose ``model`` node is
                passed to ``instantiate_from_config``) and ``.resume`` (path
                to a checkpoint loaded with ``torch.load``).
            stage1_sampler_config, stage2_sampler_config: objects exposing
                ``.target`` (import path of the sampler class, resolved with
                ``get_obj_from_str``) and ``.params`` (extra constructor
                keyword arguments).
            device: device the models are moved to; also passed to samplers.
            dtype: dtype the models are cast to; also passed to samplers.
            resize_rate: forwarded to ``do_resize_content`` in ``__call__``.
        """
        self.resize_rate = resize_rate

        self.stage1_model = self._build_model(stage1_model_config, device, dtype)
        self.stage2_model = self._build_model(stage2_model_config, device, dtype)

        self.device = device
        self.dtype = dtype

        self.stage1_sampler = self._build_sampler(
            stage1_sampler_config, self.stage1_model, device, dtype
        )
        self.stage2_sampler = self._build_sampler(
            stage2_sampler_config, self.stage2_model, device, dtype
        )

    @staticmethod
    def _build_model(model_config, device, dtype):
        """Instantiate a model from its YAML config, load its weights, move it to device/dtype."""
        model = instantiate_from_config(OmegaConf.load(model_config.config).model)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True if the installed torch
        # supports it and the checkpoint is a plain state dict).
        state_dict = torch.load(model_config.resume, map_location="cpu")
        # strict=False: missing/unexpected keys are ignored (unchanged behaviour).
        model.load_state_dict(state_dict, strict=False)
        model = model.to(device).to(dtype)
        # Store the device on the model explicitly; presumably read by the
        # model/sampler code downstream.
        model.device = device
        return model

    @staticmethod
    def _build_sampler(sampler_config, model, device, dtype):
        """Instantiate the sampler class named by ``sampler_config.target`` around ``model``."""
        sampler_cls = get_obj_from_str(sampler_config.target)
        return sampler_cls(model, device=device, dtype=dtype, **sampler_config.params)

    @staticmethod
    def _load_rgb_image(pixel_img):
        """Return ``pixel_img`` as an RGB ``PIL.Image``.

        Accepts a file path (opened with ``Image.open``) or a ``PIL.Image``.
        RGBA input is alpha-composited over a fully transparent black canvas
        and then converted to RGB. With Pillow's ``alpha_composite`` this
        zeroes fully transparent pixels while partially transparent pixels
        keep their own colour. NOTE(review): confirm this, rather than a
        blend over a solid colour, is intended.

        Raises:
            TypeError: if ``pixel_img`` is neither a ``str`` path nor a
                ``PIL.Image``.
        """
        if isinstance(pixel_img, str):
            pixel_img = Image.open(pixel_img)

        if not isinstance(pixel_img, Image.Image):
            raise TypeError(
                "pixel_img must be a file path (str) or a PIL.Image.Image, "
                f"got {type(pixel_img).__name__}"
            )

        if pixel_img.mode == "RGBA":
            background = Image.new("RGBA", pixel_img.size, (0, 0, 0, 0))
            return Image.alpha_composite(background, pixel_img).convert("RGB")
        return pixel_img.convert("RGB")

    def stage1_sample(
        self,
        pixel_img,
        prompt="3D assets",
        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
        step=50,
        scale=5,
        ddim_eta=0.0,
    ):
        """Run stage 1: generate multi-view images from one reference image.

        Args:
            pixel_img: reference image as a file path or ``PIL.Image``.
            prompt: text prompt passed to the sampler's ``i2i``.
            neg_texts: negative text, encoded with
                ``model.get_learned_conditioning`` and passed as ``uc``.
            step: number of sampling steps passed to ``i2i``.
            scale: guidance scale passed to ``i2i``.
            ddim_eta: eta passed to ``i2i``.

        Returns:
            list of ``PIL.Image`` built from the ``i2i`` outputs, with the
            entry at ``stage1_sampler.ref_position`` removed.

        Raises:
            TypeError: if ``pixel_img`` is not a path or ``PIL.Image``.
        """
        pixel_img = self._load_rgb_image(pixel_img)
        sampler = self.stage1_sampler

        uc = sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
        stage1_images = sampler.i2i(
            sampler.model,
            sampler.size,
            prompt,
            uc=uc,
            sampler=sampler.sampler,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=sampler.batch_size,
            ddim_eta=ddim_eta,
            dtype=sampler.dtype,
            device=sampler.device,
            camera=sampler.camera,
            num_frames=sampler.num_frames,
            pixel_control=(sampler.mode == "pixel"),
            transform=sampler.image_transform,
            offset_noise=sampler.offset_noise,
        )

        stage1_images = [Image.fromarray(img) for img in stage1_images]
        # Drop the view at ref_position (presumably the reference view itself).
        stage1_images.pop(sampler.ref_position)
        return stage1_images

    def stage2_sample(
        self,
        pixel_img,
        stage1_images,
        prompt="3D assets",
        step=50,
        scale=5,
        ddim_eta=0.0,
    ):
        """Run stage 2, conditioned on the reference image and stage-1 views.

        The sampling parameters default to the values previously hard-coded
        here, so existing two-argument calls behave exactly as before.

        Args:
            pixel_img: reference image as a file path or ``PIL.Image``.
            stage1_images: output of ``stage1_sample``, passed as ``pixel_images``.
            prompt: text prompt passed to ``i2iStage2``.
            step: number of sampling steps passed to ``i2iStage2``.
            scale: guidance scale passed to ``i2iStage2``.
            ddim_eta: eta passed to ``i2iStage2``.

        Returns:
            list of ``PIL.Image`` built from the ``i2iStage2`` outputs.

        Raises:
            TypeError: if ``pixel_img`` is not a path or ``PIL.Image``.
        """
        pixel_img = self._load_rgb_image(pixel_img)
        sampler = self.stage2_sampler

        stage2_images = sampler.i2iStage2(
            sampler.model,
            sampler.size,
            prompt,
            sampler.uc,
            sampler.sampler,
            pixel_images=stage1_images,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=sampler.batch_size,
            ddim_eta=ddim_eta,
            dtype=sampler.dtype,
            device=sampler.device,
            camera=sampler.camera,
            num_frames=sampler.num_frames,
            pixel_control=(sampler.mode == "pixel"),
            transform=sampler.image_transform,
            offset_noise=sampler.offset_noise,
        )
        return [Image.fromarray(img) for img in stage2_images]

    def set_seed(self, seed):
        """Set the same ``seed`` attribute on both stage samplers."""
        self.stage1_sampler.seed = seed
        self.stage2_sampler.seed = seed

    def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
        """Resize the input, then run stage 1 and stage 2.

        ``prompt``, ``scale`` and ``step`` apply to stage 1 only; stage 2 runs
        with its default sampling parameters.

        Returns:
            dict with ``"ref_img"`` (the image after ``do_resize_content``),
            ``"stage1_images"`` and ``"stage2_images"`` (lists of PIL images).
        """
        pixel_img = do_resize_content(pixel_img, self.resize_rate)
        stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
        stage2_images = self.stage2_sample(pixel_img, stage1_images)

        return {
            "ref_img": pixel_img,
            "stage1_images": stage1_images,
            "stage2_images": stage2_images,
        }
|
|
|
|
|
if __name__ == "__main__":
    # Demo: run both stages on a bundled asset and save each stage's views
    # concatenated side by side (along axis 1, i.e. image width).
    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config

    pipeline = TwoStagePipeline(
        stage1_config.models,
        stage2_config.models,
        stage1_config.sampler,
        stage2_config.sampler,
    )

    rt_dict = pipeline(Image.open("assets/astronaut.png"))

    # Build both strips first, then write them out in order.
    strips = {
        filename: np.concatenate([np.asarray(view) for view in rt_dict[key]], 1)
        for filename, key in (
            ("pixel_images.png", "stage1_images"),
            ("xyz_images.png", "stage2_images"),
        )
    }
    for filename, strip in strips.items():
        Image.fromarray(strip).save(filename)
|
|