soumickmj commited on
Commit
526c3e1
·
1 Parent(s): e03497c

patching added for trying

Browse files
Files changed (2) hide show
  1. app.py +45 -8
  2. requirements.txt +2 -4
app.py CHANGED
@@ -1,18 +1,15 @@
1
  import streamlit as st
2
- import json
3
  import math
4
  import numpy as np
5
  import nibabel as nib
6
  import torch
7
  import torch.nn.functional as F
8
- import scipy.io
9
- from io import BytesIO
10
  from transformers import AutoModel
11
  import os
12
  import tempfile
13
  from pathlib import Path
14
- import pandas as pd
15
  from skimage.filters import threshold_otsu
 
16
 
17
  def infer_full_vol(tensor, model):
18
  tensor = torch.movedim(tensor, -1, -3)
@@ -46,6 +43,37 @@ def infer_full_vol(tensor, model):
46
  output = torch.movedim(output, -3, -1).type(tensor.type())
47
  return output.squeeze().detach().cpu().numpy()
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Set page configuration
50
  st.set_page_config(
51
  page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
@@ -62,7 +90,7 @@ with st.sidebar:
62
 
63
  **Instructions**:
64
  - Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a single-slice cardiac long-axis dynamic CINE scan, where the first dimension represents time.
65
- - Select a seed value from the dropdown menu.
66
  - Click the "Process" button to generate the latent factors.
67
  """)
68
  st.markdown("---")
@@ -77,10 +105,14 @@ uploaded_file = st.file_uploader(
77
  type=["nii", "nii.gz"]
78
  )
79
 
80
- # Seed selection
81
  model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
82
  selected_model = st.selectbox("Select a pretrained model:", model_options)
83
 
 
 
 
 
84
  # Process button
85
  process_button = st.button("Process")
86
 
@@ -111,7 +143,7 @@ if uploaded_file is not None and process_button:
111
  # Add batch and channel dimensions
112
  tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W]
113
 
114
- # Construct the model name based on the selected seed
115
  model_name = f"soumickmj/{selected_model}"
116
 
117
  # Load the pre-trained model from Hugging Face
@@ -145,7 +177,12 @@ if uploaded_file is not None and process_button:
145
 
146
  # Process the tensor through the model
147
  with st.spinner('Processing the tensor through the model...'):
148
- output = infer_full_vol(tensor, model)
 
 
 
 
 
149
 
150
  st.success("Processing complete.")
151
  st.write(f"Output tensor shape: `{output.shape}`")
 
1
  import streamlit as st
 
2
  import math
3
  import numpy as np
4
  import nibabel as nib
5
  import torch
6
  import torch.nn.functional as F
 
 
7
  from transformers import AutoModel
8
  import os
9
  import tempfile
10
  from pathlib import Path
 
11
  from skimage.filters import threshold_otsu
12
+ import torchio as tio
13
 
14
  def infer_full_vol(tensor, model):
15
  tensor = torch.movedim(tensor, -1, -3)
 
43
  output = torch.movedim(output, -3, -1).type(tensor.type())
44
  return output.squeeze().detach().cpu().numpy()
45
 
46
+ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
47
+ test_subject = tio.Subject(img = tio.ScalarImage(tensor=tensor))
48
+ overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
49
+
50
+ with torch.no_grad():
51
+ grid_sampler = tio.inference.GridSampler(
52
+ test_subject,
53
+ patch_size,
54
+ overlap,
55
+ )
56
+ aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
57
+ patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)
58
+ for _, patches_batch in enumerate(patch_loader):
59
+ local_batch = patches_batch['img'][tio.DATA].float()
60
+ local_batch = local_batch / local_batch.max()
61
+ locations = patches_batch[tio.LOCATION]
62
+
63
+ local_batch = torch.movedim(local_batch, -1, -3)
64
+
65
+ output = model(local_batch)
66
+ if type(output) is tuple or type(output) is list:
67
+ output = output[0]
68
+ output = torch.sigmoid(output).detach().cpu()
69
+
70
+ output = torch.movedim(output, -3, -1).type(local_batch.type())
71
+ aggregator.add_batch(output, locations)
72
+
73
+ predicted = aggregator.get_output_tensor().squeeze().numpy()
74
+
75
+ return predicted
76
+
77
  # Set page configuration
78
  st.set_page_config(
79
  page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
 
90
 
91
  **Instructions**:
92
  - Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a single-slice cardiac long-axis dynamic CINE scan, where the first dimension represents time.
93
+ - Select a model from the dropdown menu.
94
  - Click the "Process" button to generate the latent factors.
95
  """)
96
  st.markdown("---")
 
105
  type=["nii", "nii.gz"]
106
  )
107
 
108
+ # Model selection
109
  model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
110
  selected_model = st.selectbox("Select a pretrained model:", model_options)
111
 
112
+ # Mode selection
113
+ mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
114
+ selected_mode = st.selectbox("Select the running mode:", mode_options)
115
+
116
  # Process button
117
  process_button = st.button("Process")
118
 
 
143
  # Add batch and channel dimensions
144
  tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W]
145
 
146
+ # Construct the model name based on the selected model
147
  model_name = f"soumickmj/{selected_model}"
148
 
149
  # Load the pre-trained model from Hugging Face
 
177
 
178
  # Process the tensor through the model
179
  with st.spinner('Processing the tensor through the model...'):
180
+ if selected_mode == "full volume inference":
181
+ st.info("Running full volume inference...")
182
+ output = infer_full_vol(tensor, model)
183
+ else:
184
+ st.info("Running patch-based inference [Default for DS6]...")
185
+ output = infer_patch_based(tensor, model)
186
 
187
  st.success("Processing complete.")
188
  st.write(f"Output tensor shape: `{output.shape}`")
requirements.txt CHANGED
@@ -1,7 +1,5 @@
 
1
  nibabel
2
  torch
3
- pytorch_lightning
4
- scipy
5
  transformers
6
- torchvision
7
- scikit-image
 
1
+ scikit-image
2
  nibabel
3
  torch
 
 
4
  transformers
5
+ torchio