# Listing metadata: 8ebda9e (likely a commit hash), file size 1,071 bytes.
import time
from builtins import print
import argparse

import torch
# os.environ["CUDA_VISIBLE_DEVICES"] = '3'


def get_time_str():
    """Return the current local time as a ``YYYY-mm-dd HH:MM:SS`` string."""
    now = time.localtime()
    return time.strftime("%Y-%m-%d %H:%M:%S", now)


def _strip_prefix(state_dict, prefix):
    """Return a new dict with *prefix* removed from every key that starts with it.

    Keys that do not start with *prefix* are copied unchanged. If stripping
    makes two keys collide, the entry seen later in iteration order wins.
    """
    prefix_len = len(prefix)
    return {
        (key[prefix_len:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """Extract the ``'module'`` entry of a checkpoint and save it as a standalone file.

    Command-line arguments:
        --ckpt_path: checkpoint file to read. It must be a mapping with a
            ``'module'`` key (presumably the model state dict, as written by
            DeepSpeed -- confirm against the producer).
        --bin_path: destination file for the extracted state dict.
        --rm_prefix: optional key prefix (e.g. ``"module."``) to strip from
            every key that starts with it.
    """
    total_parser = argparse.ArgumentParser(description="Pretrain Unsupervise.")
    total_parser.add_argument('--ckpt_path', required=True, type=str,
                              help='checkpoint file containing a "module" entry')
    total_parser.add_argument('--bin_path', required=True, type=str,
                              help='output path for the extracted state dict')
    total_parser.add_argument('--rm_prefix', default=None, type=str,
                              help='key prefix to strip from the state dict')
    args = total_parser.parse_args()
    print('Argument parse success.')

    # map_location='cpu' lets the conversion run on a machine without the GPUs
    # the checkpoint was saved from, and keeps CUDA tensors out of the output.
    # NOTE: torch.load unpickles the file; only use this on trusted checkpoints.
    checkpoint = torch.load(args.ckpt_path, map_location='cpu')
    state_dict = checkpoint['module']

    if args.rm_prefix is not None:
        state_dict = _strip_prefix(state_dict, args.rm_prefix)
    torch.save(state_dict, args.bin_path)


if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()