# File size: 23,898 Bytes
# 3d5f935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino

import os
import re
from pathlib import Path
from uuid import uuid4
from minydra import resolved_args
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize

from climategan.trainer import Trainer


CUDA = torch.cuda.is_available()


def concat_events(output_dict, events, i=None, axis=1):
    """
    Builds a single uint8 array by concatenating, along `axis`, the data stored
    in `output_dict` under the keys listed in `events`.

    Keys of `events` that are absent from `output_dict` are silently skipped.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate, in order.
        i (int, optional): If not None, only the `i`th item of each entry is
            concatenated. Defaults to None (concatenate the whole entries).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data.
    """
    selected = [key for key in events if key in output_dict]
    if i is None:
        pieces = [output_dict[key] for key in selected]
    else:
        pieces = [output_dict[key][i] for key in selected]
    return uint8(np.concatenate(pieces, axis=axis))


def clear(folder):
    """
    Deletes every file directly inside `folder` whose stem contains the
    inference separator "---". Files without the separator and
    sub-directories are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # Snapshot the listing before unlinking anything.
    entries = list(Path(folder).iterdir())
    for entry in entries:
        if not entry.is_file():
            continue
        if "---" in entry.stem:
            entry.unlink()


def uint8(array, rescale=False):
    """
    Converts an array to np.uint8.

    Without `rescale`, only the dtype is changed. With `rescale`, data in
    [-1, 1] (with at least one negative value) is first mapped to [0, 1], and
    data whose maximum is <= 1 is then multiplied by 255 before the cast.

    Args:
        array (np.array): array to convert
        rescale (bool, optional): Whether to rescale the values as described
            above before casting. Defaults to False.
    Raises:
        ValueError: If `rescale` is set and the array has negative values but
            is not contained in [-1, 1].
    Returns:
        np.array(np.uint8): converted array
    """
    if not rescale:
        return array.astype(np.uint8)

    has_negative = array.min() < 0
    if has_negative:
        in_m1_p1 = array.min() >= -1 and array.max() <= 1
        if not in_m1_p1:
            raise ValueError(
                f"Data range mismatch for image: ({array.min()}, {array.max()})"
            )
        # [-1, 1] -> [0, 1]
        array = (array + 1) / 2

    # [0, 1] -> [0, 255]
    if array.max() <= 1:
        array = array * 255

    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resizes an image, keeping its aspect ratio, so that its smallest dimension
    becomes `to`, then takes a centered `to x to` crop of the resized image so
    that the output is not distorted.

    Args:
        img (np.array): np.uint8 255 image
        to (int, optional): Side length of the square output. Defaults to 640.
    Returns:
        np.array: `to x to` floating-point image in [0, 1]
    """
    height, width = img.shape[:2]

    # the smallest side becomes `to`, the other one is scaled accordingly
    is_landscape = height < width
    new_height = to if is_landscape else int(to * height / width)
    new_width = int(to * width / height) if is_landscape else to

    resized = uint8(
        resize(img, (new_height, new_width), preserve_range=True, anti_aliasing=True)
    )

    # centered square crop
    rows, cols = resized.shape[:2]
    y0 = (rows - to) // 2
    x0 = (cols - to) // 2
    cropped = resized[y0 : y0 + to, x0 : x0 + to, :]

    return cropped / 255.0


def to_m1_p1(img):
    """
    Maps an image with values in [0, 1] to [-1, +1].

    Args:
        img (np.array): numpy array of an image in [0, 1]
    Raises:
        ValueError: If the image is not in [0, 1]
    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    in_unit_range = img.min() >= 0 and img.max() <= 1
    if not in_unit_range:
        raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")

    centered = img.astype(np.float32) - 0.5
    return centered * 2


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    A wrapper around the ClimateGAN trainer (and, optionally, a Stable Diffusion
    in-painting pipeline) to infer events on single images, pre-processed
    batches or folders of images.
    """

    # Keys concatenated by default into the "concat" output. Kept as an
    # immutable tuple so it can safely be shared as a default argument.
    _DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random arrays. Defaults to False.
        """
        # inference only: disable autograd globally
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Does nothing in dev mode or if the pipeline has already been set up.

        Raises:
            Exception: Whatever exception the pipeline loading raised, after
                printing a hint about the model's license.
        """
        if self.dev_mode or self._stable_diffusion_is_setup:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
        except Exception:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card https://huggingface.co/runwayml/stable-diffusion-inpainting\n"
            )
            # bare raise keeps the original traceback
            raise
        self._stable_diffusion_is_setup = True

    def _preprocess_image(self, img):
        """
        Turns a HxWxC uint8 numpy array into a 640x640x3 float32 numpy array
        in [-1, 1].

        Args:
            img (np.array): Image to resize crop and rescale. HxW (grayscale)
                images are expanded to 3 identical channels; images whose last
                dimension is not 3 are treated as RGBA.

        Returns:
            np.array: Resized, cropped and rescaled image
        """
        # grayscale to rgb: replicate the single channel
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, axis=-1)

        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to self.target_size
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=_DEFAULT_CONCATS,
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        In dev mode, no inference is run: "input" is `orig_image` itself and all
        the other entries are random integer arrays.

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new "concat" entry. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            as_pil_image (bool, optional): Whether to also hand the image over to
                self.infer_preprocessed_batch as a PIL image. Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            size = self.target_size
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (size, size)),
                "masked_input": np.random.randint(0, 255, (size, size, 3)),
                "climategan_flood": np.random.randint(0, 255, (size, size, 3)),
                "stable_flood": np.random.randint(0, 255, (size, size, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (size, size, 3)),
                "concat": np.random.randint(0, 255, (size, size * 5, 3)),
                "smog": np.random.randint(0, 255, (size, size, 3)),
                "wildfire": np.random.randint(0, 255, (size, size, 3)),
                "depth": np.random.randint(0, 255, (size, size, 1)),
                "segmentation": np.random.randint(0, 255, (size, size, 3)),
            }

        image_array = (
            np.array(Image.open(orig_image))
            if isinstance(orig_image, (str, Path))
            else orig_image
        )

        pil_image = None
        if as_pil_image:
            pil_image = Image.fromarray(image_array)
        print("Preprocessing image")
        image = self._preprocess_image(image_array)
        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )
        print("Inference done")
        # batch of size 1: return the single item of each entry
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=_DEFAULT_CONCATS,
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The concatenation of the `concats` entries (only if `concats`
            is not empty).

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new "concat" entry. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            pil_image (PIL.Image, optional): The original PIL image. If provided,
                will be used for a single inference (batch_size=1)

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        # no need for ClimateGAN's Painter if only stable diffusion is requested
        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            # back from [-1, 1] floats to a uint8 PIL image for the SD pipeline
            pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was ignored by the trainer, so the key may be missing
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            # tensors for a batch, PIL images when a single `pil_image` was given
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
                if pil_image is None
                else pil_image
            )
            input_mask = (
                torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
                if pil_image is None
                else Image.fromarray(mask[0])
            )
            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # paste the original context back outside of the mask
            copy_flood = flood * bin_mask + outputs["input"] * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # drop the singleton channel dimension of Bx1xHxW entries
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=_DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (list, optional): List of keys in `output` to concatenate together
                in a new "concat" entry. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Returns:
            dict: a dictionary mapping each output key to the list of per-image
                outputs, in the same order as the images found in the folder.
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"
        im_paths = [
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        ]
        assert im_paths, f"No images found in {str(folder_path)}"
        ims = [self._preprocess_image(np.array(Image.open(p))) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # flatten the per-batch outputs into per-image lists
        outputs = {
            k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `painter_prefix` is f"_stable_{prompt}" for the "stable_flood"
        event (prompt stripped of non-alphanumeric characters and lower-cased)
        and is empty for the other events produced by the inference methods.

        Args:
            outputs (dict): The inference procedure's output dict, mapping each
                event to a sequence of per-image arrays.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the images
                or not. If a string is provided, it will be included in the name.
                If False, a random identifier will be added to the name.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        # keep alphanumeric characters only so the prompt is filename-safe
        prompt = re.sub(r"[^0-9a-zA-Z]+", "", prompt).lower()
        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = ""
                # NOTE(review): the inference methods rename "flood" to
                # "climategan_flood", so this first branch looks unreachable with
                # their outputs; kept as-is to preserve existing file names.
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif (
                    painter in {"stable_diffusion", "both"} and event == "stable_flood"
                ):
                    painter_prefix = f"_stable_{prompt}"

                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))


if __name__ == "__main__":
    # Command-line entry point: infers every image of `input_folder` one by one
    # and writes each output as "{stem}---{output_key}{suffix}" in the output
    # folder.
    print("Run `$ python climategan_wrapper.py help` for usage instructions\n")

    # parse arguments
    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    # print help and stop: nothing else can run without an input folder
    if args.help:
        print(
            "Usage: python climategan_wrapper.py input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used.\n"
            + "Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models.\n"
            + "In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )
        raise SystemExit(0)

    # print args
    args.pretty_print()

    # check painter type before the (slow) model loading
    if args.painter not in {"climategan", "stable_diffusion", "both"}:
        raise ValueError(
            f"Unknown painter {args.painter}. "
            + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
        )

    # resolve input folder path
    if args.input_folder is None:
        raise ValueError(
            "Missing input folder: add `input_folder=/path/to/folder` "
            + "to the command line."
        )
    in_path = Path(args.input_folder).expanduser().resolve()
    if not in_path.exists():
        raise FileNotFoundError(f"Folder {str(in_path)} does not exist")
    if not in_path.is_dir():
        raise NotADirectoryError(f"{str(in_path)} is not a directory")

    # output is input if not specified
    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    # find images in input folder
    im_paths = [
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    ]
    if not im_paths:
        raise FileNotFoundError(f"No images found in {str(in_path)}")

    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    # load models
    cg = ClimateGAN("models/climategan")

    # load SD pipeline if need be
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    # infer and write
    for i, im_path in enumerate(im_paths, start=1):
        print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
        outs = cg.infer_single(
            np.array(Image.open(im_path)),
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )
        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            im = Image.fromarray(uint8(v))
            im.save(out_path / name)
        print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")