# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder

from ..utils.transforms import ResizeLongestSide


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
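
        Example (illustrative; in the upstream segment-anything codebase a Sam
        model is typically built via sam_model_registry in build_sam.py rather
        than constructed directly; `checkpoint_path` is a placeholder):
            sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)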
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.transform = ResizeLongestSide(image_encoder.img_size)
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
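
        Example (illustrative sketch; assumes `sam` is a constructed Sam model
        and that `image`, `coords`, and `labels` are tensors already
        transformed to the model's input frame, e.g. via ResizeLongestSide):
            batched_input = [{
                "image": image,                # 3xHxW
                "original_size": (1080, 1920),
                "point_coords": coords,        # BxNx2, input frame
                "point_labels": labels,        # BxN
            }]
            outputs = sam(batched_input, multimask_output=True)
            masks = outputs[0]["masks"]        # boolean masks at original size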
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
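
        The masks are recovered in three steps: bilinear upscaling to the
        encoder's square input size, cropping to input_size to drop the
        padded region, and bilinear resizing to original_size.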
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad to a square of side img_size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    @torch.no_grad()
    def forward_custom(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts. Unlike
        forward, the images and box prompts are given in the original image
        frame and are resized to the model's input frame internally via
        preprocess_custom. If prompts are not known in advance, using
        SamPredictor is recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as an HxWxC numpy array in the original
                frame; it is resized and converted to a 3xHxW tensor
                internally.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': Batched box inputs, with shape Bx4, in the original
                image frame; they are transformed to the input frame
                internally.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
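
        Example (illustrative sketch; `img` is assumed to be an HxWxC numpy
        array in the original frame and `box` an Nx4 array of XYXY boxes in
        original image coordinates):
            batched_input = [{
                "image": img,
                "original_size": img.shape[:2],
                "boxes": box,
            }]
            outputs = sam.forward_custom(batched_input, multimask_output=False)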
        """
        # Resize images and box prompts to the model's input frame.
        batched_input = self.preprocess_custom(batched_input)
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def preprocess_custom(self, input_list):
        """Resize images and box prompts from the original frame to the model's input frame."""
        for input_ in input_list:
            # Resize the image so its longest side matches the encoder's input
            # size, then convert the HxWxC array to a CxHxW torch tensor.
            img = input_["image"]
            input_image = self.transform.apply_image(img)
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
            input_["image"] = input_image_torch
            # Transform box prompts to the resized frame, if present.
            if "boxes" in input_:
                box = self.transform.apply_boxes(input_["boxes"], input_["original_size"])
                box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
                input_["boxes"] = box_torch
        return input_list

    @torch.no_grad()
    def forward_m2m(
        self,
        images,
        bboxes,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts binary masks from already-preprocessed images and box prompts.

        Arguments:
          images (torch.Tensor): A batch of preprocessed images, in
            Bx3xHxW format, ready for the image encoder.
          bboxes: Per-image box prompts, iterated alongside the image
            embeddings; each element is passed to the prompt encoder as
            its box input.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (torch.Tensor): The image embeddings from the image encoder.
          (torch.Tensor): Stacked binary masks at low resolution, one per
            image, selected by the highest predicted IoU and thresholded
            at mask_threshold.
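
        Example (illustrative sketch; `images` is assumed to be an already
        preprocessed Bx3xHxW tensor and `bboxes` to hold one box prompt per
        image, in the model's input frame):
            embeddings, masks = sam.forward_m2m(images, bboxes, multimask_output=True)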
        """
        image_embeddings = self.image_encoder(images)

        masks = []
        for image_record, curr_embedding in zip(bboxes, image_embeddings):
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=None,
                boxes=image_record,
                masks=None,
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            # Keep the mask channel with the highest predicted IoU (flattened
            # argmax over the BxC predictions), then binarize it in place.
            mask = low_res_masks[:, iou_predictions.argmax()]
            mask[mask > self.mask_threshold] = 1.0
            mask[mask <= self.mask_threshold] = 0.0
            masks.append(mask)
        return image_embeddings, torch.stack(masks, dim=0)

    @torch.no_grad()
    def forward_m2m_inference(
        self,
        input_dict,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predicts binary masks from a preprocessed image batch and either box
        or point prompts, returning masks at both low resolution and the
        original image size.

        Arguments:
          input_dict (dict): A dictionary with the following keys.
              'image': (torch.Tensor) A batch of preprocessed images, in
                Bx3xHxW format.
              'bbox': Per-image box prompts; used if present.
              'point', 'label': Per-image point prompts and their labels;
                used if 'bbox' is absent.
              'pad_shape': (tuple(int, int)) The unpadded size of the image
                input to the model, in (H, W) format. Used to remove padding.
              'ori_shape': (tuple(int, int)) The original image size, in
                (H, W) format.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (torch.Tensor): The image embeddings from the image encoder.
          (torch.Tensor): Stacked binary guide masks at low resolution
            (256x256), suitable as mask input to further iterations.
          (torch.Tensor): Stacked binary masks at the original image size.
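
        Example (illustrative sketch; names mirror the keys this method reads,
        and all prompt tensors are assumed to be in the model's input frame):
            input_dict = {
                "image": images,            # Bx3xHxW, already preprocessed
                "bbox": bboxes,             # one box prompt per image
                "pad_shape": (h_in, w_in),  # resized size before padding
                "ori_shape": (h0, w0),      # original image size
            }
            embeddings, guide_masks, masks = sam.forward_m2m_inference(
                input_dict, multimask_output=True
            )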
        """
        image = input_dict["image"]
        image_embeddings = self.image_encoder(image)
        masks = []
        post_masks = []

        if "bbox" in input_dict:
            bboxes = input_dict["bbox"]
            for image_record, curr_embedding in zip(bboxes, image_embeddings):
                sparse_embeddings, dense_embeddings = self.prompt_encoder(
                    points=None,
                    boxes=image_record,
                    masks=None,
                )
                low_res_masks, iou_predictions = self.mask_decoder(
                    image_embeddings=curr_embedding.unsqueeze(0),
                    image_pe=self.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                # Low-resolution mask channel with the highest predicted IoU.
                guide_mask = low_res_masks[:, iou_predictions.argmax()]
                final_masks = self.postprocess_masks(
                    low_res_masks,
                    input_size=input_dict["pad_shape"],
                    original_size=input_dict["ori_shape"],
                )
                post_mask = final_masks[:, iou_predictions.argmax()]

                # Binarize both masks in place at the mask threshold.
                guide_mask[guide_mask > self.mask_threshold] = 1.0
                guide_mask[guide_mask <= self.mask_threshold] = 0.0
                post_mask[post_mask > self.mask_threshold] = 1.0
                post_mask[post_mask <= self.mask_threshold] = 0.0

                masks.append(guide_mask)
                post_masks.append(post_mask)

        elif "point" in input_dict:
            points = input_dict["point"]
            labels = input_dict["label"]
            for point, label, curr_embedding in zip(points, labels, image_embeddings):
                sparse_embeddings, dense_embeddings = self.prompt_encoder(
                    points=(point[None, :], label[None, :]),
                    boxes=None,
                    masks=None,
                )
                low_res_masks, iou_predictions = self.mask_decoder(
                    image_embeddings=curr_embedding.unsqueeze(0),
                    image_pe=self.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                # Low-resolution mask channel with the highest predicted IoU.
                guide_mask = low_res_masks[:, iou_predictions.argmax()]
                final_masks = self.postprocess_masks(
                    low_res_masks,
                    input_size=input_dict["pad_shape"],
                    original_size=input_dict["ori_shape"],
                )
                post_mask = final_masks[:, iou_predictions.argmax()]

                # Binarize both masks in place at the mask threshold.
                guide_mask[guide_mask > self.mask_threshold] = 1.0
                guide_mask[guide_mask <= self.mask_threshold] = 0.0
                post_mask[post_mask > self.mask_threshold] = 1.0
                post_mask[post_mask <= self.mask_threshold] = 0.0

                masks.append(guide_mask)
                post_masks.append(post_mask)

        return image_embeddings, torch.stack(masks, dim=0), torch.stack(post_masks, dim=0)