Skip to content

Commit

Permalink
Introduce raftify and RaftContext
Browse files Browse the repository at this point in the history
Merge

Add fragment

Merge

Merge

Merge

Merge

Merge

Merge

Update test

Merge

Merge

Merge

Handle RedisLock correctly

Remove mistakenly placed function

Fix wrong logging

Fix wrong type

Fix CI

Remove useless Raft RedisConnection

Improve leader not found error handling

Merge

Update dependencies

Update log_dir

Rename LogLevel

WIP

RaftCluster -> RaftFacade

WIP

WIP

WIP

WIP

WIP

WIP

WIP

WIP

WIP

WIP

WIP

WIP

Try to fix CI

Apply RaftGlobalTimer to logs.py's init

Delete run-backend-ai.sh

Update gitignore

Delete scripts/print-raft-log-entries.py

Improve inspect_node_status

WIP
  • Loading branch information
jopemachine committed Nov 29, 2023
1 parent b31efcb commit 7d4d108
Show file tree
Hide file tree
Showing 21 changed files with 1,321 additions and 475 deletions.
1 change: 1 addition & 0 deletions changes/1506.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Raft-based leader election process to manager group in HA condition in order to make their states consistent.
42 changes: 42 additions & 0 deletions client_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
import pickle

from rraft import ConfChange, ConfChangeType

from raftify.log_entry.set_command import SetCommand
from raftify.raft_client import RaftClient
from raftify.utils import SocketAddr


async def main() -> None:
"""
A simple set of commands to test and show usage of RaftClient.
Please bootstrap the Raft cluster before running this script.
"""

print("---Message propose---")
await RaftClient("192.168.0.37:60151").propose(SetCommand("1", "B").encode())

# print("---Message propose rerouting---")
# await RaftClient("127.0.0.1:60062").propose(SetCommand("2", "A").encode())

# print("---Debug node result---", await RaftClient("127.0.0.1:60061").debug_node())

# print(
# "---Debug peers---",
# pickle.loads((await RaftClient("127.0.0.1:60061").get_peers()).peers),
# )

# print("---Make Confchange manually---")
# addr = SocketAddr.from_str("127.0.0.1:60062")

# conf_change = ConfChange.default()
# conf_change.set_node_id(2)
# conf_change.set_context(pickle.dumps([addr]))
# conf_change.set_change_type(ConfChangeType.RemoveNode)

# await RaftClient(addr).change_config(conf_change)


if __name__ == "__main__":
asyncio.run(main())
24 changes: 23 additions & 1 deletion configs/manager/halfstack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ password = "develove"


[manager]
num-proc = 4
num-proc = 3
service-addr = { host = "0.0.0.0", port = 8081 }
#user = "nobody"
#group = "nobody"
Expand All @@ -33,6 +33,28 @@ hide-agents = true
# The order of agent selection.
agent-selection-resource-priority = ["cuda", "rocm", "tpu", "cpu", "mem"]

[raft]
heartbeat-tick = 3
election-tick = 10
log-dir = "./logs"
slog-level = "debug"
log-level = "debug"

[[raft.peers]]
host = "127.0.0.1"
port = 60151
node-id = 1

[[raft.peers]]
host = "127.0.0.1"
port = 60152
node-id = 2

[[raft.peers]]
host = "127.0.0.1"
port = 60153
node-id = 3

[docker-registry]
ssl-verify = false

Expand Down
717 changes: 434 additions & 283 deletions python.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,6 @@ types-tabulate

backend.ai-krunner-alpine==5.1.0
backend.ai-krunner-static-gnu==4.1.0

rraft-py==0.2.27
raftify==0.0.62
74 changes: 73 additions & 1 deletion src/ai/backend/common/distributed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import abc
import asyncio
import logging
from typing import TYPE_CHECKING, Callable, Final

from aiomonitor.task import preserve_termination_log
from raftify import RaftNode

from .logging import BraceStyleAdapter

Expand All @@ -16,7 +18,77 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]


class GlobalTimer:
class AbstractGlobalTimer(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def generate_tick(self) -> None:
raise NotImplementedError

@abc.abstractmethod
async def join(self) -> None:
raise NotImplementedError

@abc.abstractmethod
async def leave(self) -> None:
raise NotImplementedError


class RaftGlobalTimer(AbstractGlobalTimer):
"""
Executes the given async function only once in the given interval,
uniquely among multiple manager instances across multiple nodes.
"""

_event_producer: Final[EventProducer]

def __init__(
self,
raft_node: RaftNode,
event_producer: EventProducer,
event_factory: Callable[[], AbstractEvent],
interval: float = 10.0,
initial_delay: float = 0.0,
) -> None:
self._event_producer = event_producer
self._event_factory = event_factory
self._stopped = False
self.interval = interval
self.initial_delay = initial_delay
self.raft_node = raft_node

async def generate_tick(self) -> None:
try:
await asyncio.sleep(self.initial_delay)
if self._stopped:
return
while True:
try:
if self._stopped:
return
if self.raft_node.is_leader():
await self._event_producer.produce_event(self._event_factory())
if self._stopped:
return
await asyncio.sleep(self.interval)
except asyncio.TimeoutError: # timeout raised from etcd lock
log.warn("timeout raised while trying to acquire lock. retrying...")
except asyncio.CancelledError:
pass

async def join(self) -> None:
self._tick_task = asyncio.create_task(self.generate_tick())

async def leave(self) -> None:
self._stopped = True
await asyncio.sleep(0)
if not self._tick_task.done():
try:
self._tick_task.cancel()
await self._tick_task
except asyncio.CancelledError:
pass


class DistributedLockGlobalTimer(AbstractGlobalTimer):
"""
Executes the given async function only once in the given interval,
uniquely among multiple manager instances across multiple nodes.
Expand Down
29 changes: 28 additions & 1 deletion src/ai/backend/manager/api/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, cast

import attrs
from raftify import RaftFacade, RaftNode

if TYPE_CHECKING:
from ai.backend.common.bgtask import BackgroundTaskManager
Expand All @@ -26,6 +27,31 @@ class BaseContext:
pass


class RaftClusterContext:
_cluster: Optional[RaftFacade] = None
bootstrap_done: bool
node_id_start: int

def __init__(self, bootstrap_done: bool = False, node_id_start: int = 1) -> None:
self.bootstrap_done = bootstrap_done
self.node_id_start = node_id_start

def use_raft(self) -> bool:
return self._cluster is not None

@property
def cluster(self) -> RaftFacade:
return cast(RaftFacade, self._cluster)

@cluster.setter
def cluster(self, rhs: RaftFacade) -> None:
self._cluster = rhs

@property
def raft_node(self) -> RaftNode:
return cast(RaftNode, cast(RaftFacade, self._cluster).raft_node)


@attrs.define(slots=True, auto_attribs=True, init=False)
class RootContext(BaseContext):
pidx: int
Expand Down Expand Up @@ -53,3 +79,4 @@ class RootContext(BaseContext):
error_monitor: ErrorPluginContext
stats_monitor: StatsPluginContext
background_task_manager: BackgroundTaskManager
raft_ctx: RaftClusterContext
34 changes: 24 additions & 10 deletions src/ai/backend/manager/api/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from dateutil.relativedelta import relativedelta

from ai.backend.common import validators as tx
from ai.backend.common.distributed import GlobalTimer
from ai.backend.common.distributed import (
AbstractGlobalTimer,
DistributedLockGlobalTimer,
RaftGlobalTimer,
)
from ai.backend.common.events import AbstractEvent, EmptyEventArgs, EventHandler
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AgentId, LogSeverity
Expand Down Expand Up @@ -238,7 +242,7 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean

@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
log_cleanup_timer: GlobalTimer
log_cleanup_timer: AbstractGlobalTimer
log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]


Expand All @@ -250,14 +254,24 @@ async def init(app: web.Application) -> None:
app,
log_cleanup_task,
)
app_ctx.log_cleanup_timer = GlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)

if root_ctx.raft_ctx.use_raft():
app_ctx.log_cleanup_timer = RaftGlobalTimer(
root_ctx.raft_ctx.raft_node,
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
)
else:
app_ctx.log_cleanup_timer = DistributedLockGlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)
await app_ctx.log_cleanup_timer.join()


Expand Down
66 changes: 66 additions & 0 deletions src/ai/backend/manager/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from __future__ import annotations

import asyncio
import json
import logging
import pathlib
import pickle
import subprocess
import sys
from datetime import datetime
from functools import partial

import click
from more_itertools import chunked
from raftify.peers import Peer, Peers
from raftify.raft_client import RaftClient
from raftify.utils import SocketAddr
from setproctitle import setproctitle
from tabulate import tabulate

from ai.backend.cli.types import ExitCode
from ai.backend.common import redis_helper as redis_helper
Expand Down Expand Up @@ -325,6 +331,66 @@ async def _clear_terminated_sessions():
asyncio.run(_clear_terminated_sessions())


async def inspect_node_status(cli_ctx: CLIContext) -> None:
raft_configs = cli_ctx.local_config["raft"]
table = []
headers = ["ENDPOINT", "NODE ID", "IS LEADER", "RAFT TERM", "RAFT APPLIED INDEX"]

if raft_configs is not None:
initial_peers = Peers(
{
int(peer_config["node-id"]): Peer(
addr=SocketAddr(peer_config["host"], peer_config["port"])
)
for peer_config in raft_configs.pop("peers")
}
)

peers: Peers | None = None
for peer in initial_peers.values():
raft_client = RaftClient(peer.addr)
try:
resp = await raft_client.get_peers()
peers = pickle.loads(resp.peers)
except Exception:
continue

if not peers:
print("No peers are available!")
return

for peer in peers.values():
raft_client = RaftClient(peer.addr)

resp = await raft_client.debug_node()

node_info = json.loads(resp.result)
is_leader = node_info["node_id"] == node_info["current_leader_id"]
table.append(
[
peer.addr,
node_info["node_id"],
is_leader,
node_info["raft"]["term"],
node_info["raft_log"]["applied"],
]
)

table = [headers, *sorted(table, key=lambda x: str(x[0]))]
print(
tabulate(table, headers="firstrow", tablefmt="grid", stralign="center", numalign="center")
)


@main.command()
@click.pass_obj
def status(cli_ctx: CLIContext) -> None:
"""
Collect and print each manager process's status.
"""
asyncio.run(inspect_node_status(cli_ctx))


@main.group(cls=LazyGroup, import_name="ai.backend.manager.cli.dbschema:cli")
def schema():
"""Command set for managing the database schema."""
Expand Down
Loading

0 comments on commit 7d4d108

Please sign in to comment.