vict0rsch commited on
Commit
95ed5e7
ยท
1 Parent(s): 13aed19

add gradio app

Browse files
Files changed (5) hide show
  1. README.md +3 -0
  2. app.py +70 -0
  3. climategan/generator.py +5 -1
  4. climategan/masker.py +3 -3
  5. inferences.py +108 -0
README.md CHANGED
@@ -10,6 +10,9 @@ title: ClimateGAN
10
  emoji: ๐ŸŒŽ
11
  colorFrom: blue
12
  colorTo: green
 
 
 
13
  # datasets:
14
  # -
15
  ---
 
10
  emoji: ๐ŸŒŽ
11
  colorFrom: blue
12
  colorTo: green
13
+ sdk: gradio
14
+ sdk_version: 4.6
15
+ app_file: app.py
16
  # datasets:
17
  # -
18
  ---
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/app.py # noqa: E501
2
+ # thank you @NimaBoscarino
3
+
4
+ import os
5
+ import gradio as gr
6
+ import googlemaps
7
+ from skimage import io
8
+ from urllib import parse
9
+ from inferences import ClimateGAN
10
+
11
+
12
+ def predict(api_key):
13
+ def _predict(*args):
14
+ print("args: ", args)
15
+ image = place = None
16
+ if len(args) == 1:
17
+ image = args[0]
18
+ else:
19
+ assert len(args) == 2, "Unknown number of inputs {}".format(len(args))
20
+ image, place = args
21
+
22
+ if api_key and place:
23
+ geocode_result = gmaps.geocode(place)
24
+
25
+ address = geocode_result[0]["formatted_address"]
26
+ static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
27
+ img_np = io.imread(static_map_url)
28
+ else:
29
+ img_np = image
30
+ flood, wildfire, smog = model.inference(img_np)
31
+ return img_np, flood, wildfire, smog
32
+
33
+ return _predict
34
+
35
+
36
+ if __name__ == "__main__":
37
+
38
+ api_key = os.environ.get("GMAPS_API_KEY")
39
+ gmaps = None
40
+ if api_key is not None:
41
+ gmaps = googlemaps.Client(key=api_key)
42
+
43
+ model = ClimateGAN(model_path="config/model/masker")
44
+
45
+ inputs = inputs = [gr.inputs.Image(label="Input Image")]
46
+ if api_key:
47
+ inputs += [gr.inputs.Textbox(label="Address or place name")]
48
+
49
+ gr.Interface(
50
+ predict(api_key),
51
+ inputs=[
52
+ gr.inputs.Textbox(label="Address or place name"),
53
+ gr.inputs.Image(label="Input Image"),
54
+ ],
55
+ outputs=[
56
+ gr.outputs.Image(type="numpy", label="Original image"),
57
+ gr.outputs.Image(type="numpy", label="Flooding"),
58
+ gr.outputs.Image(type="numpy", label="Wildfire"),
59
+ gr.outputs.Image(type="numpy", label="Smog"),
60
+ ],
61
+ title="ClimateGAN: Visualize Climate Change",
62
+ description='Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.', # noqa: E501
63
+ article="<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>", # noqa: E501
64
+ # examples=[
65
+ # "Vancouver Art Gallery",
66
+ # "Chicago Bean",
67
+ # "Duomo Siracusa",
68
+ # ],
69
+ css=".footer{display:none !important}",
70
+ ).launch()
climategan/generator.py CHANGED
@@ -101,6 +101,10 @@ class OmniGenerator(nn.Module):
101
  if self.verbose > 0:
102
  print(" - Add Empty Painter")
103
 
 
 
 
 
104
  def __str__(self):
105
  return strings.generator(self)
106
 
@@ -381,7 +385,7 @@ class OmniGenerator(nn.Module):
381
  val_painter_opts = Dict(yaml.safe_load(f))
382
 
383
  # load checkpoint
384
- state_dict = torch.load(ckpt_path)
385
 
386
  # create dummy painter from loaded opts
387
  painter = create_painter(val_painter_opts)
 
101
  if self.verbose > 0:
102
  print(" - Add Empty Painter")
103
 
104
+ @property
105
+ def device(self):
106
+ return next(self.parameters()).device
107
+
108
  def __str__(self):
109
  return strings.generator(self)
110
 
 
385
  val_painter_opts = Dict(yaml.safe_load(f))
386
 
387
  # load checkpoint
388
+ state_dict = torch.load(ckpt_path, map_location=self.device)
389
 
390
  # create dummy painter from loaded opts
391
  painter = create_painter(val_painter_opts)
climategan/masker.py CHANGED
@@ -186,18 +186,18 @@ class MaskSpadeDecoder(nn.Module):
186
  for i in range(self.num_layers):
187
  self.spade_blocks.append(
188
  SPADEResnetBlock(
189
- int(self.z_nc / (2 ** i)),
190
  int(self.z_nc / (2 ** (i + 1))),
191
  cond_nc,
192
  spade_use_spectral_norm,
193
  spade_param_free_norm,
194
  spade_kernel_size,
195
  spade_activation,
196
- ).cuda()
197
  )
198
  self.spade_blocks = nn.Sequential(*self.spade_blocks)
199
 
200
- self.final_nc = int(self.z_nc / (2 ** self.num_layers))
201
  self.mask_conv = Conv2dBlock(
202
  self.final_nc,
203
  1,
 
186
  for i in range(self.num_layers):
187
  self.spade_blocks.append(
188
  SPADEResnetBlock(
189
+ int(self.z_nc / (2**i)),
190
  int(self.z_nc / (2 ** (i + 1))),
191
  cond_nc,
192
  spade_use_spectral_norm,
193
  spade_param_free_norm,
194
  spade_kernel_size,
195
  spade_activation,
196
+ )
197
  )
198
  self.spade_blocks = nn.Sequential(*self.spade_blocks)
199
 
200
+ self.final_nc = int(self.z_nc / (2**self.num_layers))
201
  self.mask_conv = Conv2dBlock(
202
  self.final_nc,
203
  1,
inferences.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
2
+ # thank you @NimaBoscarino
3
+
4
+ import torch
5
+ from skimage.color import rgba2rgb
6
+ from skimage.transform import resize
7
+ import numpy as np
8
+
9
+ from climategan.trainer import Trainer
10
+
11
+
12
+ def uint8(array):
13
+ """
14
+ convert an array to np.uint8 (does not rescale or anything else than changing dtype)
15
+ Args:
16
+ array (np.array): array to modify
17
+ Returns:
18
+ np.array(np.uint8): converted array
19
+ """
20
+ return array.astype(np.uint8)
21
+
22
+
23
+ def resize_and_crop(img, to=640):
24
+ """
25
+ Resizes an image so that it keeps the aspect ratio and the smallest dimensions
26
+ is `to`, then crops this resized image in its center so that the output is `to x to`
27
+ without aspect ratio distortion
28
+ Args:
29
+ img (np.array): np.uint8 255 image
30
+ Returns:
31
+ np.array: [0, 1] np.float32 image
32
+ """
33
+ # resize keeping aspect ratio: smallest dim is 640
34
+ h, w = img.shape[:2]
35
+ if h < w:
36
+ size = (to, int(to * w / h))
37
+ else:
38
+ size = (int(to * h / w), to)
39
+
40
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
41
+ r_img = uint8(r_img)
42
+
43
+ # crop in the center
44
+ H, W = r_img.shape[:2]
45
+
46
+ top = (H - to) // 2
47
+ left = (W - to) // 2
48
+
49
+ rc_img = r_img[top : top + to, left : left + to, :]
50
+
51
+ return rc_img / 255.0
52
+
53
+
54
+ def to_m1_p1(img):
55
+ """
56
+ rescales a [0, 1] image to [-1, +1]
57
+ Args:
58
+ img (np.array): float32 numpy array of an image in [0, 1]
59
+ i (int): Index of the image being rescaled
60
+ Raises:
61
+ ValueError: If the image is not in [0, 1]
62
+ Returns:
63
+ np.array(np.float32): array in [-1, +1]
64
+ """
65
+ if img.min() >= 0 and img.max() <= 1:
66
+ return (img.astype(np.float32) - 0.5) * 2
67
+ raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
68
+
69
+
70
+ # No need to do any timing in this, since it's just for the HF Space
71
+ class ClimateGAN:
72
+ def __init__(self, model_path) -> None:
73
+ torch.set_grad_enabled(False)
74
+ self.target_size = 640
75
+ self.trainer = Trainer.resume_from_path(
76
+ model_path,
77
+ setup=True,
78
+ inference=True,
79
+ new_exp=None,
80
+ )
81
+
82
+ # Does all three inferences at the moment.
83
+ def inference(self, orig_image):
84
+ image = self._preprocess_image(orig_image)
85
+
86
+ # Retrieve numpy events as a dict {event: array[BxHxWxC]}
87
+ outputs = self.trainer.infer_all(
88
+ image,
89
+ numpy=True,
90
+ bin_value=0.5,
91
+ )
92
+
93
+ return (
94
+ outputs["flood"].squeeze(),
95
+ outputs["wildfire"].squeeze(),
96
+ outputs["smog"].squeeze(),
97
+ )
98
+
99
+ def _preprocess_image(self, img):
100
+ # rgba to rgb
101
+ data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
102
+
103
+ # to args.target_size
104
+ data = resize_and_crop(data, self.target_size)
105
+
106
+ # resize() produces [0, 1] images, rescale to [-1, 1]
107
+ data = to_m1_p1(data)
108
+ return data