Paolo-Fraccaro commited on
Commit
9bb6f80
·
1 Parent(s): a7f8c2e

add rgb output option

Browse files
Files changed (1) hide show
  1. Prithvi_run_inference.py +76 -19
Prithvi_run_inference.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
  import yaml
10
  from einops import rearrange
11
 
12
- from Prithvi import MaskedAutoencoderViT
13
 
14
 
15
  NO_DATA = -9999
@@ -21,12 +21,14 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
21
  """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
22
  original range using *data_mean* and *data_std* and then lowest and highest percentiles are
23
  removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
 
24
  Args:
25
  orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
26
  new_img: torch.Tensor representing image with shape = (bands, H, W).
27
  channels: list of indices representing RGB channels.
28
  data_mean: list of mean values for each band.
29
  data_std: list of std values for each band.
 
30
  Returns:
31
  torch.Tensor with shape (num_channels, height, width) for original image
32
  torch.Tensor with shape (num_channels, height, width) for the other image
@@ -37,7 +39,7 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
37
  for c in channels:
38
  orig_ch = orig_img[c, ...]
39
  valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
40
- valid_mask[orig_ch == 0.0001] = False
41
 
42
  # Back to original data range
43
  orig_ch = (orig_ch * data_std[c]) + data_mean[c]
@@ -64,9 +66,11 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
64
 
65
 
66
  def read_geotiff(file_path: str):
67
- """ Read all bands from *file_path* and returns image + meta info.
 
68
  Args:
69
  file_path: path to image file.
 
70
  Returns:
71
  np.ndarray with shape (bands, height, width)
72
  meta info dict
@@ -81,6 +85,7 @@ def read_geotiff(file_path: str):
81
 
82
  def save_geotiff(image, output_path: str, meta: dict):
83
  """ Save multi-band image in Geotiff file.
 
84
  Args:
85
  image: np.ndarray with shape (bands, height, width)
86
  output_path: path where to save the image
@@ -104,10 +109,12 @@ def _convert_np_uint8(float_image: torch.Tensor):
104
 
105
  def load_example(file_paths: List[str], mean: List[float], std: List[float]):
106
  """ Build an input example by loading images in *file_paths*.
 
107
  Args:
108
  file_paths: list of file paths .
109
  mean: list containing mean values for each band in the images in *file_paths*.
110
  std: list containing std values for each band in the images in *file_paths*.
 
111
  Returns:
112
  np.array containing created example
113
  list of meta info for each image in *file_paths*
@@ -126,8 +133,8 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
126
  imgs.append(img)
127
  metas.append(meta)
128
 
129
- imgs = np.stack(imgs, axis=0) # num_frames, img_size, img_size, C
130
- imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, img_size, img_size
131
  imgs = np.expand_dims(imgs, axis=0) # add batch dim
132
 
133
  return imgs, metas
@@ -135,11 +142,13 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
135
 
136
  def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
137
  """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
 
138
  Args:
139
  model: MAE model to run.
140
  input_data: torch.Tensor with shape (B, C, T, H, W).
141
  mask_ratio: mask ratio to use.
142
  device: device where model should run.
 
143
  Returns:
144
  3 torch.Tensor with shape (B, C, T, H, W).
145
  """
@@ -165,6 +174,7 @@ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: floa
165
 
166
  def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
167
  """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
 
168
  Args:
169
  input_img: input torch.Tensor with shape (C, T, H, W).
170
  rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
@@ -199,7 +209,41 @@ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir,
199
  meta=meta_data[t])
200
 
201
 
202
- def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str, mask_ratio: float):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  os.makedirs(output_dir, exist_ok=True)
205
 
@@ -262,7 +306,7 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
262
  norm_pix_loss=False)
263
 
264
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
265
- print(f"\n--> model has {total_params / 1e6} Million params.\n")
266
 
267
  model.to(device)
268
 
@@ -275,6 +319,12 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
275
  model.eval()
276
  channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
277
 
 
 
 
 
 
 
278
  # Build sliding window
279
  batch = torch.tensor(input_data, device='cpu')
280
  windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
@@ -302,20 +352,23 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
302
  mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
303
  h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
304
 
305
- # Mix original image with patches
306
- h, w = rec_imgs.shape[-2:]
307
- rec_imgs_full = batch.clone()
308
- rec_imgs_full[..., :h, :w] = rec_imgs
309
 
310
- mask_imgs_full = torch.ones_like(batch)
311
- mask_imgs_full[..., :h, :w] = mask_imgs
 
 
312
 
313
- # Build RGB images
314
- for d in meta_data:
315
- d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
 
 
316
 
317
- save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
318
- channels, mean, std, output_dir, meta_data)
319
 
320
  print("Done!")
321
 
@@ -334,6 +387,10 @@ if __name__ == "__main__":
334
  parser.add_argument('--mask_ratio', default=None, type=float,
335
  help='Masking ratio (percentage of removed patches). '
336
  'If None (default) use same value used for pretraining.')
 
 
 
337
  args = parser.parse_args()
338
 
339
- main(**vars(args))
 
 
9
  import yaml
10
  from einops import rearrange
11
 
12
+ from mae.models_mae import MaskedAutoencoderViT
13
 
14
 
15
  NO_DATA = -9999
 
21
  """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
22
  original range using *data_mean* and *data_std* and then lowest and highest percentiles are
23
  removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
24
+
25
  Args:
26
  orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
27
  new_img: torch.Tensor representing image with shape = (bands, H, W).
28
  channels: list of indices representing RGB channels.
29
  data_mean: list of mean values for each band.
30
  data_std: list of std values for each band.
31
+
32
  Returns:
33
  torch.Tensor with shape (num_channels, height, width) for original image
34
  torch.Tensor with shape (num_channels, height, width) for the other image
 
39
  for c in channels:
40
  orig_ch = orig_img[c, ...]
41
  valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
42
+ valid_mask[orig_ch == NO_DATA_FLOAT] = False
43
 
44
  # Back to original data range
45
  orig_ch = (orig_ch * data_std[c]) + data_mean[c]
 
66
 
67
 
68
  def read_geotiff(file_path: str):
69
+ """ Read all bands from *file_path* and return image + meta info.
70
+
71
  Args:
72
  file_path: path to image file.
73
+
74
  Returns:
75
  np.ndarray with shape (bands, height, width)
76
  meta info dict
 
85
 
86
  def save_geotiff(image, output_path: str, meta: dict):
87
  """ Save multi-band image in Geotiff file.
88
+
89
  Args:
90
  image: np.ndarray with shape (bands, height, width)
91
  output_path: path where to save the image
 
109
 
110
  def load_example(file_paths: List[str], mean: List[float], std: List[float]):
111
  """ Build an input example by loading images in *file_paths*.
112
+
113
  Args:
114
  file_paths: list of file paths .
115
  mean: list containing mean values for each band in the images in *file_paths*.
116
  std: list containing std values for each band in the images in *file_paths*.
117
+
118
  Returns:
119
  np.array containing created example
120
  list of meta info for each image in *file_paths*
 
133
  imgs.append(img)
134
  metas.append(meta)
135
 
136
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
137
+ imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
138
  imgs = np.expand_dims(imgs, axis=0) # add batch dim
139
 
140
  return imgs, metas
 
142
 
143
  def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
144
  """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
145
+
146
  Args:
147
  model: MAE model to run.
148
  input_data: torch.Tensor with shape (B, C, T, H, W).
149
  mask_ratio: mask ratio to use.
150
  device: device where model should run.
151
+
152
  Returns:
153
  3 torch.Tensor with shape (B, C, T, H, W).
154
  """
 
174
 
175
  def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
176
  """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
177
+
178
  Args:
179
  input_img: input torch.Tensor with shape (C, T, H, W).
180
  rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
 
209
  meta=meta_data[t])
210
 
211
 
212
+ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
213
+ """ Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
214
+
215
+ Args:
216
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
217
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
218
+ mean: list of mean values for each band.
219
+ std: list of std values for each band.
220
+ output_dir: directory where to save outputs.
221
+ meta_data: list of dicts with geotiff meta info.
222
+ """
223
+
224
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
225
+ std = torch.tensor(np.asarray(std)[:, None, None])
226
+
227
+ for t in range(rec_img.shape[1]):
228
+
229
+ # Back to original data range
230
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
231
+
232
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
233
+
234
+ # Saving images
235
+
236
+ save_geotiff(image=rec_img_t,
237
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
238
+ meta=meta_data[t])
239
+
240
+ save_geotiff(image=mask_img_t,
241
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
242
+ meta=meta_data[t])
243
+
244
+
245
+ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str,
246
+ mask_ratio: float, rgb_outputs: bool):
247
 
248
  os.makedirs(output_dir, exist_ok=True)
249
 
 
306
  norm_pix_loss=False)
307
 
308
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
309
+ print(f"\n--> Model has {total_params:,} parameters.\n")
310
 
311
  model.to(device)
312
 
 
319
  model.eval()
320
  channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
321
 
322
+ # Reflect pad if not divisible by img_size
323
+ original_h, original_w = input_data.shape[-2:]
324
+ pad_h = img_size - (original_h % img_size)
325
+ pad_w = img_size - (original_w % img_size)
326
+ input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
327
+
328
  # Build sliding window
329
  batch = torch.tensor(input_data, device='cpu')
330
  windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
 
352
  mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
353
  h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
354
 
355
+ # Cut padded images back to original size
356
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
357
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
358
+ batch_full = batch[..., :original_h, :original_w]
359
 
360
+ # Build output images
361
+ if rgb_outputs:
362
+ for d in meta_data:
363
+ d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
364
 
365
+ save_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
366
+ channels, mean, std, output_dir, meta_data)
367
+ else:
368
+ for d in meta_data:
369
+ d.update(compress='lzw', nodata=0)
370
 
371
+ save_imgs(rec_imgs_full[0, ...], mask_imgs_full[0, ...], mean, std, output_dir, meta_data)
 
372
 
373
  print("Done!")
374
 
 
387
  parser.add_argument('--mask_ratio', default=None, type=float,
388
  help='Masking ratio (percentage of removed patches). '
389
  'If None (default) use same value used for pretraining.')
390
+ parser.add_argument('--rgb_outputs', action='store_true',
391
+ help='If present, output files will only contain RGB channels. '
392
+ 'Otherwise, all bands will be saved.')
393
  args = parser.parse_args()
394
 
395
+ main(**vars(args))
396
+