Skip to content

Commit

Permalink
feat: refactor SessionRow's ORM resolvers, advertise additional inf…
Browse files Browse the repository at this point in the history
…ormations to WSProxy (#1396)

Co-authored-by: Jonghyun Park <jpark@lablup.com>
Co-authored-by: Joongi Kim <joongi@lablup.com>
  • Loading branch information
3 people authored Jul 24, 2023
1 parent 9c3b37e commit 0dc6d0c
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 189 deletions.
1 change: 1 addition & 0 deletions changes/1396.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `SessionRow` ORM queries by introducing `KernelLoadingStrategy` to generalize and reuse `SessionRow.get_session()`
160 changes: 115 additions & 45 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from aiohttp import hdrs, web
from dateutil.tz import tzutc
from redis.asyncio import Redis
from sqlalchemy.orm import noload, selectinload
from sqlalchemy.sql.expression import null, true

if TYPE_CHECKING:
Expand All @@ -66,6 +67,7 @@
from ..models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
DEAD_SESSION_STATUSES,
KernelLoadingStrategy,
KernelRole,
SessionRow,
SessionStatus,
Expand Down Expand Up @@ -763,8 +765,14 @@ async def start_service(request: web.Request, params: Mapping[str, Any]) -> web.
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await asyncio.shield(
app_ctx.database_ptask_group.create_task(
SessionRow.get_session_with_main_kernel(
session_name, access_key, db_session=db_sess
SessionRow.get_session(
db_sess,
session_name,
access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
eager_loading_op=[
selectinload(SessionRow.routing).options(noload("*")),
],
),
)
)
Expand Down Expand Up @@ -839,14 +847,27 @@ async def start_service(request: web.Request, params: Mapping[str, Any]) -> web.
if result["status"] == "failed":
raise InternalServerError("Failed to launch the app service", extra_data=result["error"])

async with aiohttp.ClientSession() as session:
async with session.post(
body = {
"login_session_token": params["login_session_token"],
"kernel_host": kernel_host,
"kernel_port": host_port,
"session": {
"id": str(session.id),
"user_uuid": str(session.user_uuid),
"group_id": str(session.group_id),
"access_key": session.access_key,
"domain_name": session.domain_name,
},
}
if session.routing:
body["endpoint"] = {
"id": str(session.routing.endpoint),
}

async with aiohttp.ClientSession() as req:
async with req.post(
f"{wsproxy_addr}/v2/conf",
json={
"login_session_token": params["login_session_token"],
"kernel_host": kernel_host,
"kernel_port": host_port,
},
json=body,
) as resp:
token_json = await resp.json()
return web.json_response(
Expand Down Expand Up @@ -880,8 +901,11 @@ async def get_commit_status(request: web.Request, params: Mapping[str, Any]) ->
)
try:
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
status_info = await root_ctx.registry.get_commit_status(session)
except BackendError:
Expand Down Expand Up @@ -911,8 +935,11 @@ async def get_abusing_report(request: web.Request, params: Mapping[str, Any]) ->
)
try:
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
kernel = session.main_kernel
report = await root_ctx.registry.get_abusing_report(kernel.id)
Expand Down Expand Up @@ -974,8 +1001,11 @@ async def commit_session(request: web.Request, params: Mapping[str, Any]) -> web
)
try:
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)

resp: Mapping[str, Any] = await asyncio.shield(
Expand Down Expand Up @@ -1081,11 +1111,11 @@ async def rename_session(request: web.Request, params: Any) -> web.Response:
)
async with root_ctx.db.begin_session() as db_sess:
compute_session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
allow_stale=True,
for_update=True,
db_session=db_sess,
)
if compute_session.status != SessionStatus.RUNNING:
raise InvalidAPIParameters("Can't change name of not running session")
Expand Down Expand Up @@ -1141,8 +1171,11 @@ async def destroy(request: web.Request, params: Any) -> web.Response:
session_name,
]
sessions = [
await SessionRow.get_session_with_kernels(
name_or_id, owner_access_key, db_session=db_sess
await SessionRow.get_session(
db_sess,
name_or_id,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS,
)
for name_or_id in target_session_references
]
Expand All @@ -1164,8 +1197,11 @@ async def destroy(request: web.Request, params: Any) -> web.Response:
return web.json_response(last_stats, status=200)
else:
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_kernels(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS,
)
last_stat = await root_ctx.registry.destroy_session(
session,
Expand Down Expand Up @@ -1231,8 +1267,11 @@ async def get_direct_access_info(request: web.Request) -> web.Response:
_, owner_access_key = await get_access_key_scopes(request)

async with root_ctx.db.begin_session() as db_sess:
sess = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
sess = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
kernel_role: KernelRole = sess.main_kernel.role
resp = {}
Expand Down Expand Up @@ -1263,8 +1302,11 @@ async def get_info(request: web.Request) -> web.Response:
log.info("GET_INFO (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name)
try:
async with root_ctx.db.begin_session() as db_sess:
sess = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
sess = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
await root_ctx.registry.increment_session_usage(sess)
resp["domainName"] = sess.domain_name
Expand Down Expand Up @@ -1318,8 +1360,11 @@ async def restart(request: web.Request) -> web.Response:
requester_access_key, owner_access_key = await get_access_key_scopes(request)
log.info("RESTART (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name)
async with root_ctx.db.begin_session() as db_sess:
session = await SessionRow.get_session_with_kernels(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS,
)
try:
await root_ctx.registry.increment_session_usage(session)
Expand Down Expand Up @@ -1348,8 +1393,11 @@ async def execute(request: web.Request) -> web.Response:
log.warning("EXECUTE: invalid/missing parameters")
raise InvalidAPIParameters
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
try:
await root_ctx.registry.increment_session_usage(session)
Expand Down Expand Up @@ -1440,8 +1488,11 @@ async def interrupt(request: web.Request) -> web.Response:
requester_access_key, owner_access_key = await get_access_key_scopes(request)
log.info("INTERRUPT(ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name)
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
try:
await root_ctx.registry.increment_session_usage(session)
Expand Down Expand Up @@ -1472,8 +1523,11 @@ async def complete(request: web.Request) -> web.Response:
except json.decoder.JSONDecodeError:
raise InvalidAPIParameters
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
try:
code = params.get("code", "")
Expand Down Expand Up @@ -1509,8 +1563,11 @@ async def shutdown_service(request: web.Request, params: Any) -> web.Response:
)
service_name = params.get("service_name")
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
try:
await root_ctx.registry.shutdown_service(session, service_name)
Expand Down Expand Up @@ -1542,7 +1599,7 @@ async def _find_dependent_sessions(session_id: uuid.UUID) -> Set[uuid.UUID]:
return dependent_sessions

root_session = await SessionRow.get_session(
root_session_name_or_id, access_key=access_key, db_session=db_session
db_session, root_session_name_or_id, access_key=access_key
)
return await _find_dependent_sessions(cast(uuid.UUID, root_session.id))

Expand All @@ -1559,8 +1616,11 @@ async def upload_files(request: web.Request) -> web.Response:
"UPLOAD_FILE (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name
)
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
try:
await root_ctx.registry.increment_session_usage(session)
Expand Down Expand Up @@ -1615,8 +1675,11 @@ async def download_files(request: web.Request, params: Any) -> web.Response:
files[0],
)
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
try:
assert len(files) <= 5, "Too many files"
Expand Down Expand Up @@ -1674,8 +1737,11 @@ async def download_single(request: web.Request, params: Any) -> web.Response:
)
try:
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
await root_ctx.registry.increment_session_usage(session)
result = await root_ctx.registry.download_single(session, owner_access_key, file)
Expand Down Expand Up @@ -1710,8 +1776,11 @@ async def list_files(request: web.Request) -> web.Response:
path,
)
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await SessionRow.get_session_with_main_kernel(
session_name, owner_access_key, db_session=db_sess
session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
except (asyncio.TimeoutError, AssertionError, json.decoder.JSONDecodeError) as e:
log.warning("LIST_FILES: invalid/missing parameters, {0!r}", e)
Expand Down Expand Up @@ -1752,11 +1821,12 @@ async def get_container_logs(request: web.Request, params: Any) -> web.Response:
)
resp = {"result": {"logs": ""}}
async with root_ctx.db.begin_readonly_session() as db_sess:
compute_session = await SessionRow.get_session_with_main_kernel(
compute_session = await SessionRow.get_session(
db_sess,
session_name,
owner_access_key,
allow_stale=True,
db_session=db_sess,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
if (
compute_session.status in DEAD_SESSION_STATUSES
Expand Down
30 changes: 21 additions & 9 deletions src/ai/backend/manager/api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from ai.backend.manager.idle import AppStreamingStatus

from ..defs import DEFAULT_ROLE
from ..models import KernelRow, SessionRow
from ..models import KernelLoadingStrategy, KernelRow, SessionRow
from .auth import auth_required
from .exceptions import (
AppNotFound,
Expand Down Expand Up @@ -86,8 +86,11 @@ async def stream_pty(defer, request: web.Request) -> web.StreamResponse:
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await asyncio.shield(
database_ptask_group.create_task(
SessionRow.get_session_with_kernels(
session_name, access_key, only_main_kern=True, db_session=db_sess
SessionRow.get_session(
db_sess,
session_name,
access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
),
)
Expand Down Expand Up @@ -301,8 +304,11 @@ async def stream_execute(defer, request: web.Request) -> web.StreamResponse:
async with root_ctx.db.begin_readonly_session() as db_sess:
session: SessionRow = await asyncio.shield(
database_ptask_group.create_task(
SessionRow.get_session_with_kernels(
session_name, access_key, only_main_kern=True, db_session=db_sess
SessionRow.get_session(
db_sess,
session_name,
access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
), # noqa
),
)
Expand Down Expand Up @@ -448,8 +454,11 @@ async def stream_proxy(
async with root_ctx.db.begin_readonly_session() as db_sess:
session = await asyncio.shield(
database_ptask_group.create_task(
SessionRow.get_session_with_kernels(
session_name, access_key, only_main_kern=True, db_session=db_sess
SessionRow.get_session(
db_sess,
session_name,
access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
),
)
)
Expand Down Expand Up @@ -632,8 +641,11 @@ async def get_stream_apps(request: web.Request) -> web.Response:
access_key = request["keypair"]["access_key"]
root_ctx: RootContext = request.app["_root.context"]
async with root_ctx.db.begin_readonly_session() as db_sess:
compute_session = await SessionRow.get_session_with_kernels(
session_name, access_key, only_main_kern=True, db_session=db_sess
compute_session = await SessionRow.get_session(
db_sess,
session_name,
access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
service_ports = compute_session.main_kernel.service_ports
if service_ports is None:
Expand Down
Loading

0 comments on commit 0dc6d0c

Please sign in to comment.