import numpy as np
import gzip
from pathlib import Path
import tempfile
import cv2
import tensorflow as tf
import skimage.morphology
import skimage.filters.rank
import skimage.util
from tensorflow.keras.models import load_model
from aix.utils import hardened_dice_coef
from aix.losses import dice_loss
class AreaModel:
    """Estimate the area of a segmented region inside an image's ROI.

    Wraps a Keras segmentation model. The input image is cropped with
    ``roi``, converted with ``tensor`` to a ``(1, *IMG_SHAPE)`` float32 batch,
    and passed to the model. The first predicted mask is resized back to the
    crop's height/width and summed.
    """

    # Channel count of the model input.
    # NOTE(review): the original code read ``self.INPUT_CHANNELS`` without ever
    # defining it, so construction failed with AttributeError. 1 (grayscale) is
    # assumed because ``tensor`` gives 2-D images a single channel axis —
    # confirm against the trained model. Subclasses may override it.
    INPUT_CHANNELS = 1

    def __init__(self, model_path="model/majority_roi_production.keras"):
        """Load the Keras model stored at ``model_path``."""
        self.model_path = model_path
        # ``dice_loss`` and ``hardened_dice_coef`` are imported by this module
        # but were otherwise unused. Presumably the model was compiled with
        # them, so they are handed to Keras for deserialization. Entries the
        # saved config does not reference are simply not looked up.
        self.model = load_model(
            model_path,
            custom_objects={
                "dice_loss": dice_loss,
                "hardened_dice_coef": hardened_dice_coef,
            },
        )
        self.IMG_SIZE = (192, 240)
        self.IMG_SHAPE = (self.IMG_SIZE[0], self.IMG_SIZE[1], self.INPUT_CHANNELS)
        self.MASK_SHAPE = (self.IMG_SIZE[0], self.IMG_SIZE[1], 1)

    def compute_area(self, img):
        """Predict the mask for ``img``'s ROI and return its summed value.

        Args:
            img: Image passed to ``roi`` (which divides it by 255, so
                presumably a uint8 grayscale array — confirm with callers).

        Returns:
            ``(area, roi_img, resized_mask)``:
            - ``area`` is ``np.sum`` of the resized mask. If the model outputs
              per-pixel probabilities, this is a soft pixel count.
            - ``roi_img`` is the cropped image.
            - ``resized_mask`` is the first predicted mask, resized to the
              crop's ``(height, width)``.
        """
        roi_img = roi(img)
        # ``tf.image.resize`` needs only (height, width). Slicing also keeps
        # this valid when the crop has a channel axis.
        roi_hw = roi_img.shape[:2]
        t_img = tensor(roi_img, self.IMG_SHAPE)
        y = self.model.predict(x=t_img)
        mask = y[0]
        resized_mask = tf.image.resize(mask, roi_hw)
        area = np.sum(resized_mask)
        return area, roi_img, resized_mask
def image_to_file_path(image):
    # NOTE(review): the body of the ``with`` statement below is missing from
    # this view, so the function is syntactically incomplete as shown.
    # ``NamedTemporaryFile(delete=False)`` creates a temp file that is kept
    # after the handle is closed. The name suggests the function writes
    # ``image`` into it and returns its path — confirm and restore the body.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
def raw_image(file_path, remove_alpha=True):
    """Load an image from disk in grayscale.

    Args:
        file_path: Path to the image file (``str`` or ``os.PathLike`` such
            as ``pathlib.Path``).
        remove_alpha: If true and the loaded image has four channels, drop
            the fourth (alpha) channel. With the grayscale read used here the
            result is normally 2-D, so this only takes effect if the read
            mode is changed.

    Returns:
        The image array returned by ``cv2.imread``.

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable, or not
            a supported image format).
    """
    import os

    path = os.fspath(file_path)  # cv2.imread does not accept Path objects
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read image file: {path}")
    if remove_alpha and img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]
    return img
def tensor(img, shape):
    """Convert an image array into a single-item float32 batch tensor.

    Args:
        img: Image array, either 2-D ``(H, W)`` or 3-D ``(H, W, C)``. A 2-D
            image is given a trailing channel axis of size 1. The caller's
            array is NOT modified.
        shape: Target ``(height, width, channels)``. The image is resized to
            ``shape[:2]`` and then reshaped to ``(1, *shape)``.

    Returns:
        A ``tf.float32`` tensor of shape ``(1, *shape)``.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        # Add the channel axis on a new view. Assigning ``img.shape`` (as
        # before) silently reshaped the caller's own array.
        img = img[:, :, np.newaxis]
    t = tf.convert_to_tensor(img)
    t = tf.image.resize(t, shape[:2])
    t = tf.reshape(t, (1, *shape))
    return tf.cast(t, tf.float32)
def roi(cv2_img, border=.01):
    """Crop an image to the region of interest located by ``extract_roi``.

    Args:
        cv2_img: Image array. It is divided by 255 before ROI detection, so
            presumably uint8 in 0..255 — confirm with callers.
        border: Fractional margin around the detected region, forwarded to
            ``extract_roi`` (default matches the previous hard-coded value).

    Returns:
        A view of ``cv2_img`` covering the ROI, including its bottom-right
        corner pixel.
    """
    _, (left, top), (right, bottom) = extract_roi(cv2_img / 255., filled=True, border=border)
    # extract_roi returns an inclusive bottom-right corner (clipped to
    # shape - 1), so +1 turns it into an exclusive slice bound.
    return cv2_img[top:bottom + 1, left:right + 1]
def overlay_mask_on_image(image, mask, alpha=0.1, mask_color=(0, 255, 0)):
    """Overlay a coloured mask on an image.

    Args:
        image (np.array): The original image. Grayscale input (2-D or a
            single channel) is converted to RGB first. A 3-channel image is
            used as-is.
        mask (np.array): The mask to overlay, ``(H, W)`` or ``(H, W, C)``.
            Only the first channel is used. Must match the image size.
        alpha (float): The opacity of the mask.
        mask_color (tuple): Per-channel colour of the mask, in the output
            image's channel order.

    Returns:
        np.array: The image with the mask overlay.
    """
    image = np.asarray(image)
    if image.ndim == 2 or image.shape[2] == 1:
        rgb_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:
        rgb_image = image
    mask = np.asarray(mask)
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    # Scale each colour channel by the per-pixel mask value. Casting to the
    # image dtype truncates fractional values, as the per-channel assignment
    # into a zeros_like(rgb_image) buffer did.
    colored_mask = (mask[:, :, np.newaxis] * np.asarray(mask_color)).astype(rgb_image.dtype)
    # Overlay the mask on the image
    return cv2.addWeighted(rgb_image, 1, colored_mask, alpha, 0)
def local_entropy(im, kernel_size=5, normalize=True):
    """Compute the local (neighbourhood) entropy of an image.

    Args:
        im: Image accepted by ``skimage.util.img_as_ubyte`` (e.g. uint8, or
            float in [0, 1]).
        kernel_size: Passed as the radius to ``skimage.morphology.disk``, so
            the footprint is ``2 * kernel_size + 1`` pixels wide.
        normalize: If true, rescale so the maximum entropy maps to 255 and
            return a uint8 array.

    Returns:
        The entropy image: uint8 in 0..255 when ``normalize`` is true,
        otherwise the raw output of ``skimage.filters.rank.entropy``.
    """
    # NOTE(review): the original referenced an undefined ``kernel`` (a
    # NameError) and never used ``kernel_size``. A disk footprint of radius
    # ``kernel_size`` is assumed — confirm against the threshold tuned in
    # ``extract_roi``.
    footprint = skimage.morphology.disk(kernel_size)
    entr_img = skimage.filters.rank.entropy(skimage.util.img_as_ubyte(im), footprint)
    if normalize:
        max_img = np.max(entr_img)
        if max_img > 0:
            entr_img = (entr_img * 255 / max_img).astype(np.uint8)
        else:
            # Constant image: entropy is zero everywhere. Avoid 0/0 -> NaN.
            entr_img = np.zeros(entr_img.shape, dtype=np.uint8)
    return entr_img
def calc_dim(contour):
    """Return the axis-aligned bounds of a contour.

    Args:
        contour: Sequence of points, each shaped like OpenCV's ``[[x, y]]``
            (as produced by ``cv2.findContours``).

    Returns:
        ``(x_min, x_max, y_min, y_max)``.

    Raises:
        ValueError: If ``contour`` is empty.
    """
    xs, ys = zip(*((pt[0][0], pt[0][1]) for pt in contour))
    return min(xs), max(xs), min(ys), max(ys)
def calc_size(dim):
    """Return the bounding-box area of a ``calc_dim`` tuple.

    ``dim`` is ``(x_min, x_max, y_min, y_max)``. Extents are ``max - min``
    with no +1, so a box one pixel wide or tall has size 0.
    """
    width = dim[1] - dim[0]
    height = dim[3] - dim[2]
    return width * height
def calc_dist(dim1, dim2):
    """Placeholder for a distance between two bounding boxes.

    Not implemented: always returns ``None`` regardless of its arguments.
    The boxes are presumably ``calc_dim`` tuples
    ``(x_min, x_max, y_min, y_max)``.
    """
    # TODO: implement or remove; nothing in the visible code calls it.
def extract_roi(img, threshold=135, kernel_size=5, min_fratio=.3, max_sratio=5, filled=True, border=.01):
    """Locate the main region of interest in a 2-D image.

    The image is turned into a normalized local-entropy map and thresholded.
    Its external contours are then examined from the largest bounding box
    down. The first contour that satisfies both criteria below is taken as
    the ROI; larger contours that fail them are treated as artifacts:
    - its fill ratio (filled-polygon pixels / bounding-box size) exceeds
      ``min_fratio``;
    - its aspect ratio (long side / short side) is below ``max_sratio``.

    Args:
        img: 2-D image accepted by ``local_entropy``.
        threshold: Cut-off applied to the 0..255 normalized entropy map.
        kernel_size: Neighbourhood size forwarded to ``local_entropy``.
        min_fratio: Exclusive lower bound on the fill ratio.
        max_sratio: Exclusive upper bound on the aspect ratio.
        filled: If true, the ROI rectangle is filled. A 2-pixel outline is
            drawn in either case.
        border: Fractional margin added around the bounding box, relative to
            the image height (vertically) and width (horizontally).

    Returns:
        ``(mask, origin, to)``:
        - ``mask`` is a uint8 array of ``img.shape`` with the ROI rectangle
          drawn in 255.
        - ``origin`` is the ``(x, y)`` top-left corner.
        - ``to`` is the inclusive ``(x, y)`` bottom-right corner.
        Both corners are clipped to the image.

    Raises:
        ValueError: If no contour satisfies both ratio criteria (including
            when no contour is found at all).
    """
    entr_img = local_entropy(img, kernel_size=kernel_size)
    _, mask = cv2.threshold(entr_img, threshold, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    contours_d = [calc_dim(c) for c in contours]
    contours_sizes = [calc_size(d) for d in contours_d]
    # Indices of the contours, largest bounding box first.
    contour_indices = np.argsort(contours_sizes)[::-1]

    # Remove artifacts: keep the first (largest) contour that is solid enough
    # and not too elongated. One scratch buffer is reused for every candidate.
    scratch = np.zeros(img.shape, dtype=np.uint8)
    cdim = None
    for candidate in contour_indices:
        dims = contours_d[candidate]
        width = dims[1] - dims[0]
        height = dims[3] - dims[2]
        if width == 0 or height == 0:
            # Degenerate (point/line) contour: both ratios would divide by
            # zero, and a 0/0 NaN used to end the search on a junk contour.
            continue
        scratch.fill(0)
        scratch = cv2.fillPoly(scratch, [contours[candidate]], 255)
        fratio = scratch.sum() / 255 / contours_sizes[candidate]
        sratio = height / width
        if sratio < 1:
            sratio = 1 / sratio
        if fratio > min_fratio and sratio < max_sratio:
            cdim = dims
            break
    if cdim is None:
        # Previously this ran off the end of contour_indices (IndexError).
        raise ValueError(
            "extract_roi: no contour satisfies min_fratio/max_sratio; "
            "cannot locate a region of interest"
        )

    # Generate the mask: the chosen bounding box grown by the border margin.
    filled_mask = np.zeros(img.shape, dtype=np.uint8)
    extra = (int(img.shape[0] * border), int(img.shape[1] * border))  # (rows, cols)
    origin = (max(0, cdim[0] - extra[1]), max(0, cdim[2] - extra[0]))
    to = (min(img.shape[1] - 1, cdim[1] + extra[1]), min(img.shape[0] - 1, cdim[3] + extra[0]))
    if filled:
        filled_mask = cv2.rectangle(filled_mask, origin, to, 255, -1)
    # The 2-px outline is drawn in both modes, as before.
    filled_mask = cv2.rectangle(filled_mask, origin, to, 255, 2)
    return filled_mask, origin, to