import torch
import torch.nn.functional as F

from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE


class SAM2Base(_SAM2Base):

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
        ## Extension: LLM prompt
        language_embd=None,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # fuse the visual features with previous memory features from the memory bank
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                # Inject the language embedding (if provided) as an extra sparse prompt
                language_embd=language_embd,
            )
        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            _,
        ) = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

        return current_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
        ## Extension: LLM prompt
        language_embd=None,
    ):
""" | |
Forward SAM prompt encoders and mask heads. | |
Inputs: | |
- backbone_features: image features of [B, C, H, W] shape | |
- point_inputs: a dictionary with "point_coords" and "point_labels", where | |
1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the | |
absolute pixel-unit coordinate in (x, y) format of the P input points | |
2) "point_labels" has shape [B, P] and int32 dtype, where 1 means | |
positive clicks, 0 means negative clicks, and -1 means padding | |
- mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the | |
same spatial size as the image. | |
- high_res_features: either 1) None or 2) or a list of length 2 containing | |
two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, | |
which will be used as high-resolution feature maps for SAM decoder. | |
- multimask_output: if it's True, we output 3 candidate masks and their 3 | |
corresponding IoU estimates, and if it's False, we output only 1 mask and | |
its corresponding IoU estimate. | |
Outputs: | |
- low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if | |
`multimask_output=True` and M = 1 if `multimask_output=False`), the SAM | |
output mask logits (before sigmoid) for the low-resolution masks, with 4x | |
the resolution (1/4 stride) of the input backbone_features. | |
- high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 | |
if `multimask_output=True` and M = 1 if `multimask_output=False`), | |
upsampled from the low-resolution masks, with shape size as the image | |
(stride is 1 pixel). | |
- ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 | |
if `multimask_output=False`), the estimated IoU of each output mask. | |
- low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. | |
If `multimask_output=True`, it's the mask with the highest IoU estimate. | |
If `multimask_output=False`, it's the same as `low_res_multimasks`. | |
- high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. | |
If `multimask_output=True`, it's the mask with the highest IoU estimate. | |
If `multimask_output=False`, it's the same as `high_res_multimasks`. | |
- obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted | |
based on the output token from the SAM mask decoder. | |
""" | |
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )

        ## Extension: LLM prompt
        if language_embd is not None:
            # B N C
            assert sparse_embeddings.size(0) == language_embd.size(0)
            assert sparse_embeddings.size(2) == language_embd.size(2)
            sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
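            # The concatenated sparse prompt now has shape [B, P + N, C]; the SAM mask
            # decoder simply treats the language tokens as additional sparse prompt tokens.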

        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            # print('Do torch.where !!!')
            # low_res_multimasks = torch.where(
            #     is_obj_appearing[:, None, None],
            #     low_res_multimasks,
            #     NO_OBJ_SCORE,
            # )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                # Only hard possible with gt
                assert not self.teacher_force_obj_scores_for_mem
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()
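
            # Blend the predicted pointer toward the learned `no_obj_ptr`, weighted by
            # how unlikely the object is to be present (a full convex combination when
            # `fixed_no_obj_ptr` is set).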
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )
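

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): a minimal, self-contained demo of
# the tensor mechanics above, using dummy tensors with assumed sizes rather than
# a real SAM2 checkpoint. It shows how language tokens are appended to the sparse
# prompt and how the best mask is selected when `multimask_output=True`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, C = 2, 256  # assumed batch size and prompt embedding dim
    sparse_embeddings = torch.zeros(B, 3, C)  # e.g. tokens from the SAM prompt encoder
    language_embd = torch.zeros(B, 4, C)      # e.g. 4 language tokens from an LLM projector
    # Language tokens are simply appended as extra sparse prompt tokens
    sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
    assert sparse_embeddings.shape == (B, 7, C)

    # Best-mask selection among 3 candidates, mirroring the multimask_output branch
    ious = torch.rand(B, 3)
    low_res_multimasks = torch.randn(B, 3, 64, 64)
    best_iou_inds = torch.argmax(ious, dim=-1)
    batch_inds = torch.arange(B)
    low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
    assert low_res_masks.shape == (B, 1, 64, 64)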