Skip to content

Commit

Permalink
feat: use gokart worker (#402)
Browse files Browse the repository at this point in the history
* feat: copied worker from luigi

* chore: remove unused components and add type annotations

* feat: add option to use custom worker on build
  • Loading branch information
hiro-o918 authored Nov 2, 2024
1 parent 17f7064 commit 61fd81b
Show file tree
Hide file tree
Showing 4 changed files with 1,134 additions and 3 deletions.
2 changes: 1 addition & 1 deletion gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gokart.build import build # noqa:F401
from gokart.build import WorkerSchedulerFactory, build # noqa:F401
from gokart.info import make_tree_info, tree_info # noqa:F401
from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401
Expand Down
38 changes: 36 additions & 2 deletions gokart/build.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from functools import partial
from logging import getLogger
from typing import Literal, Optional, TypeVar, cast, overload
from typing import Literal, Optional, Protocol, TypeVar, cast, overload

import backoff
import luigi
from luigi import rpc, scheduler

import gokart
from gokart import worker
from gokart.conflict_prevention_lock.task_lock import TaskLockException
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart
Expand Down Expand Up @@ -43,6 +45,31 @@ def __init__(self):
self.flag: bool = False


class WorkerProtocol(Protocol):
"""Protocol for Worker.
This protocol is determined by luigi.worker.Worker.
"""

def add(self, task: TaskOnKart) -> bool: ...

def run(self) -> bool: ...

def __enter__(self) -> 'WorkerProtocol': ...

def __exit__(self, type, value, traceback) -> Literal[False]: ...


class WorkerSchedulerFactory:
def create_local_scheduler(self) -> scheduler.Scheduler:
return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

def create_remote_scheduler(self, url) -> rpc.RemoteScheduler:
return rpc.RemoteScheduler(url)

def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> WorkerProtocol:
return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant)


def _get_output(task: TaskOnKart[T]) -> T:
output = task.output()
# FIXME: currently, nested output is not supported
Expand Down Expand Up @@ -106,6 +133,7 @@ def build(
"""
if reset_register:
_reset_register()

with LoggerConfig(level=log_level):
task_lock_exception_raised = TaskLockExceptionRaisedFlag()

Expand All @@ -119,7 +147,13 @@ def when_failure(task, exception):
)
def _build_task():
task_lock_exception_raised.flag = False
result = luigi.build([task], local_scheduler=True, detailed_summary=True, log_level=logging.getLevelName(log_level), **env_params)
result = luigi.build(
[task],
local_scheduler=True,
detailed_summary=True,
log_level=logging.getLevelName(log_level),
**env_params,
)
if task_lock_exception_raised.flag:
raise HasLockedTaskException()
if result.status == luigi.LuigiStatusCode.FAILED:
Expand Down
Loading

0 comments on commit 61fd81b

Please sign in to comment.