File size: 6,384 Bytes
6eb1d7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

import pickle
from functools import lru_cache
from typing import Dict, Optional, Tuple
import torch

from detectron2.utils.file_io import PathManager

from densepose.data.meshes.catalog import MeshCatalog, MeshInfo


def _maybe_copy_to_device(
    attribute: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
    """
    Move an optional tensor to the given device.

    Args:
        attribute (tensor or None): tensor to move; `None` is passed through
        device (torch.device): target device
    Return:
        `None` if `attribute` is `None`, otherwise the result of
        `attribute.to(device)`
    """
    return None if attribute is None else attribute.to(device)


class Mesh:
    """
    Mesh data (vertices, faces, geodesic distances, symmetry and texture
    coordinates) kept on a single device.

    Attributes that were not passed to the constructor are loaded on first
    access from the files referenced by `mesh_info`, if it was provided.
    """

    def __init__(
        self,
        vertices: Optional[torch.Tensor] = None,
        faces: Optional[torch.Tensor] = None,
        geodists: Optional[torch.Tensor] = None,
        symmetry: Optional[Dict[str, torch.Tensor]] = None,
        texcoords: Optional[torch.Tensor] = None,
        mesh_info: Optional[MeshInfo] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular face represented as 3
                vertex indices
            geodists (tensor [N, N] of float32): geodesic distances from
                vertex `i` to vertex `j` (optional, default: None)
            symmetry (dict: str -> tensor): various mesh symmetry data:
                - "vertex_transforms": vertex mapping under horizontal flip,
                  tensor of size [N] of type long; vertex `i` is mapped to
                  vertex `tensor[i]` (optional, default: None)
            texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
                and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): necessary to load the attributes on-the-go,
                can be used instead of passing all the variables one by one
            device (torch.device): device of the Mesh. If not provided, will use
                the device of the first provided tensor attribute, then the
                device of the symmetry tensors, and CPU as the last resort

        At least one of `vertices` and `mesh_info` must be provided. All the
        provided tensors must reside on the same device as the Mesh.
        """
        self._vertices = vertices
        self._faces = faces
        self._geodists = geodists
        self._symmetry = symmetry
        self._texcoords = texcoords
        self.mesh_info = mesh_info
        self.device = device

        assert self._vertices is not None or self.mesh_info is not None

        all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]

        if self.device is None:
            # infer the device from the first tensor attribute that is set
            for field in all_fields:
                if field is not None:
                    self.device = field.device
                    break
            # otherwise fall back to the first symmetry tensor, if any
            if self.device is None and symmetry:
                self.device = next(iter(symmetry.values())).device
            if self.device is None:
                self.device = torch.device("cpu")

        assert all(var.device == self.device for var in all_fields if var is not None)
        if symmetry:
            assert all(value.device == self.device for value in symmetry.values())
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)

    def to(self, device: torch.device):
        """
        Create a new Mesh on the given device.

        Attributes that are already set are moved to `device`; attributes
        that are still `None` stay `None` in the new Mesh and `mesh_info`
        is passed along unchanged.

        Args:
            device (torch.device): target device
        Return:
            new Mesh instance bound to `device`
        """
        device_symmetry = self._symmetry
        if device_symmetry:
            device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()}
        return Mesh(
            _maybe_copy_to_device(self._vertices, device),
            _maybe_copy_to_device(self._faces, device),
            _maybe_copy_to_device(self._geodists, device),
            device_symmetry,
            _maybe_copy_to_device(self._texcoords, device),
            self.mesh_info,
            device,
        )

    @property
    def vertices(self):
        """Vertices; loaded from `mesh_info.data` on first access if not set."""
        if self._vertices is None and self.mesh_info is not None:
            self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
        return self._vertices

    @property
    def faces(self):
        """Faces; loaded from `mesh_info.data` on first access if not set."""
        if self._faces is None and self.mesh_info is not None:
            self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
        return self._faces

    @property
    def geodists(self):
        """Geodesic distances; loaded from `mesh_info.geodists` if not set."""
        if self._geodists is None and self.mesh_info is not None:
            self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
        return self._geodists

    @property
    def symmetry(self):
        """Symmetry data; loaded from `mesh_info.symmetry` if not set."""
        if self._symmetry is None and self.mesh_info is not None:
            self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
        return self._symmetry

    @property
    def texcoords(self):
        """Texture coordinates; loaded from `mesh_info.texcoords` if not set."""
        if self._texcoords is None and self.mesh_info is not None:
            self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
        return self._texcoords

    def get_geodists(self):
        """
        Return geodesic distances, computing them if they are neither set
        nor loadable from `mesh_info`.

        Return:
            geodesic distances or `None` if they are not available
        """
        if self.geodists is None:
            # `geodists` is a read-only property (no setter is defined), so
            # the computed value has to be stored in the backing field
            self._geodists = self._compute_geodists()
        return self._geodists

    def _compute_geodists(self):
        """Compute geodesic distances; not implemented yet, returns `None`."""
        # TODO: compute using Laplace-Beltrami
        geodists = None
        return geodists


def load_mesh_data(
    mesh_fpath: str, field: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load one entry of a pickled mesh file as a float32 tensor.

    Args:
        mesh_fpath (str): path to the pickled mesh data, opened through
            `PathManager`
        field (str): key used to index the unpickled object
        device (torch.device): device argument passed to `Tensor.to`
            (optional, default: None)
    Return:
        tensor of float32 built from the requested entry
    """
    # NOTE: unpickling can execute arbitrary code, mesh files must come
    # from a trusted source
    with PathManager.open(mesh_fpath, "rb") as hFile:
        mesh_data = pickle.load(hFile)
    # NOTE(review): every field is converted to float32, including "faces",
    # which the Mesh docstring describes as a tensor of long - confirm that
    # callers expect float face indices before changing the dtype
    return torch.as_tensor(mesh_data[field], dtype=torch.float).to(device)


def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled auxiliary mesh attribute as a float32 tensor.

    The path is first resolved with `PathManager.get_local_path`, then the
    whole unpickled object is converted to a tensor.

    Args:
        fpath (str): path to the pickled data
        device (torch.device): device argument passed to `Tensor.to`
            (optional, default: None)
    Return:
        tensor of float32 built from the unpickled object
    """
    local_path = PathManager.get_local_path(fpath)
    with PathManager.open(local_path, "rb") as stream:
        payload = pickle.load(stream)
        as_float = torch.as_tensor(payload, dtype=torch.float)
        return as_float.to(device)
    # only reached if the context manager suppresses an exception
    return None


@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load mesh symmetry data from a pickled file.

    Results are memoized by `lru_cache`, so repeated calls with the same
    arguments return the same dictionary object.

    Args:
        symmetry_fpath (str): path to the pickled symmetry data, opened
            through `PathManager`
        device (torch.device): device argument passed to `Tensor.to`
            (optional, default: None)
    Return:
        dictionary with a single key "vertex_transforms" that holds the
        corresponding entry of the loaded data as a tensor of long
    """
    with PathManager.open(symmetry_fpath, "rb") as stream:
        raw = pickle.load(stream)
        vertex_transforms = torch.as_tensor(raw["vertex_transforms"], dtype=torch.long)
        return {"vertex_transforms": vertex_transforms.to(device)}
    # only reached if the context manager suppresses an exception
    return None


@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Create a Mesh for a catalog entry.

    Results are memoized by `lru_cache`, so repeated calls with the same
    arguments return the same Mesh instance.

    Args:
        mesh_name (str): key of the mesh in `MeshCatalog`
        device (torch.device): device of the Mesh (optional, default: None)
    Return:
        Mesh constructed from the catalog entry; no tensors are passed to
        the constructor
    """
    info = MeshCatalog[mesh_name]
    return Mesh(mesh_info=info, device=device)