vict0rsch commited on
Commit
3d5f935
•
1 Parent(s): f1e742e

update from climategan space

Browse files
Files changed (8) hide show
  1. README.md +29 -4
  2. app.py +274 -41
  3. climategan/trainer.py +41 -25
  4. climategan/tutils.py +46 -4
  5. climategan/utils.py +3 -3
  6. climategan_wrapper.py +624 -0
  7. inferences.py +0 -108
  8. requirements.txt +169 -0
README.md CHANGED
@@ -11,11 +11,10 @@ emoji: 🌎
11
  colorFrom: blue
12
  colorTo: green
13
  sdk: gradio
14
- sdk_version: 4.6
15
  app_file: app.py
16
  inference: true
17
- # datasets:
18
- # -
19
  ---
20
 
21
  # ClimateGAN: Raising Awareness about Climate Change by Generating Images of Floods
@@ -38,7 +37,33 @@ If you use this code, data or pre-trained weights, please cite our ICLR 2022 pap
38
  }
39
  ```
40
 
41
- ## Using pre-trained weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  In the paper, we present ClimateGAN as a solution to produce images of floods. It can actually do **more**:
44
 
 
11
  colorFrom: blue
12
  colorTo: green
13
  sdk: gradio
14
+ sdk_version: 3.6
15
  app_file: app.py
16
  inference: true
17
+ pinned: true
 
18
  ---
19
 
20
  # ClimateGAN: Raising Awareness about Climate Change by Generating Images of Floods
 
37
  }
38
  ```
39
 
40
+ ## Using pre-trained weights from this Huggingface Space and Stable Diffusion In-painting
41
+
42
+ [**Huggingface ClimateGAN Space**](https://huggingface.co/spaces/vict0rsch/climateGAN)
43
+
44
+ 1. Download code and model
45
+ ```
46
+ git clone https://huggingface.co/spaces/vict0rsch/climateGAN hf-spaces-climategan
47
+ cd hf-spaces-climategan
48
+ git lfs install
49
+ git lfs pull
50
+ ```
51
+ 2. Install requirements
52
+ ```
53
+ pip install requirements.txt
54
+ ```
55
+ 3. **Enable Stable Diffusion Inpainting** by visiting the model's card: https://huggingface.co/runwayml/stable-diffusion-inpainting **and** running `$ huggingface-cli login`
56
+ 4. Run `$ python climategan_wrapper.py help` for usage instructions on how to infer on a folder's images.
57
+ 5. Run `$ python app.py` to see the Gradio app.
58
+ 1. To use Google Street View you'll need an API key and set the `GMAPS_API_KEY` environment variable.
59
+ 2. To use Stable Diffusion if you can't run `$ huggingface-cli login` (on a Huggingface Space for instance) set the `HF_AUTH_TOKEN` env variable to a [Huggingface authorization token](https://huggingface.co/settings/tokens)
60
+ 3. To change the UI without model overhead, set the `CG_DEV_MODE` environment variable to `true`.
61
+
62
+ For a more fine-grained control on ClimateGAN's inferences, refer to `apply_events.py` (does not support Stable Diffusion painter)
63
+
64
+ **Note:** you don't have control on the prompt by design because I disabled the safety checker. Fork this space/repo and do it yourself if you really need to change the prompt. At least [open a discussion](https://huggingface.co/spaces/vict0rsch/climateGAN/discussions).
65
+
66
+ ## Using pre-trained weights from source
67
 
68
  In the paper, we present ClimateGAN as a solution to produce images of floods. It can actually do **more**:
69
 
app.py CHANGED
@@ -2,69 +2,302 @@
2
  # thank you @NimaBoscarino
3
 
4
  import os
5
- import gradio as gr
 
 
 
6
  import googlemaps
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from skimage import io
8
- from urllib import parse
9
- from inferences import ClimateGAN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def predict(api_key):
 
13
  def _predict(*args):
14
- print("args: ", args)
15
- image = place = None
16
- if len(args) == 1:
17
- image = args[0]
18
  else:
19
- assert len(args) == 2, "Unknown number of inputs {}".format(len(args))
20
- image, place = args
21
 
22
- if api_key and place:
23
  geocode_result = gmaps.geocode(place)
24
 
25
  address = geocode_result[0]["formatted_address"]
26
  static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
27
  img_np = io.imread(static_map_url)
 
28
  else:
 
29
  img_np = image
30
- flood, wildfire, smog = model.inference(img_np)
31
- return img_np, flood, wildfire, smog
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  return _predict
34
 
35
 
36
  if __name__ == "__main__":
37
 
 
 
 
38
  api_key = os.environ.get("GMAPS_API_KEY")
39
  gmaps = None
40
  if api_key is not None:
41
  gmaps = googlemaps.Client(key=api_key)
42
 
43
- model = ClimateGAN(model_path="config/model/masker")
44
-
45
- inputs = inputs = [gr.inputs.Image(label="Input Image")]
46
- if api_key:
47
- inputs += [gr.inputs.Textbox(label="Address or place name")]
48
-
49
- gr.Interface(
50
- predict(api_key),
51
- inputs=[
52
- gr.inputs.Textbox(label="Address or place name"),
53
- gr.inputs.Image(label="Input Image"),
54
- ],
55
- outputs=[
56
- gr.outputs.Image(type="numpy", label="Original image"),
57
- gr.outputs.Image(type="numpy", label="Flooding"),
58
- gr.outputs.Image(type="numpy", label="Wildfire"),
59
- gr.outputs.Image(type="numpy", label="Smog"),
60
- ],
61
- title="ClimateGAN: Visualize Climate Change",
62
- description='Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.', # noqa: E501
63
- article="<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>", # noqa: E501
64
- # examples=[
65
- # "Vancouver Art Gallery",
66
- # "Chicago Bean",
67
- # "Duomo Siracusa",
68
- # ],
69
- css=".footer{display:none !important}",
70
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  # thank you @NimaBoscarino
3
 
4
  import os
5
+ from textwrap import dedent
6
+ from urllib import parse
7
+ from requests import get
8
+
9
  import googlemaps
10
+ import gradio as gr
11
+ import numpy as np
12
+ from gradio.components import (
13
+ HTML,
14
+ Button,
15
+ Column,
16
+ Dropdown,
17
+ Image,
18
+ Markdown,
19
+ Radio,
20
+ Row,
21
+ Textbox,
22
+ )
23
  from skimage import io
24
+ from datetime import datetime
25
+
26
+ from climategan_wrapper import ClimateGAN
27
+
28
+ TEXTS = [
29
+ dedent(
30
+ """
31
+ <p>
32
+ Climate change does not impact everyone equally.
33
+ This Space shows the effects of the climate emergency,
34
+ "one address at a time".
35
+ Visit the original experience at
36
+ <a href="https://thisclimatedoesnotexist.com/">
37
+ ThisClimateDoesNotExist.com
38
+ </a>
39
+ </p>
40
+ <br>
41
+ <p>
42
+ Enter an address or upload a Street View image, and ClimateGAN
43
+ will generate images showing how the location could be impacted
44
+ by flooding, wildfires, or smog if it happened there.
45
+ </p>
46
+ <br>
47
+ <p>
48
+ This is <strong>NOT</strong> an exercise in climate prediction,
49
+ rather an exercise of empathy, to put yourself in others' shoes,
50
+ as if Climate Change came crushing on your doorstep.
51
+ </p>
52
+ <br>
53
+ <p>
54
+ After you have selected an image and started the inference you
55
+ will see all the outputs of ClimateGAN, including intermediate
56
+ outputs such as the flood mask, the segmentation map and the
57
+ depth maps used to produce the 3 events.
58
+ </p>
59
+ <br>
60
+ <p>
61
+ This Space makes use of recent Stable Diffusion in-painting
62
+ pipelines to replace ClimateGAN's original Painter. If you
63
+ select 'Both' painters, you will see a comparison
64
+ </p>
65
+ <br>
66
+ <p style='text-align: center'>
67
+ Visit
68
+ <a href='https://thisclimatedoesnotexist.com/'>
69
+ ThisClimateDoesNotExist.com</a>
70
+ &nbsp;for more information
71
+ &nbsp;&nbsp;|&nbsp;&nbsp;
72
+ Original
73
+ <a href='https://github.com/cc-ai/climategan'>
74
+ ClimateGAN GitHub Repo
75
+ </a>
76
+ &nbsp;&nbsp;|&nbsp;&nbsp;
77
+ Read the original
78
+ <a
79
+ href='https://openreview.net/forum?id=EZNOb_uNpJk'
80
+ target='_blank'>
81
+ ICLR 2021 ClimateGAN paper
82
+ </a>
83
+ </p>
84
+ """
85
+ ),
86
+ dedent(
87
+ """
88
+ ## How to use this Space
89
+
90
+ 1. Enter an address or upload a Street View image
91
+ 2. Select the type of Painter you'd like to use for the flood renderings
92
+ 3. Click on the "See for yourself!" button
93
+ 4. Wait for the inference to complete, typically around 30 seconds
94
+ (plus queue time)
95
+ 5. Enjoy the results!
96
+
97
+ 1. The prompt for Stable Diffusion is `An HD picture of a street with
98
+ dirty water after a heavy flood`
99
+ 2. Pay attention to potential "inventions" by Stable Diffusion's in-painting
100
+ 3. The "restricted to masked area" SD output is the result of:
101
+ `y = mask * flooded + (1-mask) * input`
102
+
103
+ """
104
+ ),
105
+ ]
106
+ CSS = dedent(
107
+ """
108
+ a {
109
+ color: #0088ff;
110
+ text-decoration: underline;
111
+ }
112
+ strong {
113
+ color: #c34318;
114
+ font-weight: bolder;
115
+ }
116
+ #how-to-use-md li {
117
+ margin: 0.1em;
118
+ }
119
+ #how-to-use-md li p {
120
+ margin: 0.1em;
121
+ }
122
+ """
123
+ )
124
+
125
 
126
+ def toggle(radio):
127
+ if "address" in radio.lower():
128
+ return [
129
+ gr.update(visible=True),
130
+ gr.update(visible=False),
131
+ gr.update(visible=True),
132
+ ]
133
+ else:
134
+ return [
135
+ gr.update(visible=False),
136
+ gr.update(visible=True),
137
+ gr.update(visible=True),
138
+ ]
139
 
140
+
141
+ def predict(cg: ClimateGAN, api_key):
142
  def _predict(*args):
143
+ print(f"Starting inference ({str(datetime.now())})")
144
+ image = place = painter = radio = None
145
+ if api_key:
146
+ radio, image, place, painter = args
147
  else:
148
+ image, painter = args
 
149
 
150
+ if api_key and place and "address" in radio.lower():
151
  geocode_result = gmaps.geocode(place)
152
 
153
  address = geocode_result[0]["formatted_address"]
154
  static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
155
  img_np = io.imread(static_map_url)
156
+ print("Using GSV image")
157
  else:
158
+ print("Using user image")
159
  img_np = image
160
+
161
+ painters = {
162
+ "ClimateGAN Painter": "climategan",
163
+ "Stable Diffusion Painter": "stable_diffusion",
164
+ "Both": "both",
165
+ }
166
+ print("Using painter", painters[painter])
167
+ output_dict = cg.infer_single(
168
+ img_np,
169
+ painters[painter],
170
+ concats=[
171
+ "input",
172
+ "masked_input",
173
+ "climategan_flood",
174
+ "stable_copy_flood",
175
+ ],
176
+ as_pil_image=True,
177
+ )
178
+
179
+ input_image = output_dict["input"]
180
+ masked_input = output_dict["masked_input"]
181
+ wildfire = output_dict["wildfire"]
182
+ smog = output_dict["smog"]
183
+ depth = np.repeat(output_dict["depth"], 3, axis=-1)
184
+ segmentation = output_dict["segmentation"]
185
+
186
+ climategan_flood = output_dict.get(
187
+ "climategan_flood",
188
+ np.ones(input_image.shape, dtype=np.uint8) * 255,
189
+ )
190
+ stable_flood = output_dict.get(
191
+ "stable_flood",
192
+ np.ones(input_image.shape, dtype=np.uint8) * 255,
193
+ )
194
+ stable_copy_flood = output_dict.get(
195
+ "stable_copy_flood",
196
+ np.ones(input_image.shape, dtype=np.uint8) * 255,
197
+ )
198
+ concat = output_dict.get(
199
+ "concat",
200
+ np.ones(input_image.shape, dtype=np.uint8) * 255,
201
+ )
202
+
203
+ return (
204
+ input_image,
205
+ masked_input,
206
+ segmentation,
207
+ depth,
208
+ climategan_flood,
209
+ stable_flood,
210
+ stable_copy_flood,
211
+ concat,
212
+ wildfire,
213
+ smog,
214
+ )
215
 
216
  return _predict
217
 
218
 
219
  if __name__ == "__main__":
220
 
221
+ ip = get("https://api.ipify.org").content.decode("utf8")
222
+ print("My public IP address is: {}".format(ip))
223
+
224
  api_key = os.environ.get("GMAPS_API_KEY")
225
  gmaps = None
226
  if api_key is not None:
227
  gmaps = googlemaps.Client(key=api_key)
228
 
229
+ cg = ClimateGAN(
230
+ model_path="config/model/masker",
231
+ dev_mode=os.environ.get("CG_DEV_MODE", "").lower() == "true",
232
+ )
233
+ cg._setup_stable_diffusion()
234
+
235
+ radio = address = None
236
+ pred_ins = []
237
+ pred_outs = []
238
+
239
+ with gr.Blocks(css=CSS) as app:
240
+ with Row():
241
+ with Column():
242
+ Markdown("# ClimateGAN: Visualize Climate Change")
243
+ HTML(TEXTS[0])
244
+ with Column():
245
+ Markdown(TEXTS[1], elem_id="how-to-use-md")
246
+ with Row():
247
+ HTML("<hr><br><h2 style='font-size: 1.5rem;'>Choose Inputs</h2>")
248
+ with Row():
249
+ with Column():
250
+ if api_key:
251
+ radio = Radio(["From Address", "From Image"], label="Input Type")
252
+ pred_ins += [radio]
253
+ im_inp = Image(label="Input Image", visible=not api_key)
254
+ pred_ins += [im_inp]
255
+ if api_key:
256
+ address = Textbox(label="Address or place name", visible=False)
257
+ pred_ins += [address]
258
+ with Column():
259
+ pred_ins += [
260
+ Dropdown(
261
+ choices=[
262
+ "ClimateGAN Painter",
263
+ "Stable Diffusion Painter",
264
+ "Both",
265
+ ],
266
+ label="Choose Flood Painter",
267
+ value="Both",
268
+ )
269
+ ]
270
+ btn = Button(
271
+ "See for yourself!",
272
+ label="Run",
273
+ variant="primary",
274
+ visible=not api_key,
275
+ )
276
+ with Row():
277
+ Markdown("## Outputs")
278
+ with Row():
279
+ pred_outs += [Image(type="numpy", label="Original image")]
280
+ pred_outs += [Image(type="numpy", label="Masked input image")]
281
+ pred_outs += [Image(type="numpy", label="Segmentation map")]
282
+ pred_outs += [Image(type="numpy", label="Depth map")]
283
+ with Row():
284
+ pred_outs += [Image(type="numpy", label="ClimateGAN-Flooded image")]
285
+ pred_outs += [Image(type="numpy", label="Stable Diffusion-Flooded image")]
286
+ pred_outs += [
287
+ Image(
288
+ type="numpy",
289
+ label="Stable Diffusion-Flooded image (restricted to masked area)",
290
+ )
291
+ ]
292
+ with Row():
293
+ pred_outs += [Image(type="numpy", label="Comparison of flood images")]
294
+ with Row():
295
+ pred_outs += [Image(type="numpy", label="Wildfire")]
296
+ pred_outs += [Image(type="numpy", label="Smog")]
297
+ Image(type="numpy", label="Empty on purpose", interactive=False)
298
+ btn.click(predict(cg, api_key), inputs=pred_ins, outputs=pred_outs)
299
+
300
+ if api_key:
301
+ radio.change(toggle, inputs=[radio], outputs=[address, im_inp, btn])
302
+
303
+ app.launch(show_api=False)
climategan/trainer.py CHANGED
@@ -22,7 +22,8 @@ from torch import autograd, sigmoid, softmax
22
  from torch.cuda.amp import GradScaler, autocast
23
  from tqdm import tqdm
24
 
25
- from climategan.data import get_all_loaders
 
26
  from climategan.discriminator import OmniDiscriminator, create_discriminator
27
  from climategan.eval_metrics import accuracy, mIOU
28
  from climategan.fid import compute_val_fid
@@ -41,6 +42,7 @@ from climategan.tutils import (
41
  print_num_parameters,
42
  shuffle_batch_tuple,
43
  srgb2lrgb,
 
44
  vgg_preprocess,
45
  zero_grad,
46
  )
@@ -223,18 +225,21 @@ class Trainer:
223
  bin_value=-1,
224
  half=False,
225
  xla=False,
226
- cloudy=False,
227
  auto_resize_640=False,
228
  ignore_event=set(),
229
- return_masks=False,
230
  ):
231
  """
232
- Create a dictionnary of events from a numpy or tensor,
233
  single or batch image data.
234
 
235
- stores is a dictionnary of times for the Timer class.
236
 
237
  bin_value is used to binarize (or not) flood masks
 
 
 
238
  """
239
  assert self.is_setup
240
  assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
@@ -308,28 +313,39 @@ class Trainer:
308
  if xla:
309
  xm.mark_step()
310
 
 
 
311
  if numpy:
312
  with Timer(store=stores.get("numpy", [])):
313
- # normalize to 0-1
314
- flood = normalize(flood).cpu()
315
- smog = normalize(smog).cpu()
316
- wildfire = normalize(wildfire).cpu()
317
-
318
- # convert to numpy
319
- flood = flood.permute(0, 2, 3, 1).numpy()
320
- smog = smog.permute(0, 2, 3, 1).numpy()
321
- wildfire = wildfire.permute(0, 2, 3, 1).numpy()
322
-
323
- # convert to 0-255 uint8
324
- flood = (flood * 255).astype(np.uint8)
325
- smog = (smog * 255).astype(np.uint8)
326
- wildfire = (wildfire * 255).astype(np.uint8)
327
-
328
- output_data = {"flood": flood, "wildfire": wildfire, "smog": smog}
329
- if return_masks:
330
- output_data["mask"] = (
331
- ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
332
- )
 
 
 
 
 
 
 
 
 
333
 
334
  return output_data
335
 
 
22
  from torch.cuda.amp import GradScaler, autocast
23
  from tqdm import tqdm
24
 
25
+ from climategan.data import get_all_loaders, decode_segmap_merged_labels
26
+
27
  from climategan.discriminator import OmniDiscriminator, create_discriminator
28
  from climategan.eval_metrics import accuracy, mIOU
29
  from climategan.fid import compute_val_fid
 
42
  print_num_parameters,
43
  shuffle_batch_tuple,
44
  srgb2lrgb,
45
+ tensor_to_uint8_numpy_image,
46
  vgg_preprocess,
47
  zero_grad,
48
  )
 
225
  bin_value=-1,
226
  half=False,
227
  xla=False,
228
+ cloudy=True,
229
  auto_resize_640=False,
230
  ignore_event=set(),
231
+ return_intermediates=False,
232
  ):
233
  """
234
+ Create a dictionary of events from a numpy or tensor,
235
  single or batch image data.
236
 
237
+ stores is a dictionary of times for the Timer class.
238
 
239
  bin_value is used to binarize (or not) flood masks
240
+
241
+ all values in the output dictionary have 4 dimensions:
242
+ BxHxWxC if numpy else BxCxHxW
243
  """
244
  assert self.is_setup
245
  assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
 
313
  if xla:
314
  xm.mark_step()
315
 
316
+ output_data = {}
317
+
318
  if numpy:
319
  with Timer(store=stores.get("numpy", [])):
320
+ if "flood" not in ignore_event:
321
+ # normalize to 0-1
322
+ flood = tensor_to_uint8_numpy_image(flood)
323
+ # convert to 0-255 uint8
324
+ output_data["flood"] = flood
325
+ if "wildfire" not in ignore_event:
326
+ wildfire = tensor_to_uint8_numpy_image(wildfire)
327
+ output_data["wildfire"] = wildfire
328
+ if "smog" not in ignore_event:
329
+ smog = tensor_to_uint8_numpy_image(smog)
330
+ output_data["smog"] = smog
331
+
332
+ if return_intermediates:
333
+ if numpy:
334
+ output_data["mask"] = (
335
+ ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
336
+ )
337
+ output_data["depth"] = tensor_to_uint8_numpy_image(depth)
338
+ output_data["segmentation"] = (
339
+ decode_segmap_merged_labels(segmentation, "r", False)
340
+ .cpu()
341
+ .permute(0, 2, 3, 1)
342
+ .numpy()
343
+ .astype(np.uint8)
344
+ )
345
+ else:
346
+ output_data["mask"] = mask
347
+ output_data["depth"] = depth
348
+ output_data["segmentation"] = segmentation
349
 
350
  return output_data
351
 
climategan/tutils.py CHANGED
@@ -564,14 +564,29 @@ def lrgb2srgb(ims):
564
  return outs[0]
565
 
566
 
567
- def normalize(t, mini=0, maxi=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  if len(t.shape) == 3:
569
  return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
570
 
571
  batch_size = t.shape[0]
572
- min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, 1, 1, 1)
 
573
  t = t - min_t
574
- max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, 1, 1, 1)
575
  t = t / max_t
576
  return mini + (maxi - mini) * t
577
 
@@ -644,7 +659,7 @@ def write_architecture(trainer):
644
  f.write(output)
645
 
646
 
647
- def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
648
  delta = (res[0] / shape[0], res[1] / shape[1])
649
  d = (shape[0] // res[0], shape[1] // res[1])
650
 
@@ -719,3 +734,30 @@ def tensor_ims_to_np_uint8s(ims):
719
  nps.append(n.astype(np.uint8))
720
 
721
  return nps[0] if len(nps) == 1 else nps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
  return outs[0]
565
 
566
 
567
+ def normalize(t, mini=0.0, maxi=1.0):
568
+ """
569
+ Normalizes a tensor to [0, 1].
570
+ If the tensor has more than 3 dimensions, the first one
571
+ is assumed to be the batch dimension and the tensor is
572
+ normalized per batch element, not across the batches.
573
+
574
+ Args:
575
+ t (torch.Tensor): Tensor to normalize
576
+ mini (float, optional): Min allowed value. Defaults to 0.
577
+ maxi (float, optional): Max allowed value. Defaults to 1.
578
+
579
+ Returns:
580
+ torch.Tensor: The normalized tensor
581
+ """
582
  if len(t.shape) == 3:
583
  return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
584
 
585
  batch_size = t.shape[0]
586
+ extra_dims = [1] * (t.ndim - 1)
587
+ min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, *extra_dims)
588
  t = t - min_t
589
+ max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, *extra_dims)
590
  t = t / max_t
591
  return mini + (maxi - mini) * t
592
 
 
659
  f.write(output)
660
 
661
 
662
+ def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
663
  delta = (res[0] / shape[0], res[1] / shape[1])
664
  d = (shape[0] // res[0], shape[1] // res[1])
665
 
 
734
  nps.append(n.astype(np.uint8))
735
 
736
  return nps[0] if len(nps) == 1 else nps
737
+
738
+
739
+ def tensor_to_uint8_numpy_image(tensor):
740
+ """
741
+ Turns a BxCxHxW tensor into a numpy image:
742
+ * normalize
743
+ * to [0, 255]
744
+ * detach
745
+ * channels last
746
+ * to uin8
747
+ * to cpu
748
+ * to numpy
749
+
750
+ Args:
751
+ tensor (torch.Tensor): Tensor to transform
752
+
753
+ Returns:
754
+ np.array: BxHxWxC np.uint8 array in [0, 255]
755
+ """
756
+ return (
757
+ normalize(tensor, 0, 255) # [0, 255]
758
+ .detach() # detach from graph if needed
759
+ .permute(0, 2, 3, 1) # BxHxWxC
760
+ .to(torch.uint8) # uint8
761
+ .cpu() # cpu
762
+ .numpy() # numpy array
763
+ )
climategan/utils.py CHANGED
@@ -917,14 +917,14 @@ def all_texts_to_array(texts, width=640, height=40):
917
 
918
 
919
  class Timer:
920
- def __init__(self, name="", store=None, precision=3, ignore=False, cuda=True):
921
  self.name = name
922
  self.store = store
923
  self.precision = precision
924
  self.ignore = ignore
925
- self.cuda = cuda
926
 
927
- if cuda:
928
  self._start_event = torch.cuda.Event(enable_timing=True)
929
  self._end_event = torch.cuda.Event(enable_timing=True)
930
 
 
917
 
918
 
919
  class Timer:
920
+ def __init__(self, name="", store=None, precision=3, ignore=False, cuda=None):
921
  self.name = name
922
  self.store = store
923
  self.precision = precision
924
  self.ignore = ignore
925
+ self.cuda = cuda if cuda is not None else torch.cuda.is_available()
926
 
927
+ if self.cuda:
928
  self._start_event = torch.cuda.Event(enable_timing=True)
929
  self._end_event = torch.cuda.Event(enable_timing=True)
930
 
climategan_wrapper.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
2
+ # thank you @NimaBoscarino
3
+
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+ from uuid import uuid4
8
+ from minydra import resolved_args
9
+ import numpy as np
10
+ import torch
11
+ from diffusers import StableDiffusionInpaintPipeline
12
+ from PIL import Image
13
+ from skimage.color import rgba2rgb
14
+ from skimage.transform import resize
15
+
16
+ from climategan.trainer import Trainer
17
+
18
+
19
+ CUDA = torch.cuda.is_available()
20
+
21
+
22
+ def concat_events(output_dict, events, i=None, axis=1):
23
+ """
24
+ Concatenates the `i`th data in `output_dict` according to the keys listed
25
+ in `events` on dimension `axis`.
26
+
27
+ Args:
28
+ output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
29
+ events to their corresponding data :
30
+ {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
31
+ events (list[str]): output_dict's keys to concatenate.
32
+ axis (int, optional): Concatenation axis. Defaults to 1.
33
+ """
34
+ cs = [e for e in events if e in output_dict]
35
+ if i is not None:
36
+ return uint8(np.concatenate([output_dict[c][i] for c in cs], axis=axis))
37
+ return uint8(np.concatenate([output_dict[c] for c in cs], axis=axis))
38
+
39
+
40
+ def clear(folder):
41
+ """
42
+ Deletes all the images without the inference separator "---" in their name.
43
+
44
+ Args:
45
+ folder (Union[str, Path]): The folder to clear.
46
+ """
47
+ for i in list(Path(folder).iterdir()):
48
+ if i.is_file() and "---" in i.stem:
49
+ i.unlink()
50
+
51
+
52
+ def uint8(array, rescale=False):
53
+ """
54
+ convert an array to np.uint8 (does not rescale or anything else than changing dtype)
55
+ Args:
56
+ array (np.array): array to modify
57
+ Returns:
58
+ np.array(np.uint8): converted array
59
+ """
60
+ if rescale:
61
+ if array.min() < 0:
62
+ if array.min() >= -1 and array.max() <= 1:
63
+ array = (array + 1) / 2
64
+ else:
65
+ raise ValueError(
66
+ f"Data range mismatch for image: ({array.min()}, {array.max()})"
67
+ )
68
+ if array.max() <= 1:
69
+ array = array * 255
70
+ return array.astype(np.uint8)
71
+
72
+
73
+ def resize_and_crop(img, to=640):
74
+ """
75
+ Resizes an image so that it keeps the aspect ratio and the smallest dimensions
76
+ is `to`, then crops this resized image in its center so that the output is `to x to`
77
+ without aspect ratio distortion
78
+ Args:
79
+ img (np.array): np.uint8 255 image
80
+ Returns:
81
+ np.array: [0, 1] np.float32 image
82
+ """
83
+ # resize keeping aspect ratio: smallest dim is 640
84
+ h, w = img.shape[:2]
85
+ if h < w:
86
+ size = (to, int(to * w / h))
87
+ else:
88
+ size = (int(to * h / w), to)
89
+
90
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
91
+ r_img = uint8(r_img)
92
+
93
+ # crop in the center
94
+ H, W = r_img.shape[:2]
95
+
96
+ top = (H - to) // 2
97
+ left = (W - to) // 2
98
+
99
+ rc_img = r_img[top : top + to, left : left + to, :]
100
+
101
+ return rc_img / 255.0
102
+
103
+
104
+ def to_m1_p1(img):
105
+ """
106
+ rescales a [0, 1] image to [-1, +1]
107
+ Args:
108
+ img (np.array): float32 numpy array of an image in [0, 1]
109
+ i (int): Index of the image being rescaled
110
+ Raises:
111
+ ValueError: If the image is not in [0, 1]
112
+ Returns:
113
+ np.array(np.float32): array in [-1, +1]
114
+ """
115
+ if img.min() >= 0 and img.max() <= 1:
116
+ return (img.astype(np.float32) - 0.5) * 2
117
+ raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
118
+
119
+
120
+ # No need to do any timing in this, since it's just for the HF Space
121
+ class ClimateGAN:
122
+ def __init__(self, model_path, dev_mode=False) -> None:
123
+ """
124
+ A wrapper for the ClimateGAN model that you can use to generate
125
+ events from images or folders containing images.
126
+
127
+ Args:
128
+ model_path (Union[str, Path]): Where to load the Masker from
129
+ """
130
+ torch.set_grad_enabled(False)
131
+ self.target_size = 640
132
+ self._stable_diffusion_is_setup = False
133
+ self.dev_mode = dev_mode
134
+ if self.dev_mode:
135
+ return
136
+ self.trainer = Trainer.resume_from_path(
137
+ model_path,
138
+ setup=True,
139
+ inference=True,
140
+ new_exp=None,
141
+ )
142
+ if CUDA:
143
+ self.trainer.G.half()
144
+
145
+ def _setup_stable_diffusion(self):
146
+ """
147
+ Sets up the stable diffusion pipeline for in-painting.
148
+ Make sure you have accepted the license on the model's card
149
+ https://huggingface.co/CompVis/stable-diffusion-v1-4
150
+ """
151
+ if self.dev_mode:
152
+ return
153
+
154
+ try:
155
+ self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
156
+ "runwayml/stable-diffusion-inpainting",
157
+ revision="fp16" if CUDA else "main",
158
+ torch_dtype=torch.float16 if CUDA else torch.float32,
159
+ safety_checker=None,
160
+ use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
161
+ ).to(self.trainer.device)
162
+ self._stable_diffusion_is_setup = True
163
+ except Exception as e:
164
+ print(
165
+ "\nCould not load stable diffusion model. "
166
+ + "Please make sure you have accepted the license on the model's"
167
+ + " card https://huggingface.co/CompVis/stable-diffusion-v1-4\n"
168
+ )
169
+ raise e
170
+
171
+ def _preprocess_image(self, img):
172
+ """
173
+ Turns a HxWxC uint8 numpy array into a 640x640x3 float32 numpy array
174
+ in [-1, 1].
175
+
176
+ Args:
177
+ img (np.array): Image to resize crop and rescale
178
+
179
+ Returns:
180
+ np.array: Resized, cropped and rescaled image
181
+ """
182
+ # rgba to rgb
183
+ data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
184
+
185
+ # to args.target_size
186
+ data = resize_and_crop(data, self.target_size)
187
+
188
+ # resize() produces [0, 1] images, rescale to [-1, 1]
189
+ data = to_m1_p1(data)
190
+ return data
191
+
192
+ # Does all three inferences at the moment.
193
+ def infer_single(
194
+ self,
195
+ orig_image,
196
+ painter="both",
197
+ prompt="An HD picture of a street with dirty water after a heavy flood",
198
+ concats=[
199
+ "input",
200
+ "masked_input",
201
+ "climategan_flood",
202
+ "stable_flood",
203
+ "stable_copy_flood",
204
+ ],
205
+ as_pil_image=False,
206
+ ):
207
+ """
208
+ Infers the image with the ClimateGAN model.
209
+ Importantly (and unlike self.infer_preprocessed_batch), the image is
210
+ pre-processed by self._preprocess_image before going through the networks.
211
+
212
+ Output dict contains the following keys:
213
+ - "input": The input image
214
+ - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
215
+ - "masked_input": The input image with the mask applied
216
+ - "climategan_flood": The flooded image generated by ClimateGAN's Painter
217
+ on the masked input (only if "painter" is "climategan" or "both").
218
+ - "stable_flood": The flooded image in-painted by the stable diffusion model
219
+ from the mask and the input image (only if "painter" is "stable_diffusion"
220
+ or "both").
221
+ - "stable_copy_flood": The flooded image in-painted by the stable diffusion
222
+ model with its original context pasted back in:
223
+ y = m * flooded + (1-m) * input
224
+ (only if "painter" is "stable_diffusion" or "both").
225
+
226
+ Args:
227
+ orig_image (Union[str, np.array]): image to infer on. Can be a path to
228
+ an image which will be read.
229
+ painter (str, optional): Which painter to use: "climategan",
230
+ "stable_diffusion" or "both". Defaults to "both".
231
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
232
+ to "An HD picture of a street with dirty water after a heavy flood".
233
+ concats (list, optional): List of keys in `output` to concatenate together
234
+ in a new `{original_stem}_concat` image written. Defaults to:
235
+ ["input", "masked_input", "climategan_flood", "stable_flood",
236
+ "stable_copy_flood"].
237
+
238
+ Returns:
239
+ dict: a dictionary containing the output images {k: HxWxC}. C is omitted
240
+ for masks (HxW).
241
+ """
242
+ if self.dev_mode:
243
+ return {
244
+ "input": orig_image,
245
+ "mask": np.random.randint(0, 255, (640, 640)),
246
+ "masked_input": np.random.randint(0, 255, (640, 640, 3)),
247
+ "climategan_flood": np.random.randint(0, 255, (640, 640, 3)),
248
+ "stable_flood": np.random.randint(0, 255, (640, 640, 3)),
249
+ "stable_copy_flood": np.random.randint(0, 255, (640, 640, 3)),
250
+ "concat": np.random.randint(0, 255, (640, 640 * 5, 3)),
251
+ "smog": np.random.randint(0, 255, (640, 640, 3)),
252
+ "wildfire": np.random.randint(0, 255, (640, 640, 3)),
253
+ "depth": np.random.randint(0, 255, (640, 640, 1)),
254
+ "segmentation": np.random.randint(0, 255, (640, 640, 3)),
255
+ }
256
+ return
257
+
258
+ image_array = (
259
+ np.array(Image.open(orig_image))
260
+ if isinstance(orig_image, str)
261
+ else orig_image
262
+ )
263
+
264
+ pil_image = None
265
+ if as_pil_image:
266
+ pil_image = Image.fromarray(image_array)
267
+ print("Preprocessing image")
268
+ image = self._preprocess_image(image_array)
269
+ output_dict = self.infer_preprocessed_batch(
270
+ images=image[None, ...],
271
+ painter=painter,
272
+ prompt=prompt,
273
+ concats=concats,
274
+ pil_image=pil_image,
275
+ )
276
+ print("Inference done")
277
+ return {k: v[0] for k, v in output_dict.items()}
278
+
279
+ def infer_preprocessed_batch(
280
+ self,
281
+ images,
282
+ painter="both",
283
+ prompt="An HD picture of a street with dirty water after a heavy flood",
284
+ concats=[
285
+ "input",
286
+ "masked_input",
287
+ "climategan_flood",
288
+ "stable_flood",
289
+ "stable_copy_flood",
290
+ ],
291
+ pil_image=None,
292
+ ):
293
+ """
294
+ Infers ClimateGAN predictions on a batch of preprocessed images.
295
+ It assumes that each image in the batch has been preprocessed with
296
+ self._preprocess_image().
297
+
298
+ Output dict contains the following keys:
299
+ - "input": The input image
300
+ - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
301
+ - "masked_input": The input image with the mask applied
302
+ - "climategan_flood": The flooded image generated by ClimateGAN's Painter
303
+ on the masked input (only if "painter" is "climategan" or "both").
304
+ - "stable_flood": The flooded image in-painted by the stable diffusion model
305
+ from the mask and the input image (only if "painter" is "stable_diffusion"
306
+ or "both").
307
+ - "stable_copy_flood": The flooded image in-painted by the stable diffusion
308
+ model with its original context pasted back in:
309
+ y = m * flooded + (1-m) * input
310
+ (only if "painter" is "stable_diffusion" or "both").
311
+
312
+ Args:
313
+ images (np.array): A batch of input images BxHxWx3
314
+ painter (str, optional): Which painter to use: "climategan",
315
+ "stable_diffusion" or "both". Defaults to "both".
316
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
317
+ to "An HD picture of a street with dirty water after a heavy flood".
318
+ concats (list, optional): List of keys in `output` to concatenate together
319
+ in a new `{original_stem}_concat` image written. Defaults to:
320
+ ["input", "masked_input", "climategan_flood", "stable_flood",
321
+ "stable_copy_flood"].
322
+ pil_image (PIL.Image, optional): The original PIL image. If provided,
323
+ will be used for a single inference (batch_size=1)
324
+
325
+ Returns:
326
+ dict: a dictionary containing the output images
327
+ """
328
+ assert painter in [
329
+ "both",
330
+ "stable_diffusion",
331
+ "climategan",
332
+ ], f"Unknown painter: {painter}"
333
+
334
+ ignore_event = set()
335
+ if painter == "stable_diffusion":
336
+ ignore_event.add("flood")
337
+
338
+ if pil_image is not None:
339
+ print("Warning: `pil_image` has been provided, it will override `images`")
340
+ images = self._preprocess_image(np.array(pil_image))[None, ...]
341
+ pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))
342
+
343
+ # Retrieve numpy events as a dict {event: array[BxHxWxC]}
344
+ print("Inferring ClimateGAN events")
345
+ outputs = self.trainer.infer_all(
346
+ images,
347
+ numpy=True,
348
+ bin_value=0.5,
349
+ half=CUDA,
350
+ ignore_event=ignore_event,
351
+ return_intermediates=True,
352
+ )
353
+
354
+ outputs["input"] = uint8(images, True)
355
+ # from Bx1xHxW to BxHxWx1
356
+ outputs["masked_input"] = outputs["input"] * (
357
+ outputs["mask"].squeeze(1)[..., None] == 0
358
+ )
359
+
360
+ if painter in {"both", "climategan"}:
361
+ outputs["climategan_flood"] = outputs.pop("flood")
362
+ else:
363
+ del outputs["flood"]
364
+
365
+ if painter != "climategan":
366
+ if not self._stable_diffusion_is_setup:
367
+ print("Setting up stable diffusion in-painting pipeline")
368
+ self._setup_stable_diffusion()
369
+
370
+ mask = outputs["mask"].squeeze(1)
371
+ input_images = (
372
+ torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
373
+ if pil_image is None
374
+ else pil_image
375
+ )
376
+ input_mask = (
377
+ torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
378
+ if pil_image is None
379
+ else Image.fromarray(mask[0])
380
+ )
381
+ print("Inferring stable diffusion in-painting for 50 steps")
382
+ floods = self.sdip_pipeline(
383
+ prompt=[prompt] * images.shape[0],
384
+ image=input_images,
385
+ mask_image=input_mask,
386
+ height=640,
387
+ width=640,
388
+ num_inference_steps=50,
389
+ )
390
+ print("Stable diffusion in-painting done")
391
+
392
+ bin_mask = mask[..., None] > 0
393
+ flood = np.stack([np.array(i) for i in floods.images])
394
+ copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
395
+ outputs["stable_flood"] = flood
396
+ outputs["stable_copy_flood"] = copy_flood
397
+
398
+ if concats:
399
+ print("Concatenating flood images")
400
+ outputs["concat"] = concat_events(outputs, concats, axis=2)
401
+
402
+ return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}
403
+
404
+ def infer_folder(
405
+ self,
406
+ folder_path,
407
+ painter="both",
408
+ prompt="An HD picture of a street with dirty water after a heavy flood",
409
+ batch_size=4,
410
+ concats=[
411
+ "input",
412
+ "masked_input",
413
+ "climategan_flood",
414
+ "stable_flood",
415
+ "stable_copy_flood",
416
+ ],
417
+ write=True,
418
+ overwrite=False,
419
+ ):
420
+ """
421
+ Infers the images in a folder with the ClimateGAN model, batching images for
422
+ inference according to the batch_size.
423
+
424
+ Images must end in .jpg, .jpeg or .png (not case-sensitive).
425
+ Images must not contain the separator ("---") in their name.
426
+
427
+ Images will be written to disk in the same folder as the input images, with
428
+ a name that depends on its data, potentially the prompt and a random
429
+ identifier in case multiple inferences are run in the folder.
430
+
431
+ Output dict contains the following keys:
432
+ - "input": The input image
433
+ - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
434
+ - "masked_input": The input image with the mask applied
435
+ - "climategan_flood": The flooded image generated by ClimateGAN's Painter
436
+ on the masked input (only if "painter" is "climategan" or "both").
437
+ - "stable_flood": The flooded image in-painted by the stable diffusion model
438
+ from the mask and the input image (only if "painter" is "stable_diffusion"
439
+ or "both").
440
+ - "stable_copy_flood": The flooded image in-painted by the stable diffusion
441
+ model with its original context pasted back in:
442
+ y = m * flooded + (1-m) * input
443
+ (only if "painter" is "stable_diffusion" or "both").
444
+
445
+ Args:
446
+ folder_path (Union[str, Path]): Where to read images from.
447
+ painter (str, optional): Which painter to use: "climategan",
448
+ "stable_diffusion" or "both". Defaults to "both".
449
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
450
+ to "An HD picture of a street with dirty water after a heavy flood".
451
+ batch_size (int, optional): Size of inference batches. Defaults to 4.
452
+ concats (list, optional): List of keys in `output` to concatenate together
453
+ in a new `{original_stem}_concat` image written. Defaults to:
454
+ ["input", "masked_input", "climategan_flood", "stable_flood",
455
+ "stable_copy_flood"].
456
+ write (bool, optional): Whether or not to write the outputs to the input
457
+ folder.Defaults to True.
458
+ overwrite (Union[bool, str], optional): Whether to overwrite the images or
459
+ not. If a string is provided, it will be included in the name.
460
+ Defaults to False.
461
+
462
+ Returns:
463
+ dict: a dictionary containing the output images
464
+ """
465
+ folder_path = Path(folder_path).expanduser().resolve()
466
+ assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
467
+ assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"
468
+ im_paths = [
469
+ p
470
+ for p in folder_path.iterdir()
471
+ if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
472
+ ]
473
+ assert im_paths, f"No images found in {str(folder_path)}"
474
+ ims = [self._preprocess_image(np.array(Image.open(p))) for p in im_paths]
475
+ batches = [
476
+ np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
477
+ ]
478
+ inferences = [
479
+ self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
480
+ ]
481
+
482
+ outputs = {
483
+ k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
484
+ }
485
+
486
+ if write:
487
+ self.write(outputs, im_paths, painter, overwrite, prompt)
488
+
489
+ return outputs
490
+
491
+ def write(
492
+ self,
493
+ outputs,
494
+ im_paths,
495
+ painter="both",
496
+ overwrite=False,
497
+ prompt="",
498
+ ):
499
+ """
500
+ Writes the outputs of the inference to disk, in the input folder.
501
+
502
+ Images will be named like:
503
+ f"{original_stem}---{overwrite_prefix}_{painter_type}_{output_type}.{suffix}"
504
+ `painter_type` is either "climategan" or f"stable_diffusion_{prompt}"
505
+
506
+ Args:
507
+ outputs (_type_): The inference procedure's output dict.
508
+ im_paths (list[Path]): The list of input images paths.
509
+ painter (str, optional): Which painter was used. Defaults to "both".
510
+ overwrite (bool, optional): Whether to overwrite the images or not.
511
+ If a string is provided, it will be included in the name.
512
+ If False, a random identifier will be added to the name.
513
+ Defaults to False.
514
+ prompt (str, optional): The prompt used to guide the diffusion. Defaults
515
+ to "".
516
+ """
517
+ prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()
518
+ overwrite_prefix = ""
519
+ if not overwrite:
520
+ overwrite_prefix = str(uuid4())[:8]
521
+ print("Writing events with prefix", overwrite_prefix)
522
+ else:
523
+ if isinstance(overwrite, str):
524
+ overwrite_prefix = overwrite
525
+ print("Writing events with prefix", overwrite_prefix)
526
+
527
+ # for each image, for each event/data type
528
+ for i, im_path in enumerate(im_paths):
529
+ for event, ims in outputs.items():
530
+ painter_prefix = ""
531
+ if painter == "climategan" and event == "flood":
532
+ painter_prefix = "climategan"
533
+ elif (
534
+ painter in {"stable_diffusion", "both"} and event == "stable_flood"
535
+ ):
536
+ painter_prefix = f"_stable_{prompt}"
537
+ elif painter == "both" and event == "climategan_flood":
538
+ painter_prefix = ""
539
+
540
+ im = ims[i]
541
+ im = Image.fromarray(uint8(im))
542
+ imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
543
+ im.save(im_path.parent / (imstem + im_path.suffix))
544
+
545
+
546
+ if __name__ == "__main__":
547
+ print("Run `$ python climategan_wrapper.py help` for usage instructions\n")
548
+
549
+ # parse arguments
550
+ args = resolved_args(
551
+ defaults={
552
+ "input_folder": None,
553
+ "output_folder": None,
554
+ "painter": "both",
555
+ "help": False,
556
+ }
557
+ )
558
+
559
+ # print help
560
+ if args.help:
561
+ print(
562
+ "Usage: python inference.py input_folder=/path/to/folder\n"
563
+ + "By default inferences will be stored in the input folder.\n"
564
+ + "Add `output_folder=/path/to/folder` for a different output folder.\n"
565
+ + "By default, both ClimateGAN and Stable Diffusion will be used."
566
+ + "Change this by adding `painter=climategan` or"
567
+ + " `painter=stable_diffusion`.\n"
568
+ + "Make sure you have agreed to the terms of use for the models."
569
+ + "In particular, visit SD's model card to agree to the terms of use:"
570
+ + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
571
+ )
572
+ # print args
573
+ args.pretty_print()
574
+
575
+ # load models
576
+ cg = ClimateGAN("models/climategan")
577
+
578
+ # check painter type
579
+ assert args.painter in {"climategan", "stable_diffusion", "both",}, (
580
+ f"Unknown painter {args.painter}. "
581
+ + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
582
+ )
583
+
584
+ # load SD pipeline if need be
585
+ if args.painter != "climate_gan":
586
+ cg._setup_stable_diffusion()
587
+
588
+ # resolve input folder path
589
+ in_path = Path(args.input_folder).expanduser().resolve()
590
+ assert in_path.exists(), f"Folder {str(in_path)} does not exist"
591
+
592
+ # output is input if not specified
593
+ if args.output_folder is None:
594
+ out_path = in_path
595
+
596
+ # find images in input folder
597
+ im_paths = [
598
+ p
599
+ for p in in_path.iterdir()
600
+ if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
601
+ ]
602
+ assert im_paths, f"No images found in {str(im_paths)}"
603
+
604
+ print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")
605
+
606
+ # infer and write
607
+ for i, im_path in enumerate(im_paths):
608
+ print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
609
+ outs = cg.infer_single(
610
+ np.array(Image.open(im_path)),
611
+ args.painter,
612
+ as_pil_image=True,
613
+ concats=[
614
+ "input",
615
+ "masked_input",
616
+ "climategan_flood",
617
+ "stable_copy_flood",
618
+ ],
619
+ )
620
+ for k, v in outs.items():
621
+ name = f"{im_path.stem}---{k}{im_path.suffix}"
622
+ im = Image.fromarray(uint8(v))
623
+ im.save(out_path / name)
624
+ print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")
inferences.py DELETED
@@ -1,108 +0,0 @@
1
- # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
2
- # thank you @NimaBoscarino
3
-
4
- import torch
5
- from skimage.color import rgba2rgb
6
- from skimage.transform import resize
7
- import numpy as np
8
-
9
- from climategan.trainer import Trainer
10
-
11
-
12
- def uint8(array):
13
- """
14
- convert an array to np.uint8 (does not rescale or anything else than changing dtype)
15
- Args:
16
- array (np.array): array to modify
17
- Returns:
18
- np.array(np.uint8): converted array
19
- """
20
- return array.astype(np.uint8)
21
-
22
-
23
- def resize_and_crop(img, to=640):
24
- """
25
- Resizes an image so that it keeps the aspect ratio and the smallest dimensions
26
- is `to`, then crops this resized image in its center so that the output is `to x to`
27
- without aspect ratio distortion
28
- Args:
29
- img (np.array): np.uint8 255 image
30
- Returns:
31
- np.array: [0, 1] np.float32 image
32
- """
33
- # resize keeping aspect ratio: smallest dim is 640
34
- h, w = img.shape[:2]
35
- if h < w:
36
- size = (to, int(to * w / h))
37
- else:
38
- size = (int(to * h / w), to)
39
-
40
- r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
41
- r_img = uint8(r_img)
42
-
43
- # crop in the center
44
- H, W = r_img.shape[:2]
45
-
46
- top = (H - to) // 2
47
- left = (W - to) // 2
48
-
49
- rc_img = r_img[top : top + to, left : left + to, :]
50
-
51
- return rc_img / 255.0
52
-
53
-
54
- def to_m1_p1(img):
55
- """
56
- rescales a [0, 1] image to [-1, +1]
57
- Args:
58
- img (np.array): float32 numpy array of an image in [0, 1]
59
- i (int): Index of the image being rescaled
60
- Raises:
61
- ValueError: If the image is not in [0, 1]
62
- Returns:
63
- np.array(np.float32): array in [-1, +1]
64
- """
65
- if img.min() >= 0 and img.max() <= 1:
66
- return (img.astype(np.float32) - 0.5) * 2
67
- raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
68
-
69
-
70
- # No need to do any timing in this, since it's just for the HF Space
71
- class ClimateGAN:
72
- def __init__(self, model_path) -> None:
73
- torch.set_grad_enabled(False)
74
- self.target_size = 640
75
- self.trainer = Trainer.resume_from_path(
76
- model_path,
77
- setup=True,
78
- inference=True,
79
- new_exp=None,
80
- )
81
-
82
- # Does all three inferences at the moment.
83
- def inference(self, orig_image):
84
- image = self._preprocess_image(orig_image)
85
-
86
- # Retrieve numpy events as a dict {event: array[BxHxWxC]}
87
- outputs = self.trainer.infer_all(
88
- image,
89
- numpy=True,
90
- bin_value=0.5,
91
- )
92
-
93
- return (
94
- outputs["flood"].squeeze(),
95
- outputs["wildfire"].squeeze(),
96
- outputs["smog"].squeeze(),
97
- )
98
-
99
- def _preprocess_image(self, img):
100
- # rgba to rgb
101
- data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
102
-
103
- # to args.target_size
104
- data = resize_and_crop(data, self.target_size)
105
-
106
- # resize() produces [0, 1] images, rescale to [-1, 1]
107
- data = to_m1_p1(data)
108
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict==2.4.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.2.0
4
+ anyio==3.6.2
5
+ appnope==0.1.3
6
+ APScheduler==3.7.0
7
+ asttokens==2.0.8
8
+ async-timeout==4.0.2
9
+ attrs==21.2.0
10
+ backcall==0.2.0
11
+ bcrypt==4.0.1
12
+ black==22.10.0
13
+ blis==0.7.9
14
+ Brotli==1.0.9
15
+ catalogue==2.0.8
16
+ certifi==2021.5.30
17
+ cffi==1.15.1
18
+ charset-normalizer==2.0.4
19
+ click==8.0.1
20
+ codecarbon==1.2.0
21
+ comet-ml==3.15.3
22
+ commonmark==0.9.1
23
+ confection==0.0.3
24
+ configobj==5.0.6
25
+ contourpy==1.0.5
26
+ cryptography==38.0.1
27
+ cycler==0.10.0
28
+ cymem==2.0.7
29
+ dash==2.0.0
30
+ dash-bootstrap-components==0.13.0
31
+ dash-core-components==2.0.0
32
+ dash-html-components==2.0.0
33
+ dash-table==5.0.0
34
+ dataclasses==0.6
35
+ decorator==5.0.9
36
+ diffusers==0.6.0
37
+ dulwich==0.20.25
38
+ everett==2.0.1
39
+ executing==1.1.1
40
+ fastapi==0.85.1
41
+ ffmpy==0.3.0
42
+ filelock==3.0.12
43
+ fire==0.4.0
44
+ flake8==5.0.4
45
+ Flask==2.0.1
46
+ Flask-Compress==1.10.1
47
+ fonttools==4.38.0
48
+ frozenlist==1.3.1
49
+ fsspec==2022.10.0
50
+ future==0.18.2
51
+ gdown==3.13.0
52
+ googlemaps==4.6.0
53
+ gradio==3.6
54
+ h11==0.12.0
55
+ httpcore==0.15.0
56
+ httpx==0.23.0
57
+ huggingface-hub==0.10.1
58
+ hydra-core==0.11.3
59
+ idna==3.2
60
+ imageio==2.9.0
61
+ importlib-metadata==5.0.0
62
+ ipython==7.27.0
63
+ itsdangerous==2.0.1
64
+ jedi==0.18.0
65
+ Jinja2==3.0.1
66
+ joblib==1.0.1
67
+ jsonschema==3.2.0
68
+ kiwisolver==1.3.2
69
+ kornia==0.5.10
70
+ langcodes==3.3.0
71
+ linkify-it-py==1.0.3
72
+ markdown-it-py==2.1.0
73
+ MarkupSafe==2.0.1
74
+ matplotlib==3.4.3
75
+ matplotlib-inline==0.1.2
76
+ mccabe==0.7.0
77
+ mdit-py-plugins==0.3.1
78
+ mdurl==0.1.2
79
+ minydra==0.1.6
80
+ multidict==6.0.2
81
+ murmurhash==1.0.9
82
+ mypy-extensions==0.4.3
83
+ networkx==2.6.2
84
+ numpy==1.21.2
85
+ nvidia-ml-py3==7.352.0
86
+ omegaconf==1.4.1
87
+ opencv-python==4.5.3.56
88
+ orjson==3.8.0
89
+ packaging==21.0
90
+ pandas==1.3.2
91
+ paramiko==2.11.0
92
+ parso==0.8.2
93
+ pathspec==0.10.1
94
+ pathy==0.6.2
95
+ pexpect==4.8.0
96
+ pickleshare==0.7.5
97
+ Pillow==8.3.2
98
+ platformdirs==2.5.2
99
+ plotly==5.3.1
100
+ preshed==3.0.8
101
+ prompt-toolkit==3.0.20
102
+ ptyprocess==0.7.0
103
+ pure-eval==0.2.2
104
+ py-cpuinfo==8.0.0
105
+ pycodestyle==2.9.1
106
+ pycparser==2.21
107
+ pycryptodome==3.15.0
108
+ pydantic==1.10.2
109
+ pydub==0.25.1
110
+ pyflakes==2.5.0
111
+ Pygments==2.10.0
112
+ PyNaCl==1.5.0
113
+ pynvml==11.0.0
114
+ pyparsing==2.4.7
115
+ pyrsistent==0.18.0
116
+ PySocks==1.7.1
117
+ python-dateutil==2.8.2
118
+ python-multipart==0.0.5
119
+ pytorch-ranger==0.1.1
120
+ pytz==2021.1
121
+ PyWavelets==1.1.1
122
+ PyYAML==5.4.1
123
+ regex==2022.9.13
124
+ requests==2.26.0
125
+ requests-toolbelt==0.9.1
126
+ rfc3986==1.5.0
127
+ rich==12.6.0
128
+ scikit-image==0.18.3
129
+ scikit-learn==0.24.2
130
+ scipy==1.7.1
131
+ seaborn==0.11.2
132
+ semantic-version==2.8.5
133
+ six==1.16.0
134
+ smart-open==5.2.1
135
+ sniffio==1.3.0
136
+ spacy==3.4.2
137
+ spacy-legacy==3.0.10
138
+ spacy-loggers==1.0.3
139
+ srsly==2.4.5
140
+ stack-data==0.5.1
141
+ starlette==0.20.4
142
+ tenacity==8.0.1
143
+ termcolor==1.1.0
144
+ thinc==8.1.5
145
+ threadpoolctl==2.2.0
146
+ tifffile==2021.8.30
147
+ tokenizers==0.13.1
148
+ tomli==2.0.1
149
+ torch==1.7.1
150
+ torch-optimizer==0.1.0
151
+ torchvision==0.8.2
152
+ tqdm==4.62.2
153
+ traitlets==5.1.0
154
+ transformers==4.23.1
155
+ typer==0.4.2
156
+ typing_extensions==4.4.0
157
+ tzlocal==2.1
158
+ uc-micro-py==1.0.1
159
+ urllib3==1.26.6
160
+ uvicorn==0.19.0
161
+ wasabi==0.10.1
162
+ wcwidth==0.2.5
163
+ websocket-client==1.2.1
164
+ websockets==10.3
165
+ Werkzeug==2.0.1
166
+ wrapt==1.12.1
167
+ wurlitzer==3.0.2
168
+ yarl==1.8.1
169
+ zipp==3.10.0