diff --git a/nmdc_automation/workflow_automation/watch_nmdc.py b/nmdc_automation/workflow_automation/watch_nmdc.py index 852a470..cd5b630 100644 --- a/nmdc_automation/workflow_automation/watch_nmdc.py +++ b/nmdc_automation/workflow_automation/watch_nmdc.py @@ -87,7 +87,9 @@ def write_metadata_if_not_exists(self, job: WorkflowJob)->Path: class JobManager: - def __init__(self, config, file_handler, init_cache: bool = True): + """ JobManager class for managing WorkflowJob objects """ + def __init__(self, config: SiteConfig, file_handler: FileHandler, init_cache: bool = True): + """ Initialize the JobManager with a Config object and a FileHandler object """ self.config = config self.file_handler = file_handler self._job_cache = [] @@ -98,18 +100,22 @@ def __init__(self, config, file_handler, init_cache: bool = True): @property def job_cache(self)-> List[WorkflowJob]: + """ Get the job cache """ return self._job_cache @job_cache.setter - def job_cache(self, value): + def job_cache(self, value) -> None: + """ Set the job cache """ self._job_cache = value - def job_checkpoint(self): + def job_checkpoint(self) -> Dict[str, Any]: + """ Get the state data for all jobs """ jobs = [wfjob.workflow.state for wfjob in self.job_cache] data = {"jobs": jobs} return data - def save_checkpoint(self): + def save_checkpoint(self) -> None: + """ Save jobs to state data """ data = self.job_checkpoint() self.file_handler.write_state(data) @@ -131,12 +137,13 @@ def get_workflow_jobs_from_state(self)-> List[WorkflowJob]: return wf_job_list - def find_job_by_opid(self, opid): + def find_job_by_opid(self, opid) -> Optional[WorkflowJob]: + """ Find a job by operation id """ return next((job for job in self.job_cache if job.opid == opid), None) def prepare_and_cache_new_job(self, new_job: WorkflowJob, opid: str, force=False)-> Optional[WorkflowJob]: - + """ Prepare and cache a new job """ if "object_id_latest" in new_job.workflow.config: logger.warning("Old record. Skipping.") return @@ -153,6 +160,7 @@ def prepare_and_cache_new_job(self, new_job: WorkflowJob, opid: str, force=False def get_finished_jobs(self)->Tuple[List[WorkflowJob], List[WorkflowJob]]: + """ Get finished jobs """ successful_jobs = [] failed_jobs = [] for job in self.job_cache: @@ -166,6 +174,7 @@ def get_finished_jobs(self)->Tuple[List[WorkflowJob], List[WorkflowJob]]: def process_successful_job(self, job: WorkflowJob) -> Database: + """ Process a successful job """ logger.info(f"Running post for op {job.opid}") output_path = self.file_handler.get_output_path(job) @@ -183,15 +192,13 @@ def process_successful_job(self, job: WorkflowJob) -> Database: return database - def process_failed_job(self, job): + def process_failed_job(self, job) -> None: + """ Process a failed job """ if job.failed_count < self._MAX_FAILS: job.failed_count += 1 job.cromwell_submit() - - - class RuntimeApiHandler: def __init__(self, config): self.runtime_api = NmdcRuntimeApi(config)