Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add function loads batch idle check report (#2715) #2718

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 88 additions & 17 deletions src/ai/backend/manager/idle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import math
from abc import ABCMeta, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import (
Mapping,
MutableMapping,
Sequence,
)
from datetime import datetime, timedelta
from decimal import Decimal
from typing import (
Expand All @@ -15,13 +20,11 @@
DefaultDict,
Final,
List,
Mapping,
MutableMapping,
NamedTuple,
Optional,
Sequence,
Set,
Type,
TypedDict,
Union,
cast,
)
Expand All @@ -30,6 +33,7 @@
import sqlalchemy as sa
import trafaret as t
from aiotools import TaskGroupError
from redis.asyncio import Redis
from sqlalchemy.engine import Row

import ai.backend.common.validators as tx
Expand Down Expand Up @@ -168,6 +172,12 @@ class RemainingTimeType(enum.StrEnum):
EXPIRE_AFTER = "expire_after"


class ReportInfo(TypedDict):
    """Per-checker entry of an idle check report for a single session."""

    # Value unpacked from the checker's report key; None when no report is stored.
    remaining: float | None
    # String value of the checker's ``remaining_time_type`` (a RemainingTimeType member).
    remaining_time_type: str
    # Value unpacked from the checker's extra-info key; None when the checker
    # has no such key or nothing is stored under it.
    extra: dict[str, Any] | None


class IdleCheckerHost:
check_interval: ClassVar[float] = DEFAULT_CHECK_INTERVAL

Expand Down Expand Up @@ -343,11 +353,55 @@ async def get_idle_check_report(
checker.name: {
"remaining": await checker.get_checker_result(self._redis_live, session_id),
"remaining_time_type": checker.remaining_time_type.value,
"extra": await checker.get_extra_info(session_id),
"extra": await checker.get_extra_info(self._redis_live, session_id),
}
for checker in self._checkers
}

async def get_batch_idle_check_report(
    self,
    session_ids: Sequence[SessionId],
) -> dict[SessionId, dict[str, ReportInfo]]:
    """
    Build the idle check reports of several sessions using one pipelined Redis read.

    For every given session and every registered checker, the checker's report
    key (and its extra-info key, when the checker provides one) is fetched from
    ``self._redis_live`` through a single pipeline of GET commands.

    :param session_ids: The sessions to report on.
    :return: A mapping of session ID to a mapping of checker name to its
        ``ReportInfo``.  A field stays ``None`` when nothing is stored under
        the corresponding key.
    """

    class _ReportField(enum.StrEnum):
        # The member values double as the ReportInfo keys they populate.
        REMAINING_TIME = "remaining"
        EXTRA_INFO = "extra"

    # Redis key -> (owning session, owning checker, ReportInfo field to fill).
    lookup: dict[str, tuple[SessionId, BaseIdleChecker, _ReportField]] = {}
    for target_id in session_ids:
        for checker in self._checkers:
            report_key = checker.get_report_key(target_id)
            lookup[report_key] = (target_id, checker, _ReportField.REMAINING_TIME)
            extra_key = checker.get_extra_info_key(target_id)
            if extra_key is not None:
                lookup[extra_key] = (target_id, checker, _ReportField.EXTRA_INFO)

    # Freeze the key order so that pipeline replies can be paired back by position.
    ordered_keys = list(lookup)

    async def _build_pipeline(r: Redis):
        pipeline = r.pipeline()
        for redis_key in ordered_keys:
            await pipeline.get(redis_key)
        return pipeline

    replies = await redis_helper.execute(self._redis_live, _build_pipeline)

    reports: dict[SessionId, dict[str, ReportInfo]] = {}
    for redis_key, reply in zip(ordered_keys, replies):
        owner_id, checker, field = lookup[redis_key]
        if owner_id not in reports:
            reports[owner_id] = {}
        per_checker = reports[owner_id]
        if checker.name not in per_checker:
            # Start from an empty entry so that missing keys are reported as None.
            per_checker[checker.name] = ReportInfo(
                remaining=None,
                remaining_time_type=checker.remaining_time_type.value,
                extra=None,
            )
        packed = cast(bytes | None, reply)
        if packed is not None:
            per_checker[checker.name][field.value] = msgpack.unpackb(packed)
    return reports


class AbstractIdleCheckReporter(metaclass=ABCMeta):
remaining_time_type: RemainingTimeType
Expand Down Expand Up @@ -383,8 +437,14 @@ async def update_app_streaming_status(
def get_report_key(cls, session_id: SessionId) -> str:
    """
    Return the key used to look up this checker's report for the given session.

    The key has the form ``session.<session_id>.<checker name>.report``.
    """
    return f"session.{session_id}.{cls.name}.report"

@classmethod
def get_extra_info_key(cls, session_id: SessionId) -> str | None:
    """
    Return the key holding this checker's extra info for the given session.

    This implementation always returns ``None``; the batch report builder
    treats ``None`` as "no extra-info key to fetch" for the checker.
    """
    return None

@abstractmethod
async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]]:
async def get_extra_info(
self, redis_obj: RedisConnectionInfo, session_id: SessionId
) -> Optional[dict[str, Any]]:
return None

@abstractmethod
Expand Down Expand Up @@ -456,7 +516,9 @@ async def populate_config(self, raw_config: Mapping[str, Any]) -> None:
f"NewUserGracePeriodChecker: default period = {_grace_period} seconds",
)

async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]]:
async def get_extra_info(
self, redis_obj: RedisConnectionInfo, session_id: SessionId
) -> Optional[dict[str, Any]]:
return None

async def del_remaining_time_report(
Expand Down Expand Up @@ -614,7 +676,9 @@ async def _execution_exited_cb(
) -> None:
await self._update_timeout(event.session_id)

async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]]:
async def get_extra_info(
self, redis_obj: RedisConnectionInfo, session_id: SessionId
) -> Optional[dict[str, Any]]:
return None

async def check_idleness(
Expand Down Expand Up @@ -694,7 +758,9 @@ class SessionLifetimeChecker(BaseIdleChecker):
async def populate_config(self, raw_config: Mapping[str, Any]) -> None:
pass

async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]]:
async def get_extra_info(
self, redis_obj: RedisConnectionInfo, session_id: SessionId
) -> Optional[dict[str, Any]]:
return None

async def check_idleness(
Expand Down Expand Up @@ -805,15 +871,18 @@ async def populate_config(self, raw_config: Mapping[str, Any]) -> None:
f"time-window({self.time_window.total_seconds()}s)"
)

def get_extra_info_key(self, session_id: SessionId) -> str:
return f"session.{session_id}.{self.extra_info_key}"
@classmethod
def get_extra_info_key(cls, session_id: SessionId) -> str | None:
return f"session.{session_id}.{cls.extra_info_key}"

async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]]:
async def get_extra_info(
self, redis_obj: RedisConnectionInfo, session_id: SessionId
) -> Optional[dict[str, Any]]:
key = self.get_extra_info_key(session_id)
assert key is not None
data = await redis_helper.execute(
self._redis_live,
lambda r: r.get(
self.get_extra_info_key(session_id),
),
redis_obj,
lambda r: r.get(key),
)
return msgpack.unpackb(data) if data is not None else None

Expand Down Expand Up @@ -995,10 +1064,12 @@ def _avg(util_list: list[float]) -> float:
"thresholds_check_operator": self.thresholds_check_operator.value,
"resources": util_avg_thresholds.to_dict(),
}
_key = self.get_extra_info_key(session_id)
assert _key is not None
await redis_helper.execute(
self._redis_live,
redis_obj,
lambda r: r.set(
self.get_extra_info_key(session_id),
_key,
msgpack.packb(report),
ex=int(DEFAULT_CHECK_INTERVAL) * 10,
),
Expand Down
12 changes: 9 additions & 3 deletions tests/manager/test_idle_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,9 @@ async def utilization_idle_checker(
remaining = await utilization_idle_checker.get_checker_result(
checker_host._redis_live, session_id
)
util_info = await utilization_idle_checker.get_extra_info(session_id)
util_info = await utilization_idle_checker.get_extra_info(
checker_host._redis_live, session_id
)
finally:
await checker_host.shutdown()

Expand Down Expand Up @@ -862,7 +864,9 @@ async def utilization_idle_checker(
remaining = await utilization_idle_checker.get_checker_result(
checker_host._redis_live, session_id
)
util_info = await utilization_idle_checker.get_extra_info(session_id)
util_info = await utilization_idle_checker.get_extra_info(
checker_host._redis_live, session_id
)
finally:
await checker_host.shutdown()

Expand Down Expand Up @@ -944,7 +948,9 @@ async def utilization_idle_checker(
remaining = await utilization_idle_checker.get_checker_result(
checker_host._redis_live, session_id
)
util_info = await utilization_idle_checker.get_extra_info(session_id)
util_info = await utilization_idle_checker.get_extra_info(
checker_host._redis_live, session_id
)
finally:
await checker_host.shutdown()

Expand Down
Loading