From 303ea3ed588a3daa76c64bb2970a21e238082b73 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Tue, 24 Dec 2024 14:55:07 -0500 Subject: [PATCH] [Envs][Execution] Add general task archival/restore utilities and refactor existing code (#124) This adds generic task result archival/restore utilities that will be used for results transfer to/from an environment. This also refactors the existing archive/restore tools to use these new utilities. This PR also: - Resolves #82 - Resolves #106 (we switch to `zstdmt` by default on Linux platforms) --- .gitignore | 1 + errors/errors.yml | 6 + pyproject.toml | 1 + src/conductor/cli/archive.py | 80 ++--- src/conductor/cli/restore.py | 84 +----- src/conductor/config.py | 2 +- src/conductor/envs/maestro/daemon.py | 2 + src/conductor/errors/generated.py | 16 + src/conductor/execution/version_index.py | 82 +++++- .../execution/version_index_queries.py | 27 ++ src/conductor/filename.py | 11 +- src/conductor/utils/output_archiving.py | 274 ++++++++++++++++++ tests/cond_archive_restore_test.py | 89 +++++- tests/conductor_runner.py | 9 +- tests/general_archiving_test.py | 56 ++++ website/docs/cli/restore.md | 7 + 16 files changed, 593 insertions(+), 154 deletions(-) create mode 100644 src/conductor/utils/output_archiving.py create mode 100644 tests/general_archiving_test.py diff --git a/.gitignore b/.gitignore index 0976761..cbc7a86 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ src/conductor_cli.egg-info .mypy_cache build dist +env cond-out diff --git a/errors/errors.yml b/errors/errors.yml index a6c68e3..4f57510 100644 --- a/errors/errors.yml +++ b/errors/errors.yml @@ -184,6 +184,12 @@ The provided archive contains task output(s) that already exist in the output directory '{output_dir}'. +4007: + name: UnsupportedArchiveType + message: >- + The provided archive file was compressed as {archive_type}, which is not + supported on your platform. + # General Conductor errors (error code 5xxx) 5001: diff --git a/pyproject.toml b/pyproject.toml index 776f7df..8ddea3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ exclude = ''' | src/conductor/errors/generated.py | src/conductor/envs/proto_gen/ | explorer/ + | env/ ) ''' diff --git a/src/conductor/cli/archive.py b/src/conductor/cli/archive.py index 7dff79e..e5d3b90 100644 --- a/src/conductor/cli/archive.py +++ b/src/conductor/cli/archive.py @@ -1,21 +1,22 @@ import pathlib import datetime -import subprocess from typing import List, Optional import conductor.filename as f -from conductor.config import ARCHIVE_VERSION_INDEX from conductor.context import Context from conductor.errors import ( - CreateArchiveFailed, OutputFileExists, OutputPathDoesNotExist, NoTaskOutputsToArchive, ) from conductor.task_identifier import TaskIdentifier from conductor.task_types.base import TaskType -from conductor.execution.version_index import VersionIndex from conductor.utils.user_code import cli_command +from conductor.utils.output_archiving import ( + create_archive, + platform_archive_type, + ArchiveType, +) def register_command(subparsers): @@ -49,16 +50,18 @@ def register_command(subparsers): parser.set_defaults(func=main) -def generate_archive_name() -> str: +def generate_archive_name(archive_type: ArchiveType) -> str: timestamp = datetime.datetime.now().strftime("%Y-%m-%d+%H-%M-%S") - return f.archive(timestamp=timestamp) + return f.archive(timestamp=timestamp, archive_type=archive_type) -def handle_output_path(ctx: Context, raw_output_path: Optional[str]) -> pathlib.Path: +def handle_output_path( + ctx: Context, raw_output_path: Optional[str], archive_type: ArchiveType +) -> pathlib.Path: if raw_output_path is None: output_path = pathlib.Path( ctx.output_path, - generate_archive_name(), + generate_archive_name(archive_type), ) return output_path @@ -70,7 +73,7 @@ def handle_output_path(ctx: Context, raw_output_path: Optional[str]) -> pathlib. if output_path.is_dir(): # Corresponds to the case where the user provides a path to a # directory where the archive should be stored - return output_path / generate_archive_name() + return output_path / generate_archive_name(archive_type) raise OutputFileExists() elif output_path.parent.exists() and output_path.parent.is_dir(): @@ -105,49 +108,11 @@ def append_if_archivable(task: TaskType): return relevant_tasks -def create_archive( - ctx: Context, - archive_index: VersionIndex, - output_archive_path: pathlib.Path, - archive_index_path: pathlib.Path, -) -> None: - output_dirs_str = [ - str( - pathlib.Path( - task_id.path, - f.task_output_dir(task_id, version), - ) - ) - for task_id, version in archive_index.get_all_versions() - ] - - try: - process = subprocess.Popen( - [ - "tar", - "czf", # Create a new archive and use gzip to compress - str(output_archive_path), - "-C", # Files to put in the archive are relative to `ctx.output_path` - str(ctx.output_path), - str(archive_index_path.relative_to(ctx.output_path)), - *output_dirs_str, - ], - shell=False, - ) - process.wait() - if process.returncode != 0: - raise CreateArchiveFailed().add_extra_context( - "The tar utility returned a non-zero error code." - ) - - except OSError as ex: - raise CreateArchiveFailed().add_extra_context(str(ex)) - - @cli_command def main(args): ctx = Context.from_cwd() - output_archive_path = handle_output_path(ctx, args.output) + archive_type = platform_archive_type() + output_archive_path = handle_output_path(ctx, args.output, archive_type) # If `None`, we should archive all tasks tasks_to_archive = compute_tasks_to_archive(ctx, args.task_identifier) @@ -155,17 +120,15 @@ def main(args): raise NoTaskOutputsToArchive() try: - archive_index_path = pathlib.Path(ctx.output_path, ARCHIVE_VERSION_INDEX) - archive_index_path.unlink(missing_ok=True) - archive_index = VersionIndex.create_or_load(archive_index_path) - total_entry_count = ctx.version_index.copy_entries_to( - dest=archive_index, tasks=tasks_to_archive, latest_only=args.latest + tasks_to_archive_with_versions = ctx.version_index.get_versioned_tasks( + tasks=tasks_to_archive, latest_only=args.latest ) - if total_entry_count == 0: + if len(tasks_to_archive_with_versions) == 0: raise NoTaskOutputsToArchive() - archive_index.commit_changes() - create_archive(ctx, archive_index, output_archive_path, archive_index_path) + create_archive( + ctx, tasks_to_archive_with_versions, output_archive_path, archive_type + ) # Compute a relative path to the current working directory, if possible try: @@ -177,6 +140,3 @@ def main(args): except: output_archive_path.unlink(missing_ok=True) raise - - finally: - archive_index_path.unlink(missing_ok=True) diff --git a/src/conductor/cli/restore.py b/src/conductor/cli/restore.py index e82eee7..ac499a2 100644 --- a/src/conductor/cli/restore.py +++ b/src/conductor/cli/restore.py @@ -1,14 +1,9 @@ import pathlib -import subprocess -import shutil -import sqlite3 -import conductor.filename as f -from conductor.config import ARCHIVE_STAGING, ARCHIVE_VERSION_INDEX from conductor.context import Context -from conductor.errors import ArchiveFileInvalid, DuplicateTaskOutput -from conductor.execution.version_index import VersionIndex +from conductor.errors import ArchiveFileInvalid from conductor.utils.user_code import cli_command +from conductor.utils.output_archiving import restore_archive def register_command(subparsers): @@ -21,25 +16,14 @@ def register_command(subparsers): type=str, help="Path to the archive file to restore.", ) + parser.add_argument( + "--strict", + action="store_true", + help="If set, the restore operation will fail if any task output is already present.", + ) parser.set_defaults(func=main) -def extract_archive(archive_file: pathlib.Path, staging_path: pathlib.Path): - try: - process = subprocess.Popen( - ["tar", "xzf", str(archive_file), "-C", str(staging_path)], - shell=False, - ) - process.wait() - if process.returncode != 0: - raise ArchiveFileInvalid().add_extra_context( - "The tar utility returned a non-zero error code." - ) - - except OSError as ex: - raise ArchiveFileInvalid().add_extra_context(str(ex)) - - @cli_command def main(args): ctx = Context.from_cwd() @@ -48,56 +32,4 @@ def main(args): if not archive_file.is_file(): raise ArchiveFileInvalid() - try: - archive_version_index = None - staging_path = ctx.output_path / ARCHIVE_STAGING - staging_path.mkdir(exist_ok=True) - extract_archive(archive_file, staging_path) - - archive_version_index_path = staging_path / ARCHIVE_VERSION_INDEX - if not archive_version_index_path.is_file(): - raise ArchiveFileInvalid().add_extra_context( - "Could not locate the archive version index." - ) - - archive_version_index = VersionIndex.create_or_load(archive_version_index_path) - try: - archive_version_index.copy_entries_to( - dest=ctx.version_index, tasks=None, latest_only=False - ) - except sqlite3.IntegrityError as ex: - raise DuplicateTaskOutput(output_dir=str(ctx.output_path)) from ex - - # Copy over all archived task outputs - for task_id, version in archive_version_index.get_all_versions(): - src_task_path = pathlib.Path( - staging_path, task_id.path, f.task_output_dir(task_id, version) - ) - if not src_task_path.is_dir(): - raise ArchiveFileInvalid().add_extra_context( - "Missing archived task output for '{}' at version {} in the " - "archive.".format(str(task_id), str(version)) - ) - - dest_task_path = pathlib.Path( - ctx.output_path, task_id.path, f.task_output_dir(task_id, version) - ) - dest_task_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copytree(src_task_path, dest_task_path) - if not dest_task_path.is_dir(): - raise ArchiveFileInvalid().add_extra_context( - "Missing copied archived task output for '{}' at version {}.".format( - str(task_id), str(version) - ) - ) - - # Everything was copied over and verified - safe to commit the index changes - ctx.version_index.commit_changes() - - except: - ctx.version_index.rollback_changes() - raise - - finally: - del archive_version_index - shutil.rmtree(staging_path, ignore_errors=True) + restore_archive(ctx, archive_file, expect_no_duplicates=args.strict) diff --git a/src/conductor/config.py b/src/conductor/config.py index 9ca12ee..870e016 100644 --- a/src/conductor/config.py +++ b/src/conductor/config.py @@ -42,7 +42,7 @@ SLOT_ENV_VARIABLE_NAME = "COND_SLOT" # A template for the default name of a Conductor archive. -ARCHIVE_FILE_NAME_TEMPLATE = "cond-archive+{timestamp}.tar.gz" +ARCHIVE_FILE_NAME_TEMPLATE = "cond-archive+{timestamp}.tar.{extension}" # The file name of the version index used in a Conductor archive. ARCHIVE_VERSION_INDEX = "version_index_archive.sqlite" diff --git a/src/conductor/envs/maestro/daemon.py b/src/conductor/envs/maestro/daemon.py index 0507a4a..a8ce2ad 100644 --- a/src/conductor/envs/maestro/daemon.py +++ b/src/conductor/envs/maestro/daemon.py @@ -148,6 +148,8 @@ async def execute_task( ) executor = Executor(execution_slots=1, silent=True) executor.run_plan(plan, ctx) + # Make sure any new versions are committed. + ctx.version_index.commit_changes() end_timestamp = int(time.time()) return ExecuteTaskResponse( diff --git a/src/conductor/errors/generated.py b/src/conductor/errors/generated.py index 2d45b3e..1537fde 100644 --- a/src/conductor/errors/generated.py +++ b/src/conductor/errors/generated.py @@ -554,6 +554,20 @@ def _message(self): ) +class UnsupportedArchiveType(ConductorError): + error_code = 4007 + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + self.archive_type = kwargs["archive_type"] + + def _message(self): + return "The provided archive file was compressed as {archive_type}, which is not supported on your platform.".format( + archive_type=self.archive_type, + ) + + class ConfigParseError(ConductorError): error_code = 5001 @@ -775,6 +789,7 @@ def _message(self): 4004: CreateArchiveFailed, 4005: ArchiveFileInvalid, 4006: DuplicateTaskOutput, + 4007: UnsupportedArchiveType, 5001: ConfigParseError, 5002: ConfigInvalidValue, 5003: UnsupportedPlatform, @@ -830,6 +845,7 @@ def _message(self): "CreateArchiveFailed", "ArchiveFileInvalid", "DuplicateTaskOutput", + "UnsupportedArchiveType", "ConfigParseError", "ConfigInvalidValue", "UnsupportedPlatform", diff --git a/src/conductor/execution/version_index.py b/src/conductor/execution/version_index.py index a912ec1..2dedbd3 100644 --- a/src/conductor/execution/version_index.py +++ b/src/conductor/execution/version_index.py @@ -151,11 +151,22 @@ def generate_new_output_version(self, commit: Optional[Git.Commit]) -> Version: timestamp, commit_hash, commit.has_changes if commit is not None else False ) - def insert_output_version(self, task_identifier: TaskIdentifier, version: Version): + def insert_output_version( + self, task_identifier: TaskIdentifier, version: Version, unchecked: bool = False + ) -> int: + """ + Insert a versioned task into the index and return the number of rows + inserted. This will return 0 rows inserted if the task version is + already in the index. + """ + if unchecked: + query = q.insert_new_version_unchecked + else: + query = q.insert_new_version cursor = self._conn.cursor() has_uncommitted_changes = 1 if version.has_uncommitted_changes else 0 cursor.execute( - q.insert_new_version, + query, ( str(task_identifier), version.timestamp, @@ -163,6 +174,7 @@ def insert_output_version(self, task_identifier: TaskIdentifier, version: Versio has_uncommitted_changes, ), ) + return cursor.rowcount def get_all_versions(self) -> List[Tuple[TaskIdentifier, Version]]: cursor = self._conn.cursor() @@ -172,6 +184,46 @@ def get_all_versions(self) -> List[Tuple[TaskIdentifier, Version]]: for row in cursor ] + def get_all_unversioned(self) -> List[TaskIdentifier]: + try: + cursor = self._conn.cursor() + cursor.execute(q.get_unversioned_tasks) + return [TaskIdentifier.from_str(row[0]) for row in cursor] + except sqlite3.OperationalError: + # The unversioned table does not exist, so there are no unversioned tasks. + # We create the unversioned table only when we add unversioned tasks. + return [] + + def get_versioned_tasks( + self, tasks: Optional[List[TaskIdentifier]], latest_only: bool + ) -> List[Tuple[TaskIdentifier, Version]]: + cursor = self._conn.cursor() + if tasks is None: + if latest_only: + cursor.execute(q.all_entries_latest) + else: + cursor.execute(q.all_entries) + return [ + (TaskIdentifier.from_str(row[0]), self._version_from_row(row[1:])) + for row in cursor + ] + + else: + results = [] + for task_id in tasks: + if latest_only: + cursor.execute(q.latest_entry_for_task, (str(task_id),)) + else: + cursor.execute(q.all_entries_for_task, (str(task_id),)) + for row in cursor: + results.append( + ( + TaskIdentifier.from_str(row[0]), + self._version_from_row(row[1:]), + ) + ) + return results + def copy_entries_to( self, dest: "VersionIndex", @@ -195,6 +247,22 @@ def copy_entries_to( insert_count += dest.bulk_load(cursor) return insert_count + @staticmethod + def copy_specific_entries_to( + dest: "VersionIndex", entries: List[Tuple[TaskIdentifier, Version]] + ) -> int: + values = [] + for task_id, version in entries: + values.append( + ( + str(task_id), + version.timestamp, + version.commit_hash, + 1 if version.has_uncommitted_changes else 0, + ) + ) + return dest.bulk_load(values) + def bulk_load(self, rows: Iterable) -> int: """ Load the rows into the index and return the number loaded @@ -203,6 +271,16 @@ def bulk_load(self, rows: Iterable) -> int: cursor.executemany(q.insert_new_version, rows) return cursor.rowcount + def bulk_load_unversioned(self, task_ids: Iterable[TaskIdentifier]) -> int: + """ + Load the unversioned task IDs into the index and return the number loaded. + """ + cursor = self._conn.cursor() + cursor.execute(q.create_unversioned_table) + rows = [(str(task_id),) for task_id in task_ids] + cursor.executemany(q.add_unversioned_task, rows) + return cursor.rowcount + def commit_changes(self): if not self._conn.in_transaction: return diff --git a/src/conductor/execution/version_index_queries.py b/src/conductor/execution/version_index_queries.py index 74c3e16..2b5f62c 100644 --- a/src/conductor/execution/version_index_queries.py +++ b/src/conductor/execution/version_index_queries.py @@ -8,6 +8,15 @@ ) """ +# Used when archiving non-versioned task outputs (usually for transport to/from +# a remote environment). +create_unversioned_table = """ + CREATE TABLE IF NOT EXISTS unversioned ( + task_identifier TEXT NOT NULL, + PRIMARY KEY (task_identifier) + ) +""" + set_format_version = "PRAGMA user_version = {version:d}" get_format_version = "PRAGMA user_version" @@ -24,6 +33,16 @@ VALUES (?, ?, ?, ?) """ +insert_new_version_unchecked = """ + INSERT OR IGNORE INTO version_index ( + task_identifier, + timestamp, + git_commit_hash, + has_uncommitted_changes + ) + VALUES (?, ?, ?, ?) +""" + latest_task_version = """ SELECT timestamp, @@ -101,6 +120,14 @@ version_index """ +add_unversioned_task = """ + INSERT INTO unversioned (task_identifier) VALUES (?) +""" + +get_unversioned_tasks = """ + SELECT task_identifier FROM unversioned +""" + # Queries used in format 1 (retained for testing purposes) diff --git a/src/conductor/filename.py b/src/conductor/filename.py index b7da280..141b6db 100644 --- a/src/conductor/filename.py +++ b/src/conductor/filename.py @@ -1,12 +1,17 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING from conductor.config import ARCHIVE_FILE_NAME_TEMPLATE, TASK_OUTPUT_DIR_SUFFIX from conductor.execution.version_index import Version from conductor.task_identifier import TaskIdentifier +if TYPE_CHECKING: + from conductor.utils.output_archiving import ArchiveType -def archive(timestamp: str) -> str: - return ARCHIVE_FILE_NAME_TEMPLATE.format(timestamp=timestamp) + +def archive(timestamp: str, archive_type: "ArchiveType") -> str: + return ARCHIVE_FILE_NAME_TEMPLATE.format( + timestamp=timestamp, extension=archive_type.extension() + ) def task_output_dir( diff --git a/src/conductor/utils/output_archiving.py b/src/conductor/utils/output_archiving.py new file mode 100644 index 0000000..bf83a1a --- /dev/null +++ b/src/conductor/utils/output_archiving.py @@ -0,0 +1,274 @@ +import enum +import pathlib +import platform +import subprocess +import shutil +from typing import List, Optional, Tuple, TYPE_CHECKING + +import conductor.filename as f +from conductor.config import ARCHIVE_VERSION_INDEX, ARCHIVE_STAGING +from conductor.errors import ( + InternalError, + CreateArchiveFailed, + ArchiveFileInvalid, + DuplicateTaskOutput, + UnsupportedPlatform, + UnsupportedArchiveType, +) +from conductor.execution.version_index import VersionIndex, Version +from conductor.task_identifier import TaskIdentifier + +if TYPE_CHECKING: + from conductor.context import Context + + +class ArchiveType(enum.Enum): + Gzip = "gzip" + Zstd = "zstdmt" + + def extension(self): + if self == ArchiveType.Gzip: + return "gz" + elif self == ArchiveType.Zstd: + return "zst" + else: + raise InternalError(details="Unknown archive type.") + + +def platform_archive_type() -> ArchiveType: + system = platform.system() + if system == "Linux": + return ArchiveType.Zstd + elif system == "Darwin": + return ArchiveType.Gzip + else: + # Windows is unsupported. We check for platform support at the beginning + # of all Conductor commands. + raise UnsupportedPlatform() + + +def create_archive( + ctx: "Context", + tasks_to_archive: List[Tuple[TaskIdentifier, Optional[Version]]], + output_archive_path: pathlib.Path, + archive_type: ArchiveType, +) -> None: + """ + This utility is used to create an archive of the output directories of the + given tasks for transport purposes (e.g., moving data to/from a remote + environment). + """ + + # Ensure versions are specified when they should be specified. + # Partition tasks into versioned and unversioned tasks. + versioned_tasks = [] + unversioned_tasks = [] + for task_id, version in tasks_to_archive: + task = ctx.task_index.get_task(task_id) + if task.archivable: + if version is None: + raise InternalError( + details=f"Did not provide a version for an archivable task {str(task_id)}." + ) + versioned_tasks.append((task_id, version)) + else: + unversioned_tasks.append(task_id) + + try: + archive_index_path = ctx.output_path / ARCHIVE_VERSION_INDEX + archive_index_path.unlink(missing_ok=True) + + # Store the versions of the tasks that are being archived. + archive_index = VersionIndex.create_or_load( + ctx.output_path / ARCHIVE_VERSION_INDEX + ) + VersionIndex.copy_specific_entries_to(archive_index, versioned_tasks) + archive_index.bulk_load_unversioned(unversioned_tasks) + archive_index.commit_changes() + + # Collect the output directories for the tasks to archive. + output_dirs_str = [] + for task_id, version in versioned_tasks: + output_dirs_str.append( + str( + pathlib.Path( + task_id.path, + f.task_output_dir(task_id, version), + ) + ) + ) + for task_id in unversioned_tasks: + output_dirs_str.append( + str(pathlib.Path(task_id.path, f.task_output_dir(task_id))) + ) + + # Create the archive. + process_args = [ + "tar", + "-cf", + str(output_archive_path), + "--use-compress-program", + archive_type.value, + "-C", # Files to put in the archive are relative to `ctx.output_path` + str(ctx.output_path), + str(archive_index_path.relative_to(ctx.output_path)), + *output_dirs_str, + ] + result = subprocess.run(process_args, check=False, capture_output=True) + if result.returncode != 0: + raise CreateArchiveFailed().add_extra_context( + "The tar utility returned a non-zero error code." + ) + + finally: + # Clean up the archive index file. + archive_index_path.unlink(missing_ok=True) + + +def restore_archive( + ctx: "Context", + archive_path: pathlib.Path, + archive_type: Optional[ArchiveType] = None, + expect_no_duplicates: bool = False, +) -> None: + """ + This utility is used to restore the output directories of tasks from an + archive created by `create_archive`. + """ + + try: + # Extract the archive to a staging location. + staging_path = ctx.output_path / ARCHIVE_STAGING + staging_path.mkdir(parents=True, exist_ok=True) + if archive_type is None: + archive_type = _infer_compress_program(archive_path) + if not _supports_compress_program(archive_type): + raise UnsupportedArchiveType(archive_type=archive_type.value) + _extract_archive(archive_path, staging_path, archive_type) + + # Load the archive version index. + archive_version_index_path = staging_path / ARCHIVE_VERSION_INDEX + archive_version_index = VersionIndex.create_or_load(archive_version_index_path) + + successful_restore_dirs: List[pathlib.Path] = [] + # Copy over versioned tasks, skipping the ones that already exist. + for task_id, version in archive_version_index.get_all_versions(): + insert_count = ctx.version_index.insert_output_version( + task_id, version, unchecked=True + ) + if insert_count == 0: + # Version already exists in the current version index. + if expect_no_duplicates: + # We should not have duplicates, so we raise an error. + # NOTE: We clear the successful restore directories before + # raising the error to keep the output directory clean. + for d in successful_restore_dirs: + shutil.rmtree(d, ignore_errors=True) + raise DuplicateTaskOutput(output_dir=str(ctx.output_path)) + else: + # We skip copying over this task. + continue + + src_task_path = pathlib.Path( + staging_path, task_id.path, f.task_output_dir(task_id, version) + ) + if not src_task_path.is_dir(): + raise ArchiveFileInvalid().add_extra_context( + "Missing archived task output for '{}' at version {} in the " + "archive.".format(str(task_id), str(version)) + ) + + dest_task_path = pathlib.Path( + ctx.output_path, task_id.path, f.task_output_dir(task_id, version) + ) + dest_task_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(src_task_path, dest_task_path) + if not dest_task_path.is_dir(): + raise ArchiveFileInvalid().add_extra_context( + "Missing copied archived task output for '{}' at version {}.".format( + str(task_id), str(version) + ) + ) + successful_restore_dirs.append(dest_task_path) + + # Copy over unversioned tasks. We always blindly copy over these outputs. + for task_id in archive_version_index.get_all_unversioned(): + src_task_path = pathlib.Path( + staging_path, task_id.path, f.task_output_dir(task_id) + ) + if not src_task_path.is_dir(): + raise ArchiveFileInvalid().add_extra_context( + "Missing archived task output for '{}' in the " + "archive.".format(str(task_id)) + ) + + dest_task_path = pathlib.Path( + ctx.output_path, task_id.path, f.task_output_dir(task_id) + ) + dest_task_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(src_task_path, dest_task_path, dirs_exist_ok=True) + if not dest_task_path.is_dir(): + raise ArchiveFileInvalid().add_extra_context( + "Missing copied archived task output for '{}'.".format(str(task_id)) + ) + + # Safe to commit now. + ctx.version_index.commit_changes() + + except: + # Something went wrong, so undo our changes. + ctx.version_index.rollback_changes() + raise + + finally: + # Clean up the staging directory. + shutil.rmtree(staging_path, ignore_errors=True) + + +def _extract_archive( + archive_file: pathlib.Path, staging_path: pathlib.Path, archive_type: ArchiveType +) -> None: + try: + process_args = [ + "tar", + "-xf", + str(archive_file), + ] + if archive_type == ArchiveType.Zstd: + process_args.append("--use-compress-program") + process_args.append(archive_type.value) + process_args.extend( + [ + "-C", + str(staging_path), + ] + ) + result = subprocess.run(process_args, check=False, capture_output=True) + if result.returncode != 0: + raise ArchiveFileInvalid().add_extra_context( + "The tar utility returned a non-zero error code." + ) + + except OSError as ex: + raise ArchiveFileInvalid().add_extra_context(str(ex)) + + +def _infer_compress_program(archive_file: pathlib.Path) -> ArchiveType: + if archive_file.suffix == ".gz": + # This is a heuristic we use to support legacy Conductor archives (which + # were gzip-compressed) or archives created on macOS (which does not + # have zstd installed by default). + return ArchiveType.Gzip + else: + # Conductor has switched to using zstd for compression. + return ArchiveType.Zstd + + +def _supports_compress_program(archive_type: ArchiveType) -> bool: + system = platform.system() + if archive_type == ArchiveType.Zstd: + return system == "Linux" + elif archive_type == ArchiveType.Gzip: + return system == "Linux" or system == "Darwin" + else: + return False diff --git a/tests/cond_archive_restore_test.py b/tests/cond_archive_restore_test.py index e5ca6ad..fa5d91b 100644 --- a/tests/cond_archive_restore_test.py +++ b/tests/cond_archive_restore_test.py @@ -1,7 +1,13 @@ import pathlib import shutil -from .conductor_runner import ConductorRunner, count_task_outputs, EXAMPLE_TEMPLATES +from .conductor_runner import ( + ConductorRunner, + count_task_outputs, + EXAMPLE_TEMPLATES, + FIXTURE_TEMPLATES, +) +from conductor.utils.output_archiving import platform_archive_type def test_archive_restore(tmp_path: pathlib.Path): @@ -16,12 +22,14 @@ def test_archive_restore(tmp_path: pathlib.Path): result = cond.archive("//figures:graph", output_path=None, latest=False) assert result.returncode == 0 + expected_archive_type = platform_archive_type() + # Make sure we found the archive found_archive = False archive_name = None orig_archive_path = None for file in cond.output_path.iterdir(): - if file.name.endswith(".tar.gz"): + if file.name.endswith(f".tar.{expected_archive_type.extension()}"): found_archive = True archive_name = file.name orig_archive_path = file @@ -32,12 +40,12 @@ def test_archive_restore(tmp_path: pathlib.Path): # Restoring the archive into an output directory that already contains the results # should fail - result = cond.restore(orig_archive_path) + result = cond.restore(orig_archive_path, strict=True) assert result.returncode != 0 # Remove the output directory and then try restoring shutil.rmtree(cond.output_path) - result = cond.restore(cond.project_root / archive_name) + result = cond.restore(cond.project_root / archive_name, strict=True) assert result.returncode == 0 # Only the run_experiment() task output should be restored @@ -48,7 +56,7 @@ def test_archive_restore(tmp_path: pathlib.Path): def test_restore_invalid(tmp_path: pathlib.Path): cond = ConductorRunner.from_template(tmp_path, EXAMPLE_TEMPLATES["hello_world"]) - result = cond.restore(cond.output_path / "non_existent.tar.gz") + result = cond.restore(cond.output_path / "non_existent.tar.zst", strict=True) assert result.returncode != 0 @@ -57,7 +65,9 @@ def test_archive_output(tmp_path: pathlib.Path): result = cond.run("//:hello_world") assert result.returncode == 0 - output_archive = cond.project_root / "custom.tar.gz" + expected_archive_type = platform_archive_type() + expected_extension = expected_archive_type.extension() + output_archive = cond.project_root / f"custom.tar.{expected_extension}" assert not output_archive.exists() result = cond.archive("//:hello_world", output_path=output_archive, latest=False) assert result.returncode == 0 @@ -73,11 +83,16 @@ def test_archive_output_dir(tmp_path: pathlib.Path): result = cond.archive("//:hello_world", output_path=cond.project_root, latest=False) assert result.returncode == 0 + expected_archive_type = platform_archive_type() + expected_extension = expected_archive_type.extension() + # Ensure the archive was saved in the correct output directory with a # Conductor-provided name archive_found = False for file in cond.project_root.iterdir(): - if file.name.startswith("cond-archive") and file.name.endswith(".tar.gz"): + if file.name.startswith("cond-archive") and file.name.endswith( + f".tar.{expected_extension}" + ): archive_found = True break assert archive_found @@ -108,20 +123,74 @@ def test_archive_restore_latest(tmp_path: pathlib.Path): assert result.returncode == 0 assert count_task_outputs(cond.output_path) == 2 + expected_archive_type = platform_archive_type() + expected_extension = expected_archive_type.extension() + # Archive the latest only - output_archive = cond.project_root / "latest.tar.gz" + output_archive = cond.project_root / f"latest.tar.{expected_extension}" assert not output_archive.exists() result = cond.archive("//:hello_world", output_path=output_archive, latest=True) assert result.returncode == 0 assert output_archive.exists() and output_archive.is_file() # Restoring into an existing experiment output directory should fail - result = cond.restore(output_archive) + result = cond.restore(output_archive, strict=True) assert result.returncode != 0 # Restore into an empty output path. Only one of the experiment outputs # should have been archived (and thus restored). shutil.rmtree(cond.output_path) - result = cond.restore(output_archive) + result = cond.restore(output_archive, strict=True) assert result.returncode == 0 assert count_task_outputs(cond.output_path) == 1 + + +def test_partial_archive_restore(tmp_path: pathlib.Path): + cond = ConductorRunner.from_template(tmp_path, FIXTURE_TEMPLATES["experiments"]) + result = cond.run("//sweep:threads-args-1") + assert result.returncode == 0 + + # Make an archive with just this first result. + expected_archive_type = platform_archive_type() + expected_extension = expected_archive_type.extension() + output_archive1 = cond.project_root / f"partial1.tar.{expected_extension}" + assert not output_archive1.exists() + result = cond.archive( + "//sweep:threads-args", output_path=output_archive1, latest=False + ) + assert result.returncode == 0 + assert output_archive1.exists() and output_archive1.is_file() + + # Run the second experiment. + result = cond.run("//sweep:threads-args-2") + assert result.returncode == 0 + + # Make an archive with both results. + output_archive2 = cond.project_root / f"partial2.tar.{expected_extension}" + assert not output_archive2.exists() + result = cond.archive( + "//sweep:threads-args", output_path=output_archive2, latest=False + ) + assert result.returncode == 0 + assert output_archive2.exists() and output_archive2.is_file() + + shutil.rmtree(cond.output_path) + expected_task_out = cond.output_path / "sweep" + + # Restore the first archive - it should succeed. + result = cond.restore(output_archive1, strict=False) + assert result.returncode == 0 + assert count_task_outputs(expected_task_out) == 1 + + # Restore the second archive - on strict mode it should fail. + result = cond.restore(output_archive2, strict=True) + assert result.returncode != 0 + assert count_task_outputs(expected_task_out) == 1 + + # Restore on non-strict mode. + result = cond.restore(output_archive2, strict=False) + if result.returncode != 0: + print(result.stdout.decode("utf-8")) + print(result.stderr.decode("utf-8")) + assert result.returncode == 0 + assert count_task_outputs(expected_task_out) == 2 diff --git a/tests/conductor_runner.py b/tests/conductor_runner.py index a4aee37..c2ad08d 100644 --- a/tests/conductor_runner.py +++ b/tests/conductor_runner.py @@ -88,8 +88,13 @@ def archive( cmd.append("--latest") return self._run_command(cmd) - def restore(self, archive_path: pathlib.Path) -> subprocess.CompletedProcess: - return self._run_command(["restore", str(archive_path)]) + def restore( + self, archive_path: pathlib.Path, strict: bool + ) -> subprocess.CompletedProcess: + cmd = ["restore", str(archive_path)] + if strict: + cmd.append("--strict") + return self._run_command(cmd) def gc( self, dry_run: bool = False, verbose: bool = False diff --git a/tests/general_archiving_test.py b/tests/general_archiving_test.py new file mode 100644 index 0000000..090fce5 --- /dev/null +++ b/tests/general_archiving_test.py @@ -0,0 +1,56 @@ +import pathlib +from conductor.config import VERSION_INDEX_NAME +from conductor.context import Context +from conductor.execution.version_index import VersionIndex +from conductor.task_identifier import TaskIdentifier +from conductor.utils.output_archiving import ( + create_archive, + restore_archive, + ArchiveType, +) +from .conductor_runner import ConductorRunner, EXAMPLE_TEMPLATES + + +def test_overall_archiving(tmp_path: pathlib.Path): + cond = ConductorRunner.from_template(tmp_path, EXAMPLE_TEMPLATES["dependencies"]) + result = cond.run("//figures:graph") + assert result.returncode == 0 + + run_benchmark_id = TaskIdentifier.from_str("//experiments:run_benchmark") + figures_id = TaskIdentifier.from_str("//figures:graph") + version_index = VersionIndex.create_or_load(cond.output_path / VERSION_INDEX_NAME) + versions = version_index.get_all_versions_for_task(run_benchmark_id) + assert len(versions) == 1 + + ctx = Context(cond.project_root) + ctx.task_index.load_transitive_closure(figures_id) + to_archive = [ + (run_benchmark_id, versions[0]), + (figures_id, None), + ] + archive_output_path = cond.project_root / "test_archive.tar.gz" + assert not archive_output_path.exists() + create_archive(ctx, to_archive, archive_output_path, archive_type=ArchiveType.Gzip) + assert archive_output_path.exists() + + # Clear the output directory and recreate the Context to clear out cached + # state. + result = cond.clean() + ctx = Context(cond.project_root) + assert result.returncode == 0 + + # Restore the archive. + restore_archive(ctx, archive_output_path, archive_type=ArchiveType.Gzip) + + # Check that the output directories for the relevant tasks were restored. + expt_out_dir = cond.find_task_output_dir(str(run_benchmark_id), is_experiment=True) + assert expt_out_dir is not None + assert expt_out_dir.exists() + assert expt_out_dir.is_dir() + assert (expt_out_dir / "results.csv").exists() + + figures_out_dir = cond.find_task_output_dir(str(figures_id), is_experiment=False) + assert figures_out_dir is not None + assert figures_out_dir.exists() + assert figures_out_dir.is_dir() + assert (figures_out_dir / "graph.csv").exists() diff --git a/website/docs/cli/restore.md b/website/docs/cli/restore.md index b293fc2..c63fe24 100644 --- a/website/docs/cli/restore.md +++ b/website/docs/cli/restore.md @@ -19,6 +19,13 @@ for more details about Conductor's archive and restore features. The path to the archive file to restore. +### `--strict` + +By default, Conductor will only restore result versions that do not exist in +your local results directory. If this flag is set, Conductor will abort the +restore if any result versions in the archive already exist in your local +results directory. + ## Optional Arguments ### `-h` or `--help`