Fix data.py lint
Browse files
- src/axolotl/utils/data.py +18 -15
src/axolotl/utils/data.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import logging
|
2 |
from hashlib import md5
|
3 |
from pathlib import Path
|
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
|
|
46 |
md5(
|
47 |
(
|
48 |
str(cfg.sequence_len)
|
49 |
-
+ "@"
|
50 |
-
+ "|".join(
|
51 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
52 |
)
|
53 |
-
+ "|"
|
54 |
-
+ tokenizer_name
|
55 |
).encode("utf-8")
|
56 |
).hexdigest()
|
57 |
)
|
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
|
|
81 |
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
82 |
logging.info("Loading raw datasets...")
|
83 |
datasets = []
|
|
|
84 |
for d in cfg.datasets:
|
85 |
ds: Union[Dataset, DatasetDict] = None
|
86 |
ds_from_hub = False
|
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
|
|
229 |
|
230 |
samples = []
|
231 |
for d in datasets:
|
232 |
-
samples = samples + [i for i in d]
|
233 |
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
234 |
if cfg.local_rank == 0:
|
235 |
logging.info(
|
@@ -265,14 +268,14 @@ def load_prepare_datasets(
|
|
265 |
md5(
|
266 |
(
|
267 |
str(cfg.sequence_len)
|
268 |
-
+ "@"
|
269 |
-
+ str(max_packed_sequence_len)
|
270 |
-
+ seed
|
271 |
-
+ "|".join(
|
272 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
273 |
)
|
274 |
-
+ "|"
|
275 |
-
+ tokenizer_name
|
276 |
).encode("utf-8")
|
277 |
).hexdigest()
|
278 |
)
|
@@ -327,7 +330,7 @@ def load_prepare_datasets(
|
|
327 |
logging.info(
|
328 |
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
329 |
)
|
330 |
-
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
331 |
|
332 |
# filter out bad data
|
333 |
dataset = Dataset.from_list(
|
@@ -335,9 +338,9 @@ def load_prepare_datasets(
|
|
335 |
d
|
336 |
for d in dataset
|
337 |
if len(d["input_ids"]) < cfg.sequence_len
|
338 |
-
and len(d["input_ids"]) > 0
|
339 |
-
and len(d["input_ids"]) == len(d["attention_mask"])
|
340 |
-
and len(d["input_ids"]) == len(d["labels"])
|
341 |
]
|
342 |
)
|
343 |
|
|
|
1 |
+
"""Module containing data utilities for Axolotl"""
|
2 |
+
|
3 |
import logging
|
4 |
from hashlib import md5
|
5 |
from pathlib import Path
|
|
|
48 |
md5(
|
49 |
(
|
50 |
str(cfg.sequence_len)
|
51 |
+
+ "@"
|
52 |
+
+ "|".join(
|
53 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
54 |
)
|
55 |
+
+ "|"
|
56 |
+
+ tokenizer_name
|
57 |
).encode("utf-8")
|
58 |
).hexdigest()
|
59 |
)
|
|
|
83 |
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
84 |
logging.info("Loading raw datasets...")
|
85 |
datasets = []
|
86 |
+
# pylint: disable=invalid-name
|
87 |
for d in cfg.datasets:
|
88 |
ds: Union[Dataset, DatasetDict] = None
|
89 |
ds_from_hub = False
|
|
|
232 |
|
233 |
samples = []
|
234 |
for d in datasets:
|
235 |
+
samples = samples + list(d)
|
236 |
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
237 |
if cfg.local_rank == 0:
|
238 |
logging.info(
|
|
|
268 |
md5(
|
269 |
(
|
270 |
str(cfg.sequence_len)
|
271 |
+
+ "@"
|
272 |
+
+ str(max_packed_sequence_len)
|
273 |
+
+ seed
|
274 |
+
+ "|".join(
|
275 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
276 |
)
|
277 |
+
+ "|"
|
278 |
+
+ tokenizer_name
|
279 |
).encode("utf-8")
|
280 |
).hexdigest()
|
281 |
)
|
|
|
330 |
logging.info(
|
331 |
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
332 |
)
|
333 |
+
dataset = Dataset.from_list(list(constant_len_dataset))
|
334 |
|
335 |
# filter out bad data
|
336 |
dataset = Dataset.from_list(
|
|
|
338 |
d
|
339 |
for d in dataset
|
340 |
if len(d["input_ids"]) < cfg.sequence_len
|
341 |
+
and len(d["input_ids"]) > 0
|
342 |
+
and len(d["input_ids"]) == len(d["attention_mask"])
|
343 |
+
and len(d["input_ids"]) == len(d["labels"])
|
344 |
]
|
345 |
)
|
346 |
|