chris1nexus committed on
Commit
ad1b20c
1 Parent(s): d60982d

Update app

Files changed (1)
  1. app.py +153 -137
app.py CHANGED
@@ -5,141 +5,157 @@ from streamlit_option_menu import option_menu
 import torch
 
 
-if torch.cuda.is_available():
-    os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
-    os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
-    os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
-else:
-    os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
-    os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
-    os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
-
-from predict import Predictor
-
-
-
-# environment variables for the inference api
-os.environ['DATA_DIR'] = 'queries'
-os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
-os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
-os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
-os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
-
-
-# manually put the metadata in the metadata folder
-os.environ['CLASS_METADATA'] ='metadata/label_map.pkl'
-
-# manually put the desired weights in the weights folder
-os.environ['WEIGHTS_PATH'] = WEIGHTS_PATH='weights'
-os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
-os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
-
-
-st.set_page_config(page_title="",layout='wide')
-predictor = Predictor()
-
-
-
-
-
-ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
-CONTACT_TEXT = """
-_Built by Christian Cancedda and LabLab lads with love_ ❤️
-[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
-[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
-"""
-VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
-DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
-
-
-
-with st.sidebar:
-    choice = option_menu("LastMinute - Diagnosis",
-                         ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
-                         icons=['house', 'upload', 'activity', 'person lines fill'],
-                         menu_icon="app-indicator", default_index=0,
-                         styles={
-                             # "container": {"padding": "5!important", "background-color": "#fafafa", },
-                             "container": {"border-radius": ".0rem"},
-                             # "icon": {"color": "orange", "font-size": "25px"},
-                             # "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
-                             #              "--hover-color": "#eee"},
-                             # "nav-link-selected": {"background-color": "#02ab21"},
-                         }
-                         )
-    st.sidebar.markdown(
-        """
-<style>
-.aligncenter {
-text-align: center;
-}
-</style>
-<p class="aligncenter">
-<a href="https://twitter.com/chris_cancedda" target="_blank">
-<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
-</a>
-</p>
-""",
-        unsafe_allow_html=True,
-    )
-
-
-
-if choice == "About":
-    st.title(choice)
-
-
-
-if choice == "Visualize WSI slide":
-    st.title(choice)
-    st.markdown(VISUALIZE_TEXT)
-
-    uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
-    if uploaded_file is not None:
-        ori = openslide.OpenSlide(uploaded_file.name)
-        width, height = ori.dimensions
-
-        REDUCTION_FACTOR = 20
-        w, h = int(width/512), int(height/512)
-        w_r, h_r = int(width/20), int(height/20)
-        resized_img = ori.get_thumbnail((w_r,h_r))
-        resized_img = resized_img.resize((w_r,h_r))
-        ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
-        #print('ratios ', ratio_w, ratio_h)
-        w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
-        st.image(resized_img, use_column_width='never')
-
-if choice == "Cancer Detection":
-    state = dict()
-
-    st.title(choice)
-    st.markdown(DETECT_TEXT)
-    uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
-    if uploaded_file is not None:
-        # To read file as bytes:
-        #print(uploaded_file)
-        with open(os.path.join(uploaded_file.name),"wb") as f:
-            f.write(uploaded_file.getbuffer())
-        with st.spinner(text="Computation is running"):
-            predicted_class, viz_dict = predictor.predict(uploaded_file.name)
-        st.info('Computation completed.')
-        st.header(f'Predicted to be: {predicted_class}')
-        st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
-        state['cur'] = predicted_class
-        mapper = {'ORI': predicted_class, predicted_class:'ORI'}
-        readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
-        #def fn():
-        #    st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
-        #    state['cur'] = mapper[state['cur']]
-        #    return
-
-        #st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
-        #st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
-        st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR'
-                 # use_column_width='never',
-                 )
-
 
-if choice == "Contact":
-    st.title(choice)
-    st.markdown(CONTACT_TEXT)
+def main():
+
+    from predict import Predictor
+
+
+
+    # environment variables for the inference api
+    os.environ['DATA_DIR'] = 'queries'
+    os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
+    os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
+    os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
+    os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
+
+
+    # manually put the metadata in the metadata folder
+    os.environ['CLASS_METADATA'] ='metadata/label_map.pkl'
+
+    # manually put the desired weights in the weights folder
+    os.environ['WEIGHTS_PATH'] = WEIGHTS_PATH='weights'
+    os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
+    os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
+
+
+    st.set_page_config(page_title="",layout='wide')
+    predictor = Predictor()
+
+
+
+
+
+    ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
+    CONTACT_TEXT = """
+_Built by Christian Cancedda and LabLab lads with love_ ❤️
+[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
+[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
+Star project repository:
+[![GitHub stars](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
+"""
+    VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
+    DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
+
+
+
+    with st.sidebar:
+        choice = option_menu("LastMinute - Diagnosis",
+                             ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
+                             icons=['house', 'upload', 'activity', 'person lines fill'],
+                             menu_icon="app-indicator", default_index=0,
+                             styles={
+                                 # "container": {"padding": "5!important", "background-color": "#fafafa", },
+                                 "container": {"border-radius": ".0rem"},
+                                 # "icon": {"color": "orange", "font-size": "25px"},
+                                 # "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
+                                 #              "--hover-color": "#eee"},
+                                 # "nav-link-selected": {"background-color": "#02ab21"},
+                             }
+                             )
+        st.sidebar.markdown(
+            """
+<style>
+.aligncenter {
+text-align: center;
+}
+</style>
+<p style='text-align: center'>
+<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">Project Repository</a>
+</p>
+<p class="aligncenter">
+<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">
+<img src="https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social"/>
+</a>
+</p>
+
+<p class="aligncenter">
+<a href="https://twitter.com/chris_cancedda" target="_blank">
+<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
+</a>
+</p>
+""",
+            unsafe_allow_html=True,
+        )
+
+
+
+    if choice == "About":
+        st.title(choice)
+
+
+
+    if choice == "Visualize WSI slide":
+        st.title(choice)
+        st.markdown(VISUALIZE_TEXT)
+
+        uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
+        if uploaded_file is not None:
+            ori = openslide.OpenSlide(uploaded_file.name)
+            width, height = ori.dimensions
+
+            REDUCTION_FACTOR = 20
+            w, h = int(width/512), int(height/512)
+            w_r, h_r = int(width/20), int(height/20)
+            resized_img = ori.get_thumbnail((w_r,h_r))
+            resized_img = resized_img.resize((w_r,h_r))
+            ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
+            #print('ratios ', ratio_w, ratio_h)
+            w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
+            st.image(resized_img, use_column_width='never')
+
+    if choice == "Cancer Detection":
+        state = dict()
+
+        st.title(choice)
+        st.markdown(DETECT_TEXT)
+        uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
+        if uploaded_file is not None:
+            # To read file as bytes:
+            #print(uploaded_file)
+            with open(os.path.join(uploaded_file.name),"wb") as f:
+                f.write(uploaded_file.getbuffer())
+            with st.spinner(text="Computation is running"):
+                predicted_class, viz_dict = predictor.predict(uploaded_file.name)
+            st.info('Computation completed.')
+            st.header(f'Predicted to be: {predicted_class}')
+            st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
+            state['cur'] = predicted_class
+            mapper = {'ORI': predicted_class, predicted_class:'ORI'}
+            readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
+            #def fn():
+            #    st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
+            #    state['cur'] = mapper[state['cur']]
+            #    return
+
+            #st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
+            #st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
+            st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR'
+                     # use_column_width='never',
+                     )
+
+
+    if choice == "Contact":
+        st.title(choice)
+        st.markdown(CONTACT_TEXT)
+
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
+        os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
+        os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
+    else:
+        os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
+        os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
+        os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
+    main()