Omarrran committed
Commit 4ede577 · verified · 1 Parent(s): e05dd7f

Update app.py
Files changed (1)
  1. app.py +167 -211
app.py CHANGED
@@ -1,248 +1,204 @@
- 
- import streamlit as st
  import tensorflow as tf
  from tensorflow.keras.models import load_model
  from tensorflow.keras.preprocessing import image
  import numpy as np
- import plotly.graph_objects as go
  import cv2
- from tensorflow.keras.models import Sequential
- from tensorflow.keras.layers import Dense, Dropout, Flatten
- from tensorflow.keras.optimizers import Adamax
- from tensorflow.keras.metrics import Precision, Recall
- import google.generativeai as genai
- # from google.colab import userdata
  import PIL.Image
  import os
  from dotenv import load_dotenv
- load_dotenv()
- 
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
- 
- output_dir = 'saliency_maps'
- os.makedirs(output_dir, exist_ok=True)
- 
- def generate_explanation(img_path, model_prediction, confidence):
- 
-     prompt = f"""You are an expert neurologist. You are tasked with explaining a saliency map of a brain tumor MRI scan.
-     The saliency map was generated by a deep learning model that was trained to classify brain tumors as either
-     glioma, meningioma, pituitary, or no tumor.
- 
-     The saliency map highlights the regions of the image that the machine learning model is focusing on to make the predictions.
- 
-     The deep learning model predicted the image to be of class '{model_prediction}' with a confidence of {confidence * 100}%.
- 
-     In your response:
-     - Explain what regions of the brain the model is focusing on, based on the saliency map. Refer to the regions highlighted in light cyan, those are the regions where the model is focusing on.
-     - Explain possible reasons why the model made the prediction it did.
-     - Don't mention anything like 'The saliency map highlights the regions the model is focusing on, which are in light cyan' in your explanation.
-     - Keep your explanation to 5 sentences max.
- 
-     Your response will go directly on the report to the doctor and patient, so don't add any extra phrases like 'Sure!' or ask any questions at the end
-     Let's think step by step about this.
-     """
- 
-     img = PIL.Image.open(img_path)
- 
-     model = genai.GenerativeModel(model_name="gemini-1.5-flash")
-     response = model.generate_content([prompt, img])
- 
-     return response.text
- 
- def generate_saliency_map(model, img_array, class_index, img_size):
-     with tf.GradientTape() as tape:
-         img_tensor = tf.convert_to_tensor(img_array)
-         tape.watch(img_tensor)
-         predictions = model(img_tensor)
-         target_class = predictions[:, class_index]
- 
-     gradients = tape.gradient(target_class, img_tensor)
-     gradients = tf.math.abs(gradients)
-     gradients = tf.reduce_max(gradients, axis=-1)
-     gradients = gradients.numpy().squeeze()
- 
-     # Resize gradients to match original image size
-     gradients = cv2.resize(gradients, img_size)
- 
-     # Create a circular mask for the brain area
-     center = (gradients.shape[0] // 2, gradients.shape[1] // 2)
-     radius = min(center[0], center[1]) - 10
-     y, x = np.ogrid[:gradients.shape[0], :gradients.shape[1]]
-     mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
- 
-     # Apply mask to gradients
-     gradients = gradients * mask
- 
-     # Normalize only the brain area
-     brain_gradients = gradients[mask]
-     if brain_gradients.max() > brain_gradients.min():
-         brain_gradients = (brain_gradients - brain_gradients.min()) / (brain_gradients.max() - brain_gradients.min())
-     gradients[mask] = brain_gradients
- 
-     # Apply a higher threshold
-     threshold = np.percentile(gradients[mask], 80)
-     gradients[gradients < threshold] = 0
- 
-     # Apply more aggressive smoothing
-     gradients = cv2.GaussianBlur(gradients, (11, 11), 0)
- 
-     # Create a heatmap overlay with enhanced contrast
-     heatmap = cv2.applyColorMap(np.uint8(255 * gradients), cv2.COLORMAP_JET)
-     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
- 
-     # Resize heatmap to match original image size
-     heatmap = cv2.resize(heatmap, img_size)
- 
-     # Superimpose the heatmap on original image with increased opacity
-     original_img = image.img_to_array(img)
-     superimposed_img = heatmap * 0.7 + original_img * 0.3
-     superimposed_img = superimposed_img.astype(np.uint8)
- 
-     img_path = os.path.join(output_dir, uploaded_file.name)
-     with open(img_path, "wb") as f:
-         f.write(uploaded_file.getbuffer())
- 
-     saliency_map_path = f'saliency_maps/{uploaded_file.name}'
- 
-     # Save saliency map
-     cv2.imwrite(saliency_map_path, cv2.cvtColor(superimposed_img, cv2.COLOR_RGB2BGR))
- 
-     return superimposed_img
- 
  def load_xception_model(model_path):
-     img_shape=(299, 299, 3)
-     base_model = tf.keras.applications.Xception(include_top=False, weights="imagenet", input_shape=img_shape, pooling='max')
- 
-     model = Sequential([
-         base_model,
-         Flatten(),
-         Dropout(rate=0.3),
-         Dense(128, activation='relu'),
-         Dropout(rate=0.25),
-         Dense(4, activation='softmax')
-     ])
- 
-     model.build((None,) + img_shape)
- 
-     # Compile the model
-     model.compile(Adamax(learning_rate=0.001),
-                   loss='categorical_crossentropy',
-                   metrics=['accuracy', Precision(), Recall()])
-     model.load_weights(model_path)
-     return model
- 
- st.title("Brain Tumor Classification")
- 
- st.write("Upload an image of a brain MRI scan to classify.")
- 
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
- 
- if uploaded_file is not None:
-     selected_model = st.radio(
-         "Select Model",
-         ("Transfer Learning - Xception", "Custom CNN")
-     )
- 
-     if selected_model == "Transfer Learning - Xception":
-         model = load_xception_model('xception_model.weights.h5')
-         img_size=(299, 299)
-     else:
-         model = load_model('cnn_model.h5')
-         img_size = (224, 224)
- 
- 
-     labels = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
-     img = image.load_img(uploaded_file, target_size=img_size)
-     img_array = image.img_to_array(img)
-     img_array = np.expand_dims(img_array, axis=0)
-     img_array /= 255.0
- 
-     prediction = model.predict(img_array)
- 
-     # Get the class with the highest probability
-     class_index = np.argmax(prediction[0])
-     result = labels[class_index]
- 
-     saliency_map = generate_saliency_map(model, img_array, class_index, img_size)
- 
- 
-     col1, col2 = st.columns(2)
-     with col1:
-         st.image(uploaded_file, caption='Uploaded Image', use_column_width=True)
-     with col2:
-         st.image(saliency_map, caption='Saliency Map', use_column_width=True)
- 
- 
-     st.write("## Classification Results")
- 
-     result_container = st.container()
-     result_container = st.container()
-     result_container.markdown(
-         f"""
-         <div style="background-color: #000000; color: #ffffff; padding: 30px; border-radius: 15px;">
-             <div style="display: flex; justify-content: space-between; align-items: center;">
-                 <div style="flex: 1; text-align: center;">
-                     <h3 style="color: #ffffff; margin-bottom: 10px; font-size: 20px;">Prediction</h3>
-                     <p style="font-size: 36px; font-weight: 800; color: #FF0000; margin: 0;">
-                         {result}
-                     </p>
-                 </div>
-                 <div style="width: 2px; height: 80px; background-color: #ffffff; margin: 0 20px;"></div>
-                 <div style="flex: 1; text-align: center;">
-                     <h3 style="color: #ffffff; margin-bottom: 10px; font-size: 20px;">Confidence</h3>
-                     <p style="font-size: 36px; font-weight: 800; color: #2196F3; margin: 0;">
-                         {prediction[0][class_index]:.4%}
-                     </p>
-                 </div>
-             </div>
-         </div>
-         """,
-         unsafe_allow_html=True
-     )
- 
-     # Prepare data for Plotly chart
-     probabilities = prediction[0]
-     sorted_indices = np.argsort(probabilities)[::-1]
-     sorted_labels = [labels[i] for i in sorted_indices]
-     sorted_probabilities = probabilities[sorted_indices]
- 
- 
-     # Create Plotly bar chart
-     fig = go.Figure(go.Bar(
-         x=sorted_probabilities,
-         y=sorted_labels,
-         orientation='h',
-         marker_color=['red' if label == result else 'blue' for label in sorted_labels]
-     ))
- 
-     # Customize chart layout
-     fig.update_layout(
-         title='Probability for each class',
-         xaxis_title='Probability',
-         yaxis_title='Class',
-         height=400,
-         width=600,
-         yaxis=dict(autorange='reversed')
-     )
- 
-     # Add value labels to the bars
-     for i, prob in enumerate(sorted_probabilities):
-         fig.add_annotation(
-             x=prob,
-             y=i,
-             text=f'{prob:.4f}',
-             showarrow=False,
-             xanchor='left',
-             xshift=5
          )
- 
-     # Display Plotly chart
-     st.plotly_chart(fig)
- 
-     saliency_map_path = f'saliency_maps/{uploaded_file.name}'
-     explanation = generate_explanation(saliency_map_path, result, prediction[0][class_index])
- 
-     st.write("## Explanation")
-     st.write(explanation)
+ import gradio as gr
  import tensorflow as tf
  from tensorflow.keras.models import load_model
  from tensorflow.keras.preprocessing import image
  import numpy as np
  import cv2
  import PIL.Image
  import os
  from dotenv import load_dotenv
+ import traceback
+ import base64
+ from io import BytesIO
+ 
+ # Load environment variables
+ load_dotenv()
+ 
+ # Remove Google Generative AI imports and configuration
+ # We will use a local model for generating explanations
+ 
+ # Import the transformers library for text generation
+ from transformers import pipeline
+ 
+ # Initialize the text generation pipeline (using a smaller model for demonstration)
+ # Replace 'google/flan-t5-base' with 'meta-llama/Llama-2-7b-chat-hf' if you have the resources
+ generator = pipeline('text2text-generation', model='google/flan-t5-base')
+ 
+ def generate_explanation(saliency_map_image, model_prediction, confidence):
+     # Convert the saliency map image array to a base64-encoded string
+     buffered = BytesIO()
+     img = PIL.Image.fromarray(saliency_map_image)
+     img.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
+ 
+     # Prepare the prompt (Note: models like FLAN-T5 cannot process images directly)
+     prompt = f"""You are an expert neurologist. You are tasked with explaining a saliency map of a brain tumor MRI scan.
+     The saliency map was generated by a deep learning model that was trained to classify brain tumors as either
+     glioma, meningioma, pituitary, or no tumor.
+ 
+     The saliency map highlights the regions of the image that the machine learning model is focusing on to make the predictions.
+ 
+     The deep learning model predicted the image to be of class '{model_prediction}' with a confidence of {confidence * 100:.2f}%.
+ 
+     In your response:
+     - Explain what regions of the brain the model is focusing on, based on the saliency map.
+     - Explain possible reasons why the model made the prediction it did.
+     - Do not mention phrases like 'The saliency map highlights the regions the model is focusing on, which are in light cyan.'
+     - Keep your explanation to 5 sentences max.
+ 
+     Your response will go directly in the report to the doctor and patient, so do not add extra phrases like 'Sure!' or ask any questions at the end.
+     Let's think step by step about this.
+     """
+ 
+     # Generate the explanation using the text generation pipeline
+     response = generator(prompt, max_length=500)
+     explanation = response[0]['generated_text']
+ 
+     return explanation
+ 
+ def generate_saliency_map(model, img_array, class_index, img_size):
+     with tf.GradientTape() as tape:
+         img_tensor = tf.convert_to_tensor(img_array)
+         tape.watch(img_tensor)
+         predictions = model(img_tensor)
+         target_class = predictions[:, class_index]
+ 
+     gradients = tape.gradient(target_class, img_tensor)
+     gradients = tf.math.abs(gradients)
+     gradients = tf.reduce_max(gradients, axis=-1)
+     gradients = gradients.numpy().squeeze()
+ 
+     gradients = cv2.resize(gradients, img_size)
+ 
+     # Create a circular mask to focus on the brain region
+     center = (gradients.shape[0] // 2, gradients.shape[1] // 2)
+     radius = min(center[0], center[1]) - 10
+     y, x = np.ogrid[:gradients.shape[0], :gradients.shape[1]]
+     mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
+ 
+     gradients = gradients * mask
+ 
+     # Normalize the gradients within the brain region
+     brain_gradients = gradients[mask]
+     if brain_gradients.max() > brain_gradients.min():
+         brain_gradients = (brain_gradients - brain_gradients.min()) / (brain_gradients.max() - brain_gradients.min())
+     gradients[mask] = brain_gradients
+ 
+     # Apply thresholding to highlight important regions
+     threshold = np.percentile(gradients[mask], 80)
+     gradients[gradients < threshold] = 0
+ 
+     # Apply Gaussian blur to smooth the saliency map
+     gradients = cv2.GaussianBlur(gradients, (11, 11), 0)
+ 
+     # Create a heatmap from the gradients
+     heatmap = cv2.applyColorMap(np.uint8(255 * gradients), cv2.COLORMAP_JET)
+     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+ 
+     heatmap = cv2.resize(heatmap, img_size)
+ 
+     # Superimpose the heatmap on the original image
+     original_img = image.img_to_array(PIL.Image.fromarray((img_array[0] * 255).astype(np.uint8)))
+     superimposed_img = heatmap * 0.7 + original_img * 0.3
+     superimposed_img = superimposed_img.astype(np.uint8)
+ 
+     return superimposed_img
+ 
  def load_xception_model(model_path):
+     img_shape = (299, 299, 3)
+     base_model = tf.keras.applications.Xception(include_top=False, weights="imagenet", input_shape=img_shape, pooling='max')
+ 
+     model = tf.keras.Sequential([
+         base_model,
+         tf.keras.layers.Flatten(),
+         tf.keras.layers.Dropout(rate=0.3),
+         tf.keras.layers.Dense(128, activation='relu'),
+         tf.keras.layers.Dropout(rate=0.25),
+         tf.keras.layers.Dense(4, activation='softmax')
+     ])
+ 
+     model.compile(optimizer=tf.keras.optimizers.Adamax(learning_rate=0.001),
+                   loss='categorical_crossentropy',
+                   metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
+     model.load_weights(model_path)
+     return model
+ 
+ def classify_brain_tumor(image_file, model_choice):
+     try:
+         # Load the selected model
+         if model_choice == "Transfer Learning - Xception":
+             model = load_xception_model('xception_model.weights.h5')
+             img_size = (299, 299)
+         else:
+             model = load_model('cnn_model.h5')
+             img_size = (224, 224)
+ 
+         labels = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
+ 
+         # Preprocess the input image
+         img = image.load_img(image_file, target_size=img_size)
+         img_array = image.img_to_array(img)
+         img_array = np.expand_dims(img_array, axis=0)
+         img_array /= 255.0
+ 
+         # Make the prediction
+         prediction = model.predict(img_array)
+         class_index = np.argmax(prediction[0])
+         result = labels[class_index]
+         confidence = prediction[0][class_index]
+ 
+         # Generate the saliency map
+         saliency_map = generate_saliency_map(model, img_array, class_index, img_size)
+         # No longer saving the saliency map to disk
+ 
+         # Generate the explanation
+         explanation = generate_explanation(saliency_map, result, confidence)
+ 
+         # Prepare probabilities for all classes
+         probabilities = prediction[0]
+         prob_dict = dict(zip(labels, probabilities))
+ 
+         # Return the outputs in the expected order
+         return [
+             result,
+             confidence,
+             saliency_map,
+             explanation,
+             "",  # Empty string for Logs
+             prob_dict  # For displaying probabilities
+         ]
+     except Exception as e:
+         # Return error information
+         return [
+             "Error",
+             0.0,
+             None,
+             "",
+             f"Error: {str(e)}\nTraceback:\n{traceback.format_exc()}",
+             {}  # Empty probabilities
+         ]
+ 
+ def main():
+     # Define the interface
+     interface = gr.Interface(
+         fn=classify_brain_tumor,
+         inputs=[
+             gr.Image(type="filepath"),
+             gr.Radio(choices=["Transfer Learning - Xception", "Custom CNN"], label="Select Model")
+         ],
+         outputs=[
+             gr.Textbox(label="Prediction"),
+             gr.Number(label="Confidence", precision=2),
+             gr.Image(type="numpy", label="Saliency Map"),
+             gr.Textbox(label="Explanation"),
+             gr.Textbox(label="Logs"),
+             gr.Label(num_top_classes=4, label="Class Probabilities")
+         ],
+         title="Brain Tumor Classification",
+         description="Upload an MRI scan image to classify the tumor and view saliency maps with model explanations.",
      )
+     interface.launch()
+ 
+ if __name__ == "__main__":
+     main()
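
The refactor exposes classify_brain_tumor as a plain function, so the new pipeline can be smoke-tested without launching the Gradio UI. The sketch below is illustrative and not part of the committed file: "sample_mri.jpg" is a placeholder path, and it assumes app.py, xception_model.weights.h5, and cnn_model.h5 are in the working directory (importing app will also download google/flan-t5-base on first use).

# Minimal smoke test for the refactored entry point (illustrative sketch only).
# "sample_mri.jpg" is a placeholder; any MRI image on disk will do.
from app import classify_brain_tumor

result, confidence, saliency_map, explanation, logs, probabilities = classify_brain_tumor(
    "sample_mri.jpg", "Custom CNN"
)

print(f"Prediction: {result}")
print(f"Confidence: {float(confidence):.2%}")
print(f"Explanation: {explanation}")
if logs:
    # A non-empty Logs value means classify_brain_tumor hit its except branch.
    print(logs)

Because the function returns its outputs in the same order as the Gradio outputs list, this direct call mirrors what the interface performs on each upload.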