This branch hasn't been merged, but I want to use optuna in my workflow. Although I have tried it, I want to confirm the usage. @sgugger (firstly, thanks for the PR) could you please provide instructions on what changes I need to make to get it working (like defining the search space, getting results on it, and finding the best hyperparameters)? I want to confirm that I'm using it in the right manner. Also, is the implementation complete?
Hi there!
This is a work in progress, so I'd hold off a bit before starting to use it (I'll actually make some changes today). I'll add an example in the PR once I'm done (hopefully by end of day) so you (and others) can start playing with it and give us feedback, but be prepared for some slight changes in the API as we polish it (we want to support other hp-search platforms such as Ray).
Thanks for the reply. I'll look forward to the example and to using it, and I'll hopefully contribute if I come across some rough edges. Trainer changes a lot, and my inherited trainer code breaks most of the time after each update, so I'm prepared for it.
OK, done for today, and I prepared the road to support Ray as well (not working right now, though). There is an example on a regression problem in the README because I didn't want to launch my GPU setup. I will add a real example soon, but it should be enough to get you going.
Could you please tell me where that README is? I checked your recent commits on both the trainer_optuna branch and master, but didn't see it.
Sorry, not the README, I meant the first post of the PR.
I put up a real example now.
What are the pros/cons of optuna vs. Ray?
Both work with the API. I haven't used either long enough to have a strong opinion, but basically Ray would be better if you have multiple GPUs and optuna might be better with just one, from what I understood.
FYI, this has been merged in master. Here is an example of use:
from nlp import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
dataset = load_dataset('glue', 'mrpc')
metric = load_metric('glue', 'mrpc')

def encode(examples):
    outputs = tokenizer(examples['sentence1'], examples['sentence2'], truncation=True)
    return outputs

encoded_dataset = dataset.map(encode, batched=True)
# Won't be necessary when this PR is merged with master since the Trainer will do it automatically
encoded_dataset.set_format(columns=['attention_mask', 'input_ids', 'token_type_ids', 'label'])

def model_init():
    return AutoModelForSequenceClassification.from_pretrained('bert-base-cased', return_dict=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# Evaluate during training and a bit more often than the default to be able to prune bad trials early.
# Disabling tqdm is a matter of preference.
training_args = TrainingArguments("test", evaluate_during_training=True, eval_steps=500, disable_tqdm=True)

trainer = Trainer(
    args=training_args,
    data_collator=DataCollatorWithPadding(tokenizer),
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    model_init=model_init,
    compute_metrics=compute_metrics,
)
# Default objective is the sum of all metrics when metrics are provided, so we have to maximize it.
trainer.hyperparameter_search(direction="maximize")
This will use optuna or Ray Tune, depending on which you have installed. If you have both, it will use optuna by default, but you can pass backend="ray" to use Ray Tune. Note that you need an installation from source of nlp to make the example work.
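For instance, the same search with the backend made explicit:

# Explicitly select the Ray Tune backend instead of the optuna default
trainer.hyperparameter_search(direction="maximize", backend="ray")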
To customize the hyperparameter search space, you can pass a function hp_space to this call. Here is an example if you want to search higher learning rates than the default with optuna:
def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
    }

trainer.hyperparameter_search(direction="maximize", hp_space=my_hp_space)
and with Ray:
def my_hp_space_ray(trial):
    from ray import tune

    return {
        "learning_rate": tune.loguniform(1e-4, 1e-2),
        "num_train_epochs": tune.choice(range(1, 6)),
        "seed": tune.choice(range(1, 41)),
        "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
    }

trainer.hyperparameter_search(direction="maximize", hp_space=my_hp_space_ray)
If you want to customize the objective to minimize/maximize, pass along a function to compute_objective:
def my_objective(metrics):
    # Your elaborate computation here
    return result_to_optimize

trainer.hyperparameter_search(direction="maximize", compute_objective=my_objective)
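For instance, with the MRPC setup above, an objective that tracks F1 alone could look like the sketch below (the exact key, eval_f1 here, depends on what your compute_metrics returns):

def my_objective(metrics):
    # Maximize F1 alone instead of the default sum of all metrics
    # (assumes the MRPC example above, where evaluation reports an "eval_f1" key)
    return metrics["eval_f1"]

trainer.hyperparameter_search(direction="maximize", compute_objective=my_objective)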
Thanks. I was following this PR. I wanted to know which types of hyperparameters can be tuned with this approach. Does it work with the default ones only (training_args)? What if we have a custom param that we want to tune (for instance, a lambda in an objective function)?
The hyperparams you can tune must be in the TrainingArguments you passed to your Trainer. If you have custom ones that are not in TrainingArguments, just subclass TrainingArguments and add them in your subclass.
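For illustration, a minimal sketch of such a subclass (MyTrainingArguments and loss_lambda are made-up names for this example, and your own training code would have to read the attribute, e.g. in its loss computation):

from dataclasses import dataclass, field

from transformers import TrainingArguments

@dataclass
class MyTrainingArguments(TrainingArguments):
    # Hypothetical custom hyperparameter: a lambda weighting an auxiliary loss term
    loss_lambda: float = field(default=0.1, metadata={"help": "Weight of the auxiliary loss term."})

def my_custom_hp_space(trial):
    # Tune the custom field alongside a standard one (optuna backend assumed)
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "loss_lambda": trial.suggest_float("loss_lambda", 0.01, 1.0, log=True),
    }

trainer.hyperparameter_search(direction="maximize", hp_space=my_custom_hp_space)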
The hp_space function indicates the hyperparameter search space (see the code of the defaults for optuna or Ray in training_utils.py and adapt it to your needs), and the compute_objective function should return the objective to minimize/maximize.
Thank you so much! But I have a problem when defining the Trainer. It said, "__init__() got an unexpected keyword argument 'model_init'". Does the Trainer not recognize the 'model_init' argument?
I think this error leads to the next error, when I call the 'hyperparameter_search' method. It said, "'Trainer' object has no attribute 'hyperparameter_search'".
What should I do? Very sorry for the very newbie question, and thank you in advance.
This is new, so you need an installation from source to use it. Otherwise, it will be in the next release, coming soon.
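If it helps, an install from source is typically done by pointing pip at the GitHub repo:

pip install git+https://github.com/huggingface/transformers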
Alright, I'm waiting for it!
FYI, you can pip install now to use this feature. No need to build from source.
Oh yeah, thank you, it seems to be available now. But I'm still getting a problem with the hyperparameter_search method. I set my backend parameter to 'optuna' but the error said: "You picked the optuna backend, but it is not installed. Use pip install optuna.", even though I had already pip-installed it before the hyperparameter_search line. The case was the same when I set the backend parameter to 'ray'. Have I made a mistake? I run my code in Google Colab, by the way.
It means that it is not installed in your current environment. If you are using notebooks, you have to restart the kernel. Python needs to reload the libraries to see which ones are available.
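For instance, in a Colab notebook something like this, followed by a runtime restart, should do it:

# Install in a notebook cell first...
!pip install optuna
# ...then restart the runtime (Runtime -> Restart runtime in Colab)
# and re-run your imports before calling hyperparameter_search.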
Oh yes, it works now. Thank you so much!
I wonder if Sylvain or others might have advice on how to make the hyperparameter search more efficient or manageable, time- and resource-wise.
I've tried slimming down the dataset (500K rows to 90K rows), reducing the number of parameters to tune (to just one, the number of epochs), and changing the 'direction' to 'minimize' instead of 'maximize'.
Is there something else I can do, aside from further cutting down the size of the dataset? I'm running trials on Colab Pro with GPU/high-RAM enabled, and the current version looks like it'll take about 7 hours (perfectly fine for others, I'm sure).
I don't suppose there's an equivalent of RandomizedSearchCV for Trainer?