Skip to content

Commit

Permalink
Set task execution options properly
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffxy committed Dec 20, 2024
1 parent e240dd0 commit 05c6a4c
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 51 deletions.
9 changes: 9 additions & 0 deletions proto/maestro.proto
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ message ExecuteTaskRequest {
// Note that unversioned dependencies (i.e., `run_command()` dependencies) are
// not included here.
repeated TaskDependency dep_versions = 4;

// The type of task to run.
ExecuteTaskType execute_task_type = 5;
}

// Selects how a task sent in an `ExecuteTaskRequest` should be executed
// (carried in the request's `execute_task_type` field).
enum ExecuteTaskType {
// Zero value of the enum.
// NOTE(review): presumably what the field reads as when the sender leaves it
// unset -- confirm. It does not name a runnable task type.
TT_UNSPECIFIED = 0;
// Execute the task as an experiment.
TT_RUN_EXPERIMENT = 1;
// Execute the task as a plain command.
TT_RUN_COMMAND = 2;
}

message TaskDependency {
Expand Down
13 changes: 11 additions & 2 deletions src/conductor/envs/maestro/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import conductor.envs.proto_gen.maestro_pb2 as pb
import conductor.envs.proto_gen.maestro_pb2_grpc as maestro_grpc
from conductor.envs.maestro.interface import ExecuteTaskResponse
from conductor.envs.maestro.interface import ExecuteTaskResponse, ExecuteTaskType
from conductor.task_identifier import TaskIdentifier
from conductor.errors import ConductorError
from conductor.errors import ConductorError, InternalError
from conductor.errors.generated import ERRORS_BY_CODE
from conductor.execution.version_index import Version

Expand Down Expand Up @@ -59,6 +59,7 @@ def execute_task(
workspace_rel_project_root: pathlib.Path,
task_identifier: TaskIdentifier,
dep_versions: Dict[TaskIdentifier, Version],
execute_task_type: ExecuteTaskType,
) -> ExecuteTaskResponse:
assert self._stub is not None
# pylint: disable-next=no-member
Expand All @@ -73,6 +74,14 @@ def execute_task(
dv.version.timestamp = version.timestamp
if version.commit_hash is not None:
dv.version.commit_hash = version.commit_hash
if execute_task_type == ExecuteTaskType.RunExperiment:
msg.execute_task_type = pb.TT_RUN_EXPERIMENT # pylint: disable=no-member
elif execute_task_type == ExecuteTaskType.RunCommand:
msg.execute_task_type = pb.TT_RUN_COMMAND # pylint: disable=no-member
else:
raise InternalError(
details=f"Unsupported execute task type {str(execute_task_type)}."
)
result = self._stub.ExecuteTask(msg)
if result.WhichOneof("result") == "error":
raise _pb_to_error(result.error)
Expand Down
62 changes: 42 additions & 20 deletions src/conductor/envs/maestro/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import logging
import pathlib
import time
from typing import Dict
from typing import Any, Dict

from conductor.config import MAESTRO_WORKSPACE_LOCATION, MAESTRO_WORKSPACE_NAME_FORMAT
from conductor.context import Context
from conductor.envs.maestro.interface import MaestroInterface, ExecuteTaskResponse
from conductor.envs.maestro.interface import (
MaestroInterface,
ExecuteTaskResponse,
ExecuteTaskType,
)
from conductor.errors import InternalError
from conductor.execution.executor import Executor
from conductor.execution.operation_state import OperationState
Expand Down Expand Up @@ -56,6 +60,7 @@ async def execute_task(
project_root: pathlib.Path,
task_identifier: TaskIdentifier,
dep_versions: Dict[TaskIdentifier, Version],
execute_task_type: ExecuteTaskType,
) -> ExecuteTaskResponse:
workspace_path = (
self._maestro_root / MAESTRO_WORKSPACE_LOCATION / workspace_name
Expand Down Expand Up @@ -97,24 +102,41 @@ async def execute_task(
deps_output_paths.append(output_path)

# 3. Create the task execution operation.
output_path = task_to_run.get_output_path(ctx)
assert output_path is not None
# NOTE: Set the options flags appropriately.
op = RunTaskExecutable(
initial_state=OperationState.QUEUED,
task=task_to_run,
identifier=task_identifier,
run=task_to_run.raw_run,
args=task_to_run.args,
options=task_to_run.options,
working_path=task_to_run.get_working_path(ctx),
output_path=output_path,
deps_output_paths=deps_output_paths,
record_output=True,
version_to_record=None,
serialize_args_options=True,
parallelizable=task_to_run.parallelizable,
)
kwargs: Dict[str, Any] = {
"initial_state": OperationState.QUEUED,
"task": task_to_run,
"identifier": task_identifier,
"run": task_to_run.raw_run,
"args": task_to_run.args,
"options": task_to_run.options,
"working_path": task_to_run.get_working_path(ctx),
"deps_output_paths": deps_output_paths,
"parallelizable": task_to_run.parallelizable,
}

if execute_task_type == ExecuteTaskType.RunExperiment:
assert isinstance(task_to_run, RunExperiment)
# We need it to be versioned.
exp_version = task_to_run.create_new_version(ctx)
output_path = task_to_run.get_output_path(ctx)
assert output_path is not None
kwargs["record_output"] = True
kwargs["version_to_record"] = exp_version
kwargs["serialize_args_options"] = True
kwargs["output_path"] = output_path
elif execute_task_type == ExecuteTaskType.RunCommand:
output_path = task_to_run.get_output_path(ctx)
assert output_path is not None
kwargs["record_output"] = False
kwargs["version_to_record"] = None
kwargs["serialize_args_options"] = False
kwargs["output_path"] = output_path
else:
raise InternalError(
details=f"Unsupported task type {str(execute_task_type)}."
)

op = RunTaskExecutable(**kwargs) # pylint: disable=missing-kwoa

# 4. Run the task.
plan = ExecutionPlan(
Expand Down
20 changes: 16 additions & 4 deletions src/conductor/envs/maestro/grpc_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pathlib
import conductor.envs.proto_gen.maestro_pb2_grpc as rpc
import conductor.envs.proto_gen.maestro_pb2 as pb
from conductor.envs.maestro.interface import MaestroInterface
from conductor.errors import ConductorError
from conductor.task_identifier import TaskIdentifier
from conductor.envs.maestro.interface import MaestroInterface, ExecuteTaskType
from conductor.errors import ConductorError, InternalError
from conductor.execution.version_index import Version
from conductor.task_identifier import TaskIdentifier

# pylint: disable=no-member
# See https://github.com/protocolbuffers/protobuf/issues/10372
Expand Down Expand Up @@ -49,8 +49,20 @@ async def ExecuteTask(
has_uncommitted_changes=False,
)
dep_versions[dep_id] = version
if request.execute_task_type == pb.TT_RUN_EXPERIMENT:
execute_task_type = ExecuteTaskType.RunExperiment
elif request.execute_task_type == pb.TT_RUN_COMMAND:
execute_task_type = ExecuteTaskType.RunCommand
else:
raise InternalError(
details=f"Unsupported execute task type {str(request.execute_task_type)}."
)
response = await self._maestro.execute_task(
workspace_name, project_root, task_identifier, dep_versions
workspace_name,
project_root,
task_identifier,
dep_versions,
execute_task_type,
)
return pb.ExecuteTaskResult(
response=pb.ExecuteTaskResponse(
Expand Down
7 changes: 7 additions & 0 deletions src/conductor/envs/maestro/interface.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import pathlib
from typing import Dict, NamedTuple
from conductor.task_identifier import TaskIdentifier
Expand All @@ -9,6 +10,11 @@ class ExecuteTaskResponse(NamedTuple):
end_timestamp: int


class ExecuteTaskType(enum.Enum):
    """
    The kinds of task execution that can be requested through the Maestro
    interface. The value is passed to `MaestroInterface.execute_task()` as the
    `execute_task_type` argument.
    """

    # Execute the task as an experiment.
    RunExperiment = "run_experiment"
    # Execute the task as a plain command.
    RunCommand = "run_command"


class MaestroInterface:
"""
Captures the RPC interface for Maestro. We use this interface to separate
Expand All @@ -24,6 +30,7 @@ async def execute_task(
project_root: pathlib.Path,
task_identifier: TaskIdentifier,
dep_versions: Dict[TaskIdentifier, Version],
execute_task_type: ExecuteTaskType,
) -> ExecuteTaskResponse:
raise NotImplementedError

Expand Down
38 changes: 20 additions & 18 deletions src/conductor/envs/proto_gen/maestro_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions src/conductor/envs/proto_gen/maestro_pb2.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union

DESCRIPTOR: _descriptor.FileDescriptor
TT_RUN_COMMAND: ExecuteTaskType
TT_RUN_EXPERIMENT: ExecuteTaskType
TT_UNSPECIFIED: ExecuteTaskType

class ConductorError(_message.Message):
__slots__ = ["code", "extra_context", "file_context_line_number", "file_context_path", "kwargs"]
Expand All @@ -28,16 +32,18 @@ class ErrorKwarg(_message.Message):
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...

class ExecuteTaskRequest(_message.Message):
__slots__ = ["dep_versions", "project_root", "task_identifier", "workspace_name"]
__slots__ = ["dep_versions", "execute_task_type", "project_root", "task_identifier", "workspace_name"]
DEP_VERSIONS_FIELD_NUMBER: _ClassVar[int]
EXECUTE_TASK_TYPE_FIELD_NUMBER: _ClassVar[int]
PROJECT_ROOT_FIELD_NUMBER: _ClassVar[int]
TASK_IDENTIFIER_FIELD_NUMBER: _ClassVar[int]
WORKSPACE_NAME_FIELD_NUMBER: _ClassVar[int]
dep_versions: _containers.RepeatedCompositeFieldContainer[TaskDependency]
execute_task_type: ExecuteTaskType
project_root: str
task_identifier: str
workspace_name: str
def __init__(self, workspace_name: _Optional[str] = ..., project_root: _Optional[str] = ..., task_identifier: _Optional[str] = ..., dep_versions: _Optional[_Iterable[_Union[TaskDependency, _Mapping]]] = ...) -> None: ...
def __init__(self, workspace_name: _Optional[str] = ..., project_root: _Optional[str] = ..., task_identifier: _Optional[str] = ..., dep_versions: _Optional[_Iterable[_Union[TaskDependency, _Mapping]]] = ..., execute_task_type: _Optional[_Union[ExecuteTaskType, str]] = ...) -> None: ...

class ExecuteTaskResponse(_message.Message):
__slots__ = ["end_timestamp", "start_timestamp"]
Expand Down Expand Up @@ -110,3 +116,6 @@ class UnpackBundleResult(_message.Message):
error: ConductorError
response: UnpackBundleResponse
def __init__(self, response: _Optional[_Union[UnpackBundleResponse, _Mapping]] = ..., error: _Optional[_Union[ConductorError, _Mapping]] = ...) -> None: ...

class ExecuteTaskType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = [] # type: ignore
25 changes: 20 additions & 5 deletions src/conductor/execution/ops/run_remote_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import Dict, Optional

from conductor.context import Context
from conductor.errors import MissingEnvSupport, EnvsRequireGit
from conductor.errors import MissingEnvSupport, EnvsRequireGit, InternalError
from conductor.execution.handle import OperationExecutionHandle
from conductor.execution.ops.operation import Operation
from conductor.execution.operation_state import OperationState
from conductor.task_identifier import TaskIdentifier
from conductor.execution.version_index import Version
from conductor.task_identifier import TaskIdentifier
from conductor.task_types.base import TaskType
from conductor.task_types.run import RunCommand, RunExperiment


class RunRemoteTask(Operation):
Expand All @@ -21,32 +23,45 @@ def __init__(
*,
env_name: str,
workspace_rel_project_root: pathlib.Path,
task_identifier: TaskIdentifier,
task: TaskType,
dep_versions: Dict[TaskIdentifier, Version],
) -> None:
super().__init__(initial_state)
self._env_name = env_name
self._task_identifier = task_identifier
self._task = task
self._dep_versions = dep_versions
self._project_root = workspace_rel_project_root

def start_execution(
    self, ctx: Context, slot: Optional[int]
) -> OperationExecutionHandle:
    """
    Asks the remote environment's Maestro client to execute this operation's
    task and returns a handle for a synchronously-completed execution.

    The task is classified as an experiment or a command based on its Python
    type, and that classification is forwarded to the client together with
    the task identifier and the dependency versions. `slot` is not used.

    Raises `MissingEnvSupport` if `ctx.envs` is `None`, `EnvsRequireGit` if
    `ctx.git.is_used()` is false, and `InternalError` if the task is neither
    a `RunExperiment` nor a `RunCommand`.
    """
    # Import this here to avoid import errors for people who have not
    # installed the [envs] extras.
    from conductor.envs.maestro.interface import ExecuteTaskType

    if ctx.envs is None:
        raise MissingEnvSupport()
    if not ctx.git.is_used():
        raise EnvsRequireGit()

    if isinstance(self._task, RunExperiment):
        execute_task_type = ExecuteTaskType.RunExperiment
    elif isinstance(self._task, RunCommand):
        execute_task_type = ExecuteTaskType.RunCommand
    else:
        # Validation should be performed before this point.
        raise InternalError(details=f"Unsupported task type: {type(self._task)}")

    remote_env = ctx.envs.get_remote_env(self._env_name)
    client = remote_env.client()
    workspace_name = remote_env.workspace_name()
    # NOTE: This can be made asynchronous if needed.
    # The task identifier is passed exactly once, taken from the task itself;
    # the client expects (workspace, project root, identifier, dependency
    # versions, task type).
    client.execute_task(
        workspace_name,
        self._project_root,
        self._task.identifier,
        self._dep_versions,
        execute_task_type,
    )
    return OperationExecutionHandle.from_sync_execution()

Expand Down

0 comments on commit 05c6a4c

Please sign in to comment.