Paolo-Fraccaro
committed on
Commit
·
9bb6f80
1
Parent(s):
a7f8c2e
add rgb output option
Browse files - Prithvi_run_inference.py +76 -19
Prithvi_run_inference.py
CHANGED
@@ -9,7 +9,7 @@ import torch
|
|
9 |
import yaml
|
10 |
from einops import rearrange
|
11 |
|
12 |
-
from
|
13 |
|
14 |
|
15 |
NO_DATA = -9999
|
@@ -21,12 +21,14 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
|
|
21 |
""" Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
|
22 |
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
|
23 |
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
|
|
|
24 |
Args:
|
25 |
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
|
26 |
new_img: torch.Tensor representing image with shape = (bands, H, W).
|
27 |
channels: list of indices representing RGB channels.
|
28 |
data_mean: list of mean values for each band.
|
29 |
data_std: list of std values for each band.
|
|
|
30 |
Returns:
|
31 |
torch.Tensor with shape (num_channels, height, width) for original image
|
32 |
torch.Tensor with shape (num_channels, height, width) for the other image
|
@@ -37,7 +39,7 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
|
|
37 |
for c in channels:
|
38 |
orig_ch = orig_img[c, ...]
|
39 |
valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
|
40 |
-
valid_mask[orig_ch ==
|
41 |
|
42 |
# Back to original data range
|
43 |
orig_ch = (orig_ch * data_std[c]) + data_mean[c]
|
@@ -64,9 +66,11 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
|
|
64 |
|
65 |
|
66 |
def read_geotiff(file_path: str):
|
67 |
-
""" Read all bands from *file_path* and
|
|
|
68 |
Args:
|
69 |
file_path: path to image file.
|
|
|
70 |
Returns:
|
71 |
np.ndarray with shape (bands, height, width)
|
72 |
meta info dict
|
@@ -81,6 +85,7 @@ def read_geotiff(file_path: str):
|
|
81 |
|
82 |
def save_geotiff(image, output_path: str, meta: dict):
|
83 |
""" Save multi-band image in Geotiff file.
|
|
|
84 |
Args:
|
85 |
image: np.ndarray with shape (bands, height, width)
|
86 |
output_path: path where to save the image
|
@@ -104,10 +109,12 @@ def _convert_np_uint8(float_image: torch.Tensor):
|
|
104 |
|
105 |
def load_example(file_paths: List[str], mean: List[float], std: List[float]):
|
106 |
""" Build an input example by loading images in *file_paths*.
|
|
|
107 |
Args:
|
108 |
file_paths: list of file paths .
|
109 |
mean: list containing mean values for each band in the images in *file_paths*.
|
110 |
std: list containing std values for each band in the images in *file_paths*.
|
|
|
111 |
Returns:
|
112 |
np.array containing created example
|
113 |
list of meta info for each image in *file_paths*
|
@@ -126,8 +133,8 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
|
|
126 |
imgs.append(img)
|
127 |
metas.append(meta)
|
128 |
|
129 |
-
imgs = np.stack(imgs, axis=0) # num_frames,
|
130 |
-
imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames,
|
131 |
imgs = np.expand_dims(imgs, axis=0) # add batch dim
|
132 |
|
133 |
return imgs, metas
|
@@ -135,11 +142,13 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
|
|
135 |
|
136 |
def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
|
137 |
""" Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
|
|
|
138 |
Args:
|
139 |
model: MAE model to run.
|
140 |
input_data: torch.Tensor with shape (B, C, T, H, W).
|
141 |
mask_ratio: mask ratio to use.
|
142 |
device: device where model should run.
|
|
|
143 |
Returns:
|
144 |
3 torch.Tensor with shape (B, C, T, H, W).
|
145 |
"""
|
@@ -165,6 +174,7 @@ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: floa
|
|
165 |
|
166 |
def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
|
167 |
""" Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
|
|
|
168 |
Args:
|
169 |
input_img: input torch.Tensor with shape (C, T, H, W).
|
170 |
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
@@ -199,7 +209,41 @@ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir,
|
|
199 |
meta=meta_data[t])
|
200 |
|
201 |
|
202 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
os.makedirs(output_dir, exist_ok=True)
|
205 |
|
@@ -262,7 +306,7 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
262 |
norm_pix_loss=False)
|
263 |
|
264 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
265 |
-
print(f"\n-->
|
266 |
|
267 |
model.to(device)
|
268 |
|
@@ -275,6 +319,12 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
275 |
model.eval()
|
276 |
channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
|
277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
# Build sliding window
|
279 |
batch = torch.tensor(input_data, device='cpu')
|
280 |
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
@@ -302,20 +352,23 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
302 |
mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
|
303 |
h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
|
304 |
|
305 |
-
#
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
|
310 |
-
|
311 |
-
|
|
|
|
|
312 |
|
313 |
-
|
314 |
-
|
315 |
-
|
|
|
|
|
316 |
|
317 |
-
|
318 |
-
channels, mean, std, output_dir, meta_data)
|
319 |
|
320 |
print("Done!")
|
321 |
|
@@ -334,6 +387,10 @@ if __name__ == "__main__":
|
|
334 |
parser.add_argument('--mask_ratio', default=None, type=float,
|
335 |
help='Masking ratio (percentage of removed patches). '
|
336 |
'If None (default) use same value used for pretraining.')
|
|
|
|
|
|
|
337 |
args = parser.parse_args()
|
338 |
|
339 |
-
main(**vars(args))
|
|
|
|
9 |
import yaml
|
10 |
from einops import rearrange
|
11 |
|
12 |
+
from mae.models_mae import MaskedAutoencoderViT
|
13 |
|
14 |
|
15 |
NO_DATA = -9999
|
|
|
21 |
""" Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
|
22 |
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
|
23 |
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
|
24 |
+
|
25 |
Args:
|
26 |
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
|
27 |
new_img: torch.Tensor representing image with shape = (bands, H, W).
|
28 |
channels: list of indices representing RGB channels.
|
29 |
data_mean: list of mean values for each band.
|
30 |
data_std: list of std values for each band.
|
31 |
+
|
32 |
Returns:
|
33 |
torch.Tensor with shape (num_channels, height, width) for original image
|
34 |
torch.Tensor with shape (num_channels, height, width) for the other image
|
|
|
39 |
for c in channels:
|
40 |
orig_ch = orig_img[c, ...]
|
41 |
valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
|
42 |
+
valid_mask[orig_ch == NO_DATA_FLOAT] = False
|
43 |
|
44 |
# Back to original data range
|
45 |
orig_ch = (orig_ch * data_std[c]) + data_mean[c]
|
|
|
66 |
|
67 |
|
68 |
def read_geotiff(file_path: str):
|
69 |
+
""" Read all bands from *file_path* and return image + meta info.
|
70 |
+
|
71 |
Args:
|
72 |
file_path: path to image file.
|
73 |
+
|
74 |
Returns:
|
75 |
np.ndarray with shape (bands, height, width)
|
76 |
meta info dict
|
|
|
85 |
|
86 |
def save_geotiff(image, output_path: str, meta: dict):
|
87 |
""" Save multi-band image in Geotiff file.
|
88 |
+
|
89 |
Args:
|
90 |
image: np.ndarray with shape (bands, height, width)
|
91 |
output_path: path where to save the image
|
|
|
109 |
|
110 |
def load_example(file_paths: List[str], mean: List[float], std: List[float]):
|
111 |
""" Build an input example by loading images in *file_paths*.
|
112 |
+
|
113 |
Args:
|
114 |
file_paths: list of file paths .
|
115 |
mean: list containing mean values for each band in the images in *file_paths*.
|
116 |
std: list containing std values for each band in the images in *file_paths*.
|
117 |
+
|
118 |
Returns:
|
119 |
np.array containing created example
|
120 |
list of meta info for each image in *file_paths*
|
|
|
133 |
imgs.append(img)
|
134 |
metas.append(meta)
|
135 |
|
136 |
+
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
137 |
+
imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
|
138 |
imgs = np.expand_dims(imgs, axis=0) # add batch dim
|
139 |
|
140 |
return imgs, metas
|
|
|
142 |
|
143 |
def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
|
144 |
""" Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
|
145 |
+
|
146 |
Args:
|
147 |
model: MAE model to run.
|
148 |
input_data: torch.Tensor with shape (B, C, T, H, W).
|
149 |
mask_ratio: mask ratio to use.
|
150 |
device: device where model should run.
|
151 |
+
|
152 |
Returns:
|
153 |
3 torch.Tensor with shape (B, C, T, H, W).
|
154 |
"""
|
|
|
174 |
|
175 |
def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
|
176 |
""" Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
|
177 |
+
|
178 |
Args:
|
179 |
input_img: input torch.Tensor with shape (C, T, H, W).
|
180 |
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
|
|
209 |
meta=meta_data[t])
|
210 |
|
211 |
|
212 |
+
def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
|
213 |
+
""" Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
217 |
+
mask_img: mask torch.Tensor with shape (C, T, H, W).
|
218 |
+
mean: list of mean values for each band.
|
219 |
+
std: list of std values for each band.
|
220 |
+
output_dir: directory where to save outputs.
|
221 |
+
meta_data: list of dicts with geotiff meta info.
|
222 |
+
"""
|
223 |
+
|
224 |
+
mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
|
225 |
+
std = torch.tensor(np.asarray(std)[:, None, None])
|
226 |
+
|
227 |
+
for t in range(rec_img.shape[1]):
|
228 |
+
|
229 |
+
# Back to original data range
|
230 |
+
rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
|
231 |
+
|
232 |
+
mask_img_t = mask_img[:, t, :, :].to(torch.int16)
|
233 |
+
|
234 |
+
# Saving images
|
235 |
+
|
236 |
+
save_geotiff(image=rec_img_t,
|
237 |
+
output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
|
238 |
+
meta=meta_data[t])
|
239 |
+
|
240 |
+
save_geotiff(image=mask_img_t,
|
241 |
+
output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
|
242 |
+
meta=meta_data[t])
|
243 |
+
|
244 |
+
|
245 |
+
def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str,
|
246 |
+
mask_ratio: float, rgb_outputs: bool):
|
247 |
|
248 |
os.makedirs(output_dir, exist_ok=True)
|
249 |
|
|
|
306 |
norm_pix_loss=False)
|
307 |
|
308 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
309 |
+
print(f"\n--> Model has {total_params:,} parameters.\n")
|
310 |
|
311 |
model.to(device)
|
312 |
|
|
|
319 |
model.eval()
|
320 |
channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
|
321 |
|
322 |
+
# Reflect pad if not divisible by img_size
|
323 |
+
original_h, original_w = input_data.shape[-2:]
|
324 |
+
pad_h = img_size - (original_h % img_size)
|
325 |
+
pad_w = img_size - (original_w % img_size)
|
326 |
+
input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
|
327 |
+
|
328 |
# Build sliding window
|
329 |
batch = torch.tensor(input_data, device='cpu')
|
330 |
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
|
|
352 |
mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
|
353 |
h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
|
354 |
|
355 |
+
# Cut padded images back to original size
|
356 |
+
rec_imgs_full = rec_imgs[..., :original_h, :original_w]
|
357 |
+
mask_imgs_full = mask_imgs[..., :original_h, :original_w]
|
358 |
+
batch_full = batch[..., :original_h, :original_w]
|
359 |
|
360 |
+
# Build output images
|
361 |
+
if rgb_outputs:
|
362 |
+
for d in meta_data:
|
363 |
+
d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
|
364 |
|
365 |
+
save_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
|
366 |
+
channels, mean, std, output_dir, meta_data)
|
367 |
+
else:
|
368 |
+
for d in meta_data:
|
369 |
+
d.update(compress='lzw', nodata=0)
|
370 |
|
371 |
+
save_imgs(rec_imgs_full[0, ...], mask_imgs_full[0, ...], mean, std, output_dir, meta_data)
|
|
|
372 |
|
373 |
print("Done!")
|
374 |
|
|
|
387 |
parser.add_argument('--mask_ratio', default=None, type=float,
|
388 |
help='Masking ratio (percentage of removed patches). '
|
389 |
'If None (default) use same value used for pretraining.')
|
390 |
+
parser.add_argument('--rgb_outputs', action='store_true',
|
391 |
+
help='If present, output files will only contain RGB channels. '
|
392 |
+
'Otherwise, all bands will be saved.')
|
393 |
args = parser.parse_args()
|
394 |
|
395 |
+
main(**vars(args))
|
396 |
+
|