Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,248 +1,204 @@
|
|
1 |
-
|
2 |
-
import streamlit as st
|
3 |
import tensorflow as tf
|
4 |
from tensorflow.keras.models import load_model
|
5 |
from tensorflow.keras.preprocessing import image
|
6 |
import numpy as np
|
7 |
-
import plotly.graph_objects as go
|
8 |
import cv2
|
9 |
-
from tensorflow.keras.models import Sequential
|
10 |
-
from tensorflow.keras.layers import Dense, Dropout, Flatten
|
11 |
-
from tensorflow.keras.optimizers import Adamax
|
12 |
-
from tensorflow.keras.metrics import Precision, Recall
|
13 |
-
import google.generativeai as genai
|
14 |
-
# from google.colab import userdata
|
15 |
import PIL.Image
|
16 |
import os
|
17 |
from dotenv import load_dotenv
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
output_dir = 'saliency_maps'
|
23 |
-
os.makedirs(output_dir, exist_ok=True)
|
24 |
-
|
25 |
-
def generate_explanation(img_path, model_prediction, confidence):
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
glioma, meningioma, pituitary, or no tumor.
|
30 |
|
31 |
-
|
|
|
32 |
|
33 |
-
|
|
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
- Don't mention anything like 'The saliency map highlights the regions the model is focusing on, which are in light cyan' in your explanation.
|
39 |
-
- Keep your explanation to 5 sentences max.
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
44 |
|
45 |
-
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
response = model.generate_content([prompt, img])
|
49 |
|
50 |
-
|
51 |
|
52 |
-
def generate_saliency_map(model, img_array, class_index, img_size):
|
53 |
-
with tf.GradientTape() as tape:
|
54 |
-
img_tensor = tf.convert_to_tensor(img_array)
|
55 |
-
tape.watch(img_tensor)
|
56 |
-
predictions = model(img_tensor)
|
57 |
-
target_class = predictions[:, class_index]
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
63 |
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
y, x = np.ogrid[:gradients.shape[0], :gradients.shape[1]]
|
71 |
-
mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
|
72 |
|
73 |
-
|
74 |
-
gradients = gradients * mask
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
85 |
|
86 |
-
|
87 |
-
gradients = cv2.GaussianBlur(gradients, (11, 11), 0)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
92 |
|
93 |
-
|
94 |
-
heatmap = cv2.resize(heatmap, img_size)
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
|
105 |
-
|
|
|
106 |
|
107 |
-
|
108 |
-
|
|
|
109 |
|
110 |
-
|
111 |
|
|
|
|
|
|
|
|
|
112 |
|
|
|
113 |
|
114 |
def load_xception_model(model_path):
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
probabilities = prediction[0]
|
207 |
-
sorted_indices = np.argsort(probabilities)[::-1]
|
208 |
-
sorted_labels = [labels[i] for i in sorted_indices]
|
209 |
-
sorted_probabilities = probabilities[sorted_indices]
|
210 |
-
|
211 |
-
|
212 |
-
# Create Plotly bar chart
|
213 |
-
fig = go.Figure(go.Bar(
|
214 |
-
x=sorted_probabilities,
|
215 |
-
y=sorted_labels,
|
216 |
-
orientation='h',
|
217 |
-
marker_color=['red' if label == result else 'blue' for label in sorted_labels]
|
218 |
-
))
|
219 |
-
|
220 |
-
# Customize chart layout
|
221 |
-
fig.update_layout(
|
222 |
-
title='Probability for each class',
|
223 |
-
xaxis_title='Probability',
|
224 |
-
yaxis_title='Class',
|
225 |
-
height=400,
|
226 |
-
width=600,
|
227 |
-
yaxis=dict(autorange='reversed')
|
228 |
-
)
|
229 |
-
|
230 |
-
# Add value labels to the bars
|
231 |
-
for i, prob in enumerate(sorted_probabilities):
|
232 |
-
fig.add_annotation(
|
233 |
-
x=prob,
|
234 |
-
y=i,
|
235 |
-
text=f'{prob:.4f}',
|
236 |
-
showarrow=False,
|
237 |
-
xanchor='left',
|
238 |
-
xshift=5
|
239 |
)
|
|
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
saliency_map_path = f'saliency_maps/{uploaded_file.name}'
|
245 |
-
explanation = generate_explanation(saliency_map_path, result, prediction[0][class_index])
|
246 |
-
|
247 |
-
st.write("## Explanation")
|
248 |
-
st.write(explanation)
|
|
|
1 |
+
import gradio as gr
|
|
|
2 |
import tensorflow as tf
|
3 |
from tensorflow.keras.models import load_model
|
4 |
from tensorflow.keras.preprocessing import image
|
5 |
import numpy as np
|
|
|
6 |
import cv2
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import PIL.Image
|
8 |
import os
|
9 |
from dotenv import load_dotenv
|
10 |
+
import traceback
|
11 |
+
import base64
|
12 |
+
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
# Load environment variables
|
15 |
+
load_dotenv()
|
|
|
16 |
|
17 |
+
# Remove Google Generative AI imports and configuration
|
18 |
+
# We will use a local model for generating explanations
|
19 |
|
20 |
+
# Import the transformers library for text generation
|
21 |
+
from transformers import pipeline
|
22 |
|
23 |
+
# Initialize the text generation pipeline (using a smaller model for demonstration)
|
24 |
+
# Replace 'google/flan-t5-base' with 'meta-llama/Llama-2-7b-chat-hf' if you have the resources
|
25 |
+
generator = pipeline('text2text-generation', model='google/flan-t5-base')
|
|
|
|
|
26 |
|
27 |
+
def generate_explanation(saliency_map_image, model_prediction, confidence):
|
28 |
+
# Convert the saliency map image array to a base64-encoded string
|
29 |
+
buffered = BytesIO()
|
30 |
+
img = PIL.Image.fromarray(saliency_map_image)
|
31 |
+
img.save(buffered, format="PNG")
|
32 |
+
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
33 |
|
34 |
+
# Prepare the prompt (Note: models like FLAN-T5 cannot process images directly)
|
35 |
+
prompt = f"""You are an expert neurologist. You are tasked with explaining a saliency map of a brain tumor MRI scan.
|
36 |
+
The saliency map was generated by a deep learning model that was trained to classify brain tumors as either
|
37 |
+
glioma, meningioma, pituitary, or no tumor.
|
38 |
|
39 |
+
The saliency map highlights the regions of the image that the machine learning model is focusing on to make the predictions.
|
|
|
40 |
|
41 |
+
The deep learning model predicted the image to be of class '{model_prediction}' with a confidence of {confidence * 100:.2f}%.
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
+
In your response:
|
45 |
+
- Explain what regions of the brain the model is focusing on, based on the saliency map.
|
46 |
+
- Explain possible reasons why the model made the prediction it did.
|
47 |
+
- Do not mention phrases like 'The saliency map highlights the regions the model is focusing on, which are in light cyan.'
|
48 |
+
- Keep your explanation to 5 sentences max.
|
49 |
|
50 |
+
Your response will go directly in the report to the doctor and patient, so do not add extra phrases like 'Sure!' or ask any questions at the end.
|
51 |
+
Let's think step by step about this.
|
52 |
+
"""
|
53 |
|
54 |
+
# Generate the explanation using the text generation pipeline
|
55 |
+
response = generator(prompt, max_length=500)
|
56 |
+
explanation = response[0]['generated_text']
|
|
|
|
|
57 |
|
58 |
+
return explanation
|
|
|
59 |
|
60 |
+
def generate_saliency_map(model, img_array, class_index, img_size):
|
61 |
+
with tf.GradientTape() as tape:
|
62 |
+
img_tensor = tf.convert_to_tensor(img_array)
|
63 |
+
tape.watch(img_tensor)
|
64 |
+
predictions = model(img_tensor)
|
65 |
+
target_class = predictions[:, class_index]
|
66 |
|
67 |
+
gradients = tape.gradient(target_class, img_tensor)
|
68 |
+
gradients = tf.math.abs(gradients)
|
69 |
+
gradients = tf.reduce_max(gradients, axis=-1)
|
70 |
+
gradients = gradients.numpy().squeeze()
|
71 |
|
72 |
+
gradients = cv2.resize(gradients, img_size)
|
|
|
73 |
|
74 |
+
# Create a circular mask to focus on the brain region
|
75 |
+
center = (gradients.shape[0] // 2, gradients.shape[1] // 2)
|
76 |
+
radius = min(center[0], center[1]) - 10
|
77 |
+
y, x = np.ogrid[:gradients.shape[0], :gradients.shape[1]]
|
78 |
+
mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
|
79 |
|
80 |
+
gradients = gradients * mask
|
|
|
81 |
|
82 |
+
# Normalize the gradients within the brain region
|
83 |
+
brain_gradients = gradients[mask]
|
84 |
+
if brain_gradients.max() > brain_gradients.min():
|
85 |
+
brain_gradients = (brain_gradients - brain_gradients.min()) / (brain_gradients.max() - brain_gradients.min())
|
86 |
+
gradients[mask] = brain_gradients
|
87 |
|
88 |
+
# Apply thresholding to highlight important regions
|
89 |
+
threshold = np.percentile(gradients[mask], 80)
|
90 |
+
gradients[gradients < threshold] = 0
|
91 |
|
92 |
+
# Apply Gaussian blur to smooth the saliency map
|
93 |
+
gradients = cv2.GaussianBlur(gradients, (11, 11), 0)
|
94 |
|
95 |
+
# Create a heatmap from the gradients
|
96 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * gradients), cv2.COLORMAP_JET)
|
97 |
+
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
98 |
|
99 |
+
heatmap = cv2.resize(heatmap, img_size)
|
100 |
|
101 |
+
# Superimpose the heatmap on the original image
|
102 |
+
original_img = image.img_to_array(PIL.Image.fromarray((img_array[0] * 255).astype(np.uint8)))
|
103 |
+
superimposed_img = heatmap * 0.7 + original_img * 0.3
|
104 |
+
superimposed_img = superimposed_img.astype(np.uint8)
|
105 |
|
106 |
+
return superimposed_img
|
107 |
|
108 |
def load_xception_model(model_path):
|
109 |
+
img_shape = (299, 299, 3)
|
110 |
+
base_model = tf.keras.applications.Xception(include_top=False, weights="imagenet", input_shape=img_shape, pooling='max')
|
111 |
+
|
112 |
+
model = tf.keras.Sequential([
|
113 |
+
base_model,
|
114 |
+
tf.keras.layers.Flatten(),
|
115 |
+
tf.keras.layers.Dropout(rate=0.3),
|
116 |
+
tf.keras.layers.Dense(128, activation='relu'),
|
117 |
+
tf.keras.layers.Dropout(rate=0.25),
|
118 |
+
tf.keras.layers.Dense(4, activation='softmax')
|
119 |
+
])
|
120 |
+
|
121 |
+
model.compile(optimizer=tf.keras.optimizers.Adamax(learning_rate=0.001),
|
122 |
+
loss='categorical_crossentropy',
|
123 |
+
metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
|
124 |
+
model.load_weights(model_path)
|
125 |
+
return model
|
126 |
+
|
127 |
+
def classify_brain_tumor(image_file, model_choice):
|
128 |
+
try:
|
129 |
+
# Load the selected model
|
130 |
+
if model_choice == "Transfer Learning - Xception":
|
131 |
+
model = load_xception_model('xception_model.weights.h5')
|
132 |
+
img_size = (299, 299)
|
133 |
+
else:
|
134 |
+
model = load_model('cnn_model.h5')
|
135 |
+
img_size = (224, 224)
|
136 |
+
|
137 |
+
labels = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
|
138 |
+
|
139 |
+
# Preprocess the input image
|
140 |
+
img = image.load_img(image_file, target_size=img_size)
|
141 |
+
img_array = image.img_to_array(img)
|
142 |
+
img_array = np.expand_dims(img_array, axis=0)
|
143 |
+
img_array /= 255.0
|
144 |
+
|
145 |
+
# Make the prediction
|
146 |
+
prediction = model.predict(img_array)
|
147 |
+
class_index = np.argmax(prediction[0])
|
148 |
+
result = labels[class_index]
|
149 |
+
confidence = prediction[0][class_index]
|
150 |
+
|
151 |
+
# Generate the saliency map
|
152 |
+
saliency_map = generate_saliency_map(model, img_array, class_index, img_size)
|
153 |
+
# No longer saving the saliency map to disk
|
154 |
+
|
155 |
+
# Generate the explanation
|
156 |
+
explanation = generate_explanation(saliency_map, result, confidence)
|
157 |
+
|
158 |
+
# Prepare probabilities for all classes
|
159 |
+
probabilities = prediction[0]
|
160 |
+
prob_dict = dict(zip(labels, probabilities))
|
161 |
+
|
162 |
+
# Return the outputs in the expected order
|
163 |
+
return [
|
164 |
+
result,
|
165 |
+
confidence,
|
166 |
+
saliency_map,
|
167 |
+
explanation,
|
168 |
+
"", # Empty string for Logs
|
169 |
+
prob_dict # For displaying probabilities
|
170 |
+
]
|
171 |
+
except Exception as e:
|
172 |
+
# Return error information
|
173 |
+
return [
|
174 |
+
"Error",
|
175 |
+
0.0,
|
176 |
+
None,
|
177 |
+
"",
|
178 |
+
f"Error: {str(e)}\nTraceback:\n{traceback.format_exc()}",
|
179 |
+
{} # Empty probabilities
|
180 |
+
]
|
181 |
+
|
182 |
+
def main():
|
183 |
+
# Define the interface
|
184 |
+
interface = gr.Interface(
|
185 |
+
fn=classify_brain_tumor,
|
186 |
+
inputs=[
|
187 |
+
gr.Image(type="filepath"),
|
188 |
+
gr.Radio(choices=["Transfer Learning - Xception", "Custom CNN"], label="Select Model")
|
189 |
+
],
|
190 |
+
outputs=[
|
191 |
+
gr.Textbox(label="Prediction"),
|
192 |
+
gr.Number(label="Confidence", precision=2),
|
193 |
+
gr.Image(type="numpy", label="Saliency Map"),
|
194 |
+
gr.Textbox(label="Explanation"),
|
195 |
+
gr.Textbox(label="Logs"),
|
196 |
+
gr.Label(num_top_classes=4, label="Class Probabilities")
|
197 |
+
],
|
198 |
+
title="Brain Tumor Classification",
|
199 |
+
description="Upload an MRI scan image to classify the tumor and view saliency maps with model explanations.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
)
|
201 |
+
interface.launch()
|
202 |
|
203 |
+
if __name__ == "__main__":
|
204 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|