Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -57,7 +57,9 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
|
|
57 |
)
|
58 |
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
|
59 |
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)
|
60 |
-
|
|
|
|
|
61 |
local_batch = patches_batch['img'][tio.DATA].float()
|
62 |
local_batch = local_batch / local_batch.max()
|
63 |
locations = patches_batch[tio.LOCATION]
|
@@ -72,6 +74,8 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
|
|
72 |
output = torch.movedim(output, -3, -1).type(local_batch.type())
|
73 |
aggregator.add_batch(output, locations)
|
74 |
|
|
|
|
|
75 |
predicted = aggregator.get_output_tensor().squeeze().numpy()
|
76 |
|
77 |
return predicted
|
|
|
57 |
)
|
58 |
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
|
59 |
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)
|
60 |
+
total_batches = len(patch_loader)
|
61 |
+
progress_bar = st.progress(0)
|
62 |
+
for i, patches_batch in enumerate(patch_loader):
|
63 |
local_batch = patches_batch['img'][tio.DATA].float()
|
64 |
local_batch = local_batch / local_batch.max()
|
65 |
locations = patches_batch[tio.LOCATION]
|
|
|
74 |
output = torch.movedim(output, -3, -1).type(local_batch.type())
|
75 |
aggregator.add_batch(output, locations)
|
76 |
|
77 |
+
progress_bar.progress((i + 1) / total_batches)
|
78 |
+
|
79 |
predicted = aggregator.get_output_tensor().squeeze().numpy()
|
80 |
|
81 |
return predicted
|