winglian commited on
Commit
ca84cca
·
unverified ·
1 Parent(s): 32eeeb5

convert exponential notation lr to floats (#771)

Browse files
src/axolotl/utils/config.py CHANGED
@@ -119,6 +119,9 @@ def normalize_config(cfg):
119
  or (cfg.model_type and "mistral" in cfg.model_type.lower())
120
  )
121
 
 
 
 
122
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
123
 
124
 
 
119
  or (cfg.model_type and "mistral" in cfg.model_type.lower())
120
  )
121
 
122
+ if isinstance(cfg.learning_rate, str):
123
+ cfg.learning_rate = float(cfg.learning_rate)
124
+
125
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
126
 
127
 
tests/test_normalize_config.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test classes for checking functionality of the cfg normalization
3
+ """
4
+ import unittest
5
+
6
+ from axolotl.utils.config import normalize_config
7
+ from axolotl.utils.dict import DictDefault
8
+
9
+
10
+ class NormalizeConfigTestCase(unittest.TestCase):
11
+ """
12
+ test class for normalize_config checks
13
+ """
14
+
15
+ def _get_base_cfg(self):
16
+ return DictDefault(
17
+ {
18
+ "base_model": "JackFram/llama-68m",
19
+ "base_model_config": "JackFram/llama-68m",
20
+ "tokenizer_type": "LlamaTokenizer",
21
+ "num_epochs": 1,
22
+ "micro_batch_size": 1,
23
+ "gradient_accumulation_steps": 1,
24
+ }
25
+ )
26
+
27
+ def test_lr_as_float(self):
28
+ cfg = (
29
+ self._get_base_cfg()
30
+ | DictDefault( # pylint: disable=unsupported-binary-operation
31
+ {
32
+ "learning_rate": "5e-5",
33
+ }
34
+ )
35
+ )
36
+
37
+ normalize_config(cfg)
38
+
39
+ assert cfg.learning_rate == 0.00005