# fossil_app / inference_sam.py
# Last commit (andy-wyx, 505fe72): change map location to CPU
import torch
import tensorflow as tf
# Pin the PyTorch side of the pipeline to the CPU; `device` is read by the
# checkpoint-loading patch and by the inference helpers further down.
device = torch.device("cpu")
print(f"Torch device: {device}")
# Earlier GPU-aware device selection, kept for reference:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if device.type == "cuda":
#     torch.cuda.set_per_process_memory_fraction(0.3, device=device.index if device.index is not None else 0)
# else:
#     device = "cpu"
# print(f"Torch device: {device}")

# Hide every GPU from TensorFlow by making the visible GPU list empty.
tf.config.set_visible_devices([], 'GPU')
# Earlier TensorFlow GPU setup (memory growth on the first GPU), kept for reference:
# gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# if gpu_devices:
#     tf.config.experimental.set_memory_growth(gpu_devices[0], True)
# else:
#     print(f"TensorFlow device: {gpu_devices}")
from segment_anything import SamPredictor, sam_model_registry
import matplotlib.pyplot as plt
import cv2
import numpy as np
from math import ceil
import os
from huggingface_hub import snapshot_download
# Download the model repository from the Hugging Face Hub into ./model the
# first time the module is imported; an existing 'model' path skips the download.
if not os.path.exists('model'):
    REPO_ID = 'Serrelab/SAM_Leaves'
    token = os.environ.get('READ_TOKEN')
    # The token is a credential: report only whether it is present so that it
    # never ends up in stdout or captured application logs.
    print(f"Read token set: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    # The download is attempted even without a token, as before.
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='model', local_dir='model')
# Build the SAM model with every checkpoint tensor remapped onto `device`.
# The model factory is called with only the checkpoint path, so the remapping
# is injected by temporarily wrapping torch.load (presumably because the
# factory calls torch.load without a map_location -- confirm against
# segment_anything's build code).
original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the real ``torch.load`` with ``map_location`` forced to ``device``."""
    kwargs['map_location'] = device
    return original_torch_load(*args, **kwargs)


model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
torch.load = patched_torch_load
try:
    sam = sam_model_registry["default"](model_path)
finally:
    # Always undo the monkeypatch, even when loading fails, so the rest of
    # the process never keeps seeing the wrapped torch.load.
    torch.load = original_torch_load
sam.to(device)
predictor = SamPredictor(sam)
from torch.nn import functional as F
def pad_gt(x):
    """Pad the last two dimensions of ``x`` up to the SAM encoder input size.

    Padding is added only at the end of each of the two trailing dimensions
    (``F.pad`` default mode), so the original content stays in the top-left
    corner. The target side length is ``sam.image_encoder.img_size``.
    """
    height, width = x.shape[-2:]
    extra_rows = sam.image_encoder.img_size - height
    extra_cols = sam.image_encoder.img_size - width
    # F.pad takes the padding for the last dimension first: (left, right, top, bottom).
    return F.pad(x, (0, extra_cols, 0, extra_rows))
def preprocess(img):
    """Turn an image into a SAM-ready batch tensor.

    The image is cast to uint8, passed through ``predictor.transform.apply_image``,
    moved to ``device``, reordered from channels-last to channels-first with a
    leading batch dimension, and finally passed through ``sam.preprocess``.

    Returns:
        A ``(tensor, size)`` pair: the preprocessed batch and the first two
        dimensions of the transformed image as it was before ``sam.preprocess``
        (presumably height and width -- confirm against the transform).
    """
    pixels = np.array(img).astype(np.uint8)
    transformed = predictor.transform.apply_image(pixels)

    # Keep only the spatial part of the transformed shape.
    resized_shape = transformed.shape
    rank = len(resized_shape)
    if rank == 3:
        resized_shape = resized_shape[:2]
    elif rank == 4:
        resized_shape = resized_shape[1:3]

    batch = torch.as_tensor(transformed).to(device)
    # Channels-last -> channels-first, then add the batch dimension in front.
    batch = batch.permute(2, 0, 1).contiguous().unsqueeze(0)
    return sam.preprocess(batch), resized_shape
def normalize(img):
    """Rescale ``img`` linearly so its minimum maps to -1 and its maximum to 1."""
    shifted = img - tf.math.reduce_min(img)
    unit_range = shifted / tf.math.reduce_max(shifted)
    return unit_range * 2.0 - 1.0
def resize(img):
    """Bicubically resize ``img`` to ``SIZE`` x ``SIZE``; the default resize for all ``pi`` outputs.

    NOTE(review): ``SIZE`` is read from module scope and is not defined in the
    part of the file visible here -- confirm it exists before calling this.
    """
    target = (SIZE, SIZE)
    return tf.image.resize(img, size=target, method="bicubic")
def smooth_mask(mask, ds=20):
    """Smooth ``mask`` by shrinking it to ``ds`` x ``ds`` and resizing it back.

    Both resizes are bicubic. The result has the sizes of the first two
    dimensions of the input as its spatial size.
    """
    dims = tf.shape(mask)
    coarse = tf.image.resize(mask, (ds, ds), method="bicubic")
    return tf.image.resize(coarse, (dims[0], dims[1]), method="bicubic")
def pi(img, mask):
    """Build four ``SIZE`` x ``SIZE`` views of ``img`` guided by ``mask``.

    The mask is smoothed, averaged over its last axis and used to zero out
    image pixels where it is <= 0.01. The returned views are, in order:

    1. the whole masked image resized to ``SIZE`` x ``SIZE``;
    2. the bounding box of mask values > 0.15, resized;
    3. that bounding box shrunk by a quarter of its extent on every side, resized;
    4. a window of up to ``2 * SIZE`` per side centred on the first mask maximum,
       clamped to the image bounds, resized.

    When the bounding box is not larger than 50 in both directions, views 2
    and 3 fall back to the plain resize and the padded resize of the whole image.

    Args:
        img: image indexed as ``[dim0, dim1, channels]``; cast to float32 here.
        mask: object exposing ``.cpu().numpy()`` (presumably a torch tensor --
            confirm against the caller).

    Returns:
        A float32 tensor stacking the four views.

    NOTE(review): ``SIZE`` is read from module scope and is not defined in the
    part of the file visible here -- confirm it exists before calling this.
    """
    img = tf.cast(img, tf.float32)
    shape = tf.shape(img)
    w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)

    mask = smooth_mask(mask.cpu().numpy().astype(float))
    mask = tf.reduce_mean(mask, -1)

    # Zero out everything the smoothed mask considers background.
    img = img * tf.cast(mask > 0.01, tf.float32)[:, :, None]

    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)

    # building 2 anchors
    anchors = tf.where(mask > 0.15)
    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
    anchor_ymax = tf.math.reduce_max(anchors[:, 1])

    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])

        # Second anchor: same box with a quarter trimmed off each side.
        delta_x = (anchor_xmax - anchor_xmin) // 4
        delta_y = (anchor_ymax - anchor_ymin) // 4
        img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
                           anchor_ymin+delta_y:anchor_ymax-delta_y]
        img_anchor_2 = resize(img_anchor_2)
    else:
        # Box too small to crop meaningfully: fall back to whole-image views.
        img_anchor_1 = img_resize
        img_anchor_2 = img_pad

    # building the anchors max
    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]

    img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, w),
                        tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, h)]
    img_max_zoom1 = resize(img_max_zoom1)

    # A tighter SIZE//2 zoom around the maximum used to be sliced here but was
    # never part of the output, so that unused computation has been dropped.
    return tf.cast([
        img_resize,
        img_anchor_1,
        img_anchor_2,
        img_max_zoom1,
    ], tf.float32)
def one_step_inference(x):
    """Run SAM once on ``x`` without any prompt and return a mask at input size.

    Args:
        x: image with a 3-D shape (size taken from the first two dimensions)
            or a 4-D shape (size taken from dimensions 1 and 2).

    Returns:
        The single-mask decoder output, upsampled to the encoder input size,
        cropped to the region occupied by the resized image, resized back to
        the original spatial size of ``x`` and placed on ``device``.

    Raises:
        ValueError: if ``x`` is neither 3-D nor 4-D.
    """
    rank = len(x.shape)
    if rank == 3:
        original_size = x.shape[:2]
    elif rank == 4:
        original_size = x.shape[1:3]
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"expected a 3-D or 4-D image, got shape {x.shape}")

    x, intermediate_shape = preprocess(x)

    with torch.no_grad():
        image_embedding = sam.image_encoder(x)
        # No points, boxes or mask prompts are supplied to the prompt encoder.
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=None, boxes=None, masks=None
        )
        low_res_masks, _ = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

    # Upsample to the encoder's square input size (same attribute pad_gt uses,
    # rather than a hard-coded 1024), then crop to the resized-image region.
    encoder_size = sam.image_encoder.img_size
    mask = F.interpolate(low_res_masks, (encoder_size, encoder_size))
    mask = mask[:, :, :intermediate_shape[0], :intermediate_shape[1]]
    mask = F.interpolate(mask, (original_size[0], original_size[1]))
    return mask.to(device)
def segmentation_sam(x, SIZE=384):
    """Segment ``x`` with SAM and return the mask overlay rendered as an RGB array.

    The image is resized with padding to ``SIZE`` x ``SIZE``, segmented with
    ``one_step_inference``, and drawn with the thresholded mask (> 0.2) on top
    using the 'jet' colormap at 40% opacity.

    Side effects:
        Writes the figure to ``test.png`` in the current working directory
        (saved before the axes are hidden, as before).

    Returns:
        A uint8 array of shape ``(height, width, 3)`` holding the rendered figure.
    """
    x = tf.image.resize_with_pad(x, SIZE, SIZE)
    predicted_mask = one_step_inference(x)

    # numpy() already yields a host array; the extra EagerTensor.cpu() hop is
    # deprecated in TensorFlow and not needed.
    img = x.numpy()
    mask = predicted_mask.cpu().numpy()[0][0] > 0.2

    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.imshow(mask, cmap='jet', alpha=0.4)
        fig.savefig('test.png')
        ax.axis('off')
        fig.canvas.draw()
        # tostring_rgb() is deprecated/removed in recent Matplotlib releases;
        # read the RGBA buffer instead and drop the alpha channel. The copy
        # detaches the result from the canvas before the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # Close this specific figure, even on error, so figures never pile up.
        plt.close(fig)
    return data