add optimization for group-by-len (#563)
src/axolotl/utils/trainer.py  (+10, -0)  CHANGED
@@ -358,7 +358,14 @@ class ReLoRATrainer(AxolotlTrainer):
 
 
 def add_position_ids(sample):
+    sample_len = len(sample["input_ids"])
     sample["position_ids"] = torch.arange(len(sample["input_ids"]))
+    sample["length"] = sample_len
+    return sample
+
+
+def add_length(sample):
+    sample["length"] = len(sample["input_ids"])
     return sample
 
 
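A quick way to sanity-check the two helpers above is to run them over a toy dataset. The sketch below assumes the Hugging Face datasets package and torch are installed; the helper bodies mirror the hunk above, while the toy samples and surrounding script are illustrative only and not taken from the repo.

import torch
from datasets import Dataset


def add_position_ids(sample):
    # Compute the length once and reuse it for the extra "length" column.
    sample_len = len(sample["input_ids"])
    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
    sample["length"] = sample_len
    return sample


def add_length(sample):
    # Lighter variant: only record the length, no position_ids.
    sample["length"] = len(sample["input_ids"])
    return sample


toy = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6, 7]]})

packed_ready = toy.map(add_position_ids)
print(packed_ready[0])          # adds position_ids [0, 1, 2] and length 3

grouped_ready = toy.map(add_length)
print(grouped_ready["length"])  # [3, 4]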
@@ -382,6 +389,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset:
         eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
 
+    if cfg.group_by_length:
+        train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
+
     if cfg.sample_packing:
         train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
         if eval_dataset:
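Why this helps: with group_by_length enabled, the Hugging Face Trainer builds a LengthGroupedSampler, and in the transformers versions I have checked it reads sequence lengths directly from a dataset column named by length_column_name, which defaults to "length". If that column is absent, the sampler falls back to iterating the whole dataset and calling len() on every input_ids entry, which is slow for large corpora; precomputing the column in a parallel map is the optimization. The sketch below is illustrative rather than axolotl code, and the import path and signature of LengthGroupedSampler may differ across transformers releases.

from datasets import Dataset
from transformers.trainer_pt_utils import LengthGroupedSampler

ds = Dataset.from_dict(
    {
        "input_ids": [[1] * 5, [1] * 50, [1] * 7, [1] * 48],
        "length": [5, 50, 7, 48],  # column produced by add_length / add_position_ids
    }
)

# Fast path: lengths come straight from the precomputed column, so no
# per-example pass over input_ids is needed to build the sampler.
sampler = LengthGroupedSampler(batch_size=2, lengths=ds["length"])
print(list(sampler))  # indices ordered so similarly sized samples share a batch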