workaround for md5 variations (#533)
Browse files* workaround for md5 variations
* refactor the prepared hash too
- src/axolotl/utils/data.py +15 -13
- tests/test_data.py +64 -0
src/axolotl/utils/data.py
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
import functools
|
3 |
import hashlib
|
4 |
import logging
|
5 |
-
from hashlib import md5
|
6 |
from pathlib import Path
|
7 |
from typing import Tuple, Union
|
8 |
|
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
|
|
52 |
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def prepare_dataset(cfg, tokenizer):
|
56 |
if not cfg.pretraining_dataset:
|
57 |
with zero_first(is_main_process()):
|
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
|
|
88 |
) -> DatasetDict:
|
89 |
tokenizer_name = tokenizer.__class__.__name__
|
90 |
ds_hash = str(
|
91 |
-
md5(
|
92 |
(
|
93 |
str(cfg.sequence_len)
|
94 |
+ "@"
|
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
|
|
97 |
)
|
98 |
+ "|"
|
99 |
+ tokenizer_name
|
100 |
-
)
|
101 |
-
)
|
102 |
)
|
103 |
prepared_ds_path = (
|
104 |
Path(cfg.dataset_prepared_path) / ds_hash
|
@@ -374,7 +380,7 @@ def load_prepare_datasets(
|
|
374 |
# see if we can go ahead and load the stacked dataset
|
375 |
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
376 |
ds_hash = str(
|
377 |
-
md5(
|
378 |
(
|
379 |
str(cfg.sequence_len)
|
380 |
+ "@"
|
@@ -385,8 +391,8 @@ def load_prepare_datasets(
|
|
385 |
)
|
386 |
+ "|"
|
387 |
+ tokenizer_name
|
388 |
-
)
|
389 |
-
)
|
390 |
)
|
391 |
prepared_ds_path = (
|
392 |
Path(cfg.dataset_prepared_path) / ds_hash
|
@@ -500,12 +506,8 @@ def load_prepare_datasets(
|
|
500 |
+ "|"
|
501 |
+ str(cfg.seed or 42)
|
502 |
)
|
503 |
-
train_fingerprint =
|
504 |
-
|
505 |
-
).hexdigest()
|
506 |
-
test_fingerprint = hashlib.md5(
|
507 |
-
to_hash_test.encode(), usedforsecurity=False
|
508 |
-
).hexdigest()
|
509 |
|
510 |
with zero_first(is_main_process()):
|
511 |
dataset = dataset.train_test_split(
|
|
|
2 |
import functools
|
3 |
import hashlib
|
4 |
import logging
|
|
|
5 |
from pathlib import Path
|
6 |
from typing import Tuple, Union
|
7 |
|
|
|
51 |
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
52 |
|
53 |
|
54 |
+
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
55 |
+
try:
|
56 |
+
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
57 |
+
except TypeError:
|
58 |
+
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
59 |
+
|
60 |
+
|
61 |
def prepare_dataset(cfg, tokenizer):
|
62 |
if not cfg.pretraining_dataset:
|
63 |
with zero_first(is_main_process()):
|
|
|
94 |
) -> DatasetDict:
|
95 |
tokenizer_name = tokenizer.__class__.__name__
|
96 |
ds_hash = str(
|
97 |
+
md5(
|
98 |
(
|
99 |
str(cfg.sequence_len)
|
100 |
+ "@"
|
|
|
103 |
)
|
104 |
+ "|"
|
105 |
+ tokenizer_name
|
106 |
+
)
|
107 |
+
)
|
108 |
)
|
109 |
prepared_ds_path = (
|
110 |
Path(cfg.dataset_prepared_path) / ds_hash
|
|
|
380 |
# see if we can go ahead and load the stacked dataset
|
381 |
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
382 |
ds_hash = str(
|
383 |
+
md5(
|
384 |
(
|
385 |
str(cfg.sequence_len)
|
386 |
+ "@"
|
|
|
391 |
)
|
392 |
+ "|"
|
393 |
+ tokenizer_name
|
394 |
+
)
|
395 |
+
)
|
396 |
)
|
397 |
prepared_ds_path = (
|
398 |
Path(cfg.dataset_prepared_path) / ds_hash
|
|
|
506 |
+ "|"
|
507 |
+ str(cfg.seed or 42)
|
508 |
)
|
509 |
+
train_fingerprint = md5(to_hash_train)
|
510 |
+
test_fingerprint = md5(to_hash_test)
|
|
|
|
|
|
|
|
|
511 |
|
512 |
with zero_first(is_main_process()):
|
513 |
dataset = dataset.train_test_split(
|
tests/test_data.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
test module for the axolotl.utis.data module
|
3 |
+
"""
|
4 |
+
import unittest
|
5 |
+
|
6 |
+
from transformers import LlamaTokenizer
|
7 |
+
|
8 |
+
from axolotl.utils.data import encode_pretraining, md5
|
9 |
+
|
10 |
+
|
11 |
+
class TestEncodePretraining(unittest.TestCase):
|
12 |
+
"""
|
13 |
+
test class for encode pretraining and md5 helper
|
14 |
+
"""
|
15 |
+
|
16 |
+
def setUp(self):
|
17 |
+
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
|
18 |
+
self.tokenizer.add_special_tokens(
|
19 |
+
{
|
20 |
+
"eos_token": "</s>",
|
21 |
+
"bos_token": "<s>",
|
22 |
+
"unk_token": "<unk>",
|
23 |
+
"pad_token": "<pad>",
|
24 |
+
}
|
25 |
+
)
|
26 |
+
self.max_tokens = 15 # set a small number for easy inspection
|
27 |
+
|
28 |
+
def test_encode_pretraining(self):
|
29 |
+
examples = {
|
30 |
+
"text": [
|
31 |
+
"Hello, world!",
|
32 |
+
"Nice to meet you.",
|
33 |
+
"lorem ipsum dolor sit amet.",
|
34 |
+
"Nice to meet you again!.",
|
35 |
+
"hello, hello",
|
36 |
+
]
|
37 |
+
}
|
38 |
+
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
|
39 |
+
|
40 |
+
self.assertEqual(len(result["input_ids"]), 3)
|
41 |
+
|
42 |
+
# Assert the length of input_ids and attention_mask is correct
|
43 |
+
self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
|
44 |
+
self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
|
45 |
+
|
46 |
+
# Assert EOS and PAD tokens are correctly added
|
47 |
+
# hello world! is 4 tokens
|
48 |
+
self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
|
49 |
+
self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
|
50 |
+
self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
|
51 |
+
# second part, 5 tokens
|
52 |
+
self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
|
53 |
+
self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
|
54 |
+
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
|
55 |
+
|
56 |
+
def test_md5(self):
|
57 |
+
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
|
58 |
+
self.assertEqual(
|
59 |
+
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == "__main__":
|
64 |
+
unittest.main()
|