chris1nexus
commited on
Commit
·
ad1b20c
1
Parent(s):
d60982d
Update app
Browse files
app.py
CHANGED
@@ -5,141 +5,157 @@ from streamlit_option_menu import option_menu
|
|
5 |
import torch
|
6 |
|
7 |
|
8 |
-
if torch.cuda.is_available():
|
9 |
-
os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
|
10 |
-
os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
|
11 |
-
os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
|
12 |
-
else:
|
13 |
-
os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
14 |
-
os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
15 |
-
os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
16 |
-
|
17 |
-
from predict import Predictor
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
# environment variables for the inference api
|
22 |
-
os.environ['DATA_DIR'] = 'queries'
|
23 |
-
os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
|
24 |
-
os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
|
25 |
-
os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
|
26 |
-
os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
|
27 |
-
|
28 |
-
|
29 |
-
# manually put the metadata in the metadata folder
|
30 |
-
os.environ['CLASS_METADATA'] ='metadata/label_map.pkl'
|
31 |
-
|
32 |
-
# manually put the desired weights in the weights folder
|
33 |
-
os.environ['WEIGHTS_PATH'] = WEIGHTS_PATH='weights'
|
34 |
-
os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
|
35 |
-
os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
|
36 |
-
|
37 |
-
|
38 |
-
st.set_page_config(page_title="",layout='wide')
|
39 |
-
predictor = Predictor()
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
|
46 |
-
CONTACT_TEXT = """
|
47 |
-
_Built by Christian Cancedda and LabLab lads with love_ ❤️
|
48 |
-
[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
|
49 |
-
[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
|
50 |
-
"""
|
51 |
-
VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
|
52 |
-
DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
with st.sidebar:
|
57 |
-
choice = option_menu("LastMinute - Diagnosis",
|
58 |
-
["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
|
59 |
-
icons=['house', 'upload', 'activity', 'person lines fill'],
|
60 |
-
menu_icon="app-indicator", default_index=0,
|
61 |
-
styles={
|
62 |
-
# "container": {"padding": "5!important", "background-color": "#fafafa", },
|
63 |
-
"container": {"border-radius": ".0rem"},
|
64 |
-
# "icon": {"color": "orange", "font-size": "25px"},
|
65 |
-
# "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
|
66 |
-
# "--hover-color": "#eee"},
|
67 |
-
# "nav-link-selected": {"background-color": "#02ab21"},
|
68 |
-
}
|
69 |
-
)
|
70 |
-
st.sidebar.markdown(
|
71 |
-
"""
|
72 |
-
<style>
|
73 |
-
.aligncenter {
|
74 |
-
text-align: center;
|
75 |
-
}
|
76 |
-
</style>
|
77 |
-
<p class="aligncenter">
|
78 |
-
<a href="https://twitter.com/chris_cancedda" target="_blank">
|
79 |
-
<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
|
80 |
-
</a>
|
81 |
-
</p>
|
82 |
-
""",
|
83 |
-
unsafe_allow_html=True,
|
84 |
-
)
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
if choice == "About":
|
89 |
-
st.title(choice)
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
if choice == "Visualize WSI slide":
|
94 |
-
st.title(choice)
|
95 |
-
st.markdown(VISUALIZE_TEXT)
|
96 |
-
|
97 |
-
uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
|
98 |
-
if uploaded_file is not None:
|
99 |
-
ori = openslide.OpenSlide(uploaded_file.name)
|
100 |
-
width, height = ori.dimensions
|
101 |
-
|
102 |
-
REDUCTION_FACTOR = 20
|
103 |
-
w, h = int(width/512), int(height/512)
|
104 |
-
w_r, h_r = int(width/20), int(height/20)
|
105 |
-
resized_img = ori.get_thumbnail((w_r,h_r))
|
106 |
-
resized_img = resized_img.resize((w_r,h_r))
|
107 |
-
ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
|
108 |
-
#print('ratios ', ratio_w, ratio_h)
|
109 |
-
w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
|
110 |
-
st.image(resized_img, use_column_width='never')
|
111 |
-
|
112 |
-
if choice == "Cancer Detection":
|
113 |
-
state = dict()
|
114 |
-
|
115 |
-
st.title(choice)
|
116 |
-
st.markdown(DETECT_TEXT)
|
117 |
-
uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
|
118 |
-
if uploaded_file is not None:
|
119 |
-
# To read file as bytes:
|
120 |
-
#print(uploaded_file)
|
121 |
-
with open(os.path.join(uploaded_file.name),"wb") as f:
|
122 |
-
f.write(uploaded_file.getbuffer())
|
123 |
-
with st.spinner(text="Computation is running"):
|
124 |
-
predicted_class, viz_dict = predictor.predict(uploaded_file.name)
|
125 |
-
st.info('Computation completed.')
|
126 |
-
st.header(f'Predicted to be: {predicted_class}')
|
127 |
-
st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
|
128 |
-
state['cur'] = predicted_class
|
129 |
-
mapper = {'ORI': predicted_class, predicted_class:'ORI'}
|
130 |
-
readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
|
131 |
-
#def fn():
|
132 |
-
# st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
|
133 |
-
# state['cur'] = mapper[state['cur']]
|
134 |
-
# return
|
135 |
-
|
136 |
-
#st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
|
137 |
-
#st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
|
138 |
-
st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR'
|
139 |
-
# use_column_width='never',
|
140 |
-
)
|
141 |
-
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch
|
6 |
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
def main():
|
10 |
+
|
11 |
+
from predict import Predictor
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# environment variables for the inference api
|
16 |
+
os.environ['DATA_DIR'] = 'queries'
|
17 |
+
os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
|
18 |
+
os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
|
19 |
+
os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
|
20 |
+
os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
|
21 |
+
|
22 |
+
|
23 |
+
# manually put the metadata in the metadata folder
|
24 |
+
os.environ['CLASS_METADATA'] ='metadata/label_map.pkl'
|
25 |
+
|
26 |
+
# manually put the desired weights in the weights folder
|
27 |
+
os.environ['WEIGHTS_PATH'] = WEIGHTS_PATH='weights'
|
28 |
+
os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
|
29 |
+
os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
|
30 |
+
|
31 |
+
|
32 |
+
st.set_page_config(page_title="",layout='wide')
|
33 |
+
predictor = Predictor()
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
|
40 |
+
CONTACT_TEXT = """
|
41 |
+
_Built by Christian Cancedda and LabLab lads with love_ ❤️
|
42 |
+
[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
|
43 |
+
[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
|
44 |
+
Star project repository:
|
45 |
+
[![GitHub stars](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
|
46 |
+
"""
|
47 |
+
VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
|
48 |
+
DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
with st.sidebar:
|
53 |
+
choice = option_menu("LastMinute - Diagnosis",
|
54 |
+
["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
|
55 |
+
icons=['house', 'upload', 'activity', 'person lines fill'],
|
56 |
+
menu_icon="app-indicator", default_index=0,
|
57 |
+
styles={
|
58 |
+
# "container": {"padding": "5!important", "background-color": "#fafafa", },
|
59 |
+
"container": {"border-radius": ".0rem"},
|
60 |
+
# "icon": {"color": "orange", "font-size": "25px"},
|
61 |
+
# "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
|
62 |
+
# "--hover-color": "#eee"},
|
63 |
+
# "nav-link-selected": {"background-color": "#02ab21"},
|
64 |
+
}
|
65 |
+
)
|
66 |
+
st.sidebar.markdown(
|
67 |
+
"""
|
68 |
+
<style>
|
69 |
+
.aligncenter {
|
70 |
+
text-align: center;
|
71 |
+
}
|
72 |
+
</style>
|
73 |
+
<p style='text-align: center'>
|
74 |
+
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">Project Repository</a>
|
75 |
+
</p>
|
76 |
+
<p class="aligncenter">
|
77 |
+
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">
|
78 |
+
<img src="https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social"/>
|
79 |
+
</a>
|
80 |
+
</p>
|
81 |
+
|
82 |
+
<p class="aligncenter">
|
83 |
+
<a href="https://twitter.com/chris_cancedda" target="_blank">
|
84 |
+
<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
|
85 |
+
</a>
|
86 |
+
</p>
|
87 |
+
""",
|
88 |
+
unsafe_allow_html=True,
|
89 |
+
)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
if choice == "About":
|
94 |
+
st.title(choice)
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
if choice == "Visualize WSI slide":
|
99 |
+
st.title(choice)
|
100 |
+
st.markdown(VISUALIZE_TEXT)
|
101 |
+
|
102 |
+
uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
|
103 |
+
if uploaded_file is not None:
|
104 |
+
ori = openslide.OpenSlide(uploaded_file.name)
|
105 |
+
width, height = ori.dimensions
|
106 |
+
|
107 |
+
REDUCTION_FACTOR = 20
|
108 |
+
w, h = int(width/512), int(height/512)
|
109 |
+
w_r, h_r = int(width/20), int(height/20)
|
110 |
+
resized_img = ori.get_thumbnail((w_r,h_r))
|
111 |
+
resized_img = resized_img.resize((w_r,h_r))
|
112 |
+
ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
|
113 |
+
#print('ratios ', ratio_w, ratio_h)
|
114 |
+
w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
|
115 |
+
st.image(resized_img, use_column_width='never')
|
116 |
+
|
117 |
+
if choice == "Cancer Detection":
|
118 |
+
state = dict()
|
119 |
+
|
120 |
+
st.title(choice)
|
121 |
+
st.markdown(DETECT_TEXT)
|
122 |
+
uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
|
123 |
+
if uploaded_file is not None:
|
124 |
+
# To read file as bytes:
|
125 |
+
#print(uploaded_file)
|
126 |
+
with open(os.path.join(uploaded_file.name),"wb") as f:
|
127 |
+
f.write(uploaded_file.getbuffer())
|
128 |
+
with st.spinner(text="Computation is running"):
|
129 |
+
predicted_class, viz_dict = predictor.predict(uploaded_file.name)
|
130 |
+
st.info('Computation completed.')
|
131 |
+
st.header(f'Predicted to be: {predicted_class}')
|
132 |
+
st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
|
133 |
+
state['cur'] = predicted_class
|
134 |
+
mapper = {'ORI': predicted_class, predicted_class:'ORI'}
|
135 |
+
readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
|
136 |
+
#def fn():
|
137 |
+
# st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
|
138 |
+
# state['cur'] = mapper[state['cur']]
|
139 |
+
# return
|
140 |
+
|
141 |
+
#st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
|
142 |
+
#st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
|
143 |
+
st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR'
|
144 |
+
# use_column_width='never',
|
145 |
+
)
|
146 |
+
|
147 |
+
|
148 |
+
if choice == "Contact":
|
149 |
+
st.title(choice)
|
150 |
+
st.markdown(CONTACT_TEXT)
|
151 |
+
|
152 |
+
if __name__ == '__main__':
|
153 |
+
if torch.cuda.is_available():
|
154 |
+
os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
|
155 |
+
os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
|
156 |
+
os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
|
157 |
+
else:
|
158 |
+
os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
159 |
+
os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
160 |
+
os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
161 |
+
main()
|