[Minor] Enable continuation of training #1605

Open · wants to merge 14 commits into base: main
75 changes: 54 additions & 21 deletions neuralprophet/forecaster.py
@@ -978,7 +978,7 @@ def fit(
pd.DataFrame
metrics with training and potentially evaluation metrics
"""
if self.fitted:
if self.fitted and not continue_training:
raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")

# Configuration
@@ -1067,6 +1067,10 @@ def fit(

if self.fitted is True and not continue_training:
log.error("Model has already been fitted. Re-fitting may break or produce different results.")

if continue_training and self.metrics_logger.checkpoint_path is None:
log.error("Continued training requires checkpointing in model.")
Owner:

Can you please explain what necessitates this (for my understanding)?

Contributor Author:

My thinking was that it makes sense to continue from the checkpoint, but it's probably not necessary. All the necessary parameters should still be available in the model itself. I will adapt it.

Contributor Author:

I did some testing, and it seems a checkpoint is indeed necessary to correctly continue training with the PyTorch Lightning trainer. I can get it to run without checkpointing, but fitting again always leads to a complete restart of the training.

Maybe there is some workaround, but I would suggest keeping it like this, as continued training always goes hand in hand with checkpointing in PyTorch Lightning.
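For context (not part of this PR's code), PyTorch Lightning's own resume mechanism is checkpoint-based: passing a checkpoint path to `Trainer.fit` restores the model weights, optimizer and LR-scheduler state, and the epoch/step counters. A minimal sketch, assuming a recent Lightning version where `ckpt_path` is an argument of `fit`; the helper name and its arguments are illustrative, not NeuralProphet internals:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


def resume_training(model: pl.LightningModule, train_loader: DataLoader, ckpt_path: str, total_epochs: int):
    """Continue a previous run from a Lightning checkpoint.

    Lightning restores the weights, optimizer/scheduler state, and epoch
    counter from ``ckpt_path``, then trains until ``total_epochs`` is
    reached, so ``total_epochs`` must exceed the number of epochs already
    completed in the checkpointed run.
    """
    trainer = pl.Trainer(max_epochs=total_epochs, enable_checkpointing=True)
    trainer.fit(model, train_dataloaders=train_loader, ckpt_path=ckpt_path)
```

Without a checkpoint there is no saved optimizer/scheduler/epoch state to restore, which matches the observation above that re-fitting restarts training from scratch.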


self.max_lags = df_utils.get_max_num_lags(
n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
)
@@ -2666,23 +2670,24 @@ def _init_train_loader(self, df, num_workers=0):
torch DataLoader
"""
df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be avoided?
# if not self.fitted:
self.config_normalization.init_data_params(
df=df,
config_lagged_regressors=self.config_lagged_regressors,
config_regressors=self.config_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
)
if not self.fitted:
self.config_normalization.init_data_params(
df=df,
config_lagged_regressors=self.config_lagged_regressors,
config_regressors=self.config_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
)

print("Changepoints:", self.config_trend.changepoints)
df = _normalize(df=df, config_normalization=self.config_normalization)
# if not self.fitted:
if self.config_trend.changepoints is not None:
# scale user-specified changepoint times
df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})
if not self.fitted:
if self.config_trend.changepoints is not None:
# scale user-specified changepoint times
df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})

df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
self.config_trend.changepoints = df_normalized["t"].values # type: ignore
df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
self.config_trend.changepoints = df_normalized["t"].values # type: ignore

# df_merged, _ = df_utils.join_dataframes(df)
# df_merged = df_merged.sort_values("ds")
@@ -2770,12 +2775,36 @@ def _train(
# Internal flag to check if validation is enabled
validation_enabled = df_val is not None

# Init the model, if not continue from checkpoint
# Load model and optimizer state from checkpoint if continue_training is True
if continue_training:
raise NotImplementedError(
"Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \
upcoming releases."
)
checkpoint_path = self.metrics_logger.checkpoint_path
checkpoint = torch.load(checkpoint_path)

# Load model state
self.model.load_state_dict(checkpoint["state_dict"])

# Set continue_training flag in model to update scheduler correctly
self.model.continue_training = True

previous_epoch = checkpoint["epoch"]
# Adjust epochs
if self.config_train.epochs:
additional_epochs = self.config_train.epochs
else:
additional_epochs = previous_epoch
# Get the number of epochs already trained
new_total_epochs = previous_epoch + additional_epochs
self.config_train.epochs = new_total_epochs

# Reinitialize optimizer with loaded model parameters
optimizer = torch.optim.AdamW(self.model.parameters())

# Load optimizer state
if "optimizer_states" in checkpoint and checkpoint["optimizer_states"]:
optimizer.load_state_dict(checkpoint["optimizer_states"][0])

self.config_train.optimizer = optimizer

else:
self.model = self._init_model()

@@ -2859,8 +2888,12 @@ def _train(

if not metrics_enabled:
return None

# Return metrics collected in logger as dataframe
metrics_df = pd.DataFrame(self.metrics_logger.history)
if self.metrics_logger.history is not None:
metrics_df = pd.DataFrame(self.metrics_logger.history)
else:
metrics_df = pd.DataFrame()
return metrics_df

def restore_trainer(self, accelerator: Optional[str] = None):
23 changes: 17 additions & 6 deletions neuralprophet/time_net.py
@@ -63,6 +63,7 @@ def __init__(
num_seasonalities_modelled: int = 1,
num_seasonalities_modelled_dict: dict = None,
meta_used_in_model: bool = False,
continue_training: bool = False,
):
"""
Parameters
@@ -306,6 +307,9 @@ def __init__(
else:
self.config_regressors.regressors = None

# Continued training
self.continue_training = continue_training

@property
def ar_weights(self) -> torch.Tensor:
"""sets property auto-regression weights for regularization. Update if AR is modelled differently"""
@@ -865,12 +869,19 @@ def configure_optimizers(self):
optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

# Scheduler
lr_scheduler = self._scheduler(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
**self.config_train.scheduler_args,
)
if self.continue_training:
# Update initial learning rate to the last learning rate for continued training
last_lr = optimizer.param_groups[0]["lr"]
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
for param_group in optimizer.param_groups:
param_group["initial_lr"] = last_lr
else:
lr_scheduler = self._scheduler(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
**self.config_train.scheduler_args,
)

return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

27 changes: 13 additions & 14 deletions tests/test_utils.py
@@ -101,17 +101,16 @@ def test_save_load_io():
pd.testing.assert_frame_equal(forecast, forecast3)


# TODO: add functionality to continue training
# def test_continue_training():
# df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
# m = NeuralProphet(
# epochs=EPOCHS,
# batch_size=BATCH_SIZE,
# learning_rate=LR,
# n_lags=6,
# n_forecasts=3,
# n_changepoints=0,
# )
# metrics = m.fit(df, freq="D")
# metrics2 = m.fit(df, freq="D", continue_training=True)
# assert metrics1["Loss"].sum() >= metrics2["Loss"].sum()
def test_continue_training():
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
n_lags=6,
n_forecasts=3,
n_changepoints=0,
)
metrics = m.fit(df, checkpointing=True, freq="D")
metrics2 = m.fit(df, freq="D", continue_training=True)
assert metrics["Loss"].min() >= metrics2["Loss"].min()