soumickmj commited on
Commit
59075ed
·
1 Parent(s): 087ab59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -57,7 +57,9 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
57
  )
58
  aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
59
  patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)
60
- for _, patches_batch in enumerate(patch_loader):
 
 
61
  local_batch = patches_batch['img'][tio.DATA].float()
62
  local_batch = local_batch / local_batch.max()
63
  locations = patches_batch[tio.LOCATION]
@@ -72,6 +74,8 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
72
  output = torch.movedim(output, -3, -1).type(local_batch.type())
73
  aggregator.add_batch(output, locations)
74
 
 
 
75
  predicted = aggregator.get_output_tensor().squeeze().numpy()
76
 
77
  return predicted
 
57
  )
58
  aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
59
  patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)
60
+ total_batches = len(patch_loader)
61
+ progress_bar = st.progress(0)
62
+ for i, patches_batch in enumerate(patch_loader):
63
  local_batch = patches_batch['img'][tio.DATA].float()
64
  local_batch = local_batch / local_batch.max()
65
  locations = patches_batch[tio.LOCATION]
 
74
  output = torch.movedim(output, -3, -1).type(local_batch.type())
75
  aggregator.add_batch(output, locations)
76
 
77
+ progress_bar.progress((i + 1) / total_batches)
78
+
79
  predicted = aggregator.get_output_tensor().squeeze().numpy()
80
 
81
  return predicted