File size: 23,898 Bytes
3d5f935 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 |
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino
import os
import re
from pathlib import Path
from uuid import uuid4
from minydra import resolved_args
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize
from climategan.trainer import Trainer
CUDA = torch.cuda.is_available()
def concat_events(output_dict, events, i=None, axis=1):
    """
    Joins the data stored under several keys of `output_dict` along `axis`
    and casts the result to uint8.

    Keys listed in `events` that are absent from `output_dict` are skipped;
    the remaining ones are concatenated in the order given by `events`.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If not None, only the `i`th item of each entry is
            concatenated. Defaults to None (whole entries are concatenated).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data, converted by `uint8`.
    """
    pieces = []
    for event in events:
        if event not in output_dict:
            continue
        data = output_dict[event]
        pieces.append(data if i is None else data[i])
    return uint8(np.concatenate(pieces, axis=axis))
def clear(folder):
    """
    Deletes every regular file directly inside `folder` whose stem contains
    the inference separator "---" (i.e. previously written inferences).

    Sub-directories and files without the separator are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # Snapshot the matching entries first, then remove them
    inference_files = [
        entry
        for entry in Path(folder).iterdir()
        if "---" in entry.stem and entry.is_file()
    ]
    for entry in inference_files:
        entry.unlink()
def uint8(array, rescale=False):
    """
    Converts an array to np.uint8.

    Without `rescale`, only the dtype is changed. With `rescale`, the values
    are first mapped towards [0, 255]:
      - data with negative values must lie in [-1, 1] and is shifted to [0, 1]
      - data whose maximum is <= 1 is then multiplied by 255
      - any other data is left as is before the dtype conversion

    Args:
        array (np.array): array to modify
        rescale (bool, optional): Whether to rescale the values before
            converting the dtype. Defaults to False.

    Raises:
        ValueError: If `rescale` is set and the array has negative values
            but is not contained in [-1, 1]

    Returns:
        np.array(np.uint8): converted array
    """
    if not rescale:
        return array.astype(np.uint8)

    if array.min() < 0:
        fits_m1_p1 = array.min() >= -1 and array.max() <= 1
        if not fits_m1_p1:
            raise ValueError(
                f"Data range mismatch for image: ({array.min()}, {array.max()})"
            )
        # [-1, 1] -> [0, 1]
        array = (array + 1) / 2

    if array.max() <= 1:
        # [0, 1] -> [0, 255]
        array = array * 255

    return array.astype(np.uint8)
def resize_and_crop(img, to=640):
    """
    Resizes an image, keeping its aspect ratio, so that its smallest dimension
    becomes `to`, then takes a centered `to x to` crop of the resized image
    (no aspect ratio distortion).

    Args:
        img (np.array): np.uint8 255 image, indexed as HxWxC
        to (int, optional): Side of the square output. Defaults to 640.

    Returns:
        np.array: float image in [0, 1] (the uint8 crop divided by 255.0)
    """
    height, width = img.shape[:2]

    # the smallest side becomes `to`, the other one is scaled accordingly
    new_shape = (
        (to, int(to * width / height))
        if height < width
        else (int(to * height / width), to)
    )
    resized = uint8(resize(img, new_shape, preserve_range=True, anti_aliasing=True))

    # centered square crop
    new_height, new_width = resized.shape[:2]
    y0 = (new_height - to) // 2
    x0 = (new_width - to) // 2
    cropped = resized[y0 : y0 + to, x0 : x0 + to, :]

    return cropped / 255.0
def to_m1_p1(img):
    """
    Maps an image from [0, 1] to [-1, +1] as float32.

    Args:
        img (np.array): numpy array of an image with values in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    in_unit_range = img.min() >= 0 and img.max() <= 1
    if not in_unit_range:
        raise ValueError(
            f"Data range mismatch for image: ({img.min()}, {img.max()})"
        )
    centered = img.astype(np.float32) - 0.5
    return centered * 2
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer` and, optionally, a
    Stable Diffusion in-painting pipeline used as an alternative flood painter.
    """

    # Keys concatenated side by side in the "concat" output when the caller
    # does not specify `concats`. A tuple so the shared default is immutable.
    DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )

    # Hugging Face id of the in-painting model loaded by
    # `_setup_stable_diffusion`
    SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random arrays instead of real
                inferences. Defaults to False.
        """
        # NOTE: this disables autograd globally, not only for this object
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            # no trainer in dev mode: `self.trainer` is left undefined
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Does nothing in dev mode. The access token is read from the
        `HF_AUTH_TOKEN` environment variable.

        Raises:
            Exception: Whatever the pipeline loading raised, re-raised after
                printing a hint about the model's license.
        """
        if self.dev_mode:
            return
        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                self.SD_MODEL_ID,
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
        except Exception:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + f" card https://huggingface.co/{self.SD_MODEL_ID}\n"
            )
            # bare raise keeps the original traceback
            raise
        self._stable_diffusion_is_setup = True

    @staticmethod
    def _read_image(path):
        """Reads the image at `path` into a numpy array and closes the file."""
        with Image.open(path) as pil_im:
            return np.array(pil_im)

    @staticmethod
    def _painter_prefix(painter, event, prompt):
        """
        Returns the painter-specific part of an output file name, including its
        leading underscore, or "" when the event needs none.

        Args:
            painter (str): "climategan", "stable_diffusion" or "both".
            event (str): Key of the output being written.
            prompt (str): Already sanitized prompt.

        Returns:
            str: "_climategan", f"_stable_{prompt}" or "".
        """
        if painter == "climategan" and event == "flood":
            return "_climategan"
        if painter in {"stable_diffusion", "both"} and event == "stable_flood":
            return f"_stable_{prompt}"
        return ""

    def _preprocess_image(self, img):
        """
        Turns a HxWxC uint8 numpy array into a square float32 numpy array
        of side `self.target_size` with 3 channels, in [-1, 1].

        Args:
            img (np.array): Image to resize crop and rescale

        Returns:
            np.array: Resized, cropped and rescaled image
        """
        # rgba to rgb (anything whose last dimension is not 3 is treated as RGBA)
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
        # to self.target_size
        data = resize_and_crop(data, self.target_size)
        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=DEFAULT_CONCATS,
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        In dev mode, no inference is run: `orig_image` is returned as "input"
        next to random integer arrays for the other keys.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in the "concat" output. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            as_pil_image (bool, optional): If True, the image is also handed to
                `infer_preprocessed_batch` as a PIL image (its `pil_image`
                argument). Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            s = self.target_size
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (s, s)),
                "masked_input": np.random.randint(0, 255, (s, s, 3)),
                "climategan_flood": np.random.randint(0, 255, (s, s, 3)),
                "stable_flood": np.random.randint(0, 255, (s, s, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (s, s, 3)),
                "concat": np.random.randint(0, 255, (s, s * 5, 3)),
                "smog": np.random.randint(0, 255, (s, s, 3)),
                "wildfire": np.random.randint(0, 255, (s, s, 3)),
                "depth": np.random.randint(0, 255, (s, s, 1)),
                "segmentation": np.random.randint(0, 255, (s, s, 3)),
            }

        image_array = (
            self._read_image(orig_image)
            if isinstance(orig_image, (str, Path))
            else orig_image
        )

        pil_image = Image.fromarray(image_array) if as_pil_image else None

        print("Preprocessing image")
        image = self._preprocess_image(image_array)

        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )
        print("Inference done")

        # batch of size 1: return the only item of each output
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=DEFAULT_CONCATS,
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The `concats` outputs side by side (only if `concats` is
            not empty).
        Any other key returned by the trainer's `infer_all` is passed through.

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in the "concat" output. Keys missing from the outputs are skipped.
                Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            pil_image (PIL.Image, optional): The original PIL image. If provided,
                will be used for a single inference (batch_size=1)

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        # ClimateGAN's own flood is not needed when only stable diffusion paints
        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            # PIL version of the preprocessed image: [-1, 1] -> [0, 255]
            pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was in `ignore_event`, so the trainer may not have
            # returned it at all: drop it only if present
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            # tensors for a batch, PIL images for the single `pil_image` case
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
                if pil_image is None
                else pil_image
            )
            input_mask = (
                torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
                if pil_image is None
                else Image.fromarray(mask[0])
            )
            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # paste the original context back outside of the mask
            copy_flood = flood * bin_mask + outputs["input"] * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # drop singleton channel-first dimensions (e.g. Bx1xHxW masks)
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.
        Images are processed in sorted path order.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (list, optional): List of keys in `output` to concatenate together
                in the "concat" output. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Raises:
            AssertionError: If the folder does not exist, is not a directory or
                contains no image.

        Returns:
            dict: a dictionary mapping each output key to the list of its images,
                one per input image
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        # sorted: iterdir() order is arbitrary, keep runs reproducible
        im_paths = sorted(
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in {".jpg", ".png", ".jpeg"} and "---" not in p.name
        )
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = [self._preprocess_image(self._read_image(p)) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # flatten the per-batch outputs into one list per key
        outputs = {k: [i for e in inferences for i in e[k]] for k in inferences[0]}

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{output_type}{suffix}"
        `painter_prefix` is "_climategan", f"_stable_{prompt}" or empty
        (see `_painter_prefix`); `prompt` is lower-cased and stripped of
        non-alphanumeric characters.

        Args:
            outputs (dict): The inference procedure's output dict, mapping each
                output type to a sequence of images aligned with `im_paths`.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not.
                If a string is provided, it will be included in the name.
                If False, a random identifier will be added to the name.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        prompt = re.sub(r"[^0-9a-zA-Z]+", "", prompt).lower()

        overwrite_prefix = ""
        if not overwrite:
            # random identifier so that successive runs do not clash
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = self._painter_prefix(painter, event, prompt)
                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))
if __name__ == "__main__":
    # Command-line entry point: runs `ClimateGAN.infer_single` on every image
    # of `input_folder` and writes each output as
    # `{image_stem}---{output_key}{image_suffix}` in the output folder.

    # actual name of this script, so usage messages cannot go stale
    script_name = Path(__file__).name
    print(f"Run `$ python {script_name} help` for usage instructions\n")

    # parse arguments
    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    # print help and stop: there is nothing to infer on when asking for help
    if args.help:
        print(
            f"Usage: python {script_name} input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used.\n"
            + "Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models.\n"
            + "In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )
        raise SystemExit(0)

    # print args
    args.pretty_print()

    # validate everything cheap before loading any model

    # check painter type
    assert args.painter in {"climategan", "stable_diffusion", "both"}, (
        f"Unknown painter {args.painter}. "
        + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
    )

    # resolve input folder path
    assert args.input_folder is not None, (
        "Missing `input_folder`. "
        + f"Usage: python {script_name} input_folder=/path/to/folder"
    )
    in_path = Path(args.input_folder).expanduser().resolve()
    assert in_path.exists(), f"Folder {str(in_path)} does not exist"

    # output is input if not specified
    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    # find images in input folder
    im_paths = [
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    ]
    assert im_paths, f"No images found in {str(in_path)}"

    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    # load models
    cg = ClimateGAN("models/climategan")

    # load SD pipeline if need be
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    # infer and write
    for i, im_path in enumerate(im_paths):
        print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
        # read the image and release the file handle right away
        with Image.open(im_path) as pil_im:
            im_array = np.array(pil_im)
        outs = cg.infer_single(
            im_array,
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )
        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            im = Image.fromarray(uint8(v))
            im.save(out_path / name)
        print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")
|