swap the data collator for evals if not using sample packing (#1076)
Browse files* swap the data collator for evals if not using sample packing
* drop last from dataloader to help with issues with evals
src/axolotl/core/trainer_builder.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
"""
|
2 |
Builder for the training args and trainer
|
3 |
"""
|
@@ -137,10 +138,19 @@ class AxolotlTrainer(Trainer):
|
|
137 |
args = None # type: AxolotlTrainingArguments
|
138 |
tag_names = ["axolotl"]
|
139 |
|
140 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
self.num_epochs = num_epochs
|
142 |
self.bench_data_collator = bench_data_collator
|
143 |
-
|
|
|
|
|
144 |
|
145 |
def create_scheduler(
|
146 |
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
@@ -239,6 +249,16 @@ class AxolotlTrainer(Trainer):
|
|
239 |
return super().get_train_dataloader()
|
240 |
|
241 |
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
243 |
eval_dataset = (
|
244 |
eval_dataset if eval_dataset is not None else self.eval_dataset
|
@@ -269,6 +289,7 @@ class AxolotlTrainer(Trainer):
|
|
269 |
return self.accelerator.prepare_data_loader(
|
270 |
DataLoader(eval_dataset, **dataloader_params)
|
271 |
)
|
|
|
272 |
return super().get_eval_dataloader(eval_dataset)
|
273 |
|
274 |
def _get_bench_sampler(
|
@@ -651,6 +672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
651 |
training_arguments_kwargs[
|
652 |
"dataloader_prefetch_factor"
|
653 |
] = self.cfg.dataloader_prefetch_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
|
655 |
if self.cfg.val_set_size == 0:
|
656 |
# no eval set, so don't eval
|
@@ -831,6 +858,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
831 |
eval_dataset=self.eval_dataset,
|
832 |
args=training_args,
|
833 |
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
|
|
|
|
|
|
834 |
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
835 |
self.tokenizer,
|
836 |
return_tensors="pt",
|
@@ -851,14 +881,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
851 |
|
852 |
return trainer
|
853 |
|
854 |
-
def build_collator(
|
|
|
|
|
855 |
if training_args.pretraining:
|
856 |
return None
|
857 |
|
858 |
if self.cfg.model_config_type == "mamba":
|
859 |
return MambaDataCollator(tokenizer=self.tokenizer)
|
860 |
|
861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
862 |
return BatchSamplerDataCollatorForSeq2Seq(
|
863 |
self.tokenizer,
|
864 |
return_tensors="pt",
|
|
|
1 |
+
# pylint: disable=too-many-lines
|
2 |
"""
|
3 |
Builder for the training args and trainer
|
4 |
"""
|
|
|
138 |
args = None # type: AxolotlTrainingArguments
|
139 |
tag_names = ["axolotl"]
|
140 |
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
*_args,
|
144 |
+
num_epochs=1,
|
145 |
+
bench_data_collator=None,
|
146 |
+
eval_data_collator=None,
|
147 |
+
**kwargs
|
148 |
+
):
|
149 |
self.num_epochs = num_epochs
|
150 |
self.bench_data_collator = bench_data_collator
|
151 |
+
self.eval_data_collator = eval_data_collator
|
152 |
+
super().__init__(*_args, **kwargs)
|
153 |
+
self.train_data_collator = self.data_collator
|
154 |
|
155 |
def create_scheduler(
|
156 |
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
|
249 |
return super().get_train_dataloader()
|
250 |
|
251 |
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
252 |
+
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
253 |
+
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
254 |
+
self.eval_data_collator
|
255 |
+
)
|
256 |
+
dataloader = super().get_eval_dataloader(eval_dataset)
|
257 |
+
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
258 |
+
self.train_data_collator
|
259 |
+
)
|
260 |
+
return dataloader
|
261 |
+
|
262 |
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
263 |
eval_dataset = (
|
264 |
eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
|
289 |
return self.accelerator.prepare_data_loader(
|
290 |
DataLoader(eval_dataset, **dataloader_params)
|
291 |
)
|
292 |
+
|
293 |
return super().get_eval_dataloader(eval_dataset)
|
294 |
|
295 |
def _get_bench_sampler(
|
|
|
672 |
training_arguments_kwargs[
|
673 |
"dataloader_prefetch_factor"
|
674 |
] = self.cfg.dataloader_prefetch_factor
|
675 |
+
if self.cfg.dataloader_drop_last is not None:
|
676 |
+
training_arguments_kwargs[
|
677 |
+
"dataloader_drop_last"
|
678 |
+
] = self.cfg.dataloader_drop_last
|
679 |
+
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
680 |
+
training_arguments_kwargs["dataloader_drop_last"] = True
|
681 |
|
682 |
if self.cfg.val_set_size == 0:
|
683 |
# no eval set, so don't eval
|
|
|
858 |
eval_dataset=self.eval_dataset,
|
859 |
args=training_args,
|
860 |
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
861 |
+
eval_data_collator=self.build_collator(
|
862 |
+
training_args, is_eval=True, **data_collator_kwargs
|
863 |
+
),
|
864 |
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
865 |
self.tokenizer,
|
866 |
return_tensors="pt",
|
|
|
881 |
|
882 |
return trainer
|
883 |
|
884 |
+
def build_collator(
|
885 |
+
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
886 |
+
):
|
887 |
if training_args.pretraining:
|
888 |
return None
|
889 |
|
890 |
if self.cfg.model_config_type == "mamba":
|
891 |
return MambaDataCollator(tokenizer=self.tokenizer)
|
892 |
|
893 |
+
use_batch_sampler_collator = False
|
894 |
+
if is_eval is False and training_args.sample_packing:
|
895 |
+
use_batch_sampler_collator = True
|
896 |
+
if is_eval and training_args.eval_sample_packing:
|
897 |
+
use_batch_sampler_collator = True
|
898 |
+
|
899 |
+
if use_batch_sampler_collator:
|
900 |
return BatchSamplerDataCollatorForSeq2Seq(
|
901 |
self.tokenizer,
|
902 |
return_tensors="pt",
|