# climateGAN / climategan_wrapper.py
# Last commit by vict0rsch: 3d5f935 ("update from climategan space")
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino
import os
import re
from pathlib import Path
from uuid import uuid4
from minydra import resolved_args
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize
from climategan.trainer import Trainer
# Whether a CUDA device is available. When True, ClimateGAN's generator is
# switched to half precision and the Stable Diffusion pipeline is loaded in fp16.
CUDA = torch.cuda.is_available()
def concat_events(output_dict, events, i=None, axis=1):
    """
    Stacks the data of several events side by side.

    Only the keys of `events` that are present in `output_dict` are used, in
    the order in which they appear in `events`.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data:
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If given, only the `i`-th item of each event's data
            is concatenated. Defaults to None (the whole data is used).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data, cast to uint8 (no rescaling).
    """
    present = (key for key in events if key in output_dict)
    if i is None:
        pieces = [output_dict[key] for key in present]
    else:
        pieces = [output_dict[key][i] for key in present]
    return uint8(np.concatenate(pieces, axis=axis))
def clear(folder):
    """
    Deletes every inference output file in `folder`.

    Inference outputs are the files whose stem contains the inference separator
    "---" (that is how `ClimateGAN.write` names them); input images, whose names
    must not contain "---", are kept. Sub-directories are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # Materialize the listing first so the directory is not modified while it
    # is being iterated over.
    for path in list(Path(folder).iterdir()):
        if path.is_file() and "---" in path.stem:
            path.unlink()
def uint8(array, rescale=False):
    """
    Converts an array to np.uint8.

    Without `rescale`, only the dtype changes: the cast truncates values towards
    zero. Floating-point values are clipped to [0, 255] before the cast so that
    out-of-range values saturate instead of wrapping around (casting an
    out-of-range float to uint8 is platform-dependent).

    Args:
        array (np.array): Array to convert.
        rescale (bool, optional): If True, first map the data to [0, 255]:
            data with negative values must lie in [-1, 1] and is mapped to
            [0, 1]; then data whose max is <= 1 is multiplied by 255. Data
            whose max is already above 1 is left as is. Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the data has negative values but
            does not fit in [-1, 1].

    Returns:
        np.array(np.uint8): The converted array.
    """
    if rescale:
        lo, hi = array.min(), array.max()
        if lo < 0:
            if not (lo >= -1 and hi <= 1):
                raise ValueError(f"Data range mismatch for image: ({lo}, {hi})")
            array = (array + 1) / 2
        if array.max() <= 1:
            array = array * 255
    if np.issubdtype(array.dtype, np.floating):
        # saturate instead of letting the cast wrap around / misbehave
        array = np.clip(array, 0, 255)
    return array.astype(np.uint8)
def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps its aspect ratio and its smallest
    dimension is `to`, then crops the resized image in its center so that the
    output is `to x to`, without aspect ratio distortion.

    Args:
        img (np.array): np.uint8 image with values in [0, 255], either HxW or
            HxWxC (channels are preserved).
        to (int, optional): Side of the square output. Defaults to 640.

    Returns:
        np.array: `to x to` (x C) image in [0, 1]. The dtype is np.float64
        because the uint8 crop is divided by 255.0.
    """
    h, w = img.shape[:2]
    # resize keeping aspect ratio: the smallest dimension becomes `to`
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)
    r_img = uint8(resize(img, size, preserve_range=True, anti_aliasing=True))

    # crop in the center. Only the first two axes are sliced so that both
    # HxW and HxWxC images are supported.
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    rc_img = r_img[top : top + to, left : left + to]
    return rc_img / 255.0
def to_m1_p1(img):
    """
    Rescales a [0, 1] image to [-1, +1].

    Args:
        img (np.array): Numpy array of an image with values in [0, 1].

    Raises:
        ValueError: If the image is not in [0, 1].

    Returns:
        np.array(np.float32): Array in [-1, +1].
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Wrapper around ClimateGAN's `Trainer` and a Stable Diffusion in-painting
    pipeline (loaded lazily, on first use), to flood single images, batches
    of preprocessed images or whole folders of images.
    """

    # Prompt used by default to guide Stable Diffusion's in-painting.
    DEFAULT_PROMPT = "An HD picture of a street with dirty water after a heavy flood"
    # Output keys concatenated side by side in the "concat" output by default.
    # A tuple rather than a list so that the shared default cannot be mutated.
    DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Note: instantiating this class globally disables gradients
        (`torch.set_grad_enabled(False)`).

        Args:
            model_path (Union[str, Path]): Where to load the ClimateGAN model
                from (passed to `Trainer.resume_from_path`).
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random arrays. Defaults to False.
        """
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            # run the generator in half precision on GPU
            self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the Stable Diffusion pipeline for in-painting
        ("runwayml/stable-diffusion-inpainting"), in fp16 when CUDA is available.

        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting
        The `HF_AUTH_TOKEN` environment variable, if set, is passed as the
        Hugging Face auth token.

        Does nothing in dev mode.
        """
        if self.dev_mode:
            return
        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
        except Exception:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card https://huggingface.co/runwayml/stable-diffusion-inpainting\n"
            )
            raise
        self._stable_diffusion_is_setup = True

    @staticmethod
    def _m1p1_to_uint8(images):
        """
        Maps images in [-1, 1] (as produced by `_preprocess_image`) to uint8
        images in [0, 255].

        Unlike `uint8(images, rescale=True)`, the input range is not guessed
        from the data, so bright images whose values are all >= 0 are not
        mistaken for [0, 1] images.
        """
        return np.clip((images + 1) / 2 * 255, 0, 255).astype(np.uint8)

    def _preprocess_image(self, img):
        """
        Turns a HxW (grayscale), HxWx3 or HxWx4 (RGBA) uint8 numpy array into a
        `target_size x target_size x 3` float32 numpy array in [-1, 1].

        Args:
            img (np.array): Image to resize, crop and rescale.

        Returns:
            np.array: Resized, cropped and rescaled image.
        """
        data = img
        if data.ndim == 2:
            # grayscale: replicate the single channel into RGB
            data = np.stack([data] * 3, axis=-1)
        elif data.shape[-1] != 3:
            # rgba to rgb (rgba2rgb returns floats in [0, 1])
            data = uint8(rgba2rgb(data) * 255)
        # to self.target_size
        data = resize_and_crop(data, self.target_size)
        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        return to_m1_p1(data)

    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt=DEFAULT_PROMPT,
        concats=DEFAULT_CONCATS,
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model (Masker, then ClimateGAN's
        Painter and/or Stable Diffusion depending on `painter`).

        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        Output dict contains the following keys:

        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The `concats` outputs side by side (only if `concats` is not
            empty).

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys of the output to concatenate
                side by side (along the width) in the "concat" output. Defaults
                to ("input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"). Missing keys are skipped.
            as_pil_image (bool, optional): If True, the original image is also
                passed as a PIL image to `infer_preprocessed_batch`, which then
                re-preprocesses it and feeds PIL images (instead of tensors) to
                the Stable Diffusion pipeline. Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC} (the batch
            dimension is removed). C is omitted for masks (HxW).
        """
        if self.dev_mode:
            s = self.target_size
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (s, s)),
                "masked_input": np.random.randint(0, 255, (s, s, 3)),
                "climategan_flood": np.random.randint(0, 255, (s, s, 3)),
                "stable_flood": np.random.randint(0, 255, (s, s, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (s, s, 3)),
                "concat": np.random.randint(0, 255, (s, s * 5, 3)),
                "smog": np.random.randint(0, 255, (s, s, 3)),
                "wildfire": np.random.randint(0, 255, (s, s, 3)),
                "depth": np.random.randint(0, 255, (s, s, 1)),
                "segmentation": np.random.randint(0, 255, (s, s, 3)),
            }

        if isinstance(orig_image, (str, Path)):
            # read the data and close the file right away
            with Image.open(orig_image) as im:
                image_array = np.array(im)
        else:
            image_array = orig_image

        pil_image = Image.fromarray(image_array) if as_pil_image else None

        print("Preprocessing image")
        image = self._preprocess_image(image_array)
        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )
        print("Inference done")
        # remove the batch dimension
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt=DEFAULT_PROMPT,
        concats=DEFAULT_CONCATS,
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image() (float values in [-1, 1]).

        Output dict contains the following keys:

        - "input": The input image (uint8)
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The `concats` outputs side by side (only if `concats` is not
            empty).

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys of the output to concatenate
                side by side (along the width) in the "concat" output. Defaults
                to ("input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"). Missing keys are skipped.
            pil_image (PIL.Image, optional): The original PIL image. If provided,
                it overrides `images` and is used for a single inference
                (batch_size=1).

        Returns:
            dict: a dictionary containing the output images, each with a leading
            batch dimension. Axis 1 is squeezed when it has size 1 (e.g. Bx1xHxW
            masks become BxHxW).
        """
        assert painter in {
            "both",
            "stable_diffusion",
            "climategan",
        }, f"Unknown painter: {painter}"

        # ClimateGAN's flood is not needed if only Stable Diffusion paints
        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            # the pipeline then gets the resized & cropped image as a PIL image
            pil_image = Image.fromarray(self._m1p1_to_uint8(images[0]))

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        # explicit [-1, 1] -> [0, 255] mapping (no range guessing)
        outputs["input"] = self._m1p1_to_uint8(images)
        # mask from Bx1xHxW to BxHxW, then BxHxWx1 to broadcast over channels
        mask = outputs["mask"].squeeze(1)
        outputs["masked_input"] = outputs["input"] * (mask[..., None] == 0)

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was ignored: don't fail if infer_all did not return it
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            bin_mask = mask > 0
            if pil_image is None:
                device = self.trainer.device
                input_images = torch.tensor(images).permute(0, 3, 1, 2).to(device)
                input_mask = torch.tensor(bin_mask[:, None, ...]).to(device)
            else:
                input_images = pil_image
                # explicit 0/255 mask image, independent of the mask's dtype
                input_mask = Image.fromarray(bin_mask[0].astype(np.uint8) * 255)

            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            flood = np.stack([np.array(im) for im in floods.images])
            outputs["stable_flood"] = flood
            # paste the original context back outside of the mask:
            # y = m * flooded + (1 - m) * input  (stays uint8)
            outputs["stable_copy_flood"] = np.where(
                bin_mask[..., None], flood, outputs["input"]
            )

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # squeeze a singleton axis 1 (e.g. Bx1xHxW masks -> BxHxW)
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt=DEFAULT_PROMPT,
        batch_size=4,
        concats=DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.
        Images are processed in sorted path order.

        If `write` is True, outputs are written to disk in the same folder as the
        input images (see `write` for the naming scheme).

        Output dict contains the following keys:

        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The `concats` outputs side by side (only if `concats` is not
            empty).

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (Sequence[str], optional): Keys of the output to concatenate
                side by side in the "concat" output (written as a
                `..._concat` image). Defaults to ("input", "masked_input",
                "climategan_flood", "stable_flood", "stable_copy_flood").
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Returns:
            dict: a dictionary {k: [HxWxC, ...]} with one item per image, in the
            same order as the processed images.
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        # sorted for a deterministic processing / output order
        im_paths = sorted(
            p
            for p in folder_path.iterdir()
            if p.is_file()
            and p.suffix.lower() in {".jpg", ".png", ".jpeg"}
            and "---" not in p.name
        )
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = []
        for p in im_paths:
            # read the data and close the file right away
            with Image.open(p) as im:
                ims.append(self._preprocess_image(np.array(im)))

        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]
        # flatten the batches into one list of per-image arrays per key
        outputs = {
            k: [item for batch_out in inferences for item in batch_out[k]]
            for k in inferences[0]
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, next to the input images.

        Images are named like:
            f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `event` is the output's key, `suffix` is the input image's suffix
        and `painter_prefix` is:

        - f"_stable_{prompt}" for the "stable_flood" event when `painter` is
          "stable_diffusion" or "both" (`prompt` stripped of non-alphanumeric
          characters and lower-cased),
        - "climategan" for a "flood" event when `painter` is "climategan"
          (outputs of the `infer_*` methods name this event "climategan_flood"
          instead, so this case only applies to other output dicts),
        - "" otherwise.

        Args:
            outputs (dict[str, Sequence[np.array]]): The inference procedure's
                output dict, with one item per input image for each key.
            im_paths (list[Union[str, Path]]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the images
                or not. If a string is provided, it is used as `overwrite_prefix`.
                If True, `overwrite_prefix` is empty so previous outputs are
                overwritten. If False, a random 8-character identifier is used.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        prompt = re.sub(r"[^0-9a-zA-Z]+", "", prompt).lower()

        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
        if overwrite_prefix:
            print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            im_path = Path(im_path)
            for event, ims in outputs.items():
                painter_prefix = ""
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif painter in {"stable_diffusion", "both"} and event == "stable_flood":
                    painter_prefix = f"_stable_{prompt}"
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im = Image.fromarray(uint8(ims[i]))
                im.save(im_path.parent / (imstem + im_path.suffix))
if __name__ == "__main__":
    print("Run `$ python climategan_wrapper.py help` for usage instructions\n")

    # parse arguments
    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    # print help and exit: nothing can run without an input folder anyway
    if args.help:
        print(
            "Usage: python climategan_wrapper.py input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used. "
            + "Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models. "
            + "In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )
        raise SystemExit(0)

    # print args
    args.pretty_print()

    # validate arguments before loading the (slow) models
    assert args.painter in {"climategan", "stable_diffusion", "both"}, (
        f"Unknown painter {args.painter}. "
        + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
    )
    assert (
        args.input_folder is not None
    ), "Missing `input_folder`. Run `$ python climategan_wrapper.py help`"

    # resolve input folder path
    in_path = Path(args.input_folder).expanduser().resolve()
    assert in_path.is_dir(), f"Folder {str(in_path)} does not exist or is not a directory"

    # output is input if not specified
    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    # find images in input folder (sorted for a deterministic order)
    im_paths = sorted(
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    )
    assert im_paths, f"No images found in {str(in_path)}"
    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    # load models
    cg = ClimateGAN("models/climategan")
    # load SD pipeline if need be
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    # infer and write
    for i, im_path in enumerate(im_paths, start=1):
        print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
        # read the data and close the file right away
        with Image.open(im_path) as img:
            image_array = np.array(img)
        outs = cg.infer_single(
            image_array,
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )
        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            Image.fromarray(uint8(v)).save(out_path / name)
        print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")