# -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: ps-license@tuebingen.mpg.de from typing import Optional, Dict, Union import os import os.path as osp import pickle import numpy as np import torch import torch.nn as nn from collections import namedtuple import logging logging.getLogger("smplx").setLevel(logging.ERROR) from .lbs import lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords from .vertex_ids import vertex_ids as VERTEX_IDS from .utils import ( Struct, to_np, to_tensor, Tensor, Array, SMPLOutput, SMPLHOutput, SMPLXOutput, MANOOutput, FLAMEOutput, find_joint_kin_chain, ) from .vertex_joint_selector import VertexJointSelector ModelOutput = namedtuple( "ModelOutput", [ "vertices", "joints", "full_pose", "betas", "global_orient", "body_pose", "expression", "left_hand_pose", "right_hand_pose", "jaw_pose", ], ) ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields) class SMPL(nn.Module): NUM_JOINTS = 23 NUM_BODY_JOINTS = 23 SHAPE_SPACE_DIM = 300 def __init__( self, model_path: str, kid_template_path: str = "", data_struct: Optional[Struct] = None, create_betas: bool = True, betas: Optional[Tensor] = None, num_betas: int = 10, create_global_orient: bool = True, global_orient: Optional[Tensor] = None, create_body_pose: bool = True, body_pose: Optional[Tensor] = None, create_transl: bool = True, transl: Optional[Tensor] = None, dtype=torch.float32, 
batch_size: int = 1, joint_mapper=None, gender: str = "neutral", age: str = "adult", vertex_ids: Dict[str, int] = None, v_template: Optional[Union[Tensor, Array]] = None, v_personal: Optional[Union[Tensor, Array]] = None, **kwargs, ) -> None: """SMPL model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored data_struct: Strct A struct object. If given, then the parameters of the model are read from the object. Otherwise, the model tries to read the parameters from the given `model_path`. (default = None) create_global_orient: bool, optional Flag for creating a member variable for the global orientation of the body. (default = True) global_orient: torch.tensor, optional, Bx3 The default value for the global orientation variable. (default = None) create_body_pose: bool, optional Flag for creating a member variable for the pose of the body. (default = True) body_pose: torch.tensor, optional, Bx(Body Joints * 3) The default value for the body pose variable. (default = None) num_betas: int, optional Number of shape components to use (default = 10). create_betas: bool, optional Flag for creating a member variable for the shape space (default = True). betas: torch.tensor, optional, Bx10 The default value for the shape member variable. (default = None) create_transl: bool, optional Flag for creating a member variable for the translation of the body. (default = True) transl: torch.tensor, optional, Bx3 The default value for the transl variable. (default = None) dtype: torch.dtype, optional The data type for the created variables batch_size: int, optional The batch size used for creating the member variables joint_mapper: object, optional An object that re-maps the joints. Useful if one wants to re-order the SMPL joints to some other convention (e.g. 
MSCOCO) (default = None) gender: str, optional Which gender to load vertex_ids: dict, optional A dictionary containing the indices of the extra vertices that will be selected """ self.gender = gender self.age = age if data_struct is None: if osp.isdir(model_path): model_fn = "SMPL_{}.{ext}".format(gender.upper(), ext="pkl") smpl_path = os.path.join(model_path, model_fn) else: smpl_path = model_path assert osp.exists(smpl_path), "Path {} does not exist!".format(smpl_path) with open(smpl_path, "rb") as smpl_file: data_struct = Struct(**pickle.load(smpl_file, encoding="latin1")) super(SMPL, self).__init__() self.batch_size = batch_size shapedirs = data_struct.shapedirs if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM: # print(f'WARNING: You are using a {self.name()} model, with only' # ' 10 shape coefficients.') num_betas = min(num_betas, 10) else: num_betas = min(num_betas, self.SHAPE_SPACE_DIM) if self.age == "kid": v_template_smil = np.load(kid_template_path) v_template_smil -= np.mean(v_template_smil, axis=0) v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2) shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2) num_betas = num_betas + 1 self._num_betas = num_betas shapedirs = shapedirs[:, :, :num_betas] # The shape components self.register_buffer("shapedirs", to_tensor(to_np(shapedirs), dtype=dtype)) if vertex_ids is None: # SMPL and SMPL-H share the same topology, so any extra joints can # be drawn from the same place vertex_ids = VERTEX_IDS["smplh"] self.dtype = dtype self.joint_mapper = joint_mapper self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids, **kwargs) self.faces = data_struct.f self.register_buffer( "faces_tensor", to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long), ) if create_betas: if betas is None: default_betas = torch.zeros([batch_size, self.num_betas], dtype=dtype) else: if torch.is_tensor(betas): default_betas = betas.clone().detach() else: default_betas 
= torch.tensor(betas, dtype=dtype) self.register_parameter("betas", nn.Parameter(default_betas, requires_grad=True)) # The tensor that contains the global rotation of the model # It is separated from the pose of the joints in case we wish to # optimize only over one of them if create_global_orient: if global_orient is None: default_global_orient = torch.zeros([batch_size, 3], dtype=dtype) else: if torch.is_tensor(global_orient): default_global_orient = global_orient.clone().detach() else: default_global_orient = torch.tensor(global_orient, dtype=dtype) global_orient = nn.Parameter(default_global_orient, requires_grad=True) self.register_parameter("global_orient", global_orient) if create_body_pose: if body_pose is None: default_body_pose = torch.zeros([batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) else: if torch.is_tensor(body_pose): default_body_pose = body_pose.clone().detach() else: default_body_pose = torch.tensor(body_pose, dtype=dtype) self.register_parameter("body_pose", nn.Parameter(default_body_pose, requires_grad=True)) if create_transl: if transl is None: default_transl = torch.zeros([batch_size, 3], dtype=dtype, requires_grad=True) else: default_transl = torch.tensor(transl, dtype=dtype) self.register_parameter("transl", nn.Parameter(default_transl, requires_grad=True)) if v_template is None: v_template = data_struct.v_template if not torch.is_tensor(v_template): v_template = to_tensor(to_np(v_template), dtype=dtype) if v_personal is not None: v_personal = to_tensor(to_np(v_personal), dtype=dtype) v_template += v_personal # The vertices of the template model self.register_buffer("v_template", v_template) j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype) self.register_buffer("J_regressor", j_regressor) # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 num_pose_basis = data_struct.posedirs.shape[-1] # 207 x 20670 posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T self.register_buffer("posedirs", 
to_tensor(to_np(posedirs), dtype=dtype)) # indices of parents for each joints parents = to_tensor(to_np(data_struct.kintree_table[0])).long() parents[0] = -1 self.register_buffer("parents", parents) self.register_buffer("lbs_weights", to_tensor(to_np(data_struct.weights), dtype=dtype)) @property def num_betas(self): return self._num_betas @property def num_expression_coeffs(self): return 0 def create_mean_pose(self, data_struct) -> Tensor: pass def name(self) -> str: return "SMPL" @torch.no_grad() def reset_params(self, **params_dict) -> None: for param_name, param in self.named_parameters(): if param_name in params_dict: param[:] = torch.tensor(params_dict[param_name]) else: param.fill_(0) def get_num_verts(self) -> int: return self.v_template.shape[0] def get_num_faces(self) -> int: return self.faces.shape[0] def extra_repr(self) -> str: msg = [ f"Gender: {self.gender.upper()}", f"Number of joints: {self.J_regressor.shape[0]}", f"Betas: {self.num_betas}", ] return "\n".join(msg) def forward( self, betas: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts=True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs, ) -> SMPLOutput: """Forward pass for the SMPL model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3 If given, ignore the member variable and use it as the global rotation of the body. Useful if someone wishes to predicts this with an external model. (default=None) betas: torch.tensor, optional, shape BxN_b If given, ignore the member variable `betas` and use it instead. For example, it can used if shape parameters `betas` are predicted from some external model. (default=None) body_pose: torch.tensor, optional, shape Bx(J*3) If given, ignore the member variable `body_pose` and use it instead. For example, it can used if someone predicts the pose of the body joints are predicted from some external model. 
It should be a tensor that contains joint rotations in axis-angle format. (default=None) transl: torch.tensor, optional, shape Bx3 If given, ignore the member variable `transl` and use it instead. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. (default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- """ # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose betas = betas if betas is not None else self.betas apply_trans = transl is not None or hasattr(self, "transl") if transl is None and hasattr(self, "transl"): transl = self.transl full_pose = torch.cat([global_orient, body_pose], dim=1) batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0]) if betas.shape[0] != batch_size: num_repeats = int(batch_size / betas.shape[0]) betas = betas.expand(num_repeats, -1) vertices, joints = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot, ) joints = self.vertex_joint_selector(vertices, joints) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints) if apply_trans: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = SMPLOutput( vertices=vertices if return_verts else None, global_orient=global_orient, body_pose=body_pose, joints=joints, betas=betas, full_pose=full_pose if return_full_pose else None, ) return output class SMPLLayer(SMPL): def __init__(self, *args, **kwargs) -> None: # Just create a SMPL module without any member variables super(SMPLLayer, self).__init__( create_body_pose=False, create_betas=False, create_global_orient=False, 
create_transl=False, *args, **kwargs, ) def forward( self, betas: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts=True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs, ) -> SMPLOutput: """Forward pass for the SMPL model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3x3 Global rotation of the body. Useful if someone wishes to predicts this with an external model. It is expected to be in rotation matrix format. (default=None) betas: torch.tensor, optional, shape BxN_b Shape parameters. For example, it can used if shape parameters `betas` are predicted from some external model. (default=None) body_pose: torch.tensor, optional, shape BxJx3x3 Body pose. For example, it can used if someone predicts the pose of the body joints are predicted from some external model. It should be a tensor that contains joint rotations in rotation matrix format. (default=None) transl: torch.tensor, optional, shape Bx3 Translation vector of the body. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. 
(default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- """ model_vars = [betas, global_orient, body_pose, transl] batch_size = 1 for var in model_vars: if var is None: continue batch_size = max(batch_size, len(var)) device, dtype = self.shapedirs.device, self.shapedirs.dtype if global_orient is None: global_orient = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if body_pose is None: body_pose = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous()) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) full_pose = torch.cat( [ global_orient.reshape(-1, 1, 3, 3), body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), ], dim=1, ) vertices, joints = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=False, ) joints = self.vertex_joint_selector(vertices, joints) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints) if transl is not None: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = SMPLOutput( vertices=vertices if return_verts else None, global_orient=global_orient, body_pose=body_pose, joints=joints, betas=betas, full_pose=full_pose if return_full_pose else None, ) return output class SMPLH(SMPL): # The hand joints are replaced by MANO NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2 NUM_HAND_JOINTS = 15 NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS def __init__( self, model_path, kid_template_path: str = "", data_struct: Optional[Struct] = None, create_left_hand_pose: bool = True, left_hand_pose: Optional[Tensor] = None, create_right_hand_pose: bool = True, right_hand_pose: Optional[Tensor] = None, 
use_pca: bool = True, num_pca_comps: int = 6, flat_hand_mean: bool = False, batch_size: int = 1, gender: str = "neutral", age: str = "adult", dtype=torch.float32, vertex_ids=None, use_compressed: bool = True, ext: str = "pkl", **kwargs, ) -> None: """SMPLH model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored data_struct: Strct A struct object. If given, then the parameters of the model are read from the object. Otherwise, the model tries to read the parameters from the given `model_path`. (default = None) create_left_hand_pose: bool, optional Flag for creating a member variable for the pose of the left hand. (default = True) left_hand_pose: torch.tensor, optional, BxP The default value for the left hand pose member variable. (default = None) create_right_hand_pose: bool, optional Flag for creating a member variable for the pose of the right hand. (default = True) right_hand_pose: torch.tensor, optional, BxP The default value for the right hand pose member variable. (default = None) num_pca_comps: int, optional The number of PCA components to use for each hand. (default = 6) flat_hand_mean: bool, optional If False, then the pose of the hand is initialized to False. 
batch_size: int, optional The batch size used for creating the member variables gender: str, optional Which gender to load dtype: torch.dtype, optional The data type for the created variables vertex_ids: dict, optional A dictionary containing the indices of the extra vertices that will be selected """ self.num_pca_comps = num_pca_comps # If no data structure is passed, then load the data from the given # model folder if data_struct is None: # Load the model if osp.isdir(model_path): model_fn = "SMPLH_{}.{ext}".format(gender.upper(), ext=ext) smplh_path = os.path.join(model_path, model_fn) else: smplh_path = model_path assert osp.exists(smplh_path), "Path {} does not exist!".format(smplh_path) if ext == "pkl": with open(smplh_path, "rb") as smplh_file: model_data = pickle.load(smplh_file, encoding="latin1") elif ext == "npz": model_data = np.load(smplh_path, allow_pickle=True) else: raise ValueError("Unknown extension: {}".format(ext)) data_struct = Struct(**model_data) if vertex_ids is None: vertex_ids = VERTEX_IDS["smplh"] super(SMPLH, self).__init__( model_path=model_path, kid_template_path=kid_template_path, data_struct=data_struct, batch_size=batch_size, vertex_ids=vertex_ids, gender=gender, age=age, use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs, ) self.use_pca = use_pca self.num_pca_comps = num_pca_comps self.flat_hand_mean = flat_hand_mean left_hand_components = data_struct.hands_componentsl[:num_pca_comps] right_hand_components = data_struct.hands_componentsr[:num_pca_comps] self.np_left_hand_components = left_hand_components self.np_right_hand_components = right_hand_components if self.use_pca: self.register_buffer("left_hand_components", torch.tensor(left_hand_components, dtype=dtype)) self.register_buffer( "right_hand_components", torch.tensor(right_hand_components, dtype=dtype), ) if self.flat_hand_mean: left_hand_mean = np.zeros_like(data_struct.hands_meanl) else: left_hand_mean = data_struct.hands_meanl if self.flat_hand_mean: 
right_hand_mean = np.zeros_like(data_struct.hands_meanr) else: right_hand_mean = data_struct.hands_meanr self.register_buffer("left_hand_mean", to_tensor(left_hand_mean, dtype=self.dtype)) self.register_buffer("right_hand_mean", to_tensor(right_hand_mean, dtype=self.dtype)) # Create the buffers for the pose of the left hand hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS if create_left_hand_pose: if left_hand_pose is None: default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype) else: default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype) left_hand_pose_param = nn.Parameter(default_lhand_pose, requires_grad=True) self.register_parameter("left_hand_pose", left_hand_pose_param) if create_right_hand_pose: if right_hand_pose is None: default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype) else: default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype) right_hand_pose_param = nn.Parameter(default_rhand_pose, requires_grad=True) self.register_parameter("right_hand_pose", right_hand_pose_param) # Create the buffer for the mean pose. pose_mean_tensor = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean) if not torch.is_tensor(pose_mean_tensor): pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype) self.register_buffer("pose_mean", pose_mean_tensor) def create_mean_pose(self, data_struct, flat_hand_mean=False): # Create the array for the mean pose. 
If flat_hand is false, then use # the mean that is given by the data, rather than the flat open hand global_orient_mean = torch.zeros([3], dtype=self.dtype) body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype) pose_mean = torch.cat( [ global_orient_mean, body_pose_mean, self.left_hand_mean, self.right_hand_mean, ], dim=0, ) return pose_mean def name(self) -> str: return "SMPL+H" def extra_repr(self): msg = super(SMPLH, self).extra_repr() msg = [msg] if self.use_pca: msg.append(f"Number of PCA components: {self.num_pca_comps}") msg.append(f"Flat hand mean: {self.flat_hand_mean}") return "\n".join(msg) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, left_hand_pose: Optional[Tensor] = None, right_hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs, ) -> SMPLHOutput: """""" # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose betas = betas if betas is not None else self.betas left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose) right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose) apply_trans = transl is not None or hasattr(self, "transl") if transl is None: if hasattr(self, "transl"): transl = self.transl if self.use_pca: left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components]) right_hand_pose = torch.einsum("bi,ij->bj", [right_hand_pose, self.right_hand_components]) full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1) full_pose += self.pose_mean vertices, joints = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, 
self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot, ) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) if self.joint_mapper is not None: joints = self.joint_mapper(joints) if apply_trans: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = SMPLHOutput( vertices=vertices if return_verts else None, joints=joints, betas=betas, global_orient=global_orient, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, full_pose=full_pose if return_full_pose else None, ) return output class SMPLHLayer(SMPLH): def __init__(self, *args, **kwargs) -> None: """SMPL+H as a layer model constructor""" super(SMPLHLayer, self).__init__( create_global_orient=False, create_body_pose=False, create_left_hand_pose=False, create_right_hand_pose=False, create_betas=False, create_transl=False, *args, **kwargs, ) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, left_hand_pose: Optional[Tensor] = None, right_hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs, ) -> SMPLHOutput: """Forward pass for the SMPL+H model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3x3 Global rotation of the body. Useful if someone wishes to predicts this with an external model. It is expected to be in rotation matrix format. (default=None) betas: torch.tensor, optional, shape BxN_b Shape parameters. For example, it can used if shape parameters `betas` are predicted from some external model. (default=None) body_pose: torch.tensor, optional, shape BxJx3x3 If given, ignore the member variable `body_pose` and use it instead. For example, it can used if someone predicts the pose of the body joints are predicted from some external model. 
It should be a tensor that contains joint rotations in rotation matrix format. (default=None) left_hand_pose: torch.tensor, optional, shape Bx15x3x3 If given, contains the pose of the left hand. It should be a tensor that contains joint rotations in rotation matrix format. (default=None) right_hand_pose: torch.tensor, optional, shape Bx15x3x3 If given, contains the pose of the right hand. It should be a tensor that contains joint rotations in rotation matrix format. (default=None) transl: torch.tensor, optional, shape Bx3 Translation vector of the body. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. (default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- """ model_vars = [ betas, global_orient, body_pose, transl, left_hand_pose, right_hand_pose, ] batch_size = 1 for var in model_vars: if var is None: continue batch_size = max(batch_size, len(var)) device, dtype = self.shapedirs.device, self.shapedirs.dtype if global_orient is None: global_orient = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if body_pose is None: body_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous()) if left_hand_pose is None: left_hand_pose = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) if right_hand_pose is None: right_hand_pose = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) # Concatenate all pose vectors full_pose = torch.cat( [ global_orient.reshape(-1, 1, 3, 3), body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 
3), left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), ], dim=1, ) vertices, joints = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=False, ) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) if self.joint_mapper is not None: joints = self.joint_mapper(joints) if transl is not None: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = SMPLHOutput( vertices=vertices if return_verts else None, joints=joints, betas=betas, global_orient=global_orient, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, full_pose=full_pose if return_full_pose else None, ) return output class SMPLX(SMPLH): """ SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters trained jointly for the face, hands and body. SMPL-X uses standard vertex based linear blend skinning with learned corrective blend shapes, has N=10475 vertices and K=54 joints, which includes joints for the neck, jaw, eyeballs and fingers. 
""" NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21 NUM_HAND_JOINTS = 15 NUM_FACE_JOINTS = 3 NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS EXPRESSION_SPACE_DIM = 100 NECK_IDX = 12 def __init__( self, model_path: str, kid_template_path: str = "", num_expression_coeffs: int = 10, create_expression: bool = True, expression: Optional[Tensor] = None, create_jaw_pose: bool = True, jaw_pose: Optional[Tensor] = None, create_leye_pose: bool = True, leye_pose: Optional[Tensor] = None, create_reye_pose=True, reye_pose: Optional[Tensor] = None, use_face_contour: bool = False, batch_size: int = 1, gender: str = "neutral", age: str = "adult", dtype=torch.float32, ext: str = "npz", **kwargs, ) -> None: """SMPLX model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored num_expression_coeffs: int, optional Number of expression components to use (default = 10). create_expression: bool, optional Flag for creating a member variable for the expression space (default = True). expression: torch.tensor, optional, Bx10 The default value for the expression member variable. (default = None) create_jaw_pose: bool, optional Flag for creating a member variable for the jaw pose. (default = False) jaw_pose: torch.tensor, optional, Bx3 The default value for the jaw pose variable. (default = None) create_leye_pose: bool, optional Flag for creating a member variable for the left eye pose. (default = False) leye_pose: torch.tensor, optional, Bx10 The default value for the left eye pose variable. (default = None) create_reye_pose: bool, optional Flag for creating a member variable for the right eye pose. (default = False) reye_pose: torch.tensor, optional, Bx10 The default value for the right eye pose variable. 
(default = None) use_face_contour: bool, optional Whether to compute the keypoints that form the facial contour batch_size: int, optional The batch size used for creating the member variables gender: str, optional Which gender to load dtype: torch.dtype The data type for the created variables """ # Load the model if osp.isdir(model_path): model_fn = "SMPLX_{}.{ext}".format(gender.upper(), ext=ext) smplx_path = os.path.join(model_path, model_fn) else: smplx_path = model_path assert osp.exists(smplx_path), "Path {} does not exist!".format(smplx_path) if ext == "pkl": with open(smplx_path, "rb") as smplx_file: model_data = pickle.load(smplx_file, encoding="latin1") elif ext == "npz": model_data = np.load(smplx_path, allow_pickle=True) else: raise ValueError("Unknown extension: {}".format(ext)) data_struct = Struct(**model_data) super(SMPLX, self).__init__( model_path=model_path, kid_template_path=kid_template_path, data_struct=data_struct, dtype=dtype, batch_size=batch_size, vertex_ids=VERTEX_IDS["smplx"], gender=gender, age=age, ext=ext, **kwargs, ) lmk_faces_idx = data_struct.lmk_faces_idx self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long)) lmk_bary_coords = data_struct.lmk_bary_coords self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype)) self.use_face_contour = use_face_contour if self.use_face_contour: dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long) self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx) dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, dtype=dtype) self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_bary_coords) neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long)) if create_jaw_pose: if jaw_pose is None: 
default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) self.register_parameter("jaw_pose", jaw_pose_param) if create_leye_pose: if leye_pose is None: default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_leye_pose = torch.tensor(leye_pose, dtype=dtype) leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True) self.register_parameter("leye_pose", leye_pose_param) if create_reye_pose: if reye_pose is None: default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_reye_pose = torch.tensor(reye_pose, dtype=dtype) reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True) self.register_parameter("reye_pose", reye_pose_param) shapedirs = data_struct.shapedirs if len(shapedirs.shape) < 3: shapedirs = shapedirs[:, :, None] if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM: # print(f'WARNING: You are using a {self.name()} model, with only' # ' 10 shape and 10 expression coefficients.') expr_start_idx = 10 expr_end_idx = 20 num_expression_coeffs = min(num_expression_coeffs, 10) else: expr_start_idx = self.SHAPE_SPACE_DIM expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM) self._num_expression_coeffs = num_expression_coeffs expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype)) if create_expression: if expression is None: default_expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype) else: default_expression = torch.tensor(expression, dtype=dtype) expression_param = nn.Parameter(default_expression, requires_grad=True) self.register_parameter("expression", expression_param) def name(self) -> str: return "SMPL-X" @property def num_expression_coeffs(self): return 
self._num_expression_coeffs def create_mean_pose(self, data_struct, flat_hand_mean=False): # Create the array for the mean pose. If flat_hand is false, then use # the mean that is given by the data, rather than the flat open hand global_orient_mean = torch.zeros([3], dtype=self.dtype) body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype) jaw_pose_mean = torch.zeros([3], dtype=self.dtype) leye_pose_mean = torch.zeros([3], dtype=self.dtype) reye_pose_mean = torch.zeros([3], dtype=self.dtype) pose_mean = np.concatenate( [ global_orient_mean, body_pose_mean, jaw_pose_mean, leye_pose_mean, reye_pose_mean, self.left_hand_mean, self.right_hand_mean, ], axis=0, ) return pose_mean def extra_repr(self): msg = super(SMPLX, self).extra_repr() msg = [msg, f"Number of Expression Coefficients: {self.num_expression_coeffs}"] return "\n".join(msg) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, left_hand_pose: Optional[Tensor] = None, right_hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, expression: Optional[Tensor] = None, jaw_pose: Optional[Tensor] = None, leye_pose: Optional[Tensor] = None, reye_pose: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, return_joint_transformation: bool = False, return_vertex_transformation: bool = False, pose_type: str = 'posed', **kwargs, ) -> SMPLXOutput: """ Forward pass for the SMPLX model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3 If given, ignore the member variable and use it as the global rotation of the body. Useful if someone wishes to predicts this with an external model. (default=None) betas: torch.tensor, optional, shape BxN_b If given, ignore the member variable `betas` and use it instead. For example, it can used if shape parameters `betas` are predicted from some external model. 
(default=None) expression: torch.tensor, optional, shape BxN_e If given, ignore the member variable `expression` and use it instead. For example, it can used if expression parameters `expression` are predicted from some external model. body_pose: torch.tensor, optional, shape Bx(J*3) If given, ignore the member variable `body_pose` and use it instead. For example, it can used if someone predicts the pose of the body joints are predicted from some external model. It should be a tensor that contains joint rotations in axis-angle format. (default=None) left_hand_pose: torch.tensor, optional, shape BxP If given, ignore the member variable `left_hand_pose` and use this instead. It should either contain PCA coefficients or joint rotations in axis-angle format. right_hand_pose: torch.tensor, optional, shape BxP If given, ignore the member variable `right_hand_pose` and use this instead. It should either contain PCA coefficients or joint rotations in axis-angle format. jaw_pose: torch.tensor, optional, shape Bx3 If given, ignore the member variable `jaw_pose` and use this instead. It should either joint rotations in axis-angle format. transl: torch.tensor, optional, shape Bx3 If given, ignore the member variable `transl` and use it instead. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. 
(default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- output: ModelOutput A named tuple of type `ModelOutput` """ # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose betas = betas if betas is not None else self.betas left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose) right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose) jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose leye_pose = leye_pose if leye_pose is not None else self.leye_pose reye_pose = reye_pose if reye_pose is not None else self.reye_pose expression = expression if expression is not None else self.expression apply_trans = transl is not None or hasattr(self, "transl") if transl is None: if hasattr(self, "transl"): transl = self.transl if self.use_pca: left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components]) right_hand_pose = torch.einsum("bi,ij->bj", [right_hand_pose, self.right_hand_components]) full_pose = torch.cat( [ global_orient, body_pose, jaw_pose, leye_pose, reye_pose, left_hand_pose, right_hand_pose, ], dim=1, ) if pose_type == "t-pose": full_pose *= 0.0 elif pose_type == "da-pose": body_pose = torch.zeros_like(body_pose).view(body_pose.shape[0], -1, 3) body_pose[:, 0] = torch.tensor([0., 0., 30 * np.pi / 180.]) body_pose[:, 1] = torch.tensor([0., 0., -30 * np.pi / 180.]) body_pose = body_pose.view(body_pose.shape[0], -1) full_pose = torch.cat( [ global_orient * 0., body_pose, jaw_pose * 0., leye_pose * 0., reye_pose * 0., left_hand_pose * 0., right_hand_pose * 0., ], dim=1, ) # Add the mean pose of the model. 
Does not affect the body, only the # hands when flat_hand_mean == False # full_pose += self.pose_mean batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0]) # Concatenate the shape and expression coefficients scale = int(batch_size / betas.shape[0]) if scale > 1: betas = betas.expand(scale, -1) shape_components = torch.cat([betas, expression], dim=-1) shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) if return_joint_transformation or return_vertex_transformation: vertices, joints, joint_transformation, vertex_transformation = lbs( shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot, return_transformation=True, ) else: vertices, joints = lbs( shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot, ) lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()) lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(self.batch_size, 1, 1) if self.use_face_contour: lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, self.dynamic_lmk_faces_idx, self.dynamic_lmk_bary_coords, self.neck_kin_chain, pose2rot=True, ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) # Add the landmarks to the joints joints = torch.cat([joints, landmarks], dim=1) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints=joints, vertices=vertices) if apply_trans: joints += transl.unsqueeze(dim=1) vertices += 
transl.unsqueeze(dim=1) output = SMPLXOutput( vertices=vertices if return_verts else None, joints=joints, betas=betas, expression=expression, global_orient=global_orient, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, full_pose=full_pose if return_full_pose else None, joint_transformation=joint_transformation if return_joint_transformation else None, vertex_transformation=vertex_transformation if return_vertex_transformation else None, ) return output class SMPLXLayer(SMPLX): def __init__(self, *args, **kwargs) -> None: # Just create a SMPLX module without any member variables super(SMPLXLayer, self).__init__( create_global_orient=False, create_body_pose=False, create_left_hand_pose=False, create_right_hand_pose=False, create_jaw_pose=False, create_leye_pose=False, create_reye_pose=False, create_betas=False, create_expression=False, create_transl=False, *args, **kwargs, ) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, left_hand_pose: Optional[Tensor] = None, right_hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, expression: Optional[Tensor] = None, jaw_pose: Optional[Tensor] = None, leye_pose: Optional[Tensor] = None, reye_pose: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, **kwargs, ) -> SMPLXOutput: """ Forward pass for the SMPLX model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3x3 If given, ignore the member variable and use it as the global rotation of the body. Useful if someone wishes to predicts this with an external model. It is expected to be in rotation matrix format. (default=None) betas: torch.tensor, optional, shape BxN_b If given, ignore the member variable `betas` and use it instead. For example, it can used if shape parameters `betas` are predicted from some external model. 
(default=None) expression: torch.tensor, optional, shape BxN_e Expression coefficients. For example, it can used if expression parameters `expression` are predicted from some external model. body_pose: torch.tensor, optional, shape BxJx3x3 If given, ignore the member variable `body_pose` and use it instead. For example, it can used if someone predicts the pose of the body joints are predicted from some external model. It should be a tensor that contains joint rotations in rotation matrix format. (default=None) left_hand_pose: torch.tensor, optional, shape Bx15x3x3 If given, contains the pose of the left hand. It should be a tensor that contains joint rotations in rotation matrix format. (default=None) right_hand_pose: torch.tensor, optional, shape Bx15x3x3 If given, contains the pose of the right hand. It should be a tensor that contains joint rotations in rotation matrix format. (default=None) jaw_pose: torch.tensor, optional, shape Bx3x3 Jaw pose. It should either joint rotations in rotation matrix format. transl: torch.tensor, optional, shape Bx3 Translation vector of the body. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. 
(default=True) return_full_pose: bool, optional Returns the full pose vector (default=False) Returns ------- output: ModelOutput A data class that contains the posed vertices and joints """ device, dtype = self.shapedirs.device, self.shapedirs.dtype model_vars = [ betas, global_orient, body_pose, transl, expression, left_hand_pose, right_hand_pose, jaw_pose, ] batch_size = 1 for var in model_vars: if var is None: continue batch_size = max(batch_size, len(var)) if global_orient is None: global_orient = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if body_pose is None: body_pose = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous()) if left_hand_pose is None: left_hand_pose = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) if right_hand_pose is None: right_hand_pose = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) if jaw_pose is None: jaw_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if leye_pose is None: leye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if reye_pose is None: reye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if expression is None: expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype, device=device) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) # Concatenate all pose vectors full_pose = torch.cat( [ global_orient.reshape(-1, 1, 3, 3), body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), jaw_pose.reshape(-1, 1, 3, 3), leye_pose.reshape(-1, 1, 3, 3), 
reye_pose.reshape(-1, 1, 3, 3), left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), ], dim=1, ) shape_components = torch.cat([betas, expression], dim=-1) shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) vertices, joints = lbs( shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=False, ) lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()) lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(batch_size, 1, 1) if self.use_face_contour: lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, self.dynamic_lmk_faces_idx, self.dynamic_lmk_bary_coords, self.neck_kin_chain, pose2rot=False, ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) # Add the landmarks to the joints joints = torch.cat([joints, landmarks], dim=1) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints=joints, vertices=vertices) if transl is not None: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = SMPLXOutput( vertices=vertices if return_verts else None, joints=joints, betas=betas, expression=expression, global_orient=global_orient, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, transl=transl, full_pose=full_pose if return_full_pose else None, ) return output class MANO(SMPL): # The hand joints are replaced by MANO NUM_BODY_JOINTS = 1 NUM_HAND_JOINTS = 15 NUM_JOINTS = NUM_BODY_JOINTS + 
NUM_HAND_JOINTS def __init__( self, model_path: str, is_rhand: bool = True, data_struct: Optional[Struct] = None, create_hand_pose: bool = True, hand_pose: Optional[Tensor] = None, use_pca: bool = True, num_pca_comps: int = 6, flat_hand_mean: bool = False, batch_size: int = 1, dtype=torch.float32, vertex_ids=None, use_compressed: bool = True, ext: str = "pkl", **kwargs, ) -> None: """MANO model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored is_rhand: bool, optional Whether to load the right-hand model (MANO_RIGHT) or the left-hand model (MANO_LEFT) when `model_path` is a folder. When `model_path` points to a file, the side is instead inferred from whether the file name contains 'RIGHT'. (default = True) data_struct: Struct A struct object. If given, then the parameters of the model are read from the object. Otherwise, the model tries to read the parameters from the given `model_path`. (default = None) create_hand_pose: bool, optional Flag for creating a member variable for the pose of the hand. (default = True) hand_pose: torch.tensor, optional, BxP The default value for the hand pose member variable. (default = None) use_pca: bool, optional If True, the hand pose holds PCA coefficients that are projected onto the model's hand components. It is always disabled when num_pca_comps == 45. (default = True) num_pca_comps: int, optional The number of PCA components to use for each hand. (default = 6) flat_hand_mean: bool, optional If False, the mean hand pose stored in the model data (`hands_mean`) is added to the hand pose. If True, a flat (all-zero) hand mean is used instead. (default = False)
batch_size: int, optional The batch size used for creating the member variables dtype: torch.dtype, optional The data type for the created variables vertex_ids: dict, optional A dictionary containing the indices of the extra vertices that will be selected """ self.num_pca_comps = num_pca_comps self.is_rhand = is_rhand # If no data structure is passed, then load the data from the given # model folder if data_struct is None: # Load the model if osp.isdir(model_path): model_fn = "MANO_{}.{ext}".format("RIGHT" if is_rhand else "LEFT", ext=ext) mano_path = os.path.join(model_path, model_fn) else: mano_path = model_path self.is_rhand = (True if "RIGHT" in os.path.basename(model_path) else False) assert osp.exists(mano_path), "Path {} does not exist!".format(mano_path) if ext == "pkl": with open(mano_path, "rb") as mano_file: model_data = pickle.load(mano_file, encoding="latin1") elif ext == "npz": model_data = np.load(mano_path, allow_pickle=True) else: raise ValueError("Unknown extension: {}".format(ext)) data_struct = Struct(**model_data) if vertex_ids is None: vertex_ids = VERTEX_IDS["smplh"] super(MANO, self).__init__( model_path=model_path, data_struct=data_struct, batch_size=batch_size, vertex_ids=vertex_ids, use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs, ) # add only MANO tips to the extra joints self.vertex_joint_selector.extra_joints_idxs = to_tensor(list(VERTEX_IDS["mano"].values()), dtype=torch.long) self.use_pca = use_pca self.num_pca_comps = num_pca_comps if self.num_pca_comps == 45: self.use_pca = False self.flat_hand_mean = flat_hand_mean hand_components = data_struct.hands_components[:num_pca_comps] self.np_hand_components = hand_components if self.use_pca: self.register_buffer("hand_components", torch.tensor(hand_components, dtype=dtype)) if self.flat_hand_mean: hand_mean = np.zeros_like(data_struct.hands_mean) else: hand_mean = data_struct.hands_mean self.register_buffer("hand_mean", to_tensor(hand_mean, dtype=self.dtype)) # Create the 
buffers for the pose of the left hand hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS if create_hand_pose: if hand_pose is None: default_hand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype) else: default_hand_pose = torch.tensor(hand_pose, dtype=dtype) hand_pose_param = nn.Parameter(default_hand_pose, requires_grad=True) self.register_parameter("hand_pose", hand_pose_param) # Create the buffer for the mean pose. pose_mean = self.create_mean_pose(data_struct, flat_hand_mean=flat_hand_mean) pose_mean_tensor = pose_mean.clone().to(dtype) # pose_mean_tensor = torch.tensor(pose_mean, dtype=dtype) self.register_buffer("pose_mean", pose_mean_tensor) def name(self) -> str: return "MANO" def create_mean_pose(self, data_struct, flat_hand_mean=False): # Create the array for the mean pose. If flat_hand is false, then use # the mean that is given by the data, rather than the flat open hand global_orient_mean = torch.zeros([3], dtype=self.dtype) pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0) return pose_mean def extra_repr(self): msg = [super(MANO, self).extra_repr()] if self.use_pca: msg.append(f"Number of PCA components: {self.num_pca_comps}") msg.append(f"Flat hand mean: {self.flat_hand_mean}") return "\n".join(msg) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, **kwargs, ) -> MANOOutput: """Forward pass for the MANO model""" # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) betas = betas if betas is not None else self.betas hand_pose = hand_pose if hand_pose is not None else self.hand_pose apply_trans = transl is not None or hasattr(self, "transl") if transl is None: if hasattr(self, "transl"): transl = self.transl if 
self.use_pca: hand_pose = torch.einsum("bi,ij->bj", [hand_pose, self.hand_components]) full_pose = torch.cat([global_orient, hand_pose], dim=1) full_pose += self.pose_mean vertices, joints = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=True, ) # # Add pre-selected extra joints that might be needed # joints = self.vertex_joint_selector(vertices, joints) if self.joint_mapper is not None: joints = self.joint_mapper(joints) if apply_trans: joints = joints + transl.unsqueeze(dim=1) vertices = vertices + transl.unsqueeze(dim=1) output = MANOOutput( vertices=vertices if return_verts else None, joints=joints if return_verts else None, betas=betas, global_orient=global_orient, hand_pose=hand_pose, full_pose=full_pose if return_full_pose else None, ) return output class MANOLayer(MANO): def __init__(self, *args, **kwargs) -> None: """MANO as a layer model constructor""" super(MANOLayer, self).__init__( create_global_orient=False, create_hand_pose=False, create_betas=False, create_transl=False, *args, **kwargs, ) def name(self) -> str: return "MANO" def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, **kwargs, ) -> MANOOutput: """Forward pass for the MANO model""" device, dtype = self.shapedirs.device, self.shapedirs.dtype if global_orient is None: batch_size = 1 global_orient = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) else: batch_size = global_orient.shape[0] if hand_pose is None: hand_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: transl = torch.zeros([batch_size, 3], dtype=dtype, 
device=device) full_pose = torch.cat([global_orient, hand_pose], dim=1) vertices, joints = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=False, ) if self.joint_mapper is not None: joints = self.joint_mapper(joints) if transl is not None: joints = joints + transl.unsqueeze(dim=1) vertices = vertices + transl.unsqueeze(dim=1) output = MANOOutput( vertices=vertices if return_verts else None, joints=joints if return_verts else None, betas=betas, global_orient=global_orient, hand_pose=hand_pose, full_pose=full_pose if return_full_pose else None, ) return output class FLAME(SMPL): NUM_JOINTS = 5 SHAPE_SPACE_DIM = 300 EXPRESSION_SPACE_DIM = 100 NECK_IDX = 0 def __init__( self, model_path: str, data_struct=None, num_expression_coeffs=10, create_expression: bool = True, expression: Optional[Tensor] = None, create_neck_pose: bool = True, neck_pose: Optional[Tensor] = None, create_jaw_pose: bool = True, jaw_pose: Optional[Tensor] = None, create_leye_pose: bool = True, leye_pose: Optional[Tensor] = None, create_reye_pose=True, reye_pose: Optional[Tensor] = None, use_face_contour=False, batch_size: int = 1, gender: str = "neutral", dtype: torch.dtype = torch.float32, ext="pkl", **kwargs, ) -> None: """FLAME model constructor Parameters ---------- model_path: str The path to the folder that contains the model file (FLAME_<GENDER>.<ext>) and the landmark embedding files data_struct: Ignored; the model parameters are always loaded from the file in `model_path`. num_expression_coeffs: int, optional Number of expression components to use (default = 10). It is capped at 10 when the model file provides fewer than SHAPE_SPACE_DIM + EXPRESSION_SPACE_DIM blend-shape directions, and at EXPRESSION_SPACE_DIM otherwise. create_expression: bool, optional Flag for creating a member variable for the expression space (default = True). expression: torch.tensor, optional, Bx(num_expression_coeffs) The default value for the expression member variable. (default = None) create_neck_pose: bool, optional Flag for creating a member variable for the neck pose. (default = True) neck_pose: torch.tensor, optional, Bx3 The default value for the neck pose variable.
(default = None) create_jaw_pose: bool, optional Flag for creating a member variable for the jaw pose. (default = True) jaw_pose: torch.tensor, optional, Bx3 The default value for the jaw pose variable. (default = None) create_leye_pose: bool, optional Flag for creating a member variable for the left eye pose. (default = True) leye_pose: torch.tensor, optional, Bx3 The default value for the left eye pose variable. (default = None) create_reye_pose: bool, optional Flag for creating a member variable for the right eye pose. (default = True) reye_pose: torch.tensor, optional, Bx3 The default value for the right eye pose variable. (default = None) use_face_contour: bool, optional Whether to compute the keypoints that form the facial contour (default = False) batch_size: int, optional The batch size used for creating the member variables (default = 1) gender: str, optional Which gender to load (default = 'neutral') dtype: torch.dtype The data type for the created variables (default = torch.float32) ext: str, optional The extension of the model file, either 'pkl' or 'npz'; any other value raises a ValueError (default = 'pkl') """ model_fn = f"FLAME_{gender.upper()}.{ext}" flame_path = os.path.join(model_path, model_fn) assert osp.exists(flame_path), "Path {} does not exist!".format(flame_path) if ext == "npz": file_data = np.load(flame_path, allow_pickle=True) elif ext == "pkl": with open(flame_path, "rb") as smpl_file: file_data = pickle.load(smpl_file, encoding="latin1") else: raise ValueError("Unknown extension: {}".format(ext)) data_struct = Struct(**file_data) super(FLAME, self).__init__( model_path=model_path, data_struct=data_struct, dtype=dtype, batch_size=batch_size, gender=gender, ext=ext, **kwargs, ) self.use_face_contour = use_face_contour self.vertex_joint_selector.extra_joints_idxs = to_tensor([], dtype=torch.long) if create_neck_pose: if neck_pose is None: default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_neck_pose = torch.tensor(neck_pose, dtype=dtype) neck_pose_param = nn.Parameter(default_neck_pose, requires_grad=True) self.register_parameter("neck_pose", neck_pose_param) if create_jaw_pose: if jaw_pose is None:
default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) self.register_parameter("jaw_pose", jaw_pose_param) if create_leye_pose: if leye_pose is None: default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_leye_pose = torch.tensor(leye_pose, dtype=dtype) leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True) self.register_parameter("leye_pose", leye_pose_param) if create_reye_pose: if reye_pose is None: default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_reye_pose = torch.tensor(reye_pose, dtype=dtype) reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True) self.register_parameter("reye_pose", reye_pose_param) shapedirs = data_struct.shapedirs if len(shapedirs.shape) < 3: shapedirs = shapedirs[:, :, None] if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM: # print(f'WARNING: You are using a {self.name()} model, with only' # ' 10 shape and 10 expression coefficients.') expr_start_idx = 10 expr_end_idx = 20 num_expression_coeffs = min(num_expression_coeffs, 10) else: expr_start_idx = self.SHAPE_SPACE_DIM expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs num_expression_coeffs = min(num_expression_coeffs, self.EXPRESSION_SPACE_DIM) self._num_expression_coeffs = num_expression_coeffs expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] self.register_buffer("expr_dirs", to_tensor(to_np(expr_dirs), dtype=dtype)) if create_expression: if expression is None: default_expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype) else: default_expression = torch.tensor(expression, dtype=dtype) expression_param = nn.Parameter(default_expression, requires_grad=True) self.register_parameter("expression", expression_param) # The pickle file that contains the barycentric coordinates for # regressing the landmarks 
landmark_bcoord_filename = osp.join(model_path, "flame_static_embedding.pkl") with open(landmark_bcoord_filename, "rb") as fp: landmarks_data = pickle.load(fp, encoding="latin1") lmk_faces_idx = landmarks_data["lmk_face_idx"].astype(np.int64) self.register_buffer("lmk_faces_idx", torch.tensor(lmk_faces_idx, dtype=torch.long)) lmk_bary_coords = landmarks_data["lmk_b_coords"] self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype)) if self.use_face_contour: face_contour_path = os.path.join(model_path, "flame_dynamic_embedding.npy") contour_embeddings = np.load(face_contour_path, allow_pickle=True, encoding="latin1")[()] dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64) dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long) self.register_buffer("dynamic_lmk_faces_idx", dynamic_lmk_faces_idx) dynamic_lmk_b_coords = torch.tensor(contour_embeddings["lmk_b_coords"], dtype=dtype) self.register_buffer("dynamic_lmk_bary_coords", dynamic_lmk_b_coords) neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) self.register_buffer("neck_kin_chain", torch.tensor(neck_kin_chain, dtype=torch.long)) @property def num_expression_coeffs(self): return self._num_expression_coeffs def name(self) -> str: return "FLAME" def extra_repr(self): msg = [ super(FLAME, self).extra_repr(), f"Number of Expression Coefficients: {self.num_expression_coeffs}", f"Use face contour: {self.use_face_contour}", ] return "\n".join(msg) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, neck_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, expression: Optional[Tensor] = None, jaw_pose: Optional[Tensor] = None, leye_pose: Optional[Tensor] = None, reye_pose: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs, ) -> FLAMEOutput: """ Forward pass for the SMPLX model Parameters ---------- global_orient: 
torch.tensor, optional, shape Bx3 If given, ignore the member variable and use it as the global rotation of the body. Useful if someone wishes to predict this with an external model. (default=None) betas: torch.tensor, optional, shape Bx10 If given, ignore the member variable `betas` and use it instead. For example, it can be used if shape parameters `betas` are predicted from some external model. (default=None) expression: torch.tensor, optional, shape Bx10 If given, ignore the member variable `expression` and use it instead. For example, it can be used if expression parameters `expression` are predicted from some external model. (default=None) neck_pose: torch.tensor, optional, shape Bx3 If given, ignore the member variable `neck_pose` and use this instead. It should contain joint rotations in axis-angle format when `pose2rot` is True. (default=None) jaw_pose: torch.tensor, optional, shape Bx3 If given, ignore the member variable `jaw_pose` and use this instead. It should contain joint rotations in axis-angle format when `pose2rot` is True. (default=None) leye_pose: torch.tensor, optional, shape Bx3 If given, ignore the member variable `leye_pose` and use this instead. (default=None) reye_pose: torch.tensor, optional, shape Bx3 If given, ignore the member variable `reye_pose` and use this instead. (default=None) transl: torch.tensor, optional, shape Bx3 If given, ignore the member variable `transl` and use it instead. For example, it can be used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices.
(default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- output: ModelOutput A named tuple of type `ModelOutput` """ # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose neck_pose = neck_pose if neck_pose is not None else self.neck_pose leye_pose = leye_pose if leye_pose is not None else self.leye_pose reye_pose = reye_pose if reye_pose is not None else self.reye_pose betas = betas if betas is not None else self.betas expression = expression if expression is not None else self.expression apply_trans = transl is not None or hasattr(self, "transl") if transl is None: if hasattr(self, "transl"): transl = self.transl full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) batch_size = max(betas.shape[0], global_orient.shape[0], jaw_pose.shape[0]) # Concatenate the shape and expression coefficients scale = int(batch_size / betas.shape[0]) if scale > 1: betas = betas.expand(scale, -1) shape_components = torch.cat([betas, expression], dim=-1) shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) vertices, joints = lbs( shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot, ) lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()) lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(self.batch_size, 1, 1) if self.use_face_contour: lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, self.dynamic_lmk_faces_idx, self.dynamic_lmk_bary_coords, self.neck_kin_chain, pose2rot=True, ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = 
torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) # Add the landmarks to the joints joints = torch.cat([joints, landmarks], dim=1) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints=joints, vertices=vertices) if apply_trans: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = FLAMEOutput( vertices=vertices if return_verts else None, joints=joints, betas=betas, expression=expression, global_orient=global_orient, neck_pose=neck_pose, jaw_pose=jaw_pose, full_pose=full_pose if return_full_pose else None, ) return output class FLAMELayer(FLAME): def __init__(self, *args, **kwargs) -> None: """ FLAME as a layer model constructor """ super(FLAMELayer, self).__init__( create_betas=False, create_expression=False, create_global_orient=False, create_neck_pose=False, create_jaw_pose=False, create_leye_pose=False, create_reye_pose=False, *args, **kwargs, ) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, neck_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, expression: Optional[Tensor] = None, jaw_pose: Optional[Tensor] = None, leye_pose: Optional[Tensor] = None, reye_pose: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs, ) -> FLAMEOutput: """ Forward pass for the SMPLX model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3x3 Global rotation of the body. Useful if someone wishes to predicts this with an external model. It is expected to be in rotation matrix format. (default=None) betas: torch.tensor, optional, shape BxN_b Shape parameters. For example, it can used if shape parameters `betas` are predicted from some external model. 
(default=None) expression: torch.tensor, optional, shape BxN_e If given, ignore the member variable `expression` and use it instead. For example, it can used if expression parameters `expression` are predicted from some external model. jaw_pose: torch.tensor, optional, shape Bx3x3 Jaw pose. It should either joint rotations in rotation matrix format. transl: torch.tensor, optional, shape Bx3 Translation vector of the body. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. (default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- output: ModelOutput A named tuple of type `ModelOutput` """ device, dtype = self.shapedirs.device, self.shapedirs.dtype if global_orient is None: batch_size = 1 global_orient = ( torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) else: batch_size = global_orient.shape[0] if neck_pose is None: neck_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous()) if jaw_pose is None: jaw_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if leye_pose is None: leye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if reye_pose is None: reye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if expression is None: expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype, device=device) if transl is None: transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) full_pose = torch.cat([global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) shape_components = 
torch.cat([betas, expression], dim=-1) shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) vertices, joints = lbs( shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=False, ) lmk_faces_idx = (self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()) lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(self.batch_size, 1, 1) if self.use_face_contour: lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, self.dynamic_lmk_faces_idx, self.dynamic_lmk_bary_coords, self.neck_kin_chain, pose2rot=False, ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) # Add the landmarks to the joints joints = torch.cat([joints, landmarks], dim=1) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints=joints, vertices=vertices) joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = FLAMEOutput( vertices=vertices if return_verts else None, joints=joints, betas=betas, expression=expression, global_orient=global_orient, neck_pose=neck_pose, jaw_pose=jaw_pose, full_pose=full_pose if return_full_pose else None, ) return output def build_layer(model_path: str, model_type: str = "smpl", **kwargs) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]: """Method for creating a model from a path and a model type Parameters ---------- model_path: str Either the path to the model you wish to load or a folder, where each subfolder contains the differents types, i.e.: model_path: | |-- smpl |-- SMPL_FEMALE |-- 
SMPL_NEUTRAL |-- SMPL_MALE |-- smplh |-- SMPLH_FEMALE |-- SMPLH_MALE |-- smplx |-- SMPLX_FEMALE |-- SMPLX_NEUTRAL |-- SMPLX_MALE |-- mano |-- MANO RIGHT |-- MANO LEFT |-- flame |-- FLAME_FEMALE |-- FLAME_MALE |-- FLAME_NEUTRAL model_type: str, optional When model_path is a folder, then this parameter specifies the type of model to be loaded **kwargs: dict Keyword arguments Returns ------- body_model: nn.Module The PyTorch module that implements the corresponding body model Raises ------ ValueError: In case the model type is not one of SMPL, SMPLH, SMPLX, MANO or FLAME """ if osp.isdir(model_path): model_path = os.path.join(model_path, model_type) else: model_type = osp.basename(model_path).split("_")[0].lower() if model_type.lower() == "smpl": return SMPLLayer(model_path, **kwargs) elif model_type.lower() == "smplh": return SMPLHLayer(model_path, **kwargs) elif model_type.lower() == "smplx": return SMPLXLayer(model_path, **kwargs) elif "mano" in model_type.lower(): return MANOLayer(model_path, **kwargs) elif "flame" in model_type.lower(): return FLAMELayer(model_path, **kwargs) else: raise ValueError(f"Unknown model type {model_type}, exiting!") def create(model_path: str, model_type: str = "smpl", **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]: """Method for creating a model from a path and a model type Parameters ---------- model_path: str Either the path to the model you wish to load or a folder, where each subfolder contains the differents types, i.e.: model_path: | |-- smpl |-- SMPL_FEMALE |-- SMPL_NEUTRAL |-- SMPL_MALE |-- smplh |-- SMPLH_FEMALE |-- SMPLH_MALE |-- smplx |-- SMPLX_FEMALE |-- SMPLX_NEUTRAL |-- SMPLX_MALE |-- mano |-- MANO RIGHT |-- MANO LEFT model_type: str, optional When model_path is a folder, then this parameter specifies the type of model to be loaded **kwargs: dict Keyword arguments Returns ------- body_model: nn.Module The PyTorch module that implements the corresponding body model Raises ------ ValueError: In case the model type is 
not one of SMPL, SMPLH, SMPLX, MANO or FLAME """ # If it's a folder, assume if osp.isdir(model_path): model_path = os.path.join(model_path, model_type) else: model_type = osp.basename(model_path).split("_")[0].lower() if model_type.lower() == "smpl": return SMPL(model_path, **kwargs) elif model_type.lower() == "smplh": return SMPLH(model_path, **kwargs) elif model_type.lower() == "smplx": return SMPLX(model_path, **kwargs) elif "mano" in model_type.lower(): return MANO(model_path, **kwargs) elif "flame" in model_type.lower(): return FLAME(model_path, **kwargs) else: raise ValueError(f"Unknown model type {model_type}, exiting!")