|
|
|
|
|
|
|
import torch |
|
from skimage.color import rgba2rgb |
|
from skimage.transform import resize |
|
import numpy as np |
|
|
|
from climategan.trainer import Trainer |
|
|
|
|
|
def uint8(array):
    """
    Cast an array to np.uint8.

    Only the dtype changes: values are neither rescaled nor clipped
    beforehand, whatever `astype` does with them is what you get.

    Args:
        array (np.array): array to cast

    Returns:
        np.array(np.uint8): the array converted to np.uint8
    """
    target_dtype = np.uint8
    converted = array.astype(target_dtype)
    return converted
|
|
|
|
|
def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest
    dimension is `to`, then crops this resized image in its center so that
    the output is `to x to` without aspect ratio distortion.

    Both 2D (H, W) and channel-last (H, W, C) arrays are accepted: only the
    two leading (spatial) axes are resized and cropped.

    Args:
        img (np.array): image with values on the [0, 255] scale
            (typically np.uint8)
        to (int, optional): side length of the square output.
            Defaults to 640.

    Returns:
        np.array: `to x to` image rescaled to [0, 1]. The dtype is float64,
            the result of dividing the uint8 crop by 255.0.
    """
    h, w = img.shape[:2]

    # `resize` takes the output shape as (rows, cols): pin the shorter side
    # to `to` and scale the longer side by the same factor.
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    # preserve_range=True keeps the values on the input's scale so that the
    # uint8 cast below does not collapse a [0, 1] float image to 0/1.
    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    H, W = r_img.shape[:2]

    # Offsets that center the `to x to` window in the resized image
    top = (H - to) // 2
    left = (W - to) // 2

    # Slice the two spatial axes only; any trailing channel axis is kept
    # whole, and 2D (grayscale) inputs no longer raise an IndexError.
    rc_img = r_img[top : top + to, left : left + to]

    return rc_img / 255.0
|
|
|
|
|
def to_m1_p1(img):
    """
    Map an image from the [0, 1] range onto [-1, +1].

    Args:
        img (np.array): numpy array of an image with values in [0, 1]

    Raises:
        ValueError: If the image values are not all within [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    low = img.min()
    high = img.max()

    # Written as a negated conjunction (not `low < 0 or high > 1`) so that
    # NaN bounds are rejected as well.
    if not (low >= 0 and high <= 1):
        raise ValueError(f"Data range mismatch for image: ({low}, {high})")

    centered = img.astype(np.float32) - 0.5
    return centered * 2
|
|
|
|
|
|
|
class ClimateGAN:
    """
    Inference wrapper around a climategan `Trainer` restored from a path.

    Input images are resized and center-cropped to `self.target_size`, then
    rescaled to [-1, +1] before being handed to the trainer's `infer_all`.
    """

    def __init__(self, model_path) -> None:
        """
        Restore the trainer that will be used for inference.

        Args:
            model_path: location forwarded to `Trainer.resume_from_path`
        """
        # Gradient tracking is switched off here and is never re-enabled
        # by this class.
        torch.set_grad_enabled(False)

        # Side length (in pixels) of the square image fed to the trainer
        self.target_size = 640

        resume_options = {"setup": True, "inference": True, "new_exp": None}
        self.trainer = Trainer.resume_from_path(model_path, **resume_options)

    def inference(self, orig_image):
        """
        Run the trainer on a single image and collect its three outputs.

        Args:
            orig_image (np.array): image to process; see
                `_preprocess_image` for what is done to it first

        Returns:
            tuple: the squeezed "flood", "wildfire" and "smog" entries of
                the `infer_all` result, in that order
        """
        model_input = self._preprocess_image(orig_image)

        predictions = self.trainer.infer_all(
            model_input,
            numpy=True,
            bin_value=0.5,
        )

        events = ("flood", "wildfire", "smog")
        return tuple(predictions[event].squeeze() for event in events)

    def _preprocess_image(self, img):
        """
        Turn an input image into the array passed to `infer_all`.

        Args:
            img (np.array): channel-last image. A last axis of length 3 is
                used as is; anything else goes through `rgba2rgb`
                (presumably RGBA -- this is not checked here).

        Returns:
            np.array(np.float32): `target_size x target_size` image in
                [-1, +1]
        """
        if img.shape[-1] == 3:
            rgb = img
        else:
            # Drop the alpha channel, then scale by 255 and cast to uint8
            rgb = uint8(rgba2rgb(img) * 255)

        square = resize_and_crop(rgb, self.target_size)

        return to_m1_p1(square)
|
|