PoseDiffusion_MVP / models /pose_diffusion_model.py
hugoycj
Initial commit
3d3e4e9
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Standard library imports
import base64
import io
import logging
import math
import pickle
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
# Third-party library imports
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.transforms import (
se3_exp_map,
se3_log_map,
Transform3d,
so3_relative_angle,
)
from util.camera_transform import pose_encoding_to_camera
import models
from hydra.utils import instantiate
from pytorch3d.renderer.cameras import PerspectiveCameras
# Module-level logger, keyed by this module's dotted name.
logger = logging.getLogger(name=__name__)
class PoseDiffusionModel(nn.Module):
    """Camera-pose estimator built from a feature extractor and a diffuser.

    The image feature extractor, diffuser and denoiser are all built from
    Hydra configs via ``hydra.utils.instantiate``. The denoiser is attached to
    the diffuser as ``self.diffuser.model`` and provides ``target_dim``, the
    size of one frame's pose encoding.
    """

    def __init__(
        self,
        pose_encoding_type: str,
        IMAGE_FEATURE_EXTRACTOR: Dict,
        DIFFUSER: Dict,
        DENOISER: Dict,
    ):
        """Initializes a PoseDiffusion model.

        Args:
            pose_encoding_type (str):
                Defines the encoding type for extrinsics and intrinsics.
                Currently, only `"absT_quaR_logFL"` is supported -
                a concatenation of the translation vector,
                rotation quaternion, and logarithm of focal length.
            IMAGE_FEATURE_EXTRACTOR (Dict):
                Hydra config for the image feature extractor.
            DIFFUSER (Dict):
                Hydra config for the diffuser.
            DENOISER (Dict):
                Hydra config for the denoiser. The instantiated object must
                expose a ``target_dim`` attribute.
        """
        super().__init__()

        self.pose_encoding_type = pose_encoding_type

        # _recursive_=False: only the top-level object of each config is
        # instantiated; nested configs reach its constructor un-instantiated.
        self.image_feature_extractor = instantiate(
            IMAGE_FEATURE_EXTRACTOR, _recursive_=False
        )
        self.diffuser = instantiate(DIFFUSER, _recursive_=False)

        denoiser = instantiate(DENOISER, _recursive_=False)
        # Attach the denoiser to the diffuser; presumably the diffuser's
        # sampler invokes it through `.model`.
        self.diffuser.model = denoiser

        # Size of one frame's pose encoding; used to shape the samples.
        self.target_dim = denoiser.target_dim

    def forward(
        self,
        image: torch.Tensor,
        gt_cameras: Optional[CamerasBase] = None,
        sequence_name: Optional[List[str]] = None,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Predict cameras for a set of frames by diffusion sampling.

        All frames in ``image`` are treated as a single sequence: their
        per-frame features are stacked into one batch element (B == 1)
        before sampling.

        Args:
            image (torch.Tensor):
                Frames of one scene stacked along dim 0, presumably
                N x 3 x H x W. The feature extractor must map it to a 2-D
                (N, C) tensor of per-frame features.
            gt_cameras (Optional[CamerasBase], optional):
                Ground-truth cameras. Accepted for interface compatibility
                but not used by this method. Defaults to None.
            sequence_name (Optional[List[str]], optional):
                Sequence names. Accepted for interface compatibility but
                not used by this method. Defaults to None.
            cond_fn (callable, optional):
                Conditional function forwarded to the diffuser's sampler,
                e.g. a wrapper for GGS. Defaults to None.
            cond_start_step (int, optional):
                The sampling step at which to start using ``cond_fn``.
                Defaults to 0.

        Returns:
            The cameras built by ``pose_encoding_to_camera`` from the sampled
            pose encoding (PyTorch3D cameras, presumably PerspectiveCameras).

        Raises:
            ValueError: If the feature extractor does not return a 2-D tensor.
        """
        z = self.image_feature_extractor(image)
        # Any other rank would already fail at the shape unpack below; fail
        # early with a clearer message instead (same exception type).
        if z.dim() != 2:
            raise ValueError(
                "Expected the image feature extractor to return a 2-D "
                f"(num_frames, feature_dim) tensor, got shape {tuple(z.shape)}."
            )

        # Add a batch axis: all frames form one sequence -> (1, N, C).
        z = z.unsqueeze(0)
        B, N, _ = z.shape
        target_shape = [B, N, self.target_dim]

        # Sampling; the sampler's second return value is not needed here.
        pose_encoding, _ = self.diffuser.sample(
            shape=target_shape,
            z=z,
            cond_fn=cond_fn,
            cond_start_step=cond_start_step,
        )

        # Convert the encoded representation to PyTorch3D cameras.
        pred_cameras = pose_encoding_to_camera(
            pose_encoding, pose_encoding_type=self.pose_encoding_type
        )
        return pred_cameras