"""Rewrite ``.pth`` checkpoints in a folder to match renamed model modules.

For every ``*.pth`` file directly inside the given folder, this script:

- drops parameters of removed sub-modules,
- renames keys of moved/renamed sub-modules,
- saves the result back over the original file.

Usage: python <script> FOLDER
"""
import os
import sys

import torch

# A parameter whose key contains any of these substrings is dropped.
_DROPPED_KEY_PARTS = (
    'spatial_pos_encoder',
    'skeleton_head.MLP',
    'skeleton_head.adj_output_mlp',
)

# (old, new) substring renames, applied in this order to every kept key.
_KEY_RENAMES = (
    ("keypoint_head.", "keypoint_head_module."),
    ('bias_function_prior_weight', 'markov_structural_mlp'),
)


def _convert_state_dict(state_dict):
    """Return a new dict with dropped keys removed and key renames applied."""
    converted = {}
    for key, value in state_dict.items():
        if any(part in key for part in _DROPPED_KEY_PARTS):
            continue
        new_key = key
        for old, new in _KEY_RENAMES:
            new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted


def _atomic_save(obj, path):
    """Save *obj* to *path* via a temp file so an interrupted write cannot corrupt *path*."""
    tmp_path = path + '.tmp'
    try:
        torch.save(obj, tmp_path)
        # os.replace is atomic when source and target are on the same filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_state_dicts(folder_path, map_location=None):
    """Convert every ``.pth`` checkpoint in *folder_path* in place.

    Each checkpoint must be a dict with ``'state_dict'``, ``'optimizer'`` and
    ``'meta'`` entries. A missing entry raises ``KeyError``. Only those three
    entries are written back; any other top-level entries are dropped.

    NOTE: ``torch.load`` unpickles the file, so only run this on trusted
    checkpoints.

    Args:
        folder_path: Directory whose ``*.pth`` files are rewritten. The scan
            is not recursive.
        map_location: Forwarded to ``torch.load``. For example, pass ``'cpu'``
            to convert GPU-saved checkpoints on a CPU-only machine. ``None``
            keeps ``torch.load``'s default behavior.

    Returns:
        Dict mapping each processed file name to the path it was saved to.
    """
    state_dicts = {}
    # Sorted for a deterministic processing order (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        if not filename.endswith(".pth") or not os.path.isfile(file_path):
            continue
        print('Processing {}'.format(filename))
        checkpoint = torch.load(file_path, map_location=map_location)
        new_state_dict = {
            "state_dict": _convert_state_dict(checkpoint['state_dict']),
            "optimizer": checkpoint['optimizer'],
            "meta": checkpoint['meta'],
        }
        # The converted checkpoint replaces the original file.
        new_file_path = file_path
        print(f'Saving to {new_file_path}')
        _atomic_save(new_state_dict, new_file_path)
        state_dicts[filename] = new_file_path
    return state_dicts


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f'Usage: {sys.argv[0]} FOLDER')
    load_state_dicts(sys.argv[1])