From 0dc6d0c952c698df5335171099e8c210041e6749 Mon Sep 17 00:00:00 2001 From: Kyujin Cho Date: Tue, 25 Jul 2023 00:22:35 +0900 Subject: [PATCH] feat: refactor `SessionRow`'s ORM resolvers, advertise additional informations to WSProxy (#1396) Co-authored-by: Jonghyun Park Co-authored-by: Joongi Kim --- changes/1396.feature.md | 1 + src/ai/backend/manager/api/session.py | 160 ++++++++++++------ src/ai/backend/manager/api/stream.py | 30 ++-- src/ai/backend/manager/models/session.py | 198 +++++++++-------------- src/ai/backend/manager/registry.py | 83 ++++++++-- 5 files changed, 283 insertions(+), 189 deletions(-) create mode 100644 changes/1396.feature.md diff --git a/changes/1396.feature.md b/changes/1396.feature.md new file mode 100644 index 0000000000..a7c47efa3f --- /dev/null +++ b/changes/1396.feature.md @@ -0,0 +1 @@ +Refactor `SessionRow` ORM queries by introducing `KernelLoadingStrategy` to generalize and reuse `SessionRow.get_session()` diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index d6a2209825..38812cd260 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -40,6 +40,7 @@ from aiohttp import hdrs, web from dateutil.tz import tzutc from redis.asyncio import Redis +from sqlalchemy.orm import noload, selectinload from sqlalchemy.sql.expression import null, true if TYPE_CHECKING: @@ -66,6 +67,7 @@ from ..models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, DEAD_SESSION_STATUSES, + KernelLoadingStrategy, KernelRole, SessionRow, SessionStatus, @@ -763,8 +765,14 @@ async def start_service(request: web.Request, params: Mapping[str, Any]) -> web. async with root_ctx.db.begin_readonly_session() as db_sess: session = await asyncio.shield( app_ctx.database_ptask_group.create_task( - SessionRow.get_session_with_main_kernel( - session_name, access_key, db_session=db_sess + SessionRow.get_session( + db_sess, + session_name, + access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, + eager_loading_op=[ + selectinload(SessionRow.routing).options(noload("*")), + ], ), ) ) @@ -839,14 +847,27 @@ async def start_service(request: web.Request, params: Mapping[str, Any]) -> web. if result["status"] == "failed": raise InternalServerError("Failed to launch the app service", extra_data=result["error"]) - async with aiohttp.ClientSession() as session: - async with session.post( + body = { + "login_session_token": params["login_session_token"], + "kernel_host": kernel_host, + "kernel_port": host_port, + "session": { + "id": str(session.id), + "user_uuid": str(session.user_uuid), + "group_id": str(session.group_id), + "access_key": session.access_key, + "domain_name": session.domain_name, + }, + } + if session.routing: + body["endpoint"] = { + "id": str(session.routing.endpoint), + } + + async with aiohttp.ClientSession() as req: + async with req.post( f"{wsproxy_addr}/v2/conf", - json={ - "login_session_token": params["login_session_token"], - "kernel_host": kernel_host, - "kernel_port": host_port, - }, + json=body, ) as resp: token_json = await resp.json() return web.json_response( @@ -880,8 +901,11 @@ async def get_commit_status(request: web.Request, params: Mapping[str, Any]) -> ) try: async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) status_info = await root_ctx.registry.get_commit_status(session) except BackendError: @@ -911,8 +935,11 @@ async def get_abusing_report(request: web.Request, params: Mapping[str, Any]) -> ) try: async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) kernel = session.main_kernel report = await root_ctx.registry.get_abusing_report(kernel.id) @@ -974,8 +1001,11 @@ async def commit_session(request: web.Request, params: Mapping[str, Any]) -> web ) try: async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) resp: Mapping[str, Any] = await asyncio.shield( @@ -1081,11 +1111,11 @@ async def rename_session(request: web.Request, params: Any) -> web.Response: ) async with root_ctx.db.begin_session() as db_sess: compute_session = await SessionRow.get_session( + db_sess, session_name, owner_access_key, allow_stale=True, for_update=True, - db_session=db_sess, ) if compute_session.status != SessionStatus.RUNNING: raise InvalidAPIParameters("Can't change name of not running session") @@ -1141,8 +1171,11 @@ async def destroy(request: web.Request, params: Any) -> web.Response: session_name, ] sessions = [ - await SessionRow.get_session_with_kernels( - name_or_id, owner_access_key, db_session=db_sess + await SessionRow.get_session( + db_sess, + name_or_id, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) for name_or_id in target_session_references ] @@ -1164,8 +1197,11 @@ async def destroy(request: web.Request, params: Any) -> web.Response: return web.json_response(last_stats, status=200) else: async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_kernels( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) last_stat = await root_ctx.registry.destroy_session( session, @@ -1231,8 +1267,11 @@ async def get_direct_access_info(request: web.Request) -> web.Response: _, owner_access_key = await get_access_key_scopes(request) async with root_ctx.db.begin_session() as db_sess: - sess = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + sess = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) kernel_role: KernelRole = sess.main_kernel.role resp = {} @@ -1263,8 +1302,11 @@ async def get_info(request: web.Request) -> web.Response: log.info("GET_INFO (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name) try: async with root_ctx.db.begin_session() as db_sess: - sess = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + sess = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await root_ctx.registry.increment_session_usage(sess) resp["domainName"] = sess.domain_name @@ -1318,8 +1360,11 @@ async def restart(request: web.Request) -> web.Response: requester_access_key, owner_access_key = await get_access_key_scopes(request) log.info("RESTART (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name) async with root_ctx.db.begin_session() as db_sess: - session = await SessionRow.get_session_with_kernels( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) try: await root_ctx.registry.increment_session_usage(session) @@ -1348,8 +1393,11 @@ async def execute(request: web.Request) -> web.Response: log.warning("EXECUTE: invalid/missing parameters") raise InvalidAPIParameters async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: await root_ctx.registry.increment_session_usage(session) @@ -1440,8 +1488,11 @@ async def interrupt(request: web.Request) -> web.Response: requester_access_key, owner_access_key = await get_access_key_scopes(request) log.info("INTERRUPT(ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name) async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: await root_ctx.registry.increment_session_usage(session) @@ -1472,8 +1523,11 @@ async def complete(request: web.Request) -> web.Response: except json.decoder.JSONDecodeError: raise InvalidAPIParameters async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: code = params.get("code", "") @@ -1509,8 +1563,11 @@ async def shutdown_service(request: web.Request, params: Any) -> web.Response: ) service_name = params.get("service_name") async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: await root_ctx.registry.shutdown_service(session, service_name) @@ -1542,7 +1599,7 @@ async def _find_dependent_sessions(session_id: uuid.UUID) -> Set[uuid.UUID]: return dependent_sessions root_session = await SessionRow.get_session( - root_session_name_or_id, access_key=access_key, db_session=db_session + db_session, root_session_name_or_id, access_key=access_key ) return await _find_dependent_sessions(cast(uuid.UUID, root_session.id)) @@ -1559,8 +1616,11 @@ async def upload_files(request: web.Request) -> web.Response: "UPLOAD_FILE (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name ) async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: await root_ctx.registry.increment_session_usage(session) @@ -1615,8 +1675,11 @@ async def download_files(request: web.Request, params: Any) -> web.Response: files[0], ) async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: assert len(files) <= 5, "Too many files" @@ -1674,8 +1737,11 @@ async def download_single(request: web.Request, params: Any) -> web.Response: ) try: async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await root_ctx.registry.increment_session_usage(session) result = await root_ctx.registry.download_single(session, owner_access_key, file) @@ -1710,8 +1776,11 @@ async def list_files(request: web.Request) -> web.Response: path, ) async with root_ctx.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - session_name, owner_access_key, db_session=db_sess + session = await SessionRow.get_session( + db_sess, + session_name, + owner_access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) except (asyncio.TimeoutError, AssertionError, json.decoder.JSONDecodeError) as e: log.warning("LIST_FILES: invalid/missing parameters, {0!r}", e) @@ -1752,11 +1821,12 @@ async def get_container_logs(request: web.Request, params: Any) -> web.Response: ) resp = {"result": {"logs": ""}} async with root_ctx.db.begin_readonly_session() as db_sess: - compute_session = await SessionRow.get_session_with_main_kernel( + compute_session = await SessionRow.get_session( + db_sess, session_name, owner_access_key, allow_stale=True, - db_session=db_sess, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) if ( compute_session.status in DEAD_SESSION_STATUSES diff --git a/src/ai/backend/manager/api/stream.py b/src/ai/backend/manager/api/stream.py index 704d53d1f0..7f4bfb81d2 100644 --- a/src/ai/backend/manager/api/stream.py +++ b/src/ai/backend/manager/api/stream.py @@ -50,7 +50,7 @@ from ai.backend.manager.idle import AppStreamingStatus from ..defs import DEFAULT_ROLE -from ..models import KernelRow, SessionRow +from ..models import KernelLoadingStrategy, KernelRow, SessionRow from .auth import auth_required from .exceptions import ( AppNotFound, @@ -86,8 +86,11 @@ async def stream_pty(defer, request: web.Request) -> web.StreamResponse: async with root_ctx.db.begin_readonly_session() as db_sess: session = await asyncio.shield( database_ptask_group.create_task( - SessionRow.get_session_with_kernels( - session_name, access_key, only_main_kern=True, db_session=db_sess + SessionRow.get_session( + db_sess, + session_name, + access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) ), ) @@ -301,8 +304,11 @@ async def stream_execute(defer, request: web.Request) -> web.StreamResponse: async with root_ctx.db.begin_readonly_session() as db_sess: session: SessionRow = await asyncio.shield( database_ptask_group.create_task( - SessionRow.get_session_with_kernels( - session_name, access_key, only_main_kern=True, db_session=db_sess + SessionRow.get_session( + db_sess, + session_name, + access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ), # noqa ), ) @@ -448,8 +454,11 @@ async def stream_proxy( async with root_ctx.db.begin_readonly_session() as db_sess: session = await asyncio.shield( database_ptask_group.create_task( - SessionRow.get_session_with_kernels( - session_name, access_key, only_main_kern=True, db_session=db_sess + SessionRow.get_session( + db_sess, + session_name, + access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ), ) ) @@ -632,8 +641,11 @@ async def get_stream_apps(request: web.Request) -> web.Response: access_key = request["keypair"]["access_key"] root_ctx: RootContext = request.app["_root.context"] async with root_ctx.db.begin_readonly_session() as db_sess: - compute_session = await SessionRow.get_session_with_kernels( - session_name, access_key, only_main_kern=True, db_session=db_sess + compute_session = await SessionRow.get_session( + db_sess, + session_name, + access_key, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) service_ports = compute_session.main_kernel.service_ports if service_ports is None: diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 575a70656d..c78f9c1ccb 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -97,6 +97,7 @@ "ComputeSessionList", "InferenceSession", "InferenceSessionList", + "KernelLoadingStrategy", ) @@ -534,6 +535,12 @@ class SessionOp(str, enum.Enum): GET_AGENT_LOGS = "get_logs_from_agent" +class KernelLoadingStrategy(str, enum.Enum): + ALL_KERNELS = "all" + MAIN_KERNEL_ONLY = "main" + NONE = "none" + + class SessionRow(Base): __tablename__ = "sessions" id = SessionIDColumn() @@ -827,7 +834,9 @@ async def _update() -> None: await execute_with_retry(_update) + @classmethod async def set_session_result( + cls, db: ExtendedAsyncSAEngine, session_id: SessionId, success: bool, @@ -881,9 +890,6 @@ async def match_sessions( ] try: session_id = UUID(str(session_reference)) - except ValueError: - pass - else: # Fetch id-based query first query_list = [ aiotools.apartial( @@ -902,6 +908,8 @@ async def match_sessions( ), *query_list, ] + except ValueError: + pass for fetch_func in query_list: rows = await fetch_func( @@ -920,31 +928,60 @@ async def match_sessions( @classmethod async def get_session( cls, + db_session: SASession, session_name_or_id: Union[str, UUID], access_key: Optional[AccessKey] = None, *, allow_stale: bool = False, for_update: bool = False, - db_session: SASession, + kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE, + eager_loading_op: list[Any] = [], ) -> SessionRow: """ Retrieve the session information by session's UUID, or session's name paired with access_key. This will return the information of the session and the sibling kernel(s). - :param session_name_or_id: session's id or session's name. + :param db_session: Database connection to use when fetching row. + :param session_name_or_id: Name or ID (UUID) of session to look up. :param access_key: Access key used to create session. - :param allow_stale: If True, filter "inactive" sessions as well as "active" ones. - If False, filter "active" sessions only. - :param for_update: Apply for_update during select query. - :param db_session: Database connection for reuse. + :param allow_stale: If set to True, filter "inactive" sessions as well as "active" ones. + Otherwise filter "active" sessions only. + :param for_update: Apply for_update during executing select query. + :param kernel_loading_strategy: Determines JOIN strategy of `kernels` relation when fetching session rows. + :param eager_loading_op: Extra loading operators to be passed directly to `match_sessions()` API. """ + match kernel_loading_strategy: + case KernelLoadingStrategy.ALL_KERNELS: + eager_loading_op.extend( + [ + noload("*"), + selectinload(SessionRow.kernels).options( + noload("*"), + selectinload(KernelRow.agent_row).noload("*"), + ), + ] + ) + case KernelLoadingStrategy.MAIN_KERNEL_ONLY: + kernel_rel = SessionRow.kernels + kernel_rel.and_(KernelRow.cluster_role == DEFAULT_ROLE) + eager_loading_op.extend( + [ + noload("*"), + selectinload(kernel_rel).options( + noload("*"), + selectinload(KernelRow.agent_row).noload("*"), + ), + ] + ) + session_list = await cls.match_sessions( db_session, session_name_or_id, access_key, allow_stale=allow_stale, for_update=for_update, + eager_loading_op=eager_loading_op, ) if not session_list: raise SessionNotFound(f"Session (id={session_name_or_id}) does not exist.") @@ -962,90 +999,54 @@ async def get_session( return session_list[0] @classmethod - async def get_session_with_kernels( + async def list_sessions( cls, - session_name_or_id: str | UUID, - access_key: Optional[AccessKey] = None, - *, - allow_stale: bool = False, - for_update: bool = False, - only_main_kern: bool = False, db_session: SASession, - ) -> SessionRow: - kernel_rel = SessionRow.kernels - if only_main_kern: - kernel_rel.and_(KernelRow.cluster_role == DEFAULT_ROLE) - kernel_loading_op = ( - noload("*"), - selectinload(kernel_rel).options( - noload("*"), - selectinload(KernelRow.agent_row).noload("*"), - ), - ) - session_list = await cls.match_sessions( - db_session, - session_name_or_id, - access_key, - allow_stale=allow_stale, - for_update=for_update, - eager_loading_op=kernel_loading_op, - ) - try: - return session_list[0] - except IndexError: - raise SessionNotFound(f"Session (id={session_name_or_id}) does not exist.") - - @classmethod - async def list_sessions_with_main_kernels( - cls, session_ids: list[UUID], access_key: Optional[AccessKey] = None, *, allow_stale: bool = False, for_update: bool = False, - db_session: SASession, + kernel_loading_strategy=KernelLoadingStrategy.NONE, + eager_loading_op: list[Any] = [], ) -> Iterable[SessionRow]: - kernel_rel = SessionRow.kernels - kernel_rel.and_(KernelRow.cluster_role == DEFAULT_ROLE) - kernel_loading_op = ( - noload("*"), - selectinload(kernel_rel).options( - noload("*"), - selectinload(KernelRow.agent_row).noload("*"), - ), - ) + match kernel_loading_strategy: + case KernelLoadingStrategy.ALL_KERNELS: + eager_loading_op.extend( + [ + noload("*"), + selectinload(SessionRow.kernels).options( + noload("*"), + selectinload(KernelRow.agent_row).noload("*"), + ), + ] + ) + case KernelLoadingStrategy.MAIN_KERNEL_ONLY: + kernel_rel = SessionRow.kernels + kernel_rel.and_(KernelRow.cluster_role == DEFAULT_ROLE) + eager_loading_op.extend( + [ + noload("*"), + selectinload(kernel_rel).options( + noload("*"), + selectinload(KernelRow.agent_row).noload("*"), + ), + ] + ) + session_list = await cls.match_sessions( db_session, session_ids, access_key, allow_stale=allow_stale, for_update=for_update, - eager_loading_op=kernel_loading_op, + eager_loading_op=eager_loading_op, ) try: return session_list except IndexError: raise SessionNotFound(f"Session (ids={session_ids}) does not exist.") - @classmethod - async def get_session_with_main_kernel( - cls, - session_name_or_id: str | UUID, - access_key: Optional[AccessKey] = None, - *, - allow_stale: bool = False, - for_update: bool = False, - db_session: SASession, - ) -> SessionRow: - return await cls.get_session_with_kernels( - session_name_or_id, - access_key, - allow_stale=allow_stale, - for_update=for_update, - only_main_kern=True, - db_session=db_session, - ) - @classmethod async def get_session_by_id( cls, @@ -1097,51 +1098,6 @@ async def get_sgroup_managed_sessions( result = await db_sess.execute(query) return result.scalars().all() - @classmethod - async def get_session_to_destroy( - cls, db: ExtendedAsyncSAEngine, session_id: SessionId - ) -> SessionRow: - query = ( - sa.select(SessionRow) - .where(SessionRow.id == session_id) - .options( - noload("*"), - load_only(SessionRow.creation_id, SessionRow.status), - selectinload(SessionRow.kernels).options( - noload("*"), - load_only( - KernelRow.id, - KernelRow.role, - KernelRow.access_key, - KernelRow.status, - KernelRow.container_id, - KernelRow.cluster_role, - KernelRow.agent, - KernelRow.agent_addr, - ), - ), - ) - ) - async with db.begin_readonly_session() as db_session: - return (await db_session.scalars(query)).first() - - @classmethod - async def get_session_to_produce_event( - cls, db: ExtendedAsyncSAEngine, session_id: SessionId - ) -> SessionRow: - query = ( - sa.select(SessionRow) - .where(SessionRow.id == session_id) - .options( - noload("*"), - load_only( - SessionRow.id, SessionRow.name, SessionRow.creation_id, SessionRow.access_key - ), - ) - ) - async with db.begin_readonly_session() as db_session: - return (await db_session.scalars(query)).first() - class SessionDependencyRow(Base): __tablename__ = "session_dependencies" @@ -1393,8 +1349,10 @@ async def resolve_dependencies( async def resolve_commit_status(self, info: graphene.ResolveInfo) -> str: graph_ctx: GraphQueryContext = info.context async with graph_ctx.db.begin_readonly_session() as db_sess: - session: SessionRow = await SessionRow.get_session_with_main_kernel( - self.id, db_session=db_sess + session: SessionRow = await SessionRow.get_session( + db_sess, + self.id, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) commit_status = await graph_ctx.registry.get_commit_status(session) return commit_status["status"] diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 0764a9b997..3f92514e78 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -52,6 +52,7 @@ from redis.asyncio import Redis from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import load_only, noload, selectinload from yarl import URL from ai.backend.common import msgpack, redis_helper @@ -142,6 +143,7 @@ AgentStatus, EndpointRow, ImageRow, + KernelLoadingStrategy, KernelRole, KernelRow, KernelStatus, @@ -544,10 +546,11 @@ async def create_session( # NOTE: We can reuse the session IDs of TERMINATED sessions only. # NOTE: Reusing a session in the PENDING status returns an empty value in service_ports. async with self.db.begin_readonly_session() as db_sess: - sess = await SessionRow.get_session_with_main_kernel( + sess = await SessionRow.get_session( + db_sess, session_name, owner_access_key, - db_session=db_sess, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) running_image_ref = ImageRef( sess.main_kernel.image, [sess.main_kernel.registry], sess.main_kernel.architecture @@ -719,9 +722,9 @@ async def create_cluster( # NOTE: Reusing a session in the PENDING status returns an empty value in service_ports. async with self.db.begin_readonly_session() as db_sess: await SessionRow.get_session( + db_sess, session_name, owner_access_key, - db_session=db_sess, ) except SessionNotFound: pass @@ -1310,9 +1313,9 @@ async def _post_enqueue() -> None: for dependency_id in dependency_sessions: try: match_info = await SessionRow.get_session( + db_sess, dependency_id, access_key, - db_session=db_sess, ) except SessionNotFound: raise InvalidAPIParameters( @@ -1645,8 +1648,21 @@ async def finalize_running( new_session_status = await SessionRow.transit_session_status(self.db, session_id) if new_session_status is None or new_session_status != SessionStatus.RUNNING: return - - updated_session = await SessionRow.get_session_to_produce_event(self.db, session_id) + query = ( + sa.select(SessionRow) + .where(SessionRow.id == session_id) + .options( + noload("*"), + load_only( + SessionRow.id, + SessionRow.name, + SessionRow.creation_id, + SessionRow.access_key, + ), + ) + ) + async with self.db.begin_readonly_session() as db_session: + updated_session = (await db_session.scalars(query)).first() log.debug( "Producing SessionStartedEvent({}, {})", @@ -2153,7 +2169,29 @@ async def destroy_session( session_id, set_error=True, ): - target_session = await SessionRow.get_session_to_destroy(self.db, session_id) + query = ( + sa.select(SessionRow) + .where(SessionRow.id == session_id) + .options( + noload("*"), + load_only(SessionRow.creation_id, SessionRow.status), + selectinload(SessionRow.kernels).options( + noload("*"), + load_only( + KernelRow.id, + KernelRow.role, + KernelRow.access_key, + KernelRow.status, + KernelRow.container_id, + KernelRow.cluster_role, + KernelRow.agent, + KernelRow.agent_addr, + ), + ), + ) + ) + async with self.db.begin_readonly_session() as db_session: + target_session = (await db_session.scalars(query)).first() match target_session.status: case SessionStatus.PENDING: @@ -3272,9 +3310,10 @@ async def update_appproxy_endpoint_routes( ) -> None: active_routes = [r for r in endpoint.routings if r.status == RouteStatus.HEALTHY] - target_sessions = await SessionRow.list_sessions_with_main_kernels( + target_sessions = await SessionRow.list_sessions( + db_sess, [r.session for r in active_routes], - db_session=db_sess, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) query = ( sa.select([scaling_groups.c.wsproxy_addr, scaling_groups.c.wsproxy_api_token]) @@ -3312,6 +3351,16 @@ async def update_appproxy_endpoint_routes( f"{wsproxy_addr}/v2/endpoints/{endpoint.id}", json={ "service_name": endpoint.name, + "tags": { + "session": { + "user_uuid": str(endpoint.session_owner), + "group_id": str(endpoint.project), + "domain_name": endpoint.domain, + }, + "endpoint": { + "id": endpoint.id, + }, + }, "apps": inference_apps, "open_to_public": endpoint.open_to_public, }, # TODO: support for multiple inference apps @@ -3320,7 +3369,6 @@ async def update_appproxy_endpoint_routes( }, ) as resp: endpoint_json = await resp.json() - log.debug("resp: {}", endpoint_json) async with self.db.begin_session() as db_sess: query = ( sa.update(EndpointRow) @@ -3456,7 +3504,9 @@ async def handle_destroy_session( event: DoTerminateSessionEvent, ) -> None: async with context.db.begin_session() as db_sess: - session = await SessionRow.get_session_with_kernels(event.session_id, db_session=db_sess) + session = await SessionRow.get_session( + db_sess, event.session_id, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS + ) await context.destroy_session( session, forced=False, @@ -3481,8 +3531,11 @@ async def invoke_session_callback( try: allow_stale = isinstance(event, (SessionCancelledEvent, SessionTerminatedEvent)) async with context.db.begin_readonly_session() as db_sess: - session = await SessionRow.get_session_with_main_kernel( - event.session_id, db_session=db_sess, allow_stale=allow_stale + session = await SessionRow.get_session( + db_sess, + event.session_id, + allow_stale=allow_stale, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) except SessionNotFound: return @@ -3568,8 +3621,8 @@ async def handle_batch_result( await SessionRow.set_session_result(context.db, event.session_id, False, event.exit_code) async with context.db.begin_session() as db_sess: try: - session = await SessionRow.get_session_with_kernels( - event.session_id, db_session=db_sess + session = await SessionRow.get_session( + db_sess, event.session_id, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS ) except SessionNotFound: return