diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index bc2b004fc..d44d6af81 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -94,7 +94,7 @@ class Train:
     optimizer: Union[str, Type[torch.optim.Optimizer]]
     quantiles: List[float] = field(default_factory=list)
     optimizer_args: dict = field(default_factory=dict)
-    scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None
+    scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
     scheduler_args: dict = field(default_factory=dict)
     newer_samples_weight: float = 1.0
     newer_samples_start: float = 0.0
@@ -104,6 +104,7 @@ class Train:
     n_data: int = field(init=False)
     loss_func_name: str = field(init=False)
     lr_finder_args: dict = field(default_factory=dict)
+    optimizer_state: dict = field(default_factory=dict)

     def __post_init__(self):
         # assert the uncertainty estimation params and then finalize the quantiles
@@ -189,19 +190,53 @@ def set_optimizer(self):

     def set_scheduler(self):
         """
-        Set the scheduler and scheduler args.
+        Set the scheduler and scheduler args depending on the user selection.
         The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
         """
-        self.scheduler = torch.optim.lr_scheduler.OneCycleLR
-        self.scheduler_args.update(
-            {
-                "pct_start": 0.3,
-                "anneal_strategy": "cos",
-                "div_factor": 10.0,
-                "final_div_factor": 10.0,
-                "three_phase": True,
-            }
-        )
+        self.scheduler_args.clear()
+        if isinstance(self.scheduler, str):
+            if self.scheduler.lower() == "onecyclelr":
+                self.scheduler = torch.optim.lr_scheduler.OneCycleLR
+                self.scheduler_args.update(
+                    {
+                        "pct_start": 0.3,
+                        "anneal_strategy": "cos",
+                        "div_factor": 10.0,
+                        "final_div_factor": 10.0,
+                        "three_phase": True,
+                    }
+                )
+            elif self.scheduler.lower() == "steplr":
+                self.scheduler = torch.optim.lr_scheduler.StepLR
+                self.scheduler_args.update(
+                    {
+                        "step_size": 10,
+                        "gamma": 0.1,
+                    }
+                )
+            elif self.scheduler.lower() == "exponentiallr":
+                self.scheduler = torch.optim.lr_scheduler.ExponentialLR
+                self.scheduler_args.update(
+                    {
+                        "gamma": 0.95,
+                    }
+                )
+            elif self.scheduler.lower() == "cosineannealinglr":
+                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
+                self.scheduler_args.update(
+                    {
+                        "T_max": 50,
+                    }
+                )
+            else:
+                raise NotImplementedError(f"Scheduler {self.scheduler} is not supported.")
+        elif self.scheduler is None:
+            self.scheduler = torch.optim.lr_scheduler.ExponentialLR
+            self.scheduler_args.update(
+                {
+                    "gamma": 0.95,
+                }
+            )

     def set_lr_finder_args(self, dataset_size, num_batches):
         """
@@ -239,6 +274,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
             delay_weight = 1
         return delay_weight

+    def set_optimizer_state(self, optimizer_state: dict):
+        self.optimizer_state = optimizer_state
+

 @dataclass
 class Trend:
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 85939955e..3cc386dc2 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -298,6 +298,20 @@ class NeuralProphet:
             >>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
             >>> # use custorm torchmetrics names
             >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
+        scheduler : str, torch.optim.lr_scheduler._LRScheduler
+            Type of learning rate scheduler to use.
+
+            Options
+                * (default) ``OneCycleLR``: One Cycle Learning Rate scheduler
+                * ``StepLR``: Step Learning Rate scheduler
+                * ``ExponentialLR``: Exponential Learning Rate scheduler
+                * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler
+
+            Examples
+            --------
+            >>> from neuralprophet import NeuralProphet
+            >>> # Step Learning Rate scheduler
+            >>> m = NeuralProphet(scheduler="StepLR")

         COMMENT
         Uncertainty Estimation
@@ -432,6 +446,7 @@ def __init__(
         batch_size: Optional[int] = None,
         loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss",
         optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW",
+        scheduler: Optional[str] = "onecyclelr",
         newer_samples_weight: float = 2,
         newer_samples_start: float = 0.0,
         quantiles: List[float] = [],
@@ -505,6 +520,7 @@ def __init__(
         self.config_train = configure.Train(
             quantiles=quantiles,
             learning_rate=learning_rate,
+            scheduler=scheduler,
             epochs=epochs,
             batch_size=batch_size,
             loss_func=loss_func,
@@ -916,6 +932,7 @@ def fit(
         continue_training: bool = False,
         num_workers: int = 0,
         deterministic: bool = False,
+        scheduler: Optional[str] = None,
     ):
         """Train, and potentially evaluate model.

@@ -967,14 +984,38 @@ def fit(
                 Note: using multiple workers and therefore distributed training might significantly increase
                 the training time since each batch needs to be copied to each worker for each epoch. Keeping
                 all data on the main process might be faster for most datasets.
+            scheduler : str
+                Type of learning rate scheduler to use for continued training. If None, defaults to ExponentialLR
+                as specified in the model config.
+                Options
+                    * ``StepLR``: Step Learning Rate scheduler
+                    * ``ExponentialLR``: Exponential Learning Rate scheduler
+                    * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler

         Returns
         -------
             pd.DataFrame
                 metrics with training and potentially evaluation metrics
         """
-        if self.fitted:
-            raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")
+        if self.fitted and not continue_training:
+            raise RuntimeError(
+                "Model has been fitted already. If you want to continue training, please set the flag continue_training."
+            )
+
+        if continue_training and epochs is None:
+            raise ValueError("Continued training requires setting the number of epochs to train for.")
+
+        if continue_training:
+            if scheduler is not None:
+                self.config_train.scheduler = scheduler
+            else:
+                self.config_train.scheduler = None
+            self.config_train.set_scheduler()
+
+        if scheduler is not None and not continue_training:
+            log.warning(
+                "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
+            )

         # Configuration
         if epochs is not None:
@@ -1060,8 +1101,9 @@ def fit(
             or any(value != 1 for value in self.num_seasonalities_modelled_dict.values())
         )

-        if self.fitted is True and not continue_training:
-            log.error("Model has already been fitted. Re-fitting may break or produce different results.")
+        if continue_training and self.metrics_logger.checkpoint_path is None:
+            log.error("Continued training requires checkpointing in model to continue from last epoch.")
+
         self.max_lags = df_utils.get_max_num_lags(
             n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
         )
@@ -2661,23 +2703,23 @@ def _init_train_loader(self, df, num_workers=0):
             torch DataLoader
         """
         df, _, _, _ = df_utils.prep_or_copy_df(df)  # TODO: Can this call be avoided?
-        # if not self.fitted:
-        self.config_normalization.init_data_params(
-            df=df,
-            config_lagged_regressors=self.config_lagged_regressors,
-            config_regressors=self.config_regressors,
-            config_events=self.config_events,
-            config_seasonality=self.config_seasonality,
-        )
+        if not self.fitted:
+            self.config_normalization.init_data_params(
+                df=df,
+                config_lagged_regressors=self.config_lagged_regressors,
+                config_regressors=self.config_regressors,
+                config_events=self.config_events,
+                config_seasonality=self.config_seasonality,
+            )

         df = _normalize(df=df, config_normalization=self.config_normalization)
-        # if not self.fitted:
-        if self.config_trend.changepoints is not None:
-            # scale user-specified changepoint times
-            df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})
+        if not self.fitted:
+            if self.config_trend.changepoints is not None:
+                # scale user-specified changepoint times
+                df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})

-            df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
-            self.config_trend.changepoints = df_normalized["t"].values  # type: ignore
+                df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
+                self.config_trend.changepoints = df_normalized["t"].values  # type: ignore

         # df_merged, _ = df_utils.join_dataframes(df)
         # df_merged = df_merged.sort_values("ds")
@@ -2765,12 +2807,24 @@ def _train(

         # Internal flag to check if validation is enabled
         validation_enabled = df_val is not None

-        # Init the model, if not continue from checkpoint
+        # Load model and optimizer state from checkpoint if continue_training is True
         if continue_training:
-            raise NotImplementedError(
-                "Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \
-                    upcoming releases."
-            )
+            checkpoint_path = self.metrics_logger.checkpoint_path
+            checkpoint = torch.load(checkpoint_path)
+
+            checkpoint_epoch = checkpoint["epoch"] if "epoch" in checkpoint else 0
+            previous_epoch = max(self.model.current_epoch, checkpoint_epoch)
+
+            # Set continue_training flag in model to update scheduler correctly
+            self.model.continue_training = True
+            self.model.start_epoch = previous_epoch
+
+            # Adjust epochs
+            new_total_epochs = previous_epoch + self.config_train.epochs
+            self.config_train.epochs = new_total_epochs
+
+            self.config_train.set_optimizer_state(checkpoint["optimizer_states"][0])
+
         else:
             self.model = self._init_model()
@@ -2852,8 +2906,18 @@ def _train(
         if not metrics_enabled:
             return None

+
         # Return metrics collected in logger as dataframe
-        metrics_df = pd.DataFrame(self.metrics_logger.history)
+        if self.metrics_logger.history is not None:
+            # avoid array mismatch when continuing training
+            history = self.metrics_logger.history
+            max_length = max(len(lst) for lst in history.values())
+            for key in history:
+                while len(history[key]) < max_length:
+                    history[key].append(None)
+            metrics_df = pd.DataFrame(history)
+        else:
+            metrics_df = pd.DataFrame()
         return metrics_df

     def restore_trainer(self, accelerator: Optional[str] = None):
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index a4fbfee3a..c30594e29 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -63,6 +63,8 @@ def __init__(
         num_seasonalities_modelled: int = 1,
         num_seasonalities_modelled_dict: dict = None,
         meta_used_in_model: bool = False,
+        continue_training: bool = False,
+        start_epoch: int = 0,
     ):
         """
         Parameters
@@ -312,6 +314,10 @@ def __init__(
         else:
             self.config_regressors.regressors = None

+        # Continued training
+        self.continue_training = continue_training
+        self.start_epoch = start_epoch
+
     @property
     def ar_weights(self) -> torch.Tensor:
         """sets property auto-regression weights for regularization. Update if AR is modelled differently"""
@@ -865,12 +871,34 @@ def configure_optimizers(self):
         optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

         # Scheduler
-        lr_scheduler = self._scheduler(
-            optimizer,
-            max_lr=self.learning_rate,
-            total_steps=self.trainer.estimated_stepping_batches,
-            **self.config_train.scheduler_args,
-        )
+        self._scheduler = self.config_train.scheduler
+
+        if self.continue_training:
+            optimizer.load_state_dict(self.config_train.optimizer_state)
+
+            # Update initial learning rate to the last learning rate for continued training
+            last_lr = float(optimizer.param_groups[0]["lr"])  # Ensure it's a float
+
+            for param_group in optimizer.param_groups:
+                param_group["initial_lr"] = last_lr
+
+            if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
+                log.warning("OneCycleLR scheduler is not supported for continued training. Switching to ExponentialLR")
+                self._scheduler = torch.optim.lr_scheduler.ExponentialLR
+                self.config_train.scheduler_args = {"gamma": 0.95}
+
+        if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
+            lr_scheduler = self._scheduler(
+                optimizer,
+                max_lr=self.learning_rate,
+                total_steps=self.trainer.estimated_stepping_batches,
+                **self.config_train.scheduler_args,
+            )
+        else:
+            lr_scheduler = self._scheduler(
+                optimizer,
+                **self.config_train.scheduler_args,
+            )

         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

diff --git a/tests/test_utils.py b/tests/test_utils.py
index a327f3122..3b93721bf 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -21,6 +21,7 @@
 YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv")
 NROWS = 512
 EPOCHS = 10
+ADDITIONAL_EPOCHS = 5
 LR = 1.0
 BATCH_SIZE = 64

@@ -101,17 +102,47 @@ def test_save_load_io():
     pd.testing.assert_frame_equal(forecast, forecast3)


-# TODO: add functionality to continue training
-# def test_continue_training():
-#     df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
-#     m = NeuralProphet(
-#         epochs=EPOCHS,
-#         batch_size=BATCH_SIZE,
-#         learning_rate=LR,
-#         n_lags=6,
-#         n_forecasts=3,
-#         n_changepoints=0,
-#     )
-#     metrics = m.fit(df, freq="D")
-#     metrics2 = m.fit(df, freq="D", continue_training=True)
-#     assert metrics1["Loss"].sum() >= metrics2["Loss"].sum()
+def test_continue_training():
+    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
+    m = NeuralProphet(
+        epochs=EPOCHS,
+        batch_size=BATCH_SIZE,
+        learning_rate=LR,
+        n_lags=6,
+        n_forecasts=3,
+        n_changepoints=0,
+    )
+    metrics = m.fit(df, checkpointing=True, freq="D")
+    metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS)
+    assert metrics["Loss"].min() >= metrics2["Loss"].min()
+
+
+def test_continue_training_with_scheduler_selection():
+    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
+    m = NeuralProphet(
+        epochs=EPOCHS,
+        batch_size=BATCH_SIZE,
+        learning_rate=LR,
+        n_lags=6,
+        n_forecasts=3,
+        n_changepoints=0,
+    )
+    metrics = m.fit(df, checkpointing=True, freq="D")
+    # Continue training with StepLR
+    metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR")
+    assert metrics["Loss"].min() >= metrics2["Loss"].min()
+
+
+def test_save_load_continue_training():
+    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
+    m = NeuralProphet(
+        epochs=EPOCHS,
+        n_lags=6,
+        n_forecasts=3,
+        n_changepoints=0,
+    )
+    metrics = m.fit(df, checkpointing=True, freq="D")
+    save(m, "test_model.pt")
+    m2 = load("test_model.pt")
+    metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR")
+    assert metrics["Loss"].min() >= metrics2["Loss"].min()
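
Usage sketch (illustrative, not part of the diff): the snippet below exercises the API introduced above, scheduler selection at initialization, checkpointed fitting, and continued training with a different scheduler, mirroring the new docstring examples and tests. The data file name, frequency, and epoch counts are placeholders.

    import pandas as pd
    from neuralprophet import NeuralProphet

    # Placeholder dataset with the usual "ds" / "y" columns expected by NeuralProphet.
    df = pd.read_csv("example_data.csv")

    # Select a learning rate scheduler at init; "onecyclelr" remains the default.
    m = NeuralProphet(epochs=10, learning_rate=1.0, scheduler="StepLR")

    # Checkpointing must be enabled so that continued training can restore the optimizer state.
    metrics = m.fit(df, freq="D", checkpointing=True)

    # Resume from the last checkpoint for a few more epochs. OneCycleLR is not supported here,
    # so either pass a scheduler explicitly or let it fall back to ExponentialLR.
    metrics2 = m.fit(df, freq="D", continue_training=True, epochs=5, scheduler="ExponentialLR")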