# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Standard library imports
import base64
import io
import logging
import math
import pickle
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

# Third-party library imports
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.transforms import (
    se3_exp_map,
    se3_log_map,
    Transform3d,
    so3_relative_angle,
)

from util.camera_transform import pose_encoding_to_camera

import models
from hydra.utils import instantiate
from pytorch3d.renderer.cameras import PerspectiveCameras

logger = logging.getLogger(__name__)
class PoseDiffusionModel(nn.Module):
    """Diffusion-based camera pose estimator.

    Wraps an image feature extractor, a diffuser, and a denoiser
    (all built from Hydra configs) and predicts PyTorch3D cameras
    from input images by sampling from the diffusion model.
    """

    def __init__(
        self,
        pose_encoding_type: str,
        IMAGE_FEATURE_EXTRACTOR: Dict,
        DIFFUSER: Dict,
        DENOISER: Dict,
    ):
        """Initializes a PoseDiffusion model.

        Args:
            pose_encoding_type (str):
                Defines the encoding type for extrinsics and intrinsics.
                Currently, only `"absT_quaR_logFL"` is supported -
                a concatenation of the translation vector,
                rotation quaternion, and logarithm of focal length.
            IMAGE_FEATURE_EXTRACTOR (Dict):
                Hydra config for the image feature extractor.
            DIFFUSER (Dict):
                Hydra config for the diffuser.
            DENOISER (Dict):
                Hydra config for the denoiser.
        """
        super().__init__()

        self.pose_encoding_type = pose_encoding_type

        # _recursive_=False leaves nested configs to be instantiated by
        # the target objects themselves.
        self.image_feature_extractor = instantiate(
            IMAGE_FEATURE_EXTRACTOR, _recursive_=False
        )
        self.diffuser = instantiate(DIFFUSER, _recursive_=False)

        denoiser = instantiate(DENOISER, _recursive_=False)
        self.diffuser.model = denoiser
        # Dimensionality of the pose encoding the denoiser operates on.
        self.target_dim = denoiser.target_dim

    def forward(
        self,
        image: torch.Tensor,
        gt_cameras: Optional[CamerasBase] = None,
        sequence_name: Optional[List[str]] = None,
        cond_fn=None,
        cond_start_step=0,
    ):
        """
        Forward pass of the PoseDiffusionModel.

        Args:
            image (torch.Tensor):
                Input image tensor, Bx3xHxW.
            gt_cameras (Optional[CamerasBase], optional):
                Camera object. Defaults to None.
            sequence_name (Optional[List[str]], optional):
                List of sequence names. Defaults to None.
            cond_fn ([type], optional):
                Conditional function. Wrapper for GGS or other functions.
            cond_start_step (int, optional):
                The sampling step to start using conditional function.

        Returns:
            PerspectiveCameras: PyTorch3D camera object.
        """
        z = self.image_feature_extractor(image)
        # Add a leading batch dimension: per-frame features become (1, N, C).
        # NOTE(review): this treats the whole input as one sequence of N
        # frames — confirm against callers.
        z = z.unsqueeze(0)

        B, N, _ = z.shape
        target_shape = [B, N, self.target_dim]

        # Sample the final pose encoding; the intermediate per-step
        # diffusion samples are not needed here.
        pose_encoding, _ = self.diffuser.sample(
            shape=target_shape,
            z=z,
            cond_fn=cond_fn,
            cond_start_step=cond_start_step,
        )

        # Convert the encoded representation to PyTorch3D cameras.
        pred_cameras = pose_encoding_to_camera(
            pose_encoding, pose_encoding_type=self.pose_encoding_type
        )

        return pred_cameras