ddd commited on
Commit
853fd97
·
1 Parent(s): 40e984c

fix hparam

Browse files
Files changed (1) hide show
  1. utils/hparams.py +44 -37
utils/hparams.py CHANGED
@@ -21,30 +21,35 @@ def override_config(old_config: dict, new_config: dict):
21
 
22
 
23
  def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
24
- if config == '':
25
- parser = argparse.ArgumentParser(description='neural music')
26
  parser.add_argument('--config', type=str, default='',
27
  help='location of the data corpus')
28
  parser.add_argument('--exp_name', type=str, default='', help='exp_name')
29
- parser.add_argument('--hparams', type=str, default='',
30
  help='location of the data corpus')
31
  parser.add_argument('--infer', action='store_true', help='infer')
32
  parser.add_argument('--validate', action='store_true', help='validate')
33
  parser.add_argument('--reset', action='store_true', help='reset hparams')
 
34
  parser.add_argument('--debug', action='store_true', help='debug')
35
  args, unknown = parser.parse_known_args()
 
36
  else:
37
  args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
38
- infer=False, validate=False, reset=False, debug=False)
39
- args_work_dir = ''
40
- if args.exp_name != '':
41
- args.work_dir = args.exp_name
42
- args_work_dir = f'checkpoints/{args.work_dir}'
43
 
44
  config_chains = []
45
  loaded_config = set()
46
 
47
- def load_config(config_fn): # deep first
 
 
 
48
  with open(config_fn) as f:
49
  hparams_ = yaml.safe_load(f)
50
  loaded_config.add(config_fn)
@@ -53,10 +58,10 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
53
  if not isinstance(hparams_['base_config'], list):
54
  hparams_['base_config'] = [hparams_['base_config']]
55
  for c in hparams_['base_config']:
 
 
 
56
  if c not in loaded_config:
57
- if c.startswith('.'):
58
- c = f'{os.path.dirname(config_fn)}/{c}'
59
- c = os.path.normpath(c)
60
  override_config(ret_hparams, load_config(c))
61
  override_config(ret_hparams, hparams_)
62
  else:
@@ -64,36 +69,43 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
64
  config_chains.append(config_fn)
65
  return ret_hparams
66
 
67
- global hparams
68
- assert args.config != '' or args_work_dir != ''
69
  saved_hparams = {}
70
- if args_work_dir != 'checkpoints/':
 
 
71
  ckpt_config_path = f'{args_work_dir}/config.yaml'
72
  if os.path.exists(ckpt_config_path):
73
- try:
74
- with open(ckpt_config_path) as f:
75
- saved_hparams.update(yaml.safe_load(f))
76
- except:
77
- pass
78
- if args.config == '':
79
- args.config = ckpt_config_path
80
-
81
  hparams_ = {}
82
-
83
- hparams_.update(load_config(args.config))
84
-
85
  if not args.reset:
86
  hparams_.update(saved_hparams)
87
  hparams_['work_dir'] = args_work_dir
88
 
 
 
89
  if args.hparams != "":
90
  for new_hparam in args.hparams.split(","):
91
  k, v = new_hparam.split("=")
92
- if v in ['True', 'False'] or type(hparams_[k]) == bool:
93
- hparams_[k] = eval(v)
 
 
 
 
 
 
 
94
  else:
95
- hparams_[k] = type(hparams_[k])(v)
96
-
 
 
 
97
  if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
98
  os.makedirs(hparams_['work_dir'], exist_ok=True)
99
  with open(ckpt_config_path, 'w') as f:
@@ -102,11 +114,11 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
102
  hparams_['infer'] = args.infer
103
  hparams_['debug'] = args.debug
104
  hparams_['validate'] = args.validate
 
105
  global global_print_hparams
106
  if global_hparams:
107
  hparams.clear()
108
  hparams.update(hparams_)
109
-
110
  if print_hparams and global_print_hparams and global_hparams:
111
  print('| Hparams chains: ', config_chains)
112
  print('| Hparams: ')
@@ -114,9 +126,4 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
114
  print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
115
  print("")
116
  global_print_hparams = False
117
- # print(hparams_.keys())
118
- if hparams.get('exp_name') is None:
119
- hparams['exp_name'] = args.exp_name
120
- if hparams_.get('exp_name') is None:
121
- hparams_['exp_name'] = args.exp_name
122
- return hparams_
 
21
 
22
 
23
  def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
24
+ if config == '' and exp_name == '':
25
+ parser = argparse.ArgumentParser(description='')
26
  parser.add_argument('--config', type=str, default='',
27
  help='location of the data corpus')
28
  parser.add_argument('--exp_name', type=str, default='', help='exp_name')
29
+ parser.add_argument('-hp', '--hparams', type=str, default='',
30
  help='location of the data corpus')
31
  parser.add_argument('--infer', action='store_true', help='infer')
32
  parser.add_argument('--validate', action='store_true', help='validate')
33
  parser.add_argument('--reset', action='store_true', help='reset hparams')
34
+ parser.add_argument('--remove', action='store_true', help='remove old ckpt')
35
  parser.add_argument('--debug', action='store_true', help='debug')
36
  args, unknown = parser.parse_known_args()
37
+ print("| Unknow hparams: ", unknown)
38
  else:
39
  args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
40
+ infer=False, validate=False, reset=False, debug=False, remove=False)
41
+ global hparams
42
+ assert args.config != '' or args.exp_name != ''
43
+ if args.config != '':
44
+ assert os.path.exists(args.config)
45
 
46
  config_chains = []
47
  loaded_config = set()
48
 
49
+ def load_config(config_fn):
50
+ # deep first inheritance and avoid the second visit of one node
51
+ if not os.path.exists(config_fn):
52
+ return {}
53
  with open(config_fn) as f:
54
  hparams_ = yaml.safe_load(f)
55
  loaded_config.add(config_fn)
 
58
  if not isinstance(hparams_['base_config'], list):
59
  hparams_['base_config'] = [hparams_['base_config']]
60
  for c in hparams_['base_config']:
61
+ if c.startswith('.'):
62
+ c = f'{os.path.dirname(config_fn)}/{c}'
63
+ c = os.path.normpath(c)
64
  if c not in loaded_config:
 
 
 
65
  override_config(ret_hparams, load_config(c))
66
  override_config(ret_hparams, hparams_)
67
  else:
 
69
  config_chains.append(config_fn)
70
  return ret_hparams
71
 
 
 
72
  saved_hparams = {}
73
+ args_work_dir = ''
74
+ if args.exp_name != '':
75
+ args_work_dir = f'checkpoints/{args.exp_name}'
76
  ckpt_config_path = f'{args_work_dir}/config.yaml'
77
  if os.path.exists(ckpt_config_path):
78
+ with open(ckpt_config_path) as f:
79
+ saved_hparams_ = yaml.safe_load(f)
80
+ if saved_hparams_ is not None:
81
+ saved_hparams.update(saved_hparams_)
 
 
 
 
82
  hparams_ = {}
83
+ if args.config != '':
84
+ hparams_.update(load_config(args.config))
 
85
  if not args.reset:
86
  hparams_.update(saved_hparams)
87
  hparams_['work_dir'] = args_work_dir
88
 
89
+ # Support config overriding in command line. Support list type config overriding.
90
+ # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
91
  if args.hparams != "":
92
  for new_hparam in args.hparams.split(","):
93
  k, v = new_hparam.split("=")
94
+ v = v.strip("\'\" ")
95
+ config_node = hparams_
96
+ for k_ in k.split(".")[:-1]:
97
+ config_node = config_node[k_]
98
+ k = k.split(".")[-1]
99
+ if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
100
+ if type(config_node[k]) == list:
101
+ v = v.replace(" ", ",")
102
+ config_node[k] = eval(v)
103
  else:
104
+ config_node[k] = type(config_node[k])(v)
105
+ if args_work_dir != '' and args.remove:
106
+ answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
107
+ if answer.lower() == "y":
108
+ remove_file(args_work_dir)
109
  if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
110
  os.makedirs(hparams_['work_dir'], exist_ok=True)
111
  with open(ckpt_config_path, 'w') as f:
 
114
  hparams_['infer'] = args.infer
115
  hparams_['debug'] = args.debug
116
  hparams_['validate'] = args.validate
117
+ hparams_['exp_name'] = args.exp_name
118
  global global_print_hparams
119
  if global_hparams:
120
  hparams.clear()
121
  hparams.update(hparams_)
 
122
  if print_hparams and global_print_hparams and global_hparams:
123
  print('| Hparams chains: ', config_chains)
124
  print('| Hparams: ')
 
126
  print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
127
  print("")
128
  global_print_hparams = False
129
+ return hparams_