Skip to content

Commit

Permalink
refactor: Merge kernels.role into sessions.session_type (#1587) (#…
Browse files Browse the repository at this point in the history
…2854)

Co-authored-by: Sanghun Lee <sanghun@lablup.com>
Co-authored-by: Joongi Kim <joongi@lablup.com>
  • Loading branch information
3 people authored Sep 21, 2024
1 parent 29afd1c commit 5728169
Show file tree
Hide file tree
Showing 9 changed files with 385 additions and 100 deletions.
1 change: 1 addition & 0 deletions changes/1587.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Merge `kernels.role` into `sessions.session_type` and check the image compatibility based on comparison with the `ai.backend.role` label
1 change: 1 addition & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ class SessionTypes(enum.StrEnum):
INTERACTIVE = "interactive"
BATCH = "batch"
INFERENCE = "inference"
SYSTEM = "system"


class SessionResult(enum.StrEnum):
Expand Down
9 changes: 5 additions & 4 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
DEAD_SESSION_STATUSES,
ImageRow,
KernelLoadingStrategy,
KernelRole,
SessionDependencyRow,
SessionRow,
SessionStatus,
Expand All @@ -100,6 +99,7 @@
session_templates,
vfolders,
)
from ..models.session import PRIVATE_SESSION_TYPES
from ..types import UserScope
from ..utils import query_userinfo as _query_userinfo
from .auth import auth_required
Expand Down Expand Up @@ -1491,9 +1491,9 @@ async def get_direct_access_info(request: web.Request) -> web.Response:
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
kernel_role: KernelRole = sess.main_kernel.role
resp = {}
if kernel_role == KernelRole.SYSTEM:
sess_type = cast(SessionTypes, sess.session_type)
if sess_type in PRIVATE_SESSION_TYPES:
public_host = sess.main_kernel.agent_row.public_host
found_ports: dict[str, list[str]] = {}
for sport in sess.main_kernel.service_ports:
Expand All @@ -1502,7 +1502,8 @@ async def get_direct_access_info(request: web.Request) -> web.Response:
elif sport["name"] == "sftpd":
found_ports["sftpd"] = sport["host_ports"]
resp = {
"kernel_role": kernel_role.name,
"kernel_role": sess_type.name, # legacy
"session_type": sess_type.name,
"public_host": public_host,
"sshd_ports": found_ports.get("sftpd") or found_ports["sshd"],
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
"""replace_kernelrole_to_sessiontypes
Revision ID: 3596bc12ec09
Revises: 59a622c31820
Create Date: 2023-10-04 16:43:46.281383
"""

import enum
import uuid
from typing import cast

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import text

from ai.backend.manager.models.base import GUID, mapper_registry

# revision identifiers, used by Alembic.
revision = "3596bc12ec09"
down_revision = "59a622c31820"
branch_labels = None
depends_on = None

ENUM_CLS = "sessiontypes"

PAGE_SIZE = 100


class KernelRole(enum.Enum):
INFERENCE = "INFERENCE"
COMPUTE = "COMPUTE"
SYSTEM = "SYSTEM"


kernelrole_choices = list(map(lambda v: v.name, KernelRole))
kernelrole = postgresql.ENUM(*kernelrole_choices, name="kernelrole")


class OldSessionTypes(enum.StrEnum):
INTERACTIVE = "interactive"
BATCH = "batch"
INFERENCE = "inference"


def upgrade():
connection = op.get_bind()

# Relax the sessions.session_type from enum to varchar(64).
connection.execute(
text(
"ALTER TABLE sessions ALTER COLUMN session_type TYPE varchar(64) USING session_type::text;"
)
)
connection.execute(
text("ALTER TABLE sessions ALTER COLUMN session_type SET DEFAULT 'INTERACTIVE';")
)

# Relax the kernels.session_type from enum to varchar(64).
connection.execute(
text(
"ALTER TABLE kernels ALTER COLUMN session_type TYPE varchar(64) USING session_type::text;"
)
)
connection.execute(
text("ALTER TABLE kernels ALTER COLUMN session_type SET DEFAULT 'INTERACTIVE';")
)
# Relax the kernels.role from enum to varchar(64).
connection.execute(
text("ALTER TABLE kernels ALTER COLUMN role TYPE varchar(64) USING role::text;")
)
connection.execute(text("ALTER TABLE kernels ALTER COLUMN role SET DEFAULT 'COMPUTE';"))

# Drop enum types
connection.execute(text("DROP TYPE IF EXISTS sessiontypes;"))
connection.execute(text("DROP TYPE IF EXISTS kernelrole;"))

# Update `sessions.session_type` column
kernels = sa.Table(
"kernels",
mapper_registry.metadata,
sa.Column("id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")),
sa.Column("session_id", GUID),
sa.Column(
"role",
postgresql.ENUM("INFERENCE", "COMPUTE", "SYSTEM", name="kernelrole"),
default=KernelRole.COMPUTE,
server_default=KernelRole.COMPUTE.name,
),
sa.Column(
"session_type",
sa.VARCHAR,
index=True,
nullable=False,
default="interactive",
server_default="interactive",
),
extend_existing=True,
)

sessions = sa.Table(
"sessions",
mapper_registry.metadata,
sa.Column("id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")),
sa.Column(
"session_type",
sa.VARCHAR,
index=True,
nullable=False,
default="interactive",
server_default="interactive",
),
extend_existing=True,
)
while True:
_sid_query = (
sa.select(kernels.c.session_id).where(kernels.c.role == "SYSTEM").limit(PAGE_SIZE)
)
session_ids = cast(list[uuid.UUID], connection.scalars(_sid_query).all())
_session_query = (
sa.update(sessions)
.values({"session_type": "SYSTEM"})
.where(sessions.c.id.in_(session_ids))
)
_kernel_query = (
sa.update(kernels)
.values({"session_type": "SYSTEM", "role": "COMPUTE"})
.where(kernels.c.session_id.in_(session_ids))
)
connection.execute(_session_query)
result = connection.execute(_kernel_query)

if result.rowcount == 0:
break

op.drop_column("kernels", "role")


def downgrade():
connection = op.get_bind()

kernel_role_values = ["INFERENCE", "COMPUTE", "SYSTEM"]
KernelRoleType = postgresql.ENUM(*kernel_role_values, name="kernelrole")
KernelRoleType.create(connection)
op.add_column(
"kernels",
sa.Column(
"role",
KernelRoleType,
autoincrement=False,
nullable=True,
server_default=KernelRole.COMPUTE.name,
),
)

kernels = sa.Table(
"kernels",
mapper_registry.metadata,
sa.Column("id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")),
sa.Column(
"role",
KernelRoleType,
default=KernelRole.COMPUTE,
server_default=KernelRole.COMPUTE.name,
),
sa.Column(
"session_type",
sa.VARCHAR,
index=True,
nullable=False,
default="interactive",
server_default="interactive",
),
extend_existing=True,
)

sessions = sa.Table(
"sessions",
mapper_registry.metadata,
sa.Column("id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")),
sa.Column(
"session_type",
sa.VARCHAR,
index=True,
nullable=False,
default="interactive",
server_default="interactive",
),
extend_existing=True,
)

# replace session_type.system role
while True:
_sid_query = (
sa.select(kernels.c.session_id)
.where(kernels.c.session_type == "SYSTEM")
.limit(PAGE_SIZE)
)
session_ids = cast(list[uuid.UUID], connection.scalars(_sid_query).all())
_session_query = (
sa.update(sessions)
.values({"session_type": "INTERACTIVE"})
.where(sessions.c.id.in_(session_ids))
)
_kernel_query = (
sa.update(kernels)
.values({"session_type": "INTERACTIVE", "role": "SYSTEM"})
.where(kernels.c.session_id.in_(session_ids))
)
connection.execute(_session_query)
result = connection.execute(_kernel_query)
if result.rowcount == 0:
break

op.alter_column("kernels", column_name="role", nullable=False)

connection.execute(
text(
"CREATE TYPE sessiontypes AS ENUM (%s)"
% (",".join(f"'{choice.name}'" for choice in OldSessionTypes))
)
)
# Revert sessions.session_type to enum
connection.execute(text("ALTER TABLE sessions ALTER COLUMN session_type DROP DEFAULT;"))
connection.execute(
text(
"ALTER TABLE sessions ALTER COLUMN session_type TYPE sessiontypes "
"USING session_type::sessiontypes;"
)
)
connection.execute(
text("ALTER TABLE sessions ALTER COLUMN session_type SET DEFAULT 'INTERACTIVE';")
)

# Revert kernels.session_type to enum
connection.execute(text("ALTER TABLE kernels ALTER COLUMN session_type DROP DEFAULT;"))
connection.execute(
text(
"ALTER TABLE kernels ALTER COLUMN session_type TYPE sessiontypes "
"USING session_type::sessiontypes;"
)
)
connection.execute(
text("ALTER TABLE kernels ALTER COLUMN session_type SET DEFAULT 'INTERACTIVE';")
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
"""

import enum

import sqlalchemy as sa
from alembic import op
from sqlalchemy import cast
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.functions import coalesce

from ai.backend.manager.models import ImageRow, KernelRole, kernels
from ai.backend.manager.models import ImageRow, kernels
from ai.backend.manager.models.base import EnumType

# revision identifiers, used by Alembic.
Expand All @@ -21,6 +23,13 @@
branch_labels = None
depends_on = None


class KernelRole(enum.Enum):
INFERENCE = "INFERENCE"
COMPUTE = "COMPUTE"
SYSTEM = "SYSTEM"


images = ImageRow.__table__
kernelrole_choices = list(map(lambda v: v.name, KernelRole))
kernelrole = postgresql.ENUM(*kernelrole_choices, name="kernelrole")
Expand Down
25 changes: 18 additions & 7 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,31 +214,42 @@ class StrEnumType(TypeDecorator, Generic[T_StrEnum]):
impl = sa.VARCHAR
cache_ok = True

def __init__(self, enum_cls: type[T_StrEnum], **opts) -> None:
def __init__(self, enum_cls: type[T_StrEnum], use_name: bool = False, **opts) -> None:
self._opts = opts
super().__init__(length=64, **opts)
self._use_name = use_name
self._enum_cls = enum_cls

def process_bind_param(
self,
value: Optional[T_StrEnum],
dialect: Dialect,
) -> Optional[str]:
return value.value if value is not None else None
if value is None:
return None
if self._use_name:
return value.name
else:
return value.value

def process_result_value(
self,
value: str,
value: Optional[str],
dialect: Dialect,
) -> Optional[T_StrEnum]:
return self._enum_cls(value) if value is not None else None
if value is None:
return None
if self._use_name:
return self._enum_cls[value]
else:
return self._enum_cls(value)

def copy(self, **kw) -> type[Self]:
return StrEnumType(self._enum_cls, **self._opts)
return StrEnumType(self._enum_cls, self._use_name, **self._opts)

@property
def python_type(self) -> T_StrEnum:
return self._enum_class
def python_type(self) -> type[T_StrEnum]:
return self._enum_cls


class CurvePublicKeyColumn(TypeDecorator):
Expand Down
Loading

0 comments on commit 5728169

Please sign in to comment.