Introduce raftify and RaftContext

Squashed commits: Merge; Add fragment; Merge (×6); Update test; Merge (×3); Handle RedisLock correctly; Remove mistakenly placed function; Fix wrong logging; Fix wrong type; Fix CI; Remove useless Raft RedisConnection; Improve leader-not-found error handling; Merge; Update dependencies; Update log_dir; Rename LogLevel; WIP; RaftCluster -> RaftFacade; WIP (×12)

jopemachine committed Nov 9, 2023
1 parent b31efcb commit 3ec9b9a
Showing 23 changed files with 1,214 additions and 366 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
raft-*.mdb
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
1 change: 1 addition & 0 deletions changes/1506.feature.md
@@ -0,0 +1 @@
Add a Raft-based leader election process to the manager group in HA deployments so that manager states stay consistent.
24 changes: 23 additions & 1 deletion configs/manager/halfstack.toml
@@ -14,7 +14,7 @@ password = "develove"


[manager]
num-proc = 4
num-proc = 3
service-addr = { host = "0.0.0.0", port = 8081 }
#user = "nobody"
#group = "nobody"
@@ -33,6 +33,28 @@ hide-agents = true
# The order of agent selection.
agent-selection-resource-priority = ["cuda", "rocm", "tpu", "cpu", "mem"]

[raft]
heartbeat-tick = 3
election-tick = 10
log-dir = "./logs"
slog-level = "debug"
log-level = "debug"

[[raft.peers]]
host = "127.0.0.1"
port = 60151
node-id = 1

[[raft.peers]]
host = "127.0.0.1"
port = 60152
node-id = 2

[[raft.peers]]
host = "127.0.0.1"
port = 60153
node-id = 3

[docker-registry]
ssl-verify = false

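The [raft] section above feeds raftify's peer bootstrap; note that election-tick is kept well above heartbeat-tick so followers do not start elections between normal heartbeats. A rough sketch of how such a section maps onto raftify's types (the tomllib loading is illustrative and not part of this commit; Peer, Peers, and SocketAddr are the same raftify imports used by the CLI changes further below):

```python
# Hypothetical sketch: build a raftify Peers map from the [raft] TOML section.
# Assumes Python 3.11+ for the stdlib tomllib; the manager uses its own config loader.
import tomllib

from raftify.peers import Peer, Peers
from raftify.utils import SocketAddr

with open("configs/manager/halfstack.toml", "rb") as f:
    raft_config = tomllib.load(f)["raft"]

peers = Peers(
    {
        int(p["node-id"]): Peer(addr=SocketAddr(p["host"], p["port"]))
        for p in raft_config["peers"]
    }
)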
509 changes: 328 additions & 181 deletions python.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions requirements.txt
@@ -91,3 +91,6 @@ types-tabulate

backend.ai-krunner-alpine==5.1.0
backend.ai-krunner-static-gnu==4.1.0

rraft-py==0.2.21
raftify==0.0.49
33 changes: 33 additions & 0 deletions run-backend-ai.sh
@@ -0,0 +1,33 @@
#!/bin/sh

# Use the current shell as the API-mode window
tmux rename-window API-Mode

# Session Mode
tmux new-window
tmux rename-window Session-Mode

# Manager
tmux new-window
tmux rename-window manager

# Agent
tmux new-window
tmux rename-window agent

# Storage
tmux new-window
tmux rename-window storage

# Web UI
tmux new-window
tmux rename-window web

sleep 2

tmux send-keys -t manager './backend.ai mgr start-server --debug' Enter
tmux send-keys -t agent './backend.ai ag start-server --debug' Enter
tmux send-keys -t storage './py -m ai.backend.storage.server' Enter
tmux send-keys -t web './py -m ai.backend.web.server' Enter
tmux send-keys -t API-Mode 'source env-local-admin-api.sh' Enter
tmux send-keys -t Session-Mode 'source env-local-admin-session.sh' Enter
45 changes: 45 additions & 0 deletions scripts/print-raft-log-entries.py
@@ -0,0 +1,45 @@
import os
import sys
import lmdb
import rraft
from raftify.deserializer import init_rraft_py_deserializer
from raftify.storage.lmdb import SNAPSHOT_KEY, LAST_INDEX_KEY, HARD_STATE_KEY, CONF_STATE_KEY


def main(argv):
    init_rraft_py_deserializer()
    idx = argv[1]
    assert idx.isdigit(), "idx must be a number"

    env = lmdb.open(f"{os.getcwd()}/logs/node-{idx}", max_dbs=2)

    entries_db = env.open_db(b"entries")

    print('---- Entries ----')
    with env.begin(db=entries_db) as txn:
        cursor = txn.cursor()
        for key, value in cursor:
            print(f"Key: {int(key.decode())}, Value: {rraft.Entry.decode(value)}")

    metadata_db = env.open_db(b"meta")

    print('---- Metadata ----')
    with env.begin(db=metadata_db) as txn:
        cursor = txn.cursor()
        for key, value in cursor:
            if key == SNAPSHOT_KEY:
                print(f'Key: "snapshot", Value: "{rraft.Snapshot.decode(value)}"')
            elif key == LAST_INDEX_KEY:
                print(f'Key: "last_index", Value: "{int(value.decode())}"')
            elif key == HARD_STATE_KEY:
                print(f'Key: "hard_state", Value: "{rraft.HardState.decode(value)}"')
            elif key == CONF_STATE_KEY:
                print(f'Key: "conf_state", Value: "{rraft.ConfState.decode(value)}"')
            else:
                assert False, f"Unknown key: {key}"

    env.close()


if __name__ == "__main__":
    main(sys.argv)
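Given the log-dir of "./logs" configured in halfstack.toml, each node's LMDB store lives under ./logs/node-<id>, so running "python scripts/print-raft-log-entries.py 1" from the repository root dumps node 1's persisted Raft entries and metadata.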
74 changes: 73 additions & 1 deletion src/ai/backend/common/distributed.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import abc
import asyncio
import logging
from typing import TYPE_CHECKING, Callable, Final

from aiomonitor.task import preserve_termination_log
from raftify import RaftNode

from .logging import BraceStyleAdapter

@@ -16,7 +18,77 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]


class GlobalTimer:
class AbstractGlobalTimer(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def generate_tick(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def join(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def leave(self) -> None:
        raise NotImplementedError


class RaftGlobalTimer(AbstractGlobalTimer):
    """
    Executes the given async function only once in the given interval,
    uniquely among multiple manager instances across multiple nodes.
    """

    _event_producer: Final[EventProducer]

    def __init__(
        self,
        raft_node: RaftNode,
        event_producer: EventProducer,
        event_factory: Callable[[], AbstractEvent],
        interval: float = 10.0,
        initial_delay: float = 0.0,
    ) -> None:
        self._event_producer = event_producer
        self._event_factory = event_factory
        self._stopped = False
        self.interval = interval
        self.initial_delay = initial_delay
        self.raft_node = raft_node

    async def generate_tick(self) -> None:
        try:
            await asyncio.sleep(self.initial_delay)
            if self._stopped:
                return
            while True:
                try:
                    if self._stopped:
                        return
                    if self.raft_node.is_leader():
                        await self._event_producer.produce_event(self._event_factory())
                    if self._stopped:
                        return
                    await asyncio.sleep(self.interval)
                except asyncio.TimeoutError:  # defensive; unlike the lock-based timer below, no lock is taken here
                    log.warning("timeout raised while generating tick. retrying...")
        except asyncio.CancelledError:
            pass

    async def join(self) -> None:
        self._tick_task = asyncio.create_task(self.generate_tick())

    async def leave(self) -> None:
        self._stopped = True
        await asyncio.sleep(0)
        if not self._tick_task.done():
            try:
                self._tick_task.cancel()
                await self._tick_task
            except asyncio.CancelledError:
                pass


class DistributedLockGlobalTimer(AbstractGlobalTimer):
    """
    Executes the given async function only once in the given interval,
    uniquely among multiple manager instances across multiple nodes.
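A minimal usage sketch for the new Raft-based timer (hypothetical wiring: the root-context attributes follow the manager code shown below, and the 20-second interval mirrors the log-cleanup timer):

```python
# Sketch, not part of this commit. RaftGlobalTimer lives in the module above;
# the DoLogCleanupEvent import path is assumed from the manager's event definitions.
from ai.backend.common.distributed import RaftGlobalTimer


async def start_log_cleanup_timer(root_ctx) -> RaftGlobalTimer:
    # Every node runs the same tick loop, but only the current Raft leader
    # passes the is_leader() check inside generate_tick() and produces events.
    timer = RaftGlobalTimer(
        root_ctx.raft_ctx.raft_node,  # RaftNode exposed by RaftClusterContext
        root_ctx.event_producer,
        lambda: DoLogCleanupEvent(),  # assumed event factory
        interval=20.0,
    )
    await timer.join()  # spawns the internal tick task
    return timer
```

Leadership can move at any time; a node that loses leadership simply stops producing events on its next tick, so no explicit handover is needed.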
29 changes: 28 additions & 1 deletion src/ai/backend/manager/api/context.py
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, cast

import attrs
from raftify import RaftFacade, RaftNode

if TYPE_CHECKING:
from ai.backend.common.bgtask import BackgroundTaskManager
@@ -26,6 +27,31 @@ class BaseContext:
    pass


class RaftClusterContext:
    _cluster: Optional[RaftFacade] = None
    bootstrap_done: bool
    node_id_start: int

    def __init__(self, bootstrap_done: bool = False, node_id_start: int = 1) -> None:
        self.bootstrap_done = bootstrap_done
        self.node_id_start = node_id_start

    def use_raft(self) -> bool:
        return self._cluster is not None

    @property
    def cluster(self) -> RaftFacade:
        return cast(RaftFacade, self._cluster)

    @cluster.setter
    def cluster(self, rhs: RaftFacade) -> None:
        self._cluster = rhs

    @property
    def raft_node(self) -> RaftNode:
        return cast(RaftNode, cast(RaftFacade, self._cluster).raft_node)


@attrs.define(slots=True, auto_attribs=True, init=False)
class RootContext(BaseContext):
    pidx: int
@@ -53,3 +79,4 @@ class RootContext(BaseContext):
    error_monitor: ErrorPluginContext
    stats_monitor: StatsPluginContext
    background_task_manager: BackgroundTaskManager
    raft_ctx: RaftClusterContext
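RaftClusterContext keeps the RaftFacade optional, so deployments without a [raft] section behave exactly as before. A hypothetical sketch of the intended flow:

```python
# Illustrative only: attach a facade once Raft bootstrap finishes, then let
# downstream code branch on use_raft(). `raft_facade` is assumed built elsewhere.
raft_ctx = RaftClusterContext()
assert not raft_ctx.use_raft()  # no RaftFacade attached yet

raft_ctx.cluster = raft_facade  # a bootstrapped raftify.RaftFacade
raft_ctx.bootstrap_done = True

if raft_ctx.use_raft():
    node = raft_ctx.raft_node  # typed accessor added above
```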
6 changes: 3 additions & 3 deletions src/ai/backend/manager/api/logs.py
@@ -14,7 +14,7 @@
from dateutil.relativedelta import relativedelta

from ai.backend.common import validators as tx
from ai.backend.common.distributed import GlobalTimer
from ai.backend.common.distributed import DistributedLockGlobalTimer
from ai.backend.common.events import AbstractEvent, EmptyEventArgs, EventHandler
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import AgentId, LogSeverity
@@ -238,7 +238,7 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean

@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
    log_cleanup_timer: GlobalTimer
    log_cleanup_timer: DistributedLockGlobalTimer
    log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]


@@ -250,7 +250,7 @@ async def init(app: web.Application) -> None:
        app,
        log_cleanup_task,
    )
    app_ctx.log_cleanup_timer = GlobalTimer(
    app_ctx.log_cleanup_timer = DistributedLockGlobalTimer(
        root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
        root_ctx.event_producer,
        lambda: DoLogCleanupEvent(),
50 changes: 50 additions & 0 deletions src/ai/backend/manager/cli/__main__.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import json
import logging
import pathlib
import subprocess
@@ -10,7 +11,11 @@

import click
from more_itertools import chunked
from raftify.peers import Peer, Peers
from raftify.raft_client import RaftClient
from raftify.utils import SocketAddr
from setproctitle import setproctitle
from tabulate import tabulate

from ai.backend.cli.types import ExitCode
from ai.backend.common import redis_helper as redis_helper
@@ -325,6 +330,51 @@ async def _clear_terminated_sessions():
    asyncio.run(_clear_terminated_sessions())


async def inspect_node_status(cli_ctx: CLIContext) -> None:
    # TODO: Add handling for nodes joined after initial_peers
    raft_configs = cli_ctx.local_config["raft"]
    table = []
    headers = ["ENDPOINT", "NODE ID", "IS LEADER", "RAFT TERM", "RAFT APPLIED INDEX"]

    if raft_configs is not None:
        peers = Peers(
            {
                int(peer_config["node-id"]): Peer(
                    addr=SocketAddr(peer_config["host"], peer_config["port"])
                )
                for peer_config in raft_configs.pop("peers")
            }
        )

        for peer in peers:
            raft_client = RaftClient(peer.addr)
            node_info = json.loads(await raft_client.debug_node(timeout=5.0))
            is_leader = node_info["node_id"] == node_info["current_leader_id"]
            table.append(
                [
                    peer.addr,
                    node_info["node_id"],
                    is_leader,
                    node_info["raft"]["term"],
                    node_info["raft_log"]["applied"],
                ]
            )

    table = [headers, *sorted(table, key=lambda x: str(x[0]))]
    print(
        tabulate(table, headers="firstrow", tablefmt="grid", stralign="center", numalign="center")
    )


@main.command()
@click.pass_obj
def status(cli_ctx: CLIContext) -> None:
    """
    Collect and print each manager process's status.
    """
    asyncio.run(inspect_node_status(cli_ctx))


@main.group(cls=LazyGroup, import_name="ai.backend.manager.cli.dbschema:cli")
def schema():
    """Command set for managing the database schema."""
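With the three-peer halfstack configuration above, the new status command (exposed through the manager CLI, presumably as something like "backend.ai mgr status") queries each peer's debug endpoint on ports 60151-60153, marks a node as leader when its node_id equals current_leader_id, and prints one grid-formatted row per node.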