Spaces:
Runtime error
Runtime error
new `tensor_to_uint8_numpy_image` tensor util
Browse files- climategan/trainer.py +10 -16
- climategan/tutils.py +46 -4
climategan/trainer.py
CHANGED
@@ -39,10 +39,10 @@ from climategan.tutils import (
|
|
39 |
get_WGAN_gradient,
|
40 |
lrgb2srgb,
|
41 |
normalize,
|
42 |
-
normalize_tensor,
|
43 |
print_num_parameters,
|
44 |
shuffle_batch_tuple,
|
45 |
srgb2lrgb,
|
|
|
46 |
vgg_preprocess,
|
47 |
zero_grad,
|
48 |
)
|
@@ -231,12 +231,15 @@ class Trainer:
|
|
231 |
return_intermediates=False,
|
232 |
):
|
233 |
"""
|
234 |
-
Create a
|
235 |
single or batch image data.
|
236 |
|
237 |
-
stores is a
|
238 |
|
239 |
bin_value is used to binarize (or not) flood masks
|
|
|
|
|
|
|
240 |
"""
|
241 |
assert self.is_setup
|
242 |
assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
|
@@ -316,21 +319,14 @@ class Trainer:
|
|
316 |
with Timer(store=stores.get("numpy", [])):
|
317 |
if "flood" not in ignore_event:
|
318 |
# normalize to 0-1
|
319 |
-
flood =
|
320 |
-
# convert to numpy
|
321 |
-
flood = flood.permute(0, 2, 3, 1).numpy()
|
322 |
# convert to 0-255 uint8
|
323 |
-
flood = (flood * 255).astype(np.uint8)
|
324 |
output_data["flood"] = flood
|
325 |
if "wildfire" not in ignore_event:
|
326 |
-
wildfire =
|
327 |
-
wildfire = wildfire.permute(0, 2, 3, 1).numpy()
|
328 |
-
wildfire = (wildfire * 255).astype(np.uint8)
|
329 |
output_data["wildfire"] = wildfire
|
330 |
if "smog" not in ignore_event:
|
331 |
-
smog =
|
332 |
-
smog = smog.permute(0, 2, 3, 1).numpy()
|
333 |
-
smog = (smog * 255).astype(np.uint8)
|
334 |
output_data["smog"] = smog
|
335 |
|
336 |
if return_intermediates:
|
@@ -338,9 +334,7 @@ class Trainer:
|
|
338 |
output_data["mask"] = (
|
339 |
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
340 |
)
|
341 |
-
output_data["depth"] = (
|
342 |
-
normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
|
343 |
-
)
|
344 |
output_data["segmentation"] = (
|
345 |
decode_segmap_merged_labels(segmentation, "r", False)
|
346 |
.cpu()
|
|
|
39 |
get_WGAN_gradient,
|
40 |
lrgb2srgb,
|
41 |
normalize,
|
|
|
42 |
print_num_parameters,
|
43 |
shuffle_batch_tuple,
|
44 |
srgb2lrgb,
|
45 |
+
tensor_to_uint8_numpy_image,
|
46 |
vgg_preprocess,
|
47 |
zero_grad,
|
48 |
)
|
|
|
231 |
return_intermediates=False,
|
232 |
):
|
233 |
"""
|
234 |
+
Create a dictionary of events from a numpy or tensor,
|
235 |
single or batch image data.
|
236 |
|
237 |
+
stores is a dictionary of times for the Timer class.
|
238 |
|
239 |
bin_value is used to binarize (or not) flood masks
|
240 |
+
|
241 |
+
all values in the output dictionary have 4 dimensions:
|
242 |
+
BxHxWxC if numpy else BxCxHxW
|
243 |
"""
|
244 |
assert self.is_setup
|
245 |
assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
|
|
|
319 |
with Timer(store=stores.get("numpy", [])):
|
320 |
if "flood" not in ignore_event:
|
321 |
# normalize to 0-1
|
322 |
+
flood = tensor_to_uint8_numpy_image(flood)
|
|
|
|
|
323 |
# convert to 0-255 uint8
|
|
|
324 |
output_data["flood"] = flood
|
325 |
if "wildfire" not in ignore_event:
|
326 |
+
wildfire = tensor_to_uint8_numpy_image(wildfire)
|
|
|
|
|
327 |
output_data["wildfire"] = wildfire
|
328 |
if "smog" not in ignore_event:
|
329 |
+
smog = tensor_to_uint8_numpy_image(smog)
|
|
|
|
|
330 |
output_data["smog"] = smog
|
331 |
|
332 |
if return_intermediates:
|
|
|
334 |
output_data["mask"] = (
|
335 |
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
336 |
)
|
337 |
+
output_data["depth"] = tensor_to_uint8_numpy_image(depth)
|
|
|
|
|
338 |
output_data["segmentation"] = (
|
339 |
decode_segmap_merged_labels(segmentation, "r", False)
|
340 |
.cpu()
|
climategan/tutils.py
CHANGED
@@ -564,14 +564,29 @@ def lrgb2srgb(ims):
|
|
564 |
return outs[0]
|
565 |
|
566 |
|
567 |
-
def normalize(t, mini=0, maxi=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
568 |
if len(t.shape) == 3:
|
569 |
return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
|
570 |
|
571 |
batch_size = t.shape[0]
|
572 |
-
|
|
|
573 |
t = t - min_t
|
574 |
-
max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size,
|
575 |
t = t / max_t
|
576 |
return mini + (maxi - mini) * t
|
577 |
|
@@ -644,7 +659,7 @@ def write_architecture(trainer):
|
|
644 |
f.write(output)
|
645 |
|
646 |
|
647 |
-
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t
|
648 |
delta = (res[0] / shape[0], res[1] / shape[1])
|
649 |
d = (shape[0] // res[0], shape[1] // res[1])
|
650 |
|
@@ -719,3 +734,30 @@ def tensor_ims_to_np_uint8s(ims):
|
|
719 |
nps.append(n.astype(np.uint8))
|
720 |
|
721 |
return nps[0] if len(nps) == 1 else nps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
return outs[0]
|
565 |
|
566 |
|
567 |
+
def normalize(t, mini=0.0, maxi=1.0):
|
568 |
+
"""
|
569 |
+
Normalizes a tensor to [0, 1].
|
570 |
+
If the tensor has more than 3 dimensions, the first one
|
571 |
+
is assumed to be the batch dimension and the tensor is
|
572 |
+
normalized per batch element, not across the batches.
|
573 |
+
|
574 |
+
Args:
|
575 |
+
t (torch.Tensor): Tensor to normalize
|
576 |
+
mini (float, optional): Min allowed value. Defaults to 0.
|
577 |
+
maxi (float, optional): Max allowed value. Defaults to 1.
|
578 |
+
|
579 |
+
Returns:
|
580 |
+
torch.Tensor: The normalized tensor
|
581 |
+
"""
|
582 |
if len(t.shape) == 3:
|
583 |
return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
|
584 |
|
585 |
batch_size = t.shape[0]
|
586 |
+
extra_dims = [1] * (t.ndim - 1)
|
587 |
+
min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, *extra_dims)
|
588 |
t = t - min_t
|
589 |
+
max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, *extra_dims)
|
590 |
t = t / max_t
|
591 |
return mini + (maxi - mini) * t
|
592 |
|
|
|
659 |
f.write(output)
|
660 |
|
661 |
|
662 |
+
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
663 |
delta = (res[0] / shape[0], res[1] / shape[1])
|
664 |
d = (shape[0] // res[0], shape[1] // res[1])
|
665 |
|
|
|
734 |
nps.append(n.astype(np.uint8))
|
735 |
|
736 |
return nps[0] if len(nps) == 1 else nps
|
737 |
+
|
738 |
+
|
739 |
+
def tensor_to_uint8_numpy_image(tensor):
|
740 |
+
"""
|
741 |
+
Turns a BxCxHxW tensor into a numpy image:
|
742 |
+
* normalize
|
743 |
+
* to [0, 255]
|
744 |
+
* detach
|
745 |
+
* channels last
|
746 |
+
* to uin8
|
747 |
+
* to cpu
|
748 |
+
* to numpy
|
749 |
+
|
750 |
+
Args:
|
751 |
+
tensor (torch.Tensor): Tensor to transform
|
752 |
+
|
753 |
+
Returns:
|
754 |
+
np.array: BxHxWxC np.uint8 array in [0, 255]
|
755 |
+
"""
|
756 |
+
return (
|
757 |
+
normalize(tensor, 0, 255) # [0, 255]
|
758 |
+
.detach() # detach from graph if needed
|
759 |
+
.permute(0, 2, 3, 1) # BxHxWxC
|
760 |
+
.to(torch.uint8) # uint8
|
761 |
+
.cpu() # cpu
|
762 |
+
.numpy() # numpy array
|
763 |
+
)
|