winglian committed on
Commit ead34c5 · unverified · 1 Parent(s): ec02b7c

swap the data collator for evals if not using sample packing (#1076)


* swap the data collator for evals if not using sample packing

* drop last from dataloader to help with issues with evals
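
The first bullet is the heart of the change: when training uses sample packing but evals do not, the trainer temporarily swaps its data collator while the eval dataloader is built, then restores the training collator. A minimal sketch of that pattern, assuming Axolotl-style training args that expose sample_packing and eval_sample_packing flags (the class and attribute names below are illustrative, not the exact repo code):

from typing import Optional

from torch.utils.data import DataLoader, Dataset
from transformers import Trainer


class CollatorSwappingTrainer(Trainer):
    """Illustrative only: use a plain eval collator when eval batches are not packed."""

    def __init__(self, *args, eval_data_collator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_data_collator = eval_data_collator
        # keep a handle on the collator the trainer was constructed with
        self.train_data_collator = self.data_collator

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        # assumes args carries sample_packing / eval_sample_packing, as Axolotl's do
        if self.args.sample_packing and self.args.eval_sample_packing is False:
            self.data_collator = self.eval_data_collator
            dataloader = super().get_eval_dataloader(eval_dataset)
            self.data_collator = self.train_data_collator
            return dataloader
        return super().get_eval_dataloader(eval_dataset)

Swapping the attribute around the single super() call keeps the training dataloader untouched while letting evals fall back to ordinary per-example padding.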

Files changed (1)
  1. src/axolotl/core/trainer_builder.py +42 -4
src/axolotl/core/trainer_builder.py CHANGED
@@ -1,3 +1,4 @@
+# pylint: disable=too-many-lines
 """
 Builder for the training args and trainer
 """
@@ -137,10 +138,19 @@ class AxolotlTrainer(Trainer):
     args = None  # type: AxolotlTrainingArguments
     tag_names = ["axolotl"]

-    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+    def __init__(
+        self,
+        *_args,
+        num_epochs=1,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        **kwargs
+    ):
         self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
-        super().__init__(*args, **kwargs)
+        self.eval_data_collator = eval_data_collator
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator

     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -239,6 +249,16 @@ class AxolotlTrainer(Trainer):
         return super().get_train_dataloader()

     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+        if self.args.sample_packing and self.args.eval_sample_packing is False:
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.eval_data_collator
+            )
+            dataloader = super().get_eval_dataloader(eval_dataset)
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.train_data_collator
+            )
+            return dataloader
+
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
@@ -269,6 +289,7 @@ class AxolotlTrainer(Trainer):
             return self.accelerator.prepare_data_loader(
                 DataLoader(eval_dataset, **dataloader_params)
             )
+
         return super().get_eval_dataloader(eval_dataset)

     def _get_bench_sampler(
@@ -651,6 +672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
+        if self.cfg.dataloader_drop_last is not None:
+            training_arguments_kwargs[
+                "dataloader_drop_last"
+            ] = self.cfg.dataloader_drop_last
+        elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
+            training_arguments_kwargs["dataloader_drop_last"] = True

         if self.cfg.val_set_size == 0:
             # no eval set, so don't eval
@@ -831,6 +858,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             eval_dataset=self.eval_dataset,
             args=training_args,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
+            eval_data_collator=self.build_collator(
+                training_args, is_eval=True, **data_collator_kwargs
+            ),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -851,14 +881,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         return trainer

-    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+    def build_collator(
+        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+    ):
         if training_args.pretraining:
             return None

         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)

-        if training_args.sample_packing:
+        use_batch_sampler_collator = False
+        if is_eval is False and training_args.sample_packing:
+            use_batch_sampler_collator = True
+        if is_eval and training_args.eval_sample_packing:
+            use_batch_sampler_collator = True
+
+        if use_batch_sampler_collator:
             return BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
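
On the builder side, build_collator now takes an is_eval flag so the eval split only gets the batch-sampler collator when eval_sample_packing is actually enabled; otherwise evals can fall back to a plain Seq2Seq-style collator. A rough sketch of that selection, with the transformers collator standing in for the repo's own default and with an assumed import path for the packed-batch collator (neither is the exact repo code):

import transformers

# assumed import path for the packed-batch collator referenced in the diff
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq


def build_collator(tokenizer, training_args, is_eval: bool = False, **kwargs):
    use_batch_sampler_collator = False
    if not is_eval and training_args.sample_packing:
        use_batch_sampler_collator = True
    if is_eval and training_args.eval_sample_packing:
        use_batch_sampler_collator = True

    if use_batch_sampler_collator:
        # packed splits hand the collator whole pre-grouped batches
        return BatchSamplerDataCollatorForSeq2Seq(
            tokenizer, return_tensors="pt", **kwargs
        )
    # unpacked eval splits only need ordinary per-example padding
    return transformers.DataCollatorForSeq2Seq(
        tokenizer, return_tensors="pt", **kwargs
    )

The second commit bullet lands in the same hunk as dataloader_prefetch_factor: when sample_packing is on but eval_sample_packing is off, dataloader_drop_last now defaults to True so a short trailing batch does not trip up evals, while an explicit dataloader_drop_last in the config still takes precedence.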