Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Enable continuation of training #1605

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
62 changes: 50 additions & 12 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
optimizer: Union[str, Type[torch.optim.Optimizer]]
quantiles: List[float] = field(default_factory=list)
optimizer_args: dict = field(default_factory=dict)
scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None
scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
scheduler_args: dict = field(default_factory=dict)
newer_samples_weight: float = 1.0
newer_samples_start: float = 0.0
Expand All @@ -104,6 +104,7 @@
n_data: int = field(init=False)
loss_func_name: str = field(init=False)
lr_finder_args: dict = field(default_factory=dict)
optimizer_state: dict = field(default_factory=dict)

def __post_init__(self):
# assert the uncertainty estimation params and then finalize the quantiles
Expand Down Expand Up @@ -189,19 +190,53 @@

def set_scheduler(self):
"""
Set the scheduler and scheduler args.
Set the scheduler and scheduler arg depending on the user selection.
The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
"""
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
self.scheduler_args.update(
{
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 10.0,
"final_div_factor": 10.0,
"three_phase": True,
}
)
self.scheduler_args.clear()
if isinstance(self.scheduler, str):
if self.scheduler.lower() == "onecyclelr":
self.scheduler = torch.optim.lr_scheduler.OneCycleLR

Check failure on line 199 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "scheduler" for class "Train*"   Type "type[OneCycleLR]" is incompatible with type "type[_LRScheduler] | None"     "type[OneCycleLR]" is incompatible with "type[_LRScheduler]"     Type "type[OneCycleLR]" is incompatible with type "type[_LRScheduler]"     "type[type]" is incompatible with "type[None]" (reportAttributeAccessIssue)
self.scheduler_args.update(
{
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 10.0,
"final_div_factor": 10.0,
"three_phase": True,
}
)
elif self.scheduler.lower() == "steplr":
self.scheduler = torch.optim.lr_scheduler.StepLR

Check failure on line 210 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "scheduler" for class "Train*"   Type "type[StepLR]" is incompatible with type "type[_LRScheduler] | None"     "type[StepLR]" is incompatible with "type[_LRScheduler]"     Type "type[StepLR]" is incompatible with type "type[_LRScheduler]"     "type[type]" is incompatible with "type[None]" (reportAttributeAccessIssue)
self.scheduler_args.update(
{
"step_size": 10,
"gamma": 0.1,
}
)
elif self.scheduler.lower() == "exponentiallr":
self.scheduler = torch.optim.lr_scheduler.ExponentialLR

Check failure on line 218 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "scheduler" for class "Train*"   Type "type[ExponentialLR]" is incompatible with type "type[_LRScheduler] | None"     "type[ExponentialLR]" is incompatible with "type[_LRScheduler]"     Type "type[ExponentialLR]" is incompatible with type "type[_LRScheduler]"     "type[type]" is incompatible with "type[None]" (reportAttributeAccessIssue)
self.scheduler_args.update(

Check warning on line 219 in neuralprophet/configure.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/configure.py#L217-L219

Added lines #L217 - L219 were not covered by tests
{
"gamma": 0.95,
}
)
elif self.scheduler.lower() == "cosineannealinglr":
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR

Check failure on line 225 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "scheduler" for class "Train*"   Type "type[CosineAnnealingLR]" is incompatible with type "type[_LRScheduler] | None"     "type[CosineAnnealingLR]" is incompatible with "type[_LRScheduler]"     Type "type[CosineAnnealingLR]" is incompatible with type "type[_LRScheduler]"     "type[type]" is incompatible with "type[None]" (reportAttributeAccessIssue)
self.scheduler_args.update(

Check warning on line 226 in neuralprophet/configure.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/configure.py#L224-L226

Added lines #L224 - L226 were not covered by tests
{
"T_max": 50,
}
)
else:
raise NotImplementedError(f"Scheduler {self.scheduler} is not supported.")

Check warning on line 232 in neuralprophet/configure.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/configure.py#L232

Added line #L232 was not covered by tests
elif self.scheduler is None:
self.scheduler = torch.optim.lr_scheduler.ExponentialLR

Check failure on line 234 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "scheduler" for class "Train*"   Type "type[ExponentialLR]" is incompatible with type "type[_LRScheduler] | None"     "type[ExponentialLR]" is incompatible with "type[_LRScheduler]"     Type "type[ExponentialLR]" is incompatible with type "type[_LRScheduler]"     "type[type]" is incompatible with "type[None]" (reportAttributeAccessIssue)
self.scheduler_args.update(
{
"gamma": 0.95,
}
)

def set_lr_finder_args(self, dataset_size, num_batches):
"""
Expand Down Expand Up @@ -239,6 +274,9 @@
delay_weight = 1
return delay_weight

def set_optimizer_state(self, optimizer_state: dict):
self.optimizer_state = optimizer_state


@dataclass
class Trend:
Expand Down Expand Up @@ -304,7 +342,7 @@
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

if self.trend_local_reg < 0:

Check failure on line 345 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

Expand Down Expand Up @@ -353,13 +391,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

self.periods = OrderedDict(

Check failure on line 394 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

No overloads for "__init__" match the provided arguments (reportCallIssue)
{

Check failure on line 395 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)
"yearly": Season(
resolution=6,
period=365.25,
arg=self.yearly_arg,
global_local=(

Check failure on line 400 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.yearly_global_local
if self.yearly_global_local in ["global", "local"]
else self.global_local
Expand Down
112 changes: 88 additions & 24 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,20 @@
>>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
>>> # use custorm torchmetrics names
>>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
scheduler : str, torch.optim.lr_scheduler._LRScheduler
Type of learning rate scheduler to use.

Options
* (default) ``OneCycleLR``: One Cycle Learning Rate scheduler
* ``StepLR``: Step Learning Rate scheduler
* ``ExponentialLR``: Exponential Learning Rate scheduler
* ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler

Examples
--------
>>> from neuralprophet import NeuralProphet
>>> # Step Learning Rate scheduler
>>> m = NeuralProphet(scheduler="StepLR")

COMMENT
Uncertainty Estimation
Expand Down Expand Up @@ -432,6 +446,7 @@
batch_size: Optional[int] = None,
loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss",
optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW",
scheduler: Optional[str] = "onecyclelr",
newer_samples_weight: float = 2,
newer_samples_start: float = 0.0,
quantiles: List[float] = [],
Expand Down Expand Up @@ -505,6 +520,7 @@
self.config_train = configure.Train(
quantiles=quantiles,
learning_rate=learning_rate,
scheduler=scheduler,
epochs=epochs,
batch_size=batch_size,
loss_func=loss_func,
Expand Down Expand Up @@ -916,6 +932,7 @@
continue_training: bool = False,
num_workers: int = 0,
deterministic: bool = False,
scheduler: Optional[str] = None,
):
"""Train, and potentially evaluate model.

Expand Down Expand Up @@ -967,14 +984,38 @@
Note: using multiple workers and therefore distributed training might significantly increase
the training time since each batch needs to be copied to each worker for each epoch. Keeping
all data on the main process might be faster for most datasets.
scheduler : str
Type of learning rate scheduler to use for continued training. If None, uses ExponentialLR as
default as specified in the model config.
Options
* ``StepLR``: Step Learning Rate scheduler
* ``ExponentialLR``: Exponential Learning Rate scheduler
* ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler

Returns
-------
pd.DataFrame
metrics with training and potentially evaluation metrics
"""
if self.fitted:
raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")
if self.fitted and not continue_training:
Constantin343 marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(
"Model has been fitted already. If you want to continue training please set the flag continue_training."
)

if continue_training and epochs is None:
raise ValueError("Continued training requires setting the number of epochs to train for.")

Check warning on line 1006 in neuralprophet/forecaster.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/forecaster.py#L1006

Added line #L1006 was not covered by tests

if continue_training:
if scheduler is not None:
self.config_train.scheduler = scheduler
else:
self.config_train.scheduler = None
self.config_train.set_scheduler()

if scheduler is not None and not continue_training:
log.warning(

Check warning on line 1016 in neuralprophet/forecaster.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/forecaster.py#L1016

Added line #L1016 was not covered by tests
"Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
)

# Configuration
if epochs is not None:
Expand Down Expand Up @@ -1060,8 +1101,9 @@
or any(value != 1 for value in self.num_seasonalities_modelled_dict.values())
)

if self.fitted is True and not continue_training:
log.error("Model has already been fitted. Re-fitting may break or produce different results.")
if continue_training and self.metrics_logger.checkpoint_path is None:
log.error("Continued training requires checkpointing in model to continue from last epoch.")

Check warning on line 1105 in neuralprophet/forecaster.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/forecaster.py#L1105

Added line #L1105 was not covered by tests

self.max_lags = df_utils.get_max_num_lags(
n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
)
Expand Down Expand Up @@ -2661,23 +2703,23 @@
torch DataLoader
"""
df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be avoided?
# if not self.fitted:
self.config_normalization.init_data_params(
df=df,
config_lagged_regressors=self.config_lagged_regressors,
config_regressors=self.config_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
)
if not self.fitted:
self.config_normalization.init_data_params(
df=df,
config_lagged_regressors=self.config_lagged_regressors,
config_regressors=self.config_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
)

df = _normalize(df=df, config_normalization=self.config_normalization)
# if not self.fitted:
if self.config_trend.changepoints is not None:
# scale user-specified changepoint times
df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})
if not self.fitted:
if self.config_trend.changepoints is not None:
# scale user-specified changepoint times
df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})

df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
self.config_trend.changepoints = df_normalized["t"].values # type: ignore
df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
self.config_trend.changepoints = df_normalized["t"].values # type: ignore

# df_merged, _ = df_utils.join_dataframes(df)
# df_merged = df_merged.sort_values("ds")
Expand Down Expand Up @@ -2765,12 +2807,24 @@
# Internal flag to check if validation is enabled
validation_enabled = df_val is not None

# Init the model, if not continue from checkpoint
# Load model and optimizer state from checkpoint if continue_training is True
if continue_training:
raise NotImplementedError(
"Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \
upcoming releases."
)
checkpoint_path = self.metrics_logger.checkpoint_path
checkpoint = torch.load(checkpoint_path)

checkpoint_epoch = checkpoint["epoch"] if "epoch" in checkpoint else 0
previous_epoch = max(self.model.current_epoch, checkpoint_epoch)

# Set continue_training flag in model to update scheduler correctly
self.model.continue_training = True
self.model.start_epoch = previous_epoch

# Adjust epochs
new_total_epochs = previous_epoch + self.config_train.epochs
self.config_train.epochs = new_total_epochs

self.config_train.set_optimizer_state(checkpoint["optimizer_states"][0])

else:
self.model = self._init_model()

Expand Down Expand Up @@ -2852,8 +2906,18 @@

if not metrics_enabled:
return None

# Return metrics collected in logger as dataframe
metrics_df = pd.DataFrame(self.metrics_logger.history)
if self.metrics_logger.history is not None:
# avoid array mismatch when continuing training
history = self.metrics_logger.history
max_length = max(len(lst) for lst in history.values())
for key in history:
while len(history[key]) < max_length:
history[key].append(None)
metrics_df = pd.DataFrame(history)
else:
metrics_df = pd.DataFrame()

Check warning on line 2920 in neuralprophet/forecaster.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/forecaster.py#L2920

Added line #L2920 was not covered by tests
return metrics_df

def restore_trainer(self, accelerator: Optional[str] = None):
Expand Down
40 changes: 34 additions & 6 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
num_seasonalities_modelled: int = 1,
num_seasonalities_modelled_dict: dict = None,
meta_used_in_model: bool = False,
continue_training: bool = False,
start_epoch: int = 0,
):
"""
Parameters
Expand Down Expand Up @@ -312,6 +314,10 @@
else:
self.config_regressors.regressors = None

# Continued training
self.continue_training = continue_training
self.start_epoch = start_epoch

@property
def ar_weights(self) -> torch.Tensor:
"""sets property auto-regression weights for regularization. Update if AR is modelled differently"""
Expand Down Expand Up @@ -865,12 +871,34 @@
optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

# Scheduler
lr_scheduler = self._scheduler(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
**self.config_train.scheduler_args,
)
self._scheduler = self.config_train.scheduler

if self.continue_training:
Constantin343 marked this conversation as resolved.
Show resolved Hide resolved
optimizer.load_state_dict(self.config_train.optimizer_state)

# Update initial learning rate to the last learning rate for continued training
last_lr = float(optimizer.param_groups[0]["lr"]) # Ensure it's a float

for param_group in optimizer.param_groups:
param_group["initial_lr"] = (last_lr,)

if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
log.warning("OneCycleLR scheduler is not supported for continued training. Switching to ExponentialLR")
self._scheduler = torch.optim.lr_scheduler.ExponentialLR
self.config_train.scheduler_args = {"gamma": 0.95}

Check warning on line 888 in neuralprophet/time_net.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/time_net.py#L886-L888

Added lines #L886 - L888 were not covered by tests

if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
lr_scheduler = self._scheduler(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
**self.config_train.scheduler_args,
)
else:
lr_scheduler = self._scheduler(
optimizer,
**self.config_train.scheduler_args,
)

return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

Expand Down
Loading
Loading