Skip to content

Commit

Permalink
Add operation for remote task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffxy committed Dec 20, 2024
1 parent 6c3940e commit cc6543a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/conductor/envs/maestro/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import grpc
import pathlib
from typing import Optional
from typing import Dict, Optional

import conductor.envs.proto_gen.maestro_pb2 as pb
import conductor.envs.proto_gen.maestro_pb2_grpc as maestro_grpc
from conductor.envs.maestro.interface import ExecuteTaskResponse
from conductor.task_identifier import TaskIdentifier
from conductor.errors import ConductorError
from conductor.errors.generated import ERRORS_BY_CODE
from conductor.execution.version_index import Version


class MaestroGrpcClient:
Expand Down Expand Up @@ -55,16 +56,23 @@ def unpack_bundle(self, bundle_path: pathlib.Path) -> str:
def execute_task(
self,
workspace_name: str,
project_root: pathlib.Path,
workspace_rel_project_root: pathlib.Path,
task_identifier: TaskIdentifier,
dep_versions: Dict[TaskIdentifier, Version],
) -> ExecuteTaskResponse:
assert self._stub is not None
# pylint: disable-next=no-member
msg = pb.ExecuteTaskRequest(
workspace_name=workspace_name,
project_root=str(project_root),
project_root=str(workspace_rel_project_root),
task_identifier=str(task_identifier),
)
for task_id, version in dep_versions.items():
dv = msg.dep_versions.add()
dv.task_identifier = str(task_id)
dv.version.timestamp = version.timestamp
if version.commit_hash is not None:
dv.version.commit_hash = version.commit_hash
result = self._stub.ExecuteTask(msg)
if result.WhichOneof("result") == "error":
raise _pb_to_error(result.error)
Expand Down
54 changes: 54 additions & 0 deletions src/conductor/execution/ops/run_remote_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pathlib
from typing import Dict, Optional

from conductor.context import Context
from conductor.errors import MissingEnvSupport, EnvsRequireGit
from conductor.execution.handle import OperationExecutionHandle
from conductor.execution.ops.operation import Operation
from conductor.execution.operation_state import OperationState
from conductor.task_identifier import TaskIdentifier
from conductor.execution.version_index import Version


class RunRemoteTask(Operation):
"""
Runs a task (given by identifier) in a remote environment.
"""

def __init__(
self,
initial_state: OperationState,
*,
env_name: str,
workspace_rel_project_root: pathlib.Path,
task_identifier: TaskIdentifier,
dep_versions: Dict[TaskIdentifier, Version],
) -> None:
super().__init__(initial_state)
self._env_name = env_name
self._task_identifier = task_identifier
self._dep_versions = dep_versions
self._project_root = workspace_rel_project_root

def start_execution(
self, ctx: Context, slot: Optional[int]
) -> OperationExecutionHandle:
if ctx.envs is None:
raise MissingEnvSupport()
if not ctx.git.is_used():
raise EnvsRequireGit()

remote_env = ctx.envs.get_remote_env(self._env_name)
client = remote_env.client()
workspace_name = remote_env.workspace_name()
# NOTE: This can be made asynchronous if needed.
client.execute_task(
workspace_name,
self._project_root,
self._task_identifier,
self._dep_versions,
)
return OperationExecutionHandle.from_sync_execution()

def finish_execution(self, handle: OperationExecutionHandle, ctx: Context) -> None:
pass

0 comments on commit cc6543a

Please sign in to comment.