diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4365a5b2e5..2c4603ea87 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -180,6 +180,7 @@ def __init__( mlflow_logging_config: Optional[dict] = None, flatten_imports: Sequence[str] = ('llmfoundry',), final_register_only: bool = False, + register_wait_seconds: int = 7200, ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -193,6 +194,7 @@ def __init__( self.using_peft = False self.final_register_only = final_register_only + self.register_wait_seconds = register_wait_seconds self.mlflow_registered_model_name = mlflow_registered_model_name if self.final_register_only and self.mlflow_registered_model_name is None: @@ -325,7 +327,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. - timeout = 3600 + timeout = self.register_wait_seconds wait_start = time.time() while not self._all_register_processes_done(state.device): wait_time = time.time() - wait_start