Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
# pyre-unsafe | |
import json | |
import logging | |
from typing import List, Optional | |
import torch | |
from torch import nn | |
from detectron2.utils.file_io import PathManager | |
from densepose.structures.mesh import create_mesh | |
class MeshAlignmentEvaluator:
    """
    Class for evaluation of 3D mesh alignment based on the learned vertex embeddings.

    For every ordered pair of distinct meshes (source, target), each annotated key
    vertex of the source mesh is matched to the target vertex whose embedding has
    the largest dot product with the source key vertex's embedding. Each match is
    scored by looking up ``create_mesh(target).geodists`` between the matched vertex
    and the target key vertex carrying the same name (presumably a geodesic
    distance on the target mesh):

    * GE: mean of those looked-up distances;
    * GPS: mean of ``exp(-d**2 / (2 * gps_sigma**2))`` over those distances.

    Per-mesh scores average over all target meshes for a given source mesh; the
    global scores average the per-mesh scores.
    """

    # Bandwidth of the Gaussian kernel used for GPS. This is the value that was
    # previously hard-coded inside ``evaluate``; its derivation is not shown here.
    DEFAULT_GPS_SIGMA: float = 0.255

    def __init__(
        self,
        embedder: nn.Module,
        mesh_names: Optional[List[str]],
        *,
        gps_sigma: float = DEFAULT_GPS_SIGMA,
    ):
        """
        Args:
            embedder: module called as ``embedder(mesh_name)``; the result is used
                as a 2D tensor whose rows are per-vertex embeddings. Its
                ``mesh_names`` attribute is the fallback list of meshes.
            mesh_names: meshes to evaluate; when None or empty,
                ``embedder.mesh_names`` is used instead.
            gps_sigma: bandwidth of the Gaussian kernel in the GPS score.

        Raises:
            ValueError: if ``gps_sigma`` is not positive (the GPS formula divides
                by ``gps_sigma**2``).
        """
        if gps_sigma <= 0:
            raise ValueError(f"gps_sigma must be positive, got {gps_sigma}")
        self.embedder = embedder
        # use the provided mesh names if not None and not an empty list
        self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
        self.gps_sigma = gps_sigma
        self.logger = logging.getLogger(__name__)
        # As used by ``evaluate``: {mesh_name: {keyvertex_name: vertex_index}}.
        # Fetched over the network via PathManager at construction time.
        with PathManager.open(
            "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r"
        ) as f:
            self.mesh_keyvertices = json.load(f)

    @torch.no_grad()
    def evaluate(self):
        """
        Compute GE and GPS over all ordered pairs of distinct meshes.

        Runs under ``torch.no_grad()`` since only Python floats are returned.

        Returns:
            Tuple ``(ge_mean_global, gps_mean_global, per_mesh_metrics)`` where
            ``per_mesh_metrics == {"GE": {mesh_name: float}, "GPS": {mesh_name: float}}``.
            A mesh with no other mesh to compare against (e.g. a single evaluated
            mesh) gets NaN scores; the global means are NaN when no meshes are
            evaluated.

        Raises:
            KeyError: if a mesh has no key vertex annotations, or a target mesh
                lacks a key vertex name present on the source mesh.
        """
        # Each embedding serves once as a source and once per other mesh as a
        # target, so compute it a single time instead of once per pair.
        embeddings = {name: self.embedder(name) for name in self.mesh_names}
        ge_per_mesh = {}
        gps_per_mesh = {}
        for mesh_name_1 in self.mesh_names:
            embeddings_1 = embeddings[mesh_name_1]
            keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
            keyvertex_names_1 = list(keyvertices_1.keys())
            keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
            avg_errors = []
            avg_gps = []
            for mesh_name_2 in self.mesh_names:
                if mesh_name_1 == mesh_name_2:
                    continue
                embeddings_2 = embeddings[mesh_name_2]
                keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
                # Dot-product similarity of each source key vertex to every target
                # vertex; the best-scoring target vertex is taken as the match.
                sim_matrix_12 = embeddings_1[keyvertex_indices_1].mm(embeddings_2.T)
                vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(dim=1)
                mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
                # Element-wise (paired) indexing: one distance per key vertex, between
                # its match and the same-named key vertex of the target mesh.
                geodists = mesh_2.geodists[
                    vertices_2_matching_keyvertices_1,
                    [keyvertices_2[name] for name in keyvertex_names_1],
                ]
                gps = torch.exp(-(geodists**2) / (2 * self.gps_sigma**2))
                avg_errors.append(geodists.mean().item())
                avg_gps.append(gps.mean().item())
            # Mean of an empty list is NaN: no other mesh to compare against.
            ge_per_mesh[mesh_name_1] = torch.as_tensor(avg_errors).mean().item()
            gps_per_mesh[mesh_name_1] = torch.as_tensor(avg_gps).mean().item()
        ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
        gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
        per_mesh_metrics = {
            "GE": ge_per_mesh,
            "GPS": gps_per_mesh,
        }
        return ge_mean_global, gps_mean_global, per_mesh_metrics