vict0rsch committed on
Commit
0d2cc0b
·
1 Parent(s): 490814b

new `tensor_to_uint8_numpy_image` tensor util

Browse files
Files changed (2) hide show
  1. climategan/trainer.py +10 -16
  2. climategan/tutils.py +46 -4
climategan/trainer.py CHANGED
@@ -39,10 +39,10 @@ from climategan.tutils import (
39
  get_WGAN_gradient,
40
  lrgb2srgb,
41
  normalize,
42
- normalize_tensor,
43
  print_num_parameters,
44
  shuffle_batch_tuple,
45
  srgb2lrgb,
 
46
  vgg_preprocess,
47
  zero_grad,
48
  )
@@ -231,12 +231,15 @@ class Trainer:
231
  return_intermediates=False,
232
  ):
233
  """
234
- Create a dictionnary of events from a numpy or tensor,
235
  single or batch image data.
236
 
237
- stores is a dictionnary of times for the Timer class.
238
 
239
  bin_value is used to binarize (or not) flood masks
 
 
 
240
  """
241
  assert self.is_setup
242
  assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
@@ -316,21 +319,14 @@ class Trainer:
316
  with Timer(store=stores.get("numpy", [])):
317
  if "flood" not in ignore_event:
318
  # normalize to 0-1
319
- flood = normalize(flood).cpu()
320
- # convert to numpy
321
- flood = flood.permute(0, 2, 3, 1).numpy()
322
  # convert to 0-255 uint8
323
- flood = (flood * 255).astype(np.uint8)
324
  output_data["flood"] = flood
325
  if "wildfire" not in ignore_event:
326
- wildfire = normalize(wildfire).cpu()
327
- wildfire = wildfire.permute(0, 2, 3, 1).numpy()
328
- wildfire = (wildfire * 255).astype(np.uint8)
329
  output_data["wildfire"] = wildfire
330
  if "smog" not in ignore_event:
331
- smog = normalize(smog).cpu()
332
- smog = smog.permute(0, 2, 3, 1).numpy()
333
- smog = (smog * 255).astype(np.uint8)
334
  output_data["smog"] = smog
335
 
336
  if return_intermediates:
@@ -338,9 +334,7 @@ class Trainer:
338
  output_data["mask"] = (
339
  ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
340
  )
341
- output_data["depth"] = (
342
- normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
343
- )
344
  output_data["segmentation"] = (
345
  decode_segmap_merged_labels(segmentation, "r", False)
346
  .cpu()
 
39
  get_WGAN_gradient,
40
  lrgb2srgb,
41
  normalize,
 
42
  print_num_parameters,
43
  shuffle_batch_tuple,
44
  srgb2lrgb,
45
+ tensor_to_uint8_numpy_image,
46
  vgg_preprocess,
47
  zero_grad,
48
  )
 
231
  return_intermediates=False,
232
  ):
233
  """
234
+ Create a dictionary of events from a numpy or tensor,
235
  single or batch image data.
236
 
237
+ stores is a dictionary of times for the Timer class.
238
 
239
  bin_value is used to binarize (or not) flood masks
240
+
241
+ all values in the output dictionary have 4 dimensions:
242
+ BxHxWxC if numpy else BxCxHxW
243
  """
244
  assert self.is_setup
245
  assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
 
319
  with Timer(store=stores.get("numpy", [])):
320
  if "flood" not in ignore_event:
321
  # normalize to 0-1
322
+ flood = tensor_to_uint8_numpy_image(flood)
 
 
323
  # convert to 0-255 uint8
 
324
  output_data["flood"] = flood
325
  if "wildfire" not in ignore_event:
326
+ wildfire = tensor_to_uint8_numpy_image(wildfire)
 
 
327
  output_data["wildfire"] = wildfire
328
  if "smog" not in ignore_event:
329
+ smog = tensor_to_uint8_numpy_image(smog)
 
 
330
  output_data["smog"] = smog
331
 
332
  if return_intermediates:
 
334
  output_data["mask"] = (
335
  ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
336
  )
337
+ output_data["depth"] = tensor_to_uint8_numpy_image(depth)
 
 
338
  output_data["segmentation"] = (
339
  decode_segmap_merged_labels(segmentation, "r", False)
340
  .cpu()
climategan/tutils.py CHANGED
@@ -564,14 +564,29 @@ def lrgb2srgb(ims):
564
  return outs[0]
565
 
566
 
567
- def normalize(t, mini=0, maxi=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  if len(t.shape) == 3:
569
  return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
570
 
571
  batch_size = t.shape[0]
572
- min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, 1, 1, 1)
 
573
  t = t - min_t
574
- max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, 1, 1, 1)
575
  t = t / max_t
576
  return mini + (maxi - mini) * t
577
 
@@ -644,7 +659,7 @@ def write_architecture(trainer):
644
  f.write(output)
645
 
646
 
647
- def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
648
  delta = (res[0] / shape[0], res[1] / shape[1])
649
  d = (shape[0] // res[0], shape[1] // res[1])
650
 
@@ -719,3 +734,30 @@ def tensor_ims_to_np_uint8s(ims):
719
  nps.append(n.astype(np.uint8))
720
 
721
  return nps[0] if len(nps) == 1 else nps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
  return outs[0]
565
 
566
 
567
+ def normalize(t, mini=0.0, maxi=1.0):
568
+ """
569
+ Normalizes a tensor to [mini, maxi] ([0, 1] by default).
570
+ If the tensor has more than 3 dimensions, the first one
571
+ is assumed to be the batch dimension and the tensor is
572
+ normalized per batch element, not across the batches.
573
+
574
+ Args:
575
+ t (torch.Tensor): Tensor to normalize
576
+ mini (float, optional): Min allowed value. Defaults to 0.
577
+ maxi (float, optional): Max allowed value. Defaults to 1.
578
+
579
+ Returns:
580
+ torch.Tensor: The normalized tensor
581
+ """
582
  if len(t.shape) == 3:
583
  return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
584
 
585
  batch_size = t.shape[0]
586
+ extra_dims = [1] * (t.ndim - 1)
587
+ min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, *extra_dims)
588
  t = t - min_t
589
+ max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, *extra_dims)
590
  t = t / max_t
591
  return mini + (maxi - mini) * t
592
 
 
659
  f.write(output)
660
 
661
 
662
+ def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
663
  delta = (res[0] / shape[0], res[1] / shape[1])
664
  d = (shape[0] // res[0], shape[1] // res[1])
665
 
 
734
  nps.append(n.astype(np.uint8))
735
 
736
  return nps[0] if len(nps) == 1 else nps
737
+
738
+
739
+ def tensor_to_uint8_numpy_image(tensor):
740
+ """
741
+ Turns a BxCxHxW tensor into a numpy image:
742
+ * normalize
743
+ * to [0, 255]
744
+ * detach
745
+ * channels last
746
+ * to uint8
747
+ * to cpu
748
+ * to numpy
749
+
750
+ Args:
751
+ tensor (torch.Tensor): Tensor to transform
752
+
753
+ Returns:
754
+ np.array: BxHxWxC np.uint8 array in [0, 255]
755
+ """
756
+ return (
757
+ normalize(tensor, 0, 255) # [0, 255]
758
+ .detach() # detach from graph if needed
759
+ .permute(0, 2, 3, 1) # BxHxWxC
760
+ .to(torch.uint8) # uint8
761
+ .cpu() # cpu
762
+ .numpy() # numpy array
763
+ )