Correctly handle splits for datasets.arrow_dataset.Dataset objects (#1504)
* Correctly handle splits for datasets.arrow_dataset.Dataset objects
The `load_tokenized_prepared_datasets` function currently has logic for loading a dataset from a local path that always checks whether a split is in the dataset. The problem is that if the dataset was loaded with `load_from_disk` and is a plain Arrow-backed `Dataset`, *there is no* split information. Evaluating `split in ds` then presumably falls back to scanning the rows of the Arrow dataset looking for e.g. 'train' (when `split == 'train'`), which causes the program to hang.
See https://chat.openai.com/share/0d567dbd-d60b-4079-9040-e1de58a4dff3 for context.
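A minimal sketch of the failure mode (illustrative only; the variable names below are not from the axolotl source). Membership tests behave very differently on `DatasetDict` vs. a bare `Dataset`:

```python
from datasets import Dataset, DatasetDict

dd = DatasetDict({"train": Dataset.from_dict({"text": ["a", "b"]})})
print("train" in dd)  # True: DatasetDict is dict-like, so this is a key lookup

ds = Dataset.from_dict({"text": ["a", "b"]})  # what load_from_disk can return
# A bare Dataset has no split keys; per the analysis above, `"train" in ds`
# falls back to iterating rows and comparing each row dict to the string
# "train" -- never True, and on a large Arrow dataset the scan can take long
# enough to look like a hang.
print("train" in ds)  # False, after scanning every row
```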
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/utils/data/sft.py
CHANGED
@@ -379,14 +379,15 @@ def load_tokenized_prepared_datasets(
         d_base_type = d_type_split[0]
         d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
 
-        if config_dataset.split and config_dataset.split in ds:
-            ds = ds[config_dataset.split]
-        elif split in ds:
-            ds = ds[split]
-        else:
-            raise ValueError(
-                f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
-            )
+        if isinstance(ds, DatasetDict):
+            if config_dataset.split and config_dataset.split in ds:
+                ds = ds[config_dataset.split]
+            elif split in ds:
+                ds = ds[split]
+            else:
+                raise ValueError(
+                    f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+                )
 
         # support for using a subset of the data
         if config_dataset.shards:
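For context on why the `isinstance` guard is the right shape here, a small usage sketch (the path below is a placeholder, not from the axolotl config): `load_from_disk` returns a `DatasetDict` when the saved dataset has named splits and a bare `Dataset` when it does not, and only the former supports key-style split lookup.

```python
from datasets import DatasetDict, load_from_disk

ds = load_from_disk("/path/to/saved_dataset")  # may be Dataset or DatasetDict
split = "train"

if isinstance(ds, DatasetDict):
    if split in ds:  # safe: dict-like key lookup on DatasetDict
        ds = ds[split]
    else:
        raise ValueError(f"no {split} split found")
# a bare Dataset is already a single split and passes through unchanged
```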