winglian commited on
Commit
0b4cf5b
·
unverified ·
1 Parent(s): 78ee2cd

workaround for md5 variations (#533)

Browse files

* workaround for md5 variations

* refactor the prepared hash too

Files changed (2) hide show
  1. src/axolotl/utils/data.py +15 -13
  2. tests/test_data.py +64 -0
src/axolotl/utils/data.py CHANGED
@@ -2,7 +2,6 @@
2
  import functools
3
  import hashlib
4
  import logging
5
- from hashlib import md5
6
  from pathlib import Path
7
  from typing import Tuple, Union
8
 
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
52
  DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
53
 
54
 
 
 
 
 
 
 
 
55
  def prepare_dataset(cfg, tokenizer):
56
  if not cfg.pretraining_dataset:
57
  with zero_first(is_main_process()):
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
88
  ) -> DatasetDict:
89
  tokenizer_name = tokenizer.__class__.__name__
90
  ds_hash = str(
91
- md5( # nosec
92
  (
93
  str(cfg.sequence_len)
94
  + "@"
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
97
  )
98
  + "|"
99
  + tokenizer_name
100
- ).encode("utf-8")
101
- ).hexdigest()
102
  )
103
  prepared_ds_path = (
104
  Path(cfg.dataset_prepared_path) / ds_hash
@@ -374,7 +380,7 @@ def load_prepare_datasets(
374
  # see if we can go ahead and load the stacked dataset
375
  seed = f"@{str(cfg.seed)}" if cfg.seed else ""
376
  ds_hash = str(
377
- md5( # nosec
378
  (
379
  str(cfg.sequence_len)
380
  + "@"
@@ -385,8 +391,8 @@ def load_prepare_datasets(
385
  )
386
  + "|"
387
  + tokenizer_name
388
- ).encode("utf-8")
389
- ).hexdigest()
390
  )
391
  prepared_ds_path = (
392
  Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +506,8 @@ def load_prepare_datasets(
500
  + "|"
501
  + str(cfg.seed or 42)
502
  )
503
- train_fingerprint = hashlib.md5(
504
- to_hash_train.encode(), usedforsecurity=False
505
- ).hexdigest()
506
- test_fingerprint = hashlib.md5(
507
- to_hash_test.encode(), usedforsecurity=False
508
- ).hexdigest()
509
 
510
  with zero_first(is_main_process()):
511
  dataset = dataset.train_test_split(
 
2
  import functools
3
  import hashlib
4
  import logging
 
5
  from pathlib import Path
6
  from typing import Tuple, Union
7
 
 
51
  DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
52
 
53
 
54
+ def md5(to_hash: str, encoding: str = "utf-8") -> str:
55
+ try:
56
+ return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
57
+ except TypeError:
58
+ return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
59
+
60
+
61
  def prepare_dataset(cfg, tokenizer):
62
  if not cfg.pretraining_dataset:
63
  with zero_first(is_main_process()):
 
94
  ) -> DatasetDict:
95
  tokenizer_name = tokenizer.__class__.__name__
96
  ds_hash = str(
97
+ md5(
98
  (
99
  str(cfg.sequence_len)
100
  + "@"
 
103
  )
104
  + "|"
105
  + tokenizer_name
106
+ )
107
+ )
108
  )
109
  prepared_ds_path = (
110
  Path(cfg.dataset_prepared_path) / ds_hash
 
380
  # see if we can go ahead and load the stacked dataset
381
  seed = f"@{str(cfg.seed)}" if cfg.seed else ""
382
  ds_hash = str(
383
+ md5(
384
  (
385
  str(cfg.sequence_len)
386
  + "@"
 
391
  )
392
  + "|"
393
  + tokenizer_name
394
+ )
395
+ )
396
  )
397
  prepared_ds_path = (
398
  Path(cfg.dataset_prepared_path) / ds_hash
 
506
  + "|"
507
  + str(cfg.seed or 42)
508
  )
509
+ train_fingerprint = md5(to_hash_train)
510
+ test_fingerprint = md5(to_hash_test)
 
 
 
 
511
 
512
  with zero_first(is_main_process()):
513
  dataset = dataset.train_test_split(
tests/test_data.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test module for the axolotl.utis.data module
3
+ """
4
+ import unittest
5
+
6
+ from transformers import LlamaTokenizer
7
+
8
+ from axolotl.utils.data import encode_pretraining, md5
9
+
10
+
11
+ class TestEncodePretraining(unittest.TestCase):
12
+ """
13
+ test class for encode pretraining and md5 helper
14
+ """
15
+
16
+ def setUp(self):
17
+ self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
18
+ self.tokenizer.add_special_tokens(
19
+ {
20
+ "eos_token": "</s>",
21
+ "bos_token": "<s>",
22
+ "unk_token": "<unk>",
23
+ "pad_token": "<pad>",
24
+ }
25
+ )
26
+ self.max_tokens = 15 # set a small number for easy inspection
27
+
28
+ def test_encode_pretraining(self):
29
+ examples = {
30
+ "text": [
31
+ "Hello, world!",
32
+ "Nice to meet you.",
33
+ "lorem ipsum dolor sit amet.",
34
+ "Nice to meet you again!.",
35
+ "hello, hello",
36
+ ]
37
+ }
38
+ result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
39
+
40
+ self.assertEqual(len(result["input_ids"]), 3)
41
+
42
+ # Assert the length of input_ids and attention_mask is correct
43
+ self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
44
+ self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
45
+
46
+ # Assert EOS and PAD tokens are correctly added
47
+ # hello world! is 4 tokens
48
+ self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
49
+ self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
50
+ self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
51
+ # second part, 5 tokens
52
+ self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
53
+ self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
54
+ self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
55
+
56
+ def test_md5(self):
57
+ self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
58
+ self.assertEqual(
59
+ md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
60
+ )
61
+
62
+
63
+ if __name__ == "__main__":
64
+ unittest.main()