Trainer
ORTTrainer
class optimum.onnxruntime.ORTTrainer
(
    model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None,
    tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
    feature: str = 'default',
    args: TrainingArguments = None,
    data_collator: typing.Optional[DataCollator] = None,
    train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None,
    eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None,
    model_init: typing.Callable[[], transformers.modeling_utils.PreTrainedModel] = None,
    compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None,
    callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None,
    optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    onnx_model_path: typing.Union[str, os.PathLike] = None
)
compute_loss
How the loss is computed by Trainer. By default, all models return the loss as the first element of their outputs. Subclass and override this method for custom behavior.
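The default behavior described above can be illustrated with a minimal, stdlib-only sketch. `default_loss` is a hypothetical helper and not part of Optimum. It assumes the Transformers convention, where a model's outputs are either dict-like with a "loss" key or a tuple whose first element is the loss.

```python
from typing import Any, Mapping, Sequence, Union


def default_loss(outputs: Union[Mapping[str, Any], Sequence[Any]]) -> Any:
    """Return the loss from a model's outputs (hypothetical helper).

    Dict-like outputs are read through the "loss" key; tuple-like
    outputs are assumed to hold the loss as their first element.
    """
    if isinstance(outputs, Mapping):
        return outputs["loss"]
    return outputs[0]


# Tuple-style outputs: (loss, logits)
print(default_loss((0.42, [[0.1, 0.9]])))  # 0.42
# Dict-style outputs
print(default_loss({"loss": 0.42, "logits": [[0.1, 0.9]]}))  # 0.42
```

A subclass that needs a different objective would override compute_loss and compute its own value from the outputs, rather than picking the first element.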
evaluate
( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'eval', inference_with_ort: bool = False )
Runs evaluation with the ONNX Runtime or PyTorch backend and returns metrics (overrides Trainer.evaluate()).
evaluation_loop_ort
( dataloader: DataLoader, description: str, prediction_loss_only: typing.Optional[bool] = None, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'eval' )
Prediction/evaluation loop, shared by ORTTrainer.evaluate() and ORTTrainer.predict(). Works both with and without labels.
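The "with or without labels" behavior can be sketched in plain Python. `eval_loop`, `predict_fn`, and the batch layout (a dict with an optional "labels" key) are hypothetical. This is not the actual loop in Optimum, which consumes DataLoader batches. The sketch only shows the control flow: predictions are always gathered, while labels and metrics are produced only when labels are present.

```python
from typing import Callable, Dict, Iterable, List, Optional, Tuple

Batch = Dict[str, list]  # hypothetical batch layout: inputs plus optional "labels"


def eval_loop(
    batches: Iterable[Batch],
    predict_fn: Callable[[Batch], List[int]],
    compute_metrics: Optional[Callable[[List[int], List[int]], Dict[str, float]]] = None,
) -> Tuple[List[int], Optional[List[int]], Dict[str, float]]:
    all_preds: List[int] = []
    all_labels: List[int] = []
    saw_labels = False
    for batch in batches:
        # Predictions are gathered for every batch.
        all_preds.extend(predict_fn(batch))
        # Labels are gathered only if the batch carries them.
        if "labels" in batch:
            saw_labels = True
            all_labels.extend(batch["labels"])
    labels = all_labels if saw_labels else None
    # Metrics need labels, so they stay empty for unlabeled data.
    metrics: Dict[str, float] = {}
    if labels is not None and compute_metrics is not None:
        metrics = compute_metrics(all_preds, labels)
    return all_preds, labels, metrics
```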
predict
( test_dataset: Dataset, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'test', inference_with_ort: bool = False )
Runs prediction with the ONNX Runtime or PyTorch backend and returns predictions and potential metrics (overrides Trainer.predict()).
prediction_loop_ort
( dataloader: DataLoader, description: str, prediction_loss_only: typing.Optional[bool] = None, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'eval' )
Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict(). Works both with and without labels.
train
( resume_from_checkpoint: typing.Union[bool, str, NoneType] = None, trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None, ignore_keys_for_eval: typing.Optional[typing.List[str]] = None, **kwargs )
Parameters
- resume_from_checkpoint (str or bool, optional): If a str, a local path to a checkpoint saved by a previous instance of Trainer. If a bool equal to True, the last checkpoint in args.output_dir saved by a previous instance of Trainer is loaded. If present, training resumes from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or Dict[str, Any], optional): The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (List[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
- kwargs: Additional keyword arguments used to hide deprecated arguments.
Main ONNX Runtime training entry point.
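With resume_from_checkpoint=True, "the last checkpoint in args.output_dir" means the checkpoint with the highest step number. Transformers' Trainer names checkpoint folders checkpoint-<global_step>, and Transformers ships its own helper for this lookup, `transformers.trainer_utils.get_last_checkpoint`. The stdlib sketch below (`find_last_checkpoint` is a hypothetical name) shows the lookup. The step is compared numerically, so checkpoint-200 wins over checkpoint-30.

```python
import os
import re
from typing import Optional

# Assumed naming convention: "<output_dir>/checkpoint-<global_step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint folder with the highest step number, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            # Compare steps as integers, not strings.
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    _, last = max(candidates)
    return os.path.join(output_dir, last)
```

Passing an explicit path, as in train(resume_from_checkpoint="path/to/checkpoint-200"), skips this lookup entirely.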
ORTSeq2SeqTrainer
class optimum.onnxruntime.ORTSeq2SeqTrainer
(
    model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None,
    tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
    feature: str = 'default',
    args: TrainingArguments = None,
    data_collator: typing.Optional[DataCollator] = None,
    train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None,
    eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None,
    model_init: typing.Callable[[], transformers.modeling_utils.PreTrainedModel] = None,
    compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None,
    callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None,
    optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    onnx_model_path: typing.Union[str, os.PathLike] = None
)
evaluate
( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'eval', max_length: typing.Optional[int] = None, num_beams: typing.Optional[int] = None, inference_with_ort: bool = False )
Parameters
- eval_dataset (Dataset, optional): Pass a dataset if you wish to override self.eval_dataset. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "eval"): An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named "eval_bleu" if the prefix is "eval" (default).
- max_length (int, optional): The maximum target length to use when predicting with the generate method.
- num_beams (int, optional): Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- inference_with_ort (bool, optional, defaults to False): Whether to run inference with the ONNX Runtime backend. By default, inference is done with PyTorch.
Runs evaluation and returns metrics.
The calling script is responsible for providing a method to compute metrics, since metrics are task-dependent (pass it to the compute_metrics argument of the constructor).
You can also subclass and override this method to inject custom behavior.
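A compute_metrics callable receives an EvalPrediction (see the constructor signature) and returns a dict. The returned names are then prefixed with metric_key_prefix. The sketch below is stdlib-only. `EvalPredictionLike` is a hypothetical stand-in for `transformers.trainer_utils.EvalPrediction`, whose real fields hold arrays rather than lists. `add_prefix` is a hypothetical helper that mimics the documented naming ("bleu" becomes "eval_bleu"). Leaving already-prefixed keys untouched is an assumption that mirrors Transformers.

```python
from typing import Dict, NamedTuple, Sequence


class EvalPredictionLike(NamedTuple):
    """Hypothetical stand-in for transformers' EvalPrediction."""
    predictions: Sequence[int]
    label_ids: Sequence[int]


def compute_metrics(eval_pred: EvalPredictionLike) -> Dict[str, float]:
    """Task-specific metric: plain accuracy over class ids."""
    pairs = zip(eval_pred.predictions, eval_pred.label_ids)
    correct = sum(int(p == l) for p, l in pairs)
    return {"accuracy": correct / len(eval_pred.label_ids)}


def add_prefix(metrics: Dict[str, float], metric_key_prefix: str = "eval") -> Dict[str, float]:
    """Prefix metric names, skipping keys that already carry the prefix."""
    return {
        (key if key.startswith(f"{metric_key_prefix}_") else f"{metric_key_prefix}_{key}"): value
        for key, value in metrics.items()
    }


metrics = add_prefix(compute_metrics(EvalPredictionLike([1, 0, 1, 1], [1, 1, 1, 0])))
print(metrics)  # {'eval_accuracy': 0.5}
```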
predict
( test_dataset: Dataset, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'eval', max_length: typing.Optional[int] = None, num_beams: typing.Optional[int] = None, inference_with_ort: bool = False )
Parameters
- test_dataset (Dataset): Dataset to run the predictions on. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "eval"): An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named "eval_bleu" if the prefix is "eval" (default).
- max_length (int, optional): The maximum target length to use when predicting with the generate method.
- num_beams (int, optional): Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- inference_with_ort (bool, optional, defaults to False): Whether to run inference with the ONNX Runtime backend. By default, inference is done with PyTorch.
Runs prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method also returns metrics, as evaluate() does.
If your predictions or labels have different sequence lengths (for instance, because you're doing dynamic padding in a token classification task), the predictions are padded on the right so they can be concatenated into one array. The padding index is -100.
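The right-padding with -100 described above can be sketched on plain lists of token ids. The trainer itself works on arrays. `pad_and_concat` is a hypothetical helper. Every row from every batch is padded to the longest row seen across all batches, and the rows are then stacked into one table.

```python
from typing import List, Sequence

PAD_INDEX = -100  # padding value stated in the docs


def pad_and_concat(
    batches: Sequence[Sequence[Sequence[int]]], pad_index: int = PAD_INDEX
) -> List[List[int]]:
    """Right-pad every row to the longest row across all batches, then concatenate."""
    max_len = max(len(row) for batch in batches for row in batch)
    return [
        list(row) + [pad_index] * (max_len - len(row))
        for batch in batches
        for row in batch
    ]


# Two batches that were dynamically padded to different lengths.
batch_a = [[5, 6, 7], [8, 9, 10]]
batch_b = [[1, 2, 3, 4, 5]]
print(pad_and_concat([batch_a, batch_b]))
# [[5, 6, 7, -100, -100], [8, 9, 10, -100, -100], [1, 2, 3, 4, 5]]
```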
Returns: NamedTuple
A namedtuple with the following keys:
- predictions (np.ndarray): The predictions on test_dataset.
- label_ids (np.ndarray, optional): The labels (if the dataset contained some).
- metrics (Dict[str, float], optional): The potential dictionary of metrics (if the dataset contained labels).
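Consumers of this namedtuple usually need to skip the -100 padding positions and handle a missing label_ids. The stdlib sketch below shows one way to do that. `PredictionOutputLike` is a hypothetical stand-in that reuses the documented field names. The real fields hold np.ndarray values. `token_accuracy` is a hypothetical helper, and the input values are made up for illustration.

```python
from typing import Dict, List, NamedTuple, Optional


class PredictionOutputLike(NamedTuple):
    """Hypothetical stand-in with the field names returned by predict()."""
    predictions: List[List[int]]
    label_ids: Optional[List[List[int]]]
    metrics: Optional[Dict[str, float]]


def token_accuracy(output: PredictionOutputLike, pad_index: int = -100) -> Optional[float]:
    """Accuracy over non-padded label positions; None when there are no labels."""
    if output.label_ids is None:
        return None
    correct = total = 0
    for pred_row, label_row in zip(output.predictions, output.label_ids):
        for pred, label in zip(pred_row, label_row):
            if label == pad_index:  # skip padding added for concatenation
                continue
            total += 1
            correct += int(pred == label)
    return correct / total if total else None


out = PredictionOutputLike(
    predictions=[[3, 4, 9], [7, 7, 7]],
    label_ids=[[3, 4, -100], [7, 1, -100]],
    metrics={"eval_loss": 0.5},
)
print(token_accuracy(out))  # 0.75
```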