How to operate on columns of a dataset

I am playing with LLMs and have the usual two columns, "input_ids" and "labels", in a Dataset object. The dataset was created from a pandas DataFrame:

from datasets import Dataset
dataset = Dataset.from_pandas(df)

Then I encode two columns of this dataframe using a Tokenizer.
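For context, the encoding step looks roughly like this (the checkpoint and the source/target column names below are simplified placeholders, not my actual code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint

def encode(batch):
    # "source" and "target" are placeholder column names
    model_inputs = tokenizer(batch["source"])
    targets = tokenizer(batch["target"])
    return {"input_ids": model_inputs["input_ids"], "labels": targets["input_ids"]}

dataset = dataset.map(encode, batched=True)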

I would like to compute the maximum length of the encodings.

First way:

max_source_len = max(len(x) for x in train_dataset["input_ids"])
max_target_len = max(len(x) for x in train_dataset["labels"])
print(max_source_len)
print(max_target_len)

This takes approx. 17 seconds.

I have also tested variations of the previous code, but the bottleneck seems to be the conversion of the columns to Python list objects. I also tried to force the dataset to stay in memory, but, frankly, I did not understand the documentation on this very well (it seems to be possible only when a dataset is loaded with load_dataset).
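For reference, what I found in the docs looks something like this (the file name is just a placeholder, and I may be misreading the option):

from datasets import load_dataset

# keep_in_memory loads the Arrow table into RAM instead of memory-mapping it from disk
dataset = load_dataset("csv", data_files="train.csv", keep_in_memory=True)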

If I transform the dataset back into a pandas DataFrame:

df2 = dataset.to_pandas()
print(df2.input_ids.apply(len).max())
print(df2.labels.apply(len).max())

the two values are printed in less than 2 seconds.

How am I supposed to apply operations on the columns of a dataset?


Be warned, I am no expert in this field, but in my experience accessing columns of a HF dataset directly is slower than other methods. When I want to access columns of a dataset, I use the dataset.set_format() function to set the columns I need to a specific format. For your example, going from this:

train_dataset = dataset_dict['train']
max_source_len = max(len(x) for x in train_dataset['input_ids'])
print(f"Max source len: {max_source_len}")

to this:

train_dataset = dataset_dict['train']
train_dataset.set_format('numpy')
max_source_len = max(len(x) for x in train_dataset['input_ids'])
print(f"Max source len: {max_source_len}")

execution takes roughly a quarter of the time, and the dataset itself is left intact.

If you want to process your columns further and reformat them somehow, the dataset.map function helps with this. It also accepts a batched argument, which usually speeds things up, although for this example it is not needed and can even be slower. A rough sketch is below.
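For illustration only (the new column name "input_len" is just my own choice), something like this should work:

# Store the length of each encoding as a new column, then take the max
train_dataset = train_dataset.map(
    lambda batch: {"input_len": [len(x) for x in batch["input_ids"]]},
    batched=True,
)
print(max(train_dataset["input_len"]))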

Hope this helps a little.
