Nanobit committed on
Commit
cb7cd34
·
1 Parent(s): d57ba56

Fix data.py lint

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/data.py +18 -15
src/axolotl/utils/data.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import logging
2
  from hashlib import md5
3
  from pathlib import Path
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
46
  md5(
47
  (
48
  str(cfg.sequence_len)
49
- + "@" # noqa: W503
50
- + "|".join( # noqa: W503
51
  sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
52
  )
53
- + "|" # noqa: W503
54
- + tokenizer_name # noqa: W503
55
  ).encode("utf-8")
56
  ).hexdigest()
57
  )
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
81
  logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
82
  logging.info("Loading raw datasets...")
83
  datasets = []
 
84
  for d in cfg.datasets:
85
  ds: Union[Dataset, DatasetDict] = None
86
  ds_from_hub = False
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
229
 
230
  samples = []
231
  for d in datasets:
232
- samples = samples + [i for i in d]
233
  dataset = Dataset.from_list(samples).shuffle(seed=42)
234
  if cfg.local_rank == 0:
235
  logging.info(
@@ -265,14 +268,14 @@ def load_prepare_datasets(
265
  md5(
266
  (
267
  str(cfg.sequence_len)
268
- + "@" # noqa: W503
269
- + str(max_packed_sequence_len) # noqa: W503
270
- + seed # noqa: W503
271
- + "|".join( # noqa: W503
272
  sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
273
  )
274
- + "|" # noqa: W503
275
- + tokenizer_name # noqa: W503
276
  ).encode("utf-8")
277
  ).hexdigest()
278
  )
@@ -327,7 +330,7 @@ def load_prepare_datasets(
327
  logging.info(
328
  f"packing master dataset to len: {cfg.max_packed_sequence_len}"
329
  )
330
- dataset = Dataset.from_list([_ for _ in constant_len_dataset])
331
 
332
  # filter out bad data
333
  dataset = Dataset.from_list(
@@ -335,9 +338,9 @@ def load_prepare_datasets(
335
  d
336
  for d in dataset
337
  if len(d["input_ids"]) < cfg.sequence_len
338
- and len(d["input_ids"]) > 0 # noqa: W503
339
- and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
340
- and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
341
  ]
342
  )
343
 
 
1
+ """Module containing data utilities for Axolotl"""
2
+
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
 
48
  md5(
49
  (
50
  str(cfg.sequence_len)
51
+ + "@"
52
+ + "|".join(
53
  sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
54
  )
55
+ + "|"
56
+ + tokenizer_name
57
  ).encode("utf-8")
58
  ).hexdigest()
59
  )
 
83
  logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
84
  logging.info("Loading raw datasets...")
85
  datasets = []
86
+ # pylint: disable=invalid-name
87
  for d in cfg.datasets:
88
  ds: Union[Dataset, DatasetDict] = None
89
  ds_from_hub = False
 
232
 
233
  samples = []
234
  for d in datasets:
235
+ samples = samples + list(d)
236
  dataset = Dataset.from_list(samples).shuffle(seed=42)
237
  if cfg.local_rank == 0:
238
  logging.info(
 
268
  md5(
269
  (
270
  str(cfg.sequence_len)
271
+ + "@"
272
+ + str(max_packed_sequence_len)
273
+ + seed
274
+ + "|".join(
275
  sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
276
  )
277
+ + "|"
278
+ + tokenizer_name
279
  ).encode("utf-8")
280
  ).hexdigest()
281
  )
 
330
  logging.info(
331
  f"packing master dataset to len: {cfg.max_packed_sequence_len}"
332
  )
333
+ dataset = Dataset.from_list(list(constant_len_dataset))
334
 
335
  # filter out bad data
336
  dataset = Dataset.from_list(
 
338
  d
339
  for d in dataset
340
  if len(d["input_ids"]) < cfg.sequence_len
341
+ and len(d["input_ids"]) > 0
342
+ and len(d["input_ids"]) == len(d["attention_mask"])
343
+ and len(d["input_ids"]) == len(d["labels"])
344
  ]
345
  )
346