scottifer8 winglian commited on
Commit
8fa0785
·
unverified ·
1 Parent(s): 4313b1a

Correctly handle splits for datasets.arrow_dataset.Dataset objects (#1504)

Browse files

* Correctly handle splits for datasets.arrow_dataset.Dataset objects

The `load_tokenized_prepared_datasets` function currently has logic for loading a dataset from local path that always checks if a split is in the dataset. The problem is, if the dataset is loaded using `load_from_disk` and it is an Arrow-based dataset, *there is no* split information. Instead what happens is, by calling `split in ds`, it presumably searches through all the rows and columns of the arrow dataset object to find e.g., 'train' assuming `split == 'train'`. This causes the program to hang.

See https://chat.openai.com/share/0d567dbd-d60b-4079-9040-e1de58a4dff3 for context.

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>

Files changed (1) hide show
  1. src/axolotl/utils/data/sft.py +9 -8
src/axolotl/utils/data/sft.py CHANGED
@@ -379,14 +379,15 @@ def load_tokenized_prepared_datasets(
379
  d_base_type = d_type_split[0]
380
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
381
 
382
- if config_dataset.split and config_dataset.split in ds:
383
- ds = ds[config_dataset.split]
384
- elif split in ds:
385
- ds = ds[split]
386
- elif isinstance(ds, DatasetDict):
387
- raise ValueError(
388
- f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
389
- )
 
390
 
391
  # support for using a subset of the data
392
  if config_dataset.shards:
 
379
  d_base_type = d_type_split[0]
380
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
381
 
382
+ if isinstance(ds, DatasetDict):
383
+ if config_dataset.split and config_dataset.split in ds:
384
+ ds = ds[config_dataset.split]
385
+ elif split in ds:
386
+ ds = ds[split]
387
+ else:
388
+ raise ValueError(
389
+ f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
390
+ )
391
 
392
  # support for using a subset of the data
393
  if config_dataset.shards: