Spaces:
Sleeping
Sleeping
import sys | |
import os | |
import yaml | |
import random | |
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union | |
import numpy as np | |
from scipy import stats | |
import pandas as pd | |
import darts | |
from darts import models | |
from darts import metrics | |
from darts import TimeSeries | |
from pytorch_lightning.callbacks import Callback | |
from darts.logging import get_logger, raise_if_not | |
# for optuna callback | |
import warnings | |
import optuna | |
from optuna.storages._cached_storage import _CachedStorage | |
from optuna.storages._rdb.storage import RDBStorage | |
# Key names stored in `Trial.system_attrs`. They are not referenced by the code in
# this part of the module; presumably kept for DDP pruning support -- confirm
# before removing.
_PRUNED_KEY = "ddp_pl:pruned"
_EPOCH_KEY = "ddp_pl:epoch"

# Import PyTorch Lightning under optuna's import guard: a failed import is
# recorded on `_imports` instead of raising at module load time.
with optuna._imports.try_import() as _imports:
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import Callback

if not _imports.is_successful():
    # Placeholder base classes/annotations so the definitions below can still be
    # created without PyTorch Lightning installed.
    # NOTE(review): `Callback` is also imported unconditionally near the top of the
    # module, so this fallback may never be reached in practice.
    Callback = LightningModule = Trainer = object  # type: ignore # NOQA
def print_callback(study, trial, study_file=None):
    """Optuna callback that logs the finished trial and the current best trial.

    Intended for ``study.optimize(..., callbacks=[...])``; bind ``study_file``
    with ``functools.partial`` or a lambda.

    Args:
        study: The study being optimized (reads ``best_value`` and
            ``best_trial.params``).
        trial: The trial that just finished (reads ``value`` and ``params``).
        study_file: Path of a text file to append the two log lines to. When
            ``None`` (the default) the lines are printed to stdout instead of
            failing on ``open(None)``.
    """
    current_line = f"Current value: {trial.value}, Current params: {trial.params}\n"
    try:
        best_line = (
            f"Best value: {study.best_value}, Best params: {study.best_trial.params}\n"
        )
    except ValueError:
        # Optuna raises ValueError while no trial has completed yet (e.g. the
        # first trials were pruned or failed); logging must not abort the study.
        best_line = "Best value: None, Best params: None\n"

    if study_file is None:
        print(current_line + best_line, end="")
        return
    with open(study_file, "a", encoding="utf-8") as f:
        f.write(current_line)
        f.write(best_line)
def early_stopping_check(study,
                         trial,
                         study_file,
                         early_stopping_rounds=10):
    """Optuna callback that stops the study when the best trial stops improving.

    The study is stopped once the finished trial's number is at least
    ``early_stopping_rounds`` greater than the best trial's number, and a note
    is appended to ``study_file``.

    Args:
        study: The study being optimized (reads ``best_trial.number``, calls ``stop()``).
        trial: The trial that just finished (reads ``number``).
        study_file: Path of the text log to append the stop message to; ``None``
            prints the message to stdout instead.
        early_stopping_rounds: Trial-number distance from the best trial at
            which the study is stopped.
    """
    try:
        best_trial_number = study.best_trial.number
    except ValueError:
        # Optuna raises ValueError while no trial has completed (all pruned or
        # failed so far): there is no best trial to measure patience against,
        # so keep optimizing instead of crashing the study.
        return

    current_trial_number = trial.number
    if current_trial_number - best_trial_number < early_stopping_rounds:
        return

    # Trailing newline so later appends to the same log start on a fresh line.
    message = '\nEarly stopping at trial {} (best trial: {})\n'.format(
        current_trial_number, best_trial_number)
    if study_file is None:
        print(message, end='')
    else:
        with open(study_file, 'a', encoding='utf-8') as f:
            f.write(message)
    study.stop()
class LossLogger(Callback):
    """Lightning callback that records epoch-level training and validation loss.

    Values are read from ``trainer.callback_metrics`` under the keys
    ``"train_loss"`` and ``"val_loss"`` and stored as plain floats in
    ``self.train_loss`` / ``self.val_loss``.

    Args:
        skip_sanity_check: When ``True``, validation runs performed during
            Lightning's pre-training sanity check are not recorded, so
            ``val_loss`` lines up with real epochs. Defaults to ``False``, which
            keeps the original behaviour of also recording those runs.
    """

    def __init__(self, skip_sanity_check: bool = False):
        super().__init__()
        self.skip_sanity_check = skip_sanity_check
        self.train_loss = []
        self.val_loss = []

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # KeyError here means the module did not log a metric named "train_loss".
        self.train_loss.append(float(trainer.callback_metrics["train_loss"]))

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # The sanity-check run validates an untrained model; optionally ignore it.
        if self.skip_sanity_check and trainer.sanity_checking:
            return
        self.val_loss.append(float(trainer.callback_metrics["val_loss"]))
class PyTorchLightningPruningCallback(Callback):
    """Report a monitored metric to Optuna after each validation run and prune bad trials.

    See `the example
    <https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_lightning_simple.py>`__
    for a pruning callback that observes accuracy.

    Args:
        trial:
            The :class:`~optuna.trial.Trial` for the objective evaluation in progress.
        monitor:
            Name of the metric used for pruning, e.g. ``val_loss`` or ``val_acc``.
            It is looked up in ``trainer.callback_metrics``, so it must match the
            name under which the LightningModule logs it.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()
        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report this epoch's monitored score and stop the trial if the pruner says so.

        Emits a warning and reports nothing if ``self.monitor`` is missing from
        ``trainer.callback_metrics``.

        Raises:
            optuna.TrialPruned: if ``trial.should_prune()`` returns true after reporting.
        """
        # Lightning also fires this hook for its sanity check; reporting then
        # would call `trial.report` more than once for epoch 0. See
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch
        score = trainer.callback_metrics.get(self.monitor)
        if score is None:
            warnings.warn(
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            return

        trial = self._trial
        trial.report(score, step=epoch)
        if not trial.should_prune():
            return
        raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(epoch))