|
from io import BytesIO |
|
|
|
import torch |
|
|
|
import PIL |
|
import requests |
|
from diffusers import RePaintPipeline, RePaintScheduler |
|
|
|
|
|
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Raising here avoids handing an error page to PIL.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule. Import it
    # explicitly instead of relying on some other library having done so.
    from PIL import Image

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
# Source image and inpainting mask, both hosted in the diffusers test-image
# dataset (the file names indicate a 256px CelebA-HQ sample and its mask).
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"

# Download the image first and then the mask, and bring both to 256x256.
original_image, mask_image = (
    download_image(url).resize((256, 256)) for url in (img_url, mask_url)
)
|
|
|
|
|
import os

# The device and weight-cache location were hard-coded for one specific
# machine (second GPU, personal cache path). Keep those values as defaults
# but allow environment overrides so the demo can run elsewhere.
DEVICE = os.environ.get("REPAINT_DEVICE", "cuda:1")
CACHE_DIR = os.environ.get(
    "REPAINT_CACHE_DIR", "/comp_robot/rentianhe/weights/diffusers/"
)

# The scheduler and the pipeline are both loaded from this one checkpoint.
MODEL_ID = "google/ddpm-ema-celebahq-256"

scheduler = RePaintScheduler.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
pipe = RePaintPipeline.from_pretrained(
    MODEL_ID, scheduler=scheduler, cache_dir=CACHE_DIR
)
pipe = pipe.to(DEVICE)

# Fixed seed on the pipeline's device, so repeated runs draw the same noise.
generator = torch.Generator(device=DEVICE).manual_seed(0)

output = pipe(
    image=original_image,
    mask_image=mask_image,
    num_inference_steps=250,
    eta=0.0,
    # RePaint resampling schedule; see the RePaintScheduler docs for the
    # meaning of jump_length / jump_n_sample.
    jump_length=10,
    jump_n_sample=10,
    generator=generator,
)

inpainted_image = output.images[0]
inpainted_image.save("./repaint_demo.jpg")