import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import warnings
from warnings import warn
from PIL import Image
import romatch
from romatch.utils import get_tuple_transform_ops, tensor_to_pil
from romatch.utils.local_correlation import local_correlation
from romatch.utils.utils import cls_to_flow_refine
from romatch.utils.kde import kde
from typing import Union


class ConvRefiner(nn.Module):
    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=False,
        kernel_size=5,
        hidden_blocks=3,
        displacement_emb=None,
        displacement_emb_dim=None,
        local_corr_radius=None,
        corr_in_other=None,
        no_im_B_fm=False,
        amp=False,
        concat_logits=False,
        use_bias_block_1=True,
        use_cosine_corr=False,
        disable_local_corr_grad=False,
        is_classifier=False,
        sample_mode="bilinear",
        norm_type=nn.BatchNorm2d,
        bn_momentum=0.1,
        amp_dtype=torch.float16,
    ):
        super().__init__()
        self.bn_momentum = bn_momentum
        self.block1 = self.create_block(
            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias=use_bias_block_1,
        )
        self.hidden_blocks = nn.Sequential(
            *[
                self.create_block(
                    hidden_dim,
                    hidden_dim,
                    dw=dw,
                    kernel_size=kernel_size,
                    norm_type=norm_type,
                )
                for hb in range(hidden_blocks)
            ]
        )
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        if displacement_emb:
            self.has_displacement_emb = True
            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
        else:
            self.has_displacement_emb = False
        self.local_corr_radius = local_corr_radius
        self.corr_in_other = corr_in_other
        self.no_im_B_fm = no_im_B_fm
        self.amp = amp
        self.concat_logits = concat_logits
        self.use_cosine_corr = use_cosine_corr
        self.disable_local_corr_grad = disable_local_corr_grad
        self.is_classifier = is_classifier
        self.sample_mode = sample_mode
        self.amp_dtype = amp_dtype

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
        bias=True,
        norm_type=nn.BatchNorm2d,
    ):
        num_groups = 1 if not dw else in_dim
        if dw:
            assert (
                out_dim % in_dim == 0
            ), "out_dim must be divisible by in_dim for depthwise"
        conv1 = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=num_groups,
            bias=bias,
        )
        norm = (
            norm_type(out_dim, momentum=self.bn_momentum)
            if norm_type is nn.BatchNorm2d
            else norm_type(num_channels=out_dim)
        )
        relu = nn.ReLU(inplace=True)
        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(conv1, norm, relu, conv2)

    def forward(self, x, y, flow, scale_factor=1, logits=None):
        b, c, hs, ws = x.shape
        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
            with torch.no_grad():
                x_hat = F.grid_sample(
                    y, flow.permute(0, 2, 3, 1), align_corners=False, mode=self.sample_mode
                )
            if self.has_displacement_emb:
                im_A_coords = torch.meshgrid(
                    (
                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
                    ),
                    indexing="ij",
                )
                im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
                in_displacement = flow - im_A_coords
                emb_in_displacement = self.disp_emb(40 / 32 * scale_factor * in_displacement)
                if self.local_corr_radius:
                    if self.corr_in_other:
                        # Corr in other means take a kxk grid around the predicted coordinate in other image
                        local_corr = local_correlation(
                            x,
                            y,
                            local_radius=self.local_corr_radius,
                            flow=flow,
                            sample_mode=self.sample_mode,
                        )
                    else:
                        raise NotImplementedError("Local corr in own frame should not be used.")
                    if self.no_im_B_fm:
                        x_hat = torch.zeros_like(x)
                    d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
                else:
                    d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
            else:
                if self.no_im_B_fm:
                    x_hat = torch.zeros_like(x)
                d = torch.cat((x, x_hat), dim=1)
            if self.concat_logits:
                d = torch.cat((d, logits), dim=1)
            d = self.block1(d)
            d = self.hidden_blocks(d)
        d = self.out_conv(d.float())
        displacement, certainty = d[:, :-1], d[:, -1:]
        return displacement, certainty

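# Usage sketch for ConvRefiner (illustrative, not part of the original file).
# With no displacement embedding and no local correlation, the refiner input is
# cat(x, x_hat), so in_dim must be 2 * C. All shapes below are assumptions.
#
#   refiner = ConvRefiner(in_dim=128, hidden_dim=64, out_dim=3)
#   x = torch.randn(1, 64, 32, 32)    # im_A features
#   y = torch.randn(1, 64, 32, 32)    # im_B features
#   flow = torch.zeros(1, 2, 32, 32)  # normalized [-1, 1] coords into im_B
#   displacement, certainty = refiner(x, y, flow)
#   # displacement: (1, 2, 32, 32), certainty logits: (1, 1, 32, 32)
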
class CosKernel(nn.Module):  # similar to softmax kernel
    def __init__(self, T, learn_temperature=False):
        super().__init__()
        self.learn_temperature = learn_temperature
        if self.learn_temperature:
            self.T = nn.Parameter(torch.tensor(T))
        else:
            self.T = T

    def __call__(self, x, y, eps=1e-6):
        c = torch.einsum("bnd,bmd->bnm", x, y) / (
            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
        )
        if self.learn_temperature:
            T = self.T.abs() + 0.01
        else:
            T = torch.tensor(self.T, device=c.device)
        K = ((c - 1.0) / T).exp()
        return K


class GP(nn.Module):
    def __init__(
        self,
        kernel,
        T=1,
        learn_temperature=False,
        only_attention=False,
        gp_dim=64,
        basis="fourier",
        covar_size=5,
        only_nearest_neighbour=False,
        sigma_noise=0.1,
        no_cov=False,
        predict_features=False,
    ):
        super().__init__()
        self.K = kernel(T=T, learn_temperature=learn_temperature)
        self.sigma_noise = sigma_noise
        self.covar_size = covar_size
        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
        self.only_attention = only_attention
        self.only_nearest_neighbour = only_nearest_neighbour
        self.basis = basis
        self.no_cov = no_cov
        self.dim = gp_dim
        self.predict_features = predict_features

    def get_local_cov(self, cov):
        K = self.covar_size
        b, h, w, h, w = cov.shape
        hw = h * w
        cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
        delta = torch.stack(
            torch.meshgrid(
                torch.arange(-(K // 2), K // 2 + 1),
                torch.arange(-(K // 2), K // 2 + 1),
                indexing="ij",
            ),
            dim=-1,
        )
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(K // 2, h + K // 2),
                torch.arange(K // 2, w + K // 2),
                indexing="ij",
            ),
            dim=-1,
        )
        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
        points = torch.arange(hw)[:, None].expand(hw, K**2)
        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
            :,
            points.flatten(),
            neighbours[..., 0].flatten(),
            neighbours[..., 1].flatten(),
        ].reshape(b, h, w, K**2)
        return local_cov

    def reshape(self, x):
        return rearrange(x, "b d h w -> b (h w) d")

    def project_to_basis(self, x):
        if self.basis == "fourier":
            return torch.cos(8 * math.pi * self.pos_conv(x))
        elif self.basis == "linear":
            return self.pos_conv(x)
        else:
            raise ValueError(
                "No bases other than fourier and linear currently supported in public release"
            )

    def get_pos_enc(self, y):
        b, c, h, w = y.shape
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
            ),
            indexing="ij",
        )
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        coarse_embedded_coords = self.project_to_basis(coarse_coords)
        return coarse_embedded_coords

    def forward(self, x, y, **kwargs):
        b, c, h1, w1 = x.shape
        b, c, h2, w2 = y.shape
        f = self.get_pos_enc(y)
        b, d, h2, w2 = f.shape
        x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
        K_xx = self.K(x, x)
        K_yy = self.K(y, y)
        K_xy = self.K(x, y)
        K_yx = K_xy.permute(0, 2, 1)
        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
        with warnings.catch_warnings():
            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)

        mu_x = K_xy.matmul(K_yy_inv.matmul(f))
        mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
        if not self.no_cov:
            cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
            local_cov_x = self.get_local_cov(cov_x)
            local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
            gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
        else:
            gp_feats = mu_x
        return gp_feats

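# Usage sketch for GP (illustrative, not part of the original file). The GP
# regresses embedded im_B coordinates for every im_A location; with the default
# covar_size=5 the output carries gp_dim + 5**2 channels (mean + local cov).
#
#   gp = GP(CosKernel, T=0.2, gp_dim=64)
#   x = torch.randn(2, 32, 8, 8)  # im_A features (b, c, h, w)
#   y = torch.randn(2, 32, 8, 8)  # im_B features, same spatial size assumed
#   gp_feats = gp(x, y)           # (2, 64 + 25, 8, 8)
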
class Decoder(nn.Module):
    def __init__(
        self,
        embedding_decoder,
        gps,
        proj,
        conv_refiner,
        detach=False,
        scales="all",
        pos_embeddings=None,
        num_refinement_steps_per_scale=1,
        warp_noise_std=0.0,
        displacement_dropout_p=0.0,
        gm_warp_dropout_p=0.0,
        flow_upsample_mode="bilinear",
        amp_dtype=torch.float16,
    ):
        super().__init__()
        self.embedding_decoder = embedding_decoder
        self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
        self.gps = gps
        self.proj = proj
        self.conv_refiner = conv_refiner
        self.detach = detach
        if pos_embeddings is None:
            self.pos_embeddings = {}
        else:
            self.pos_embeddings = pos_embeddings
        if scales == "all":
            self.scales = ["32", "16", "8", "4", "2", "1"]
        else:
            self.scales = scales
        self.warp_noise_std = warp_noise_std
        self.refine_init = 4
        self.displacement_dropout_p = displacement_dropout_p
        self.gm_warp_dropout_p = gm_warp_dropout_p
        self.flow_upsample_mode = flow_upsample_mode
        self.amp_dtype = amp_dtype

    def get_placeholder_flow(self, b, h, w, device):
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
            ),
            indexing="ij",
        )
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        return coarse_coords

    def get_positional_embedding(self, b, h, w, device):
        # NOTE: relies on a `self.pos_embedding` attribute that is never set in
        # __init__ (only `self.pos_embeddings` is); this helper is not used by forward.
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
            ),
            indexing="ij",
        )
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        coarse_embedded_coords = self.pos_embedding(coarse_coords)
        return coarse_embedded_coords

    def forward(self, f1, f2, gt_warp=None, gt_prob=None, upsample=False, flow=None, certainty=None, scale_factor=1):
        coarse_scales = self.embedding_decoder.scales()
        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
        h, w = sizes[1]
        b = f1[1].shape[0]
        device = f1[1].device
        coarsest_scale = int(all_scales[0])
        old_stuff = torch.zeros(
            b,
            self.embedding_decoder.hidden_dim,
            *sizes[coarsest_scale],
            device=f1[coarsest_scale].device,
        )
        corresps = {}
        if not upsample:
            flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
            certainty = 0.0
        else:
            flow = F.interpolate(
                flow,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
            certainty = F.interpolate(
                certainty,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
        displacement = 0.0
        for new_scale in all_scales:
            ins = int(new_scale)
            corresps[ins] = {}
            f1_s, f2_s = f1[ins], f2[ins]
            if new_scale in self.proj:
                with torch.autocast("cuda", dtype=self.amp_dtype):
                    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)

            if ins in coarse_scales:
                old_stuff = F.interpolate(
                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
                )
                gp_posterior = self.gps[new_scale](f1_s, f2_s)
                gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                    gp_posterior, f1_s, old_stuff, new_scale
                )

                if self.embedding_decoder.is_classifier:
                    flow = cls_to_flow_refine(
                        gm_warp_or_cls,
                    ).permute(0, 3, 1, 2)
                    if self.training:
                        corresps[ins].update({"gm_cls": gm_warp_or_cls, "gm_certainty": certainty})
                else:
                    if self.training:
                        corresps[ins].update({"gm_flow": gm_warp_or_cls, "gm_certainty": certainty})
                    flow = gm_warp_or_cls.detach()

            if new_scale in self.conv_refiner:
                if self.training:
                    corresps[ins].update({"flow_pre_delta": flow})
                delta_flow, delta_certainty = self.conv_refiner[new_scale](
                    f1_s, f2_s, flow, scale_factor=scale_factor, logits=certainty,
                )
                if self.training:
                    corresps[ins].update({"delta_flow": delta_flow})
                displacement = ins * torch.stack(
                    (
                        delta_flow[:, 0].float() / (self.refine_init * w),
                        delta_flow[:, 1].float() / (self.refine_init * h),
                    ),
                    dim=1,
                )
                flow = flow + displacement
                certainty = certainty + delta_certainty  # predict both certainty and displacement

            corresps[ins].update({
                "certainty": certainty,
                "flow": flow,
            })
            if new_scale != "1":
                flow = F.interpolate(
                    flow,
                    size=sizes[ins // 2],
                    mode=self.flow_upsample_mode,
                )
                certainty = F.interpolate(
                    certainty,
                    size=sizes[ins // 2],
                    mode=self.flow_upsample_mode,
                )
                if self.detach:
                    flow = flow.detach()
                    certainty = certainty.detach()
        return corresps

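# Structure of the corresps dict produced by Decoder.forward (illustrative
# summary based on the update() calls above): one entry per integer scale, e.g.
#
#   corresps[1]["flow"]       # (b, 2, h, w) normalized im_A -> im_B warp
#   corresps[1]["certainty"]  # (b, 1, h, w) certainty logits
#
# During training, scales additionally carry "gm_cls" or "gm_flow",
# "gm_certainty", "flow_pre_delta", and "delta_flow".
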
class RegressionMatcher(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        h=448,
        w=448,
        sample_mode="threshold_balanced",
        upsample_preds=False,
        symmetric=False,
        name=None,
        attenuate_cert=None,
        recrop_upsample=False,
    ):
        super().__init__()
        self.attenuate_cert = attenuate_cert
        self.encoder = encoder
        self.decoder = decoder
        self.name = name
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
        self.symmetric = symmetric
        self.sample_thresh = 0.05
        self.recrop_upsample = recrop_upsample

    def get_output_resolution(self):
        if not self.upsample_preds:
            return self.h_resized, self.w_resized
        else:
            return self.upsample_res

    def extract_backbone_features(self, batch, batched=True, upsample=False):
        x_q = batch["im_A"]
        x_s = batch["im_B"]
        if batched:
            X = torch.cat((x_q, x_s), dim=0)
            feature_pyramid = self.encoder(X, upsample=upsample)
        else:
            feature_pyramid = self.encoder(x_q, upsample=upsample), self.encoder(x_s, upsample=upsample)
        return feature_pyramid

    def sample(
        self,
        matches,
        certainty,
        num=10000,
    ):
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False,
        )
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        density = kde(good_matches, std=0.1, half=False)
        p = 1 / (density + 1)
        p[density < 10] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        balanced_samples = torch.multinomial(
            p,
            num_samples=min(num, len(good_certainty)),
            replacement=False,
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

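    # Balanced sampling, sketched (illustrative comment, not part of the original
    # file): dense warps oversample large textured regions, so matches are first
    # drawn proportionally to certainty (with a 4x expansion factor), then
    # re-weighted by inverse KDE density, p ~ 1 / (density + 1), so that
    # spatially isolated but confident matches survive the second draw.
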
    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample=upsample)
        if batched:
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {}),
            scale_factor=scale_factor,
        )
        return corresps

    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample=upsample)
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
            for scale, f_scale in feature_pyramid.items()
        }
        corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {}),
            scale_factor=scale_factor,
        )
        return corresps

    def conf_from_fb_consistency(self, flow_forward, flow_backward, th=2):
        # assumes that flow forward is of shape (..., H, W, 2)
        has_batch = False
        if len(flow_forward.shape) == 3:
            flow_forward, flow_backward = flow_forward[None], flow_backward[None]
        else:
            has_batch = True
        H, W = flow_forward.shape[-3:-1]
        th_n = 2 * th / max(H, W)  # pixel threshold in normalized [-1, 1] coordinates
        coords = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
                torch.linspace(-1 + 1 / H, 1 - 1 / H, H),
                indexing="xy",
            ),
            dim=-1,
        ).to(flow_forward.device)
        coords_fb = F.grid_sample(
            flow_backward.permute(0, 3, 1, 2),
            flow_forward,
            align_corners=False,
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        diff = (coords - coords_fb).norm(dim=-1)
        in_th = (diff < th_n).float()
        if not has_batch:
            in_th = in_th[0]
        return in_th

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
        if coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)

        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)

    def _to_pixel_coordinates(self, coords, H, W):
        kpts = torch.stack((W / 2 * (coords[..., 0] + 1), H / 2 * (coords[..., 1] + 1)), dim=-1)
        return kpts

    def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        kpts_A = torch.stack((2 / W_A * kpts_A[..., 0] - 1, 2 / H_A * kpts_A[..., 1] - 1), dim=-1)
        kpts_B = torch.stack((2 / W_B * kpts_B[..., 0] - 1, 2 / H_B * kpts_B[..., 1] - 1), dim=-1)
        return kpts_A, kpts_B

    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False):
        x_A_to_B = F.grid_sample(
            warp[..., -2:].permute(2, 0, 1)[None],
            x_A[None, None],
            align_corners=False,
            mode="bilinear",
        )[0, :, 0].mT
        cert_A_to_B = F.grid_sample(
            certainty[None, None, ...],
            x_A[None, None],
            align_corners=False,
            mode="bilinear",
        )[0, 0, 0]
        D = torch.cdist(x_A_to_B, x_B)
        inds_A, inds_B = torch.nonzero(
            (D == D.min(dim=-1, keepdim=True).values)
            * (D == D.min(dim=-2, keepdim=True).values)
            * (cert_A_to_B[:, None] > self.sample_thresh),
            as_tuple=True,
        )

        if return_tuple:
            if return_inds:
                return inds_A, inds_B
            else:
                return x_A[inds_A], x_B[inds_B]
        else:
            if return_inds:
                return torch.cat((inds_A, inds_B), dim=-1)
            else:
                return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)

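    # Usage sketch for match_keypoints (illustrative, not part of the original
    # file): keypoints are expected in normalized [-1, 1] coordinates, e.g.
    #
    #   kpts_A_n, kpts_B_n = matcher.to_normalized_coordinates((kpts_A, kpts_B), H_A, W_A, H_B, W_B)
    #   mkpts_A, mkpts_B = matcher.match_keypoints(kpts_A_n, kpts_B_n, warp, certainty)
    #
    # A pair is kept only if it is a mutual nearest neighbour under the warped
    # distance and its sampled certainty exceeds self.sample_thresh.
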
    def get_roi(self, certainty, W, H, thr=0.025):
        raise NotImplementedError("WIP, disable for now")
        hs, ws = certainty.shape
        certainty = certainty / certainty.sum(dim=(-1, -2))
        cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
        cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
        left = int(W / ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
        right = int(W / ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
        top = int(H / hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
        bottom = int(H / hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
        return left, top, right, bottom

    def recrop(self, certainty, image_path):
        roi = self.get_roi(certainty, *Image.open(image_path).size)
        return Image.open(image_path).convert("RGB").crop(roi)

    @torch.inference_mode()
    def match(
        self,
        im_A_path: Union[str, os.PathLike, Image.Image],
        im_B_path: Union[str, os.PathLike, Image.Image],
        *args,
        batched=False,
        device=None,
    ):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if isinstance(im_A_path, (str, os.PathLike)):
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        else:
            im_A, im_B = im_A_path, im_B_path

        symmetric = self.symmetric
        self.train(False)
        with torch.no_grad():
            if not batched:
                b = 1
                w, h = im_A.size
                w2, h2 = im_B.size
                # Get images in good format
                ws = self.w_resized
                hs = self.h_resized
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True, clahe=False
                )
                im_A, im_B = test_transform((im_A, im_B))
                batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
            else:
                b, c, h, w = im_A.shape
                b, c, h2, w2 = im_B.shape
                assert w == w2 and h == h2, "For batched images we assume same size"
                batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
                if h != self.h_resized or self.w_resized != w:
                    warn("Model resolution and batch resolution differ, may produce unexpected results")
                hs, ws = h, w
            finest_scale = 1
            # Run matcher
            if symmetric:
                corresps = self.forward_symmetric(batch)
            else:
                corresps = self.forward(batch, batched=True)

            if self.upsample_preds:
                hs, ws = self.upsample_res

            if self.attenuate_cert:
                low_res_certainty = F.interpolate(
                    corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
                )
                cert_clamp = 0
                factor = 0.5
                low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)

            if self.upsample_preds:
                finest_corresps = corresps[finest_scale]
                torch.cuda.empty_cache()
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                if self.recrop_upsample:
                    raise NotImplementedError("recrop_upsample not implemented")
                    certainty = corresps[finest_scale]["certainty"]
                    im_A = self.recrop(certainty[0, 0], im_A_path)
                    im_B = self.recrop(certainty[1, 0], im_B_path)
                    # TODO: need to adjust corresps when doing this
                im_A, im_B = test_transform((im_A, im_B))
                im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                scale_factor = math.sqrt(
                    self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)
                )
                batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
                if symmetric:
                    corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
                else:
                    corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)

            im_A_to_im_B = corresps[finest_scale]["flow"]
            certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
            if finest_scale != 1:
                im_A_to_im_B = F.interpolate(
                    im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
                )
                certainty = F.interpolate(
                    certainty, size=(hs, ws), align_corners=False, mode="bilinear"
                )
            im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
            # Create im_A meshgrid
            im_A_coords = torch.meshgrid(
                (
                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                ),
                indexing="ij",
            )
            im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
            im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
            certainty = certainty.sigmoid()  # logits -> probs
            im_A_coords = im_A_coords.permute(0, 2, 3, 1)
            if (im_A_to_im_B.abs() > 1).any():
                wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
                certainty[wrong[:, None]] = 0
            im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
            if symmetric:
                A_to_B, B_to_A = im_A_to_im_B.chunk(2)
                q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
                im_B_coords = im_A_coords
                s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
                warp = torch.cat((q_warp, s_warp), dim=2)
                certainty = torch.cat(certainty.chunk(2), dim=3)
            else:
                warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
            if batched:
                return (
                    warp,
                    certainty[:, 0],
                )
            else:
                return (
                    warp[0],
                    certainty[0, 0],
                )

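    # Return convention of match() (illustrative summary, not part of the
    # original file): for a non-batched, symmetric call the warp has shape
    # (H, 2W, 4), where warp[:, :W] holds (x_A, y_A, x_B, y_B) for the A -> B
    # direction and warp[:, W:] the B -> A direction; certainty has shape
    # (H, 2W). Batched calls keep a leading batch dimension.
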
    def visualize_warp(self, warp, certainty, im_A=None, im_B=None,
                       im_A_path=None, im_B_path=None, device="cuda",
                       symmetric=True, save_path=None, unnormalize=False):
        # assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
        H, W2, _ = warp.shape
        W = W2 // 2 if symmetric else W2
        if im_A is None:
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2 * W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device=device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im
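

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original file). It assumes the
    # romatch package exposes the pretrained `roma_outdoor` constructor and that
    # `im_A.jpg` / `im_B.jpg` exist on disk; adjust paths and device as needed.
    from romatch import roma_outdoor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    roma_model = roma_outdoor(device=device)
    # Dense warp and per-pixel certainty in normalized [-1, 1] coordinates.
    warp, certainty = roma_model.match("im_A.jpg", "im_B.jpg", device=device)
    # Certainty-weighted (and, by default, KDE-balanced) sparse sampling.
    matches, match_certainty = roma_model.sample(warp, certainty)
    # Convert normalized matches to pixel coordinates of the original images.
    H_A, W_A = Image.open("im_A.jpg").size[::-1]
    H_B, W_B = Image.open("im_B.jpg").size[::-1]
    kpts_A, kpts_B = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)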