Hi,
I’d like to know why group_texts is so slow.
Initially it’s super fast, but when it reaches around 92% I can see from the top command that only 2 of the 16 worker processes (that’s 16/8) are still working.
The dataset has over 80 million examples, so it takes almost forever.
I’m wondering why dataset.map() only uses 1/8 of the worker processes after 92%.
Is there a better solution for this preprocessing task?
You can copy & paste the code below and just run it; I copied it from the official HF tutorial.
Cheers!
Aiden
```python
from datasets import concatenate_datasets, load_dataset
from transformers import BertTokenizerFast


def preprocess_function(examples):
    # In a batched map, examples["text"] is a list of strings, one per example.
    tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-tiny')
    return tokenizer(examples["text"])
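
# Note (my addition): the tokenizer above is re-loaded from the local cache for every
# batch in every worker; a sketch of loading it once at module level instead, since
# fast tokenizers can be pickled into the map() worker processes:
#     tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-tiny')
#     def preprocess_function(examples):
#         return tokenizer(examples["text"])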

def group_texts(examples):
    block_size = 384
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    return result
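
# Note (my addition, not from the tutorial): `sum(examples[k], [])` above builds the
# concatenated list by repeated copying, which gets quadratic as the batch grows; a
# linear-time sketch using itertools would be:
#     from itertools import chain
#     concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
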
# cpu cores: 16
num_proc_1 = 16
num_proc_2 = 16
# num_proc_2 = 16 * 8
bookcorpus = load_dataset("bookcorpus", split="train")
wiki = load_dataset("wikipedia", "20220301.en", split="train")
wiki = wiki.remove_columns([col for col in wiki.column_names if col != "text"]) # only keep the 'text' column
assert bookcorpus.features.type == wiki.features.type
raw_datasets_bert = concatenate_datasets([bookcorpus, wiki])
tokenized_datasets_bert_1 = raw_datasets_bert.map(
    preprocess_function,
    batched=True,
    batch_size=20000,
    writer_batch_size=20000,
    remove_columns=["text"],
    num_proc=num_proc_1,
)
print("start tokenized_datasets_bert_2")
tokenized_datasets_bert_2 = tokenized_datasets_bert_1.map(
    group_texts,
    batched=True,
    batch_size=2_000,
    writer_batch_size=2_000,
    num_proc=num_proc_2,
)
```
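
For what it’s worth, here is a minimal timing sketch (my own addition, not from the tutorial) that runs group_texts on a single batch of the already-tokenized dataset, to get a feel for the pure Python cost per batch, independent of multiprocessing and Arrow writing:

```python
import time

# One batch (2_000 examples) as a dict of column -> list of lists,
# which is exactly what map(batched=True) passes to group_texts.
sample_batch = tokenized_datasets_bert_1[:2_000]

start = time.perf_counter()
_ = group_texts(sample_batch)
print(f"group_texts on one batch took {time.perf_counter() - start:.2f}s")
```

Multiplying that single-batch time by the number of batches per shard gives a rough lower bound on each worker’s runtime for the second map() call.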