File size: 3,046 Bytes
95ed5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino

import torch
from skimage.color import rgba2rgb
from skimage.transform import resize
import numpy as np

from climategan.trainer import Trainer


def uint8(array):
    """
    Cast ``array`` to ``np.uint8``.

    Only the dtype changes: values are neither rescaled nor clipped, so for
    in-range floats the fractional part is simply dropped by the cast.

    Args:
        array (np.array): array to convert
    Returns:
        np.array(np.uint8): a new array with dtype np.uint8
    """
    target_dtype = np.uint8
    converted = array.astype(target_dtype)
    return converted


def resize_and_crop(img, to=640):
    """
    Resize an image so that it keeps its aspect ratio and its smallest dimension
    is `to`, then crop the center of the resized image so that the output is
    `to x to` without aspect ratio distortion.

    Args:
        img (np.array): np.uint8 image with values in [0, 255], either
            H x W (grayscale) or H x W x C
        to (int, optional): side length of the square output. Defaults to 640.
    Returns:
        np.array: np.float64 image with values in [0, 1], of shape (to, to)
            or (to, to, C)
    """
    # resize keeping aspect ratio: the smallest dim becomes `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    # preserve_range keeps the original [0, 255] value range instead of
    # converting to [0, 1]; the uint8 cast then drops fractional parts
    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center; only the two spatial axes are sliced so that both
    # H x W and H x W x C images are supported
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    rc_img = r_img[top : top + to, left : left + to]

    # uint8 / float -> float64 in [0, 1]
    return rc_img / 255.0


def to_m1_p1(img):
    """
    Rescale a [0, 1] image to [-1, +1].

    Args:
        img (np.array): numpy array of an image with values in [0, 1]; any
            numeric dtype is accepted since it is cast to np.float32
    Raises:
        ValueError: If the image values are not all in [0, 1]
    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    # compute the value range once: it serves both the check and the message
    lo, hi = img.min(), img.max()
    if lo >= 0 and hi <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({lo}, {hi})")


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN ``Trainer`` restored from disk.

    Takes a single RGB or RGBA numpy image and returns the flood, wildfire and
    smog outputs produced by ``Trainer.infer_all``.
    """

    def __init__(self, model_path) -> None:
        """
        Args:
            model_path: path handed to ``Trainer.resume_from_path`` to restore
                the trained model (presumably an experiment/checkpoint
                directory — TODO confirm)
        """
        # NOTE(review): this turns autograd off for the whole process, not
        # just for this object; it is done before the trainer is built.
        torch.set_grad_enabled(False)
        # side length of the square image fed to the model
        self.target_size = 640
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )

    def inference(self, orig_image):
        """
        Run all three inferences (flood, wildfire, smog) on one image.

        Args:
            orig_image (np.array): H x W x 3 (RGB) or H x W x 4 (RGBA) image,
                presumably np.uint8 — TODO confirm against callers
        Returns:
            tuple: ``(flood, wildfire, smog)`` arrays from
            ``Trainer.infer_all`` with singleton dimensions squeezed out
        """
        model_input = self._preprocess_image(orig_image)

        # numpy=True: outputs come back as a dict {event: array[BxHxWxC]}
        events = self.trainer.infer_all(model_input, numpy=True, bin_value=0.5)

        return tuple(
            events[name].squeeze() for name in ("flood", "wildfire", "smog")
        )

    def _preprocess_image(self, img):
        """
        Turn a raw image into the square [-1, 1] float32 array fed to the model.

        Args:
            img (np.array): RGB image, or RGBA image that is flattened to RGB
        Returns:
            np.array(np.float32): target_size x target_size x 3 array in [-1, 1]
        """
        # anything without exactly 3 channels is treated as RGBA; rgba2rgb
        # yields floats in [0, 1], so scale back to the uint8 [0, 255] range
        if img.shape[-1] != 3:
            img = uint8(rgba2rgb(img) * 255)

        # square center crop of side self.target_size, values in [0, 1]
        cropped = resize_and_crop(img, self.target_size)

        # rescale [0, 1] -> [-1, 1]
        return to_m1_p1(cropped)