|
|
|
|
|
|
|
import os |
|
import re |
|
from pathlib import Path |
|
from uuid import uuid4 |
|
from minydra import resolved_args |
|
import numpy as np |
|
import torch |
|
from diffusers import StableDiffusionInpaintPipeline |
|
from PIL import Image |
|
from skimage.color import rgba2rgb |
|
from skimage.transform import resize |
|
|
|
from climategan.trainer import Trainer |
|
|
|
|
|
# Whether a CUDA device is usable. Evaluated once at import time; the rest of the
# module uses it to choose between half and full precision.
CUDA = torch.cuda.is_available()
|
|
|
|
|
def concat_events(output_dict, events, i=None, axis=1):
    """
    Concatenates the data stored in `output_dict` under the keys listed in
    `events`, along dimension `axis`, and casts the result to np.uint8.

    Keys of `events` which are not in `output_dict` are silently skipped.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If not None, only the `i`th item of each selected
            entry is concatenated. Defaults to None (whole entries are used).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data.
    """
    pieces = []
    for key in events:
        if key not in output_dict:
            continue
        data = output_dict[key]
        pieces.append(data if i is None else data[i])
    return uint8(np.concatenate(pieces, axis=axis))
|
|
|
|
|
def clear(folder):
    """
    Deletes all the files of `folder` which have the inference separator "---"
    in their stem, i.e. the files written by a previous inference.

    Files without the separator, and directories (whatever their name), are
    left untouched. The folder is not explored recursively.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # The listing is materialised first so that no file is removed from the
    # directory while it is still being iterated over.
    for path in list(Path(folder).iterdir()):
        if path.is_file() and "---" in path.stem:
            path.unlink()
|
|
|
|
|
def uint8(array, rescale=False):
    """
    Converts an array to np.uint8.

    With `rescale=False` (default) only the dtype is changed: values are cast
    as they are, without any scaling, rounding or clipping.

    With `rescale=True` the data is first brought towards the [0, 255] range:
        - if it has negative values it must lie in [-1, 1] and is mapped to [0, 1]
        - if its maximum is then <= 1, it is multiplied by 255
        - otherwise it is left as is
    Empty arrays have no data range and are only cast.

    Args:
        array (np.array): array to modify (any array-like is accepted)
        rescale (bool, optional): Whether to rescale the data before casting.
            Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the data has negative values
            without being contained in [-1, 1]

    Returns:
        np.array(np.uint8): converted array
    """
    array = np.asarray(array)

    # min() / max() raise on empty arrays: there is nothing to rescale then.
    if rescale and array.size > 0:
        low, high = array.min(), array.max()
        if low < 0:
            if low >= -1 and high <= 1:
                array = (array + 1) / 2
                high = array.max()
            else:
                raise ValueError(f"Data range mismatch for image: ({low}, {high})")
        if high <= 1:
            array = array * 255

    return array.astype(np.uint8)
|
|
|
|
|
def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimensions
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion.

    The resized image is cast to np.uint8 before being cropped and divided
    by 255, so the input is expected to be in the [0, 255] range.

    Args:
        img (np.array): np.uint8 255 image, HxW or HxWxC
        to (int, optional): Side of the square output, in pixels. Defaults to 640.

    Returns:
        np.array: [0, 1] np.float32 image
    """
    h, w = img.shape[:2]
    # The smallest side becomes exactly `to`, the other one is scaled accordingly.
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2

    # Only the two leading (spatial) axes are sliced so that images with or
    # without a channel axis are both supported.
    rc_img = r_img[top : top + to, left : left + to]

    # uint8 / float yields float64: cast to honour the documented float32 output.
    return (rc_img / 255.0).astype(np.float32)
|
|
|
|
|
def to_m1_p1(img):
    """
    Rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): numpy array of an image in [0, 1] (any dtype,
            array-likes are accepted)

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    img = np.asarray(img)
    low, high = img.min(), img.max()
    if low >= 0 and high <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({low}, {high})")
|
|
|
|
|
|
|
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer`, optionally combined with a
    Stable Diffusion in-painting pipeline used as an alternative flood painter.
    """

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Note that gradients are globally disabled in torch by this constructor.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random data. Defaults to False.
        """
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Does nothing in dev mode. The Hugging Face token is read from the
        `HF_AUTH_TOKEN` environment variable.

        Raises:
            Exception: Whatever exception was raised while loading the pipeline,
                after a hint has been printed.
        """
        if self.dev_mode:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
        except Exception:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card"
                + " https://huggingface.co/runwayml/stable-diffusion-inpainting\n"
            )
            # A bare `raise` propagates the error with its original traceback.
            raise
        self._stable_diffusion_is_setup = True

    def _preprocess_image(self, img):
        """
        Turns a HxWxC uint8 numpy array into a 640x640x3 float32 numpy array
        in [-1, 1] (640 being `self.target_size`).

        Images whose last dimension is not 3 are converted with `rgba2rgb`.

        Args:
            img (np.array): Image to resize crop and rescale

        Returns:
            np.array: Resized, cropped and rescaled image
        """
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
        data = resize_and_crop(data, self.target_size)
        data = to_m1_p1(data)
        return data

    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=(
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ),
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        In dev mode no model is run: `orig_image` is returned as "input" and
        every other output is random data.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. Falsy to disable. Defaults to:
                ("input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood").
            as_pil_image (bool, optional): If True, the input is also turned into
                a PIL image which is forwarded to `infer_preprocessed_batch` as
                `pil_image`. Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            size = self.target_size
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (size, size)),
                "masked_input": np.random.randint(0, 255, (size, size, 3)),
                "climategan_flood": np.random.randint(0, 255, (size, size, 3)),
                "stable_flood": np.random.randint(0, 255, (size, size, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (size, size, 3)),
                "concat": np.random.randint(0, 255, (size, size * 5, 3)),
                "smog": np.random.randint(0, 255, (size, size, 3)),
                "wildfire": np.random.randint(0, 255, (size, size, 3)),
                "depth": np.random.randint(0, 255, (size, size, 1)),
                "segmentation": np.random.randint(0, 255, (size, size, 3)),
            }

        if isinstance(orig_image, (str, Path)):
            # np.array() reads all the pixels, so the file can be closed right away.
            with Image.open(orig_image) as image_file:
                image_array = np.array(image_file)
        else:
            image_array = orig_image

        pil_image = None
        if as_pil_image:
            pil_image = Image.fromarray(image_array)
        print("Preprocessing image")
        image = self._preprocess_image(image_array)
        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )
        print("Inference done")
        # Batch of one: drop the batch dimension of every output.
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=(
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ),
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. Falsy to disable. Defaults to:
                ("input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood").
            pil_image (PIL.Image, optional): The original PIL image. If provided,
                will be used for a single inference (batch_size=1)

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            # PIL version of the *preprocessed* image, for the diffusion pipeline
            pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))

        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        outputs["input"] = uint8(images, True)
        # Pixels under the mask are zeroed-out
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was listed in `ignore_event` so the trainer may not have
            # returned it at all: a plain `del` could raise a KeyError.
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            # Tensors for a regular batch, PIL images when `pil_image` was provided
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
                if pil_image is None
                else pil_image
            )
            input_mask = (
                torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
                if pil_image is None
                else Image.fromarray(mask[0])
            )
            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # y = m * flooded + (1-m) * input
            copy_flood = flood * bin_mask + outputs["input"] * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # Singleton dimension 1 is squeezed out
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=(
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ),
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. Falsy to disable. Defaults to:
                ("input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood").
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Returns:
            dict: a dictionary containing the output images, as
                {key: list of per-image outputs}, in the order the images were read
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"
        im_paths = [
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        ]
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = []
        for p in im_paths:
            # np.array() reads all the pixels, so the file can be closed right away.
            with Image.open(p) as image_file:
                ims.append(self._preprocess_image(np.array(image_file)))

        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # Flatten the per-batch outputs into one list per key
        outputs = {
            k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `event` is the key in `outputs` and `painter_prefix` is:
        - "climategan" for the "flood" event when `painter` is "climategan"
        - f"_stable_{prompt}" for the "stable_flood" event when `painter` is
            "stable_diffusion" or "both" (`prompt` being lower-cased and stripped
            from its non-alphanumeric characters)
        - "" otherwise

        Args:
            outputs (dict): The inference procedure's output dict, mapping each
                event to a sequence of images aligned with `im_paths`.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the images
                or not.
                If a string is provided, it will be included in the name.
                If False, a random identifier will be added to the name.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        prompt = re.sub(r"[^0-9a-zA-Z]+", "", prompt).lower()
        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)

        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = ""
                if painter == "climategan" and event == "flood":
                    # NOTE(review): no "_" separator here, unlike the stable
                    # diffusion prefix below -- kept as is to preserve file names.
                    painter_prefix = "climategan"
                elif (
                    painter in {"stable_diffusion", "both"} and event == "stable_flood"
                ):
                    painter_prefix = f"_stable_{prompt}"

                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: floods every image of `input_folder` and writes
    # each output as `{stem}---{output_key}{suffix}` in the output folder.
    script_name = Path(__file__).name
    print(f"Run `$ python {script_name} help` for usage instructions\n")

    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    if args.help:
        print(
            f"Usage: python {script_name} input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used. "
            + "Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models. "
            + "In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )
        # Asking for help must not trigger an inference
        raise SystemExit(0)

    args.pretty_print()

    # Every argument is checked *before* the (slow) models are loaded.
    assert args.painter in {"climategan", "stable_diffusion", "both"}, (
        f"Unknown painter {args.painter}. "
        + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
    )
    assert (
        args.input_folder is not None
    ), "Missing input folder: add `input_folder=/path/to/folder`"

    in_path = Path(args.input_folder).expanduser().resolve()
    assert in_path.exists(), f"Folder {str(in_path)} does not exist"
    assert in_path.is_dir(), f"{str(in_path)} is not a directory"

    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    im_paths = [
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    ]
    assert im_paths, f"No images found in {str(in_path)}"

    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    cg = ClimateGAN("models/climategan")

    # Stable Diffusion is only needed when it is one of the requested painters
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    for i, im_path in enumerate(im_paths):
        print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
        # np.array() reads all the pixels, so the file can be closed right away.
        with Image.open(im_path) as image_file:
            image_array = np.array(image_file)
        outs = cg.infer_single(
            image_array,
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )
        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            im = Image.fromarray(uint8(v))
            im.save(out_path / name)
        print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")
|
|