Spaces:
Runtime error
Runtime error
File size: 3,700 Bytes
3d3e4e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Standard library imports
import base64
import io
import logging
import math
import pickle
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
# Third-party library imports
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.transforms import (
se3_exp_map,
se3_log_map,
Transform3d,
so3_relative_angle,
)
from util.camera_transform import pose_encoding_to_camera
import models
from hydra.utils import instantiate
from pytorch3d.renderer.cameras import PerspectiveCameras
# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class PoseDiffusionModel(nn.Module):
    """Camera-pose model that samples pose encodings with a diffuser.

    Built from three Hydra configs: an image feature extractor, a diffuser,
    and a denoiser that is attached to the diffuser as ``diffuser.model``.
    ``forward`` extracts image features, calls the diffuser's ``sample``
    method conditioned on them, and decodes the sampled pose encoding into
    PyTorch3D cameras via ``pose_encoding_to_camera``.
    """

    def __init__(
        self,
        pose_encoding_type: str,
        IMAGE_FEATURE_EXTRACTOR: Dict,
        DIFFUSER: Dict,
        DENOISER: Dict,
    ) -> None:
        """Initializes a PoseDiffusion model.

        Args:
            pose_encoding_type (str):
                Defines the encoding type for extrinsics and intrinsics.
                Currently, only `"absT_quaR_logFL"` is supported -
                a concatenation of the translation vector,
                rotation quaternion, and logarithm of focal length.
            IMAGE_FEATURE_EXTRACTOR (Dict):
                Hydra config for the image feature extractor.
            DIFFUSER (Dict):
                Hydra config for the diffuser.
            DENOISER (Dict):
                Hydra config for the denoiser. The instantiated module is
                attached to the diffuser as ``self.diffuser.model``, and its
                ``target_dim`` is stored as ``self.target_dim``.
        """
        super().__init__()

        self.pose_encoding_type = pose_encoding_type

        # `_recursive_=False`: Hydra hands nested sub-configs to each
        # constructor unchanged instead of instantiating them first.
        self.image_feature_extractor = instantiate(
            IMAGE_FEATURE_EXTRACTOR, _recursive_=False
        )
        self.diffuser = instantiate(DIFFUSER, _recursive_=False)

        denoiser = instantiate(DENOISER, _recursive_=False)
        # Attach the denoiser to the diffuser; presumably `diffuser.sample`
        # uses it as the denoising network (TODO confirm in the diffuser).
        self.diffuser.model = denoiser
        # Last dimension of the pose encoding sampled in `forward`.
        self.target_dim = denoiser.target_dim

    def forward(
        self,
        image: torch.Tensor,
        gt_cameras: Optional[CamerasBase] = None,
        sequence_name: Optional[List[str]] = None,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Samples camera poses for the given images.

        The extracted features get a new leading dimension of size 1 before
        sampling, so all images passed in one call are handled as a single
        sequence (sampling batch size is always 1).

        Args:
            image (torch.Tensor):
                Input images, presumably (N, 3, H, W) with N frames of one
                sequence — TODO confirm against the feature extractor. The
                extractor must return a 2-D tensor (one row per frame) for
                the shape unpacking below to succeed.
            gt_cameras (Optional[CamerasBase], optional):
                Ground-truth cameras. Unused here. Defaults to None.
            sequence_name (Optional[List[str]], optional):
                Sequence names. Unused here. Defaults to None.
            cond_fn (callable, optional):
                Conditional function (e.g. a wrapper for GGS), forwarded to
                the diffuser's sampler. Defaults to None.
            cond_start_step (int, optional):
                The sampling step from which ``cond_fn`` is used; forwarded
                to the diffuser's sampler. Defaults to 0.

        Returns:
            PerspectiveCameras: PyTorch3D camera object decoded from the
            sampled pose encoding.
        """
        features = self.image_feature_extractor(image)
        # Add a leading batch axis of size 1: (N, C) -> (1, N, C).
        z = features.unsqueeze(0)
        batch_size, num_frames, _ = z.shape

        # Only the final pose encoding is needed; the sampler's second
        # return value (presumably the per-step diffusion samples) is
        # discarded.
        pose_encoding, _ = self.diffuser.sample(
            shape=[batch_size, num_frames, self.target_dim],
            z=z,
            cond_fn=cond_fn,
            cond_start_step=cond_start_step,
        )

        # Convert the encoded representation to PyTorch3D cameras.
        return pose_encoding_to_camera(
            pose_encoding, pose_encoding_type=self.pose_encoding_type
        )
|