diff --git a/src/ai/backend/manager/models/domain.py b/src/ai/backend/manager/models/domain.py index 2d18a447dd..18d9d56668 100644 --- a/src/ai/backend/manager/models/domain.py +++ b/src/ai/backend/manager/models/domain.py @@ -1,7 +1,20 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypedDict +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + NamedTuple, + Optional, + Sequence, + TypeAlias, + TypedDict, + cast, + override, +) import graphene import sqlalchemy as sa @@ -10,7 +23,8 @@ from sqlalchemy.engine.result import Result from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection -from sqlalchemy.orm import relationship +from sqlalchemy.ext.asyncio import AsyncSession as SASession +from sqlalchemy.orm import load_only, relationship from ai.backend.common import msgpack from ai.backend.common.types import ResourceSlot @@ -29,6 +43,17 @@ simple_db_mutate, simple_db_mutate_returning_item, ) +from .rbac import ( + AbstractPermissionContext, + AbstractPermissionContextBuilder, + DomainScope, + ProjectScope, + ScopeType, + UserScope, + get_predefined_roles_in_scope, +) +from .rbac.context import ClientContext +from .rbac.permission_defs import DomainPermission from .scaling_group import ScalingGroup from .user import UserRole @@ -415,3 +440,183 @@ def verify_dotfile_name(dotfile: str) -> bool: if dotfile in RESERVED_DOTFILES: return False return True + + +OWNER_PERMISSIONS: frozenset[DomainPermission] = frozenset([perm for perm in DomainPermission]) +ADMIN_PERMISSIONS: frozenset[DomainPermission] = frozenset([ + DomainPermission.READ_ATTRIBUTE, + DomainPermission.UPDATE_ATTRIBUTE, + DomainPermission.CREATE_USER, + DomainPermission.CREATE_PROJECT, +]) +MONITOR_PERMISSIONS: frozenset[DomainPermission] = frozenset([ + DomainPermission.READ_ATTRIBUTE, + DomainPermission.UPDATE_ATTRIBUTE, +]) +PRIVILEGED_MEMBER_PERMISSIONS: frozenset[DomainPermission] = frozenset() +MEMBER_PERMISSIONS: frozenset[DomainPermission] = frozenset() + +WhereClauseType: TypeAlias = ( + sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList +) + + +@dataclass +class DomainPermissionContext(AbstractPermissionContext[DomainPermission, DomainRow, str]): + @property + def query_condition(self) -> WhereClauseType | None: + cond: WhereClauseType | None = None + + def _OR_coalesce( + base_cond: WhereClauseType | None, + _cond: sa.sql.expression.BinaryExpression, + ) -> WhereClauseType: + return base_cond | _cond if base_cond is not None else _cond + + if self.object_id_to_additional_permission_map: + cond = _OR_coalesce( + cond, DomainRow.name.in_(self.object_id_to_additional_permission_map.keys()) + ) + if self.object_id_to_overriding_permission_map: + cond = _OR_coalesce( + cond, DomainRow.name.in_(self.object_id_to_overriding_permission_map.keys()) + ) + return cond + + async def build_query(self) -> sa.sql.Select | None: + cond = self.query_condition + if cond is None: + return None + return sa.select(DomainRow).where(cond) + + async def calculate_final_permission(self, rbac_obj: DomainRow) -> frozenset[DomainPermission]: + domain_row = rbac_obj + domain_name = cast(str, domain_row.name) + permissions: frozenset[DomainPermission] = frozenset() + + if ( + overriding_perm := self.object_id_to_overriding_permission_map.get(domain_name) + ) is not None: + permissions = overriding_perm + else: + permissions |= self.object_id_to_additional_permission_map.get(domain_name, set()) + return permissions + + +class DomainPermissionContextBuilder( + AbstractPermissionContextBuilder[DomainPermission, DomainPermissionContext] +): + db_session: SASession + + def __init__(self, db_session: SASession) -> None: + self.db_session = db_session + + @override + async def calculate_permission( + self, + ctx: ClientContext, + target_scope: ScopeType, + ) -> frozenset[DomainPermission]: + roles = await get_predefined_roles_in_scope(ctx, target_scope, self.db_session) + permissions = await self._calculate_permission_by_predefined_roles(roles) + return permissions + + @override + async def build_ctx_in_system_scope( + self, + ctx: ClientContext, + ) -> DomainPermissionContext: + from .domain import DomainRow + + perm_ctx = DomainPermissionContext() + _domain_query_stmt = sa.select(DomainRow).options(load_only(DomainRow.name)) + for row in await self.db_session.scalars(_domain_query_stmt): + to_be_merged = await self.build_ctx_in_domain_scope(ctx, DomainScope(row.name)) + perm_ctx.merge(to_be_merged) + return perm_ctx + + @override + async def build_ctx_in_domain_scope( + self, + ctx: ClientContext, + scope: DomainScope, + ) -> DomainPermissionContext: + permissions = await self.calculate_permission(ctx, scope) + return DomainPermissionContext( + object_id_to_additional_permission_map={scope.domain_name: permissions} + ) + + @override + async def build_ctx_in_project_scope( + self, ctx: ClientContext, scope: ProjectScope + ) -> DomainPermissionContext: + return DomainPermissionContext() + + @override + async def build_ctx_in_user_scope( + self, ctx: ClientContext, scope: UserScope + ) -> DomainPermissionContext: + return DomainPermissionContext() + + @override + @classmethod + async def _permission_for_owner( + cls, + ) -> frozenset[DomainPermission]: + return OWNER_PERMISSIONS + + @override + @classmethod + async def _permission_for_admin( + cls, + ) -> frozenset[DomainPermission]: + return ADMIN_PERMISSIONS + + @override + @classmethod + async def _permission_for_monitor( + cls, + ) -> frozenset[DomainPermission]: + return MONITOR_PERMISSIONS + + @override + @classmethod + async def _permission_for_privileged_member( + cls, + ) -> frozenset[DomainPermission]: + return PRIVILEGED_MEMBER_PERMISSIONS + + @override + @classmethod + async def _permission_for_member( + cls, + ) -> frozenset[DomainPermission]: + return MEMBER_PERMISSIONS + + +class DomainWithPermissionSet(NamedTuple): + domain_row: DomainRow + permissions: frozenset[DomainPermission] + + +async def get_domains( + target_scope: ScopeType, + requested_permission: DomainPermission, + domain_name: Optional[str] = None, + *, + ctx: ClientContext, + db_conn: SAConnection, +) -> list[DomainWithPermissionSet]: + async with ctx.db.begin_readonly_session(db_conn) as db_session: + builder = DomainPermissionContextBuilder(db_session) + permission_ctx = await builder.build(ctx, target_scope, requested_permission) + query_stmt = await permission_ctx.build_query() + if query_stmt is None: + return [] + if domain_name is not None: + query_stmt = query_stmt.where(DomainRow.name == domain_name) + result: list[DomainWithPermissionSet] = [] + async for row in await db_session.stream_scalars(query_stmt): + permissions = await permission_ctx.calculate_final_permission(row) + result.append(DomainWithPermissionSet(row, permissions)) + return result diff --git a/src/ai/backend/manager/models/group.py b/src/ai/backend/manager/models/group.py index 8ca968ebb3..a9047fafc9 100644 --- a/src/ai/backend/manager/models/group.py +++ b/src/ai/backend/manager/models/group.py @@ -4,16 +4,21 @@ import enum import logging import uuid +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, Dict, Iterable, + NamedTuple, Optional, Sequence, + TypeAlias, TypedDict, Union, + cast, overload, + override, ) import aiotools @@ -24,7 +29,8 @@ from graphql import Undefined from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection -from sqlalchemy.orm import relationship +from sqlalchemy.ext.asyncio import AsyncSession as SASession +from sqlalchemy.orm import load_only, relationship from ai.backend.common import msgpack from ai.backend.common.types import ResourceSlot, VFolderID @@ -49,6 +55,17 @@ simple_db_mutate, simple_db_mutate_returning_item, ) +from .rbac import ( + AbstractPermissionContext, + AbstractPermissionContextBuilder, + DomainScope, + ProjectScope, + ScopeType, + UserScope, + get_predefined_roles_in_scope, +) +from .rbac.context import ClientContext +from .rbac.permission_defs import ProjectPermission from .user import ModifyUserInput, UserRole from .utils import ExtendedAsyncSAEngine, execute_with_retry @@ -840,3 +857,190 @@ def verify_dotfile_name(dotfile: str) -> bool: if dotfile in RESERVED_DOTFILES: return False return True + + +ALL_PROJECT_PERMISSIONS = frozenset([perm for perm in ProjectPermission]) +OWNER_PERMISSIONS: frozenset[ProjectPermission] = ALL_PROJECT_PERMISSIONS +ADMIN_PERMISSIONS: frozenset[ProjectPermission] = ALL_PROJECT_PERMISSIONS +MONITOR_PERMISSIONS: frozenset[ProjectPermission] = frozenset([ + ProjectPermission.READ_ATTRIBUTE, + ProjectPermission.UPDATE_ATTRIBUTE, +]) +PRIVILEGED_MEMBER_PERMISSIONS: frozenset[ProjectPermission] = frozenset() +MEMBER_PERMISSIONS: frozenset[ProjectPermission] = frozenset() + +WhereClauseType: TypeAlias = ( + sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList +) + + +@dataclass +class ProjectPermissionContext(AbstractPermissionContext[ProjectPermission, GroupRow, uuid.UUID]): + @property + def query_condition(self) -> WhereClauseType | None: + cond: WhereClauseType | None = None + + def _OR_coalesce( + base_cond: WhereClauseType | None, + _cond: sa.sql.expression.BinaryExpression, + ) -> WhereClauseType: + return base_cond | _cond if base_cond is not None else _cond + + if self.domain_name_to_permission_map: + cond = _OR_coalesce( + cond, GroupRow.domain_name.in_(self.domain_name_to_permission_map.keys()) + ) + if self.object_id_to_additional_permission_map: + cond = _OR_coalesce( + cond, GroupRow.id.in_(self.object_id_to_additional_permission_map.keys()) + ) + if self.object_id_to_overriding_permission_map: + cond = _OR_coalesce( + cond, GroupRow.id.in_(self.object_id_to_overriding_permission_map.keys()) + ) + return cond + + async def build_query(self) -> sa.sql.Select | None: + cond = self.query_condition + if cond is None: + return None + return sa.select(GroupRow).where(cond) + + async def calculate_final_permission(self, rbac_obj: GroupRow) -> frozenset[ProjectPermission]: + project_row = rbac_obj + project_id = cast(uuid.UUID, project_row.id) + permissions: frozenset[ProjectPermission] = frozenset() + + if ( + overriding_perm := self.object_id_to_overriding_permission_map.get(project_id) + ) is not None: + permissions = overriding_perm + else: + permissions |= self.object_id_to_additional_permission_map.get(project_id, set()) + permissions |= self.domain_name_to_permission_map.get(project_row.domain_name, set()) + return permissions + + +class ProjectPermissionContextBuilder( + AbstractPermissionContextBuilder[ProjectPermission, ProjectPermissionContext] +): + db_session: SASession + + def __init__(self, db_session: SASession) -> None: + self.db_session = db_session + + @override + async def calculate_permission( + self, + ctx: ClientContext, + target_scope: ScopeType, + ) -> frozenset[ProjectPermission]: + roles = await get_predefined_roles_in_scope(ctx, target_scope, self.db_session) + permissions = await self._calculate_permission_by_predefined_roles(roles) + return permissions + + @override + async def build_ctx_in_system_scope( + self, + ctx: ClientContext, + ) -> ProjectPermissionContext: + from .domain import DomainRow + + perm_ctx = ProjectPermissionContext() + _domain_query_stmt = sa.select(DomainRow).options(load_only(DomainRow.name)) + for row in await self.db_session.scalars(_domain_query_stmt): + to_be_merged = await self.build_ctx_in_domain_scope(ctx, DomainScope(row.name)) + perm_ctx.merge(to_be_merged) + return perm_ctx + + @override + async def build_ctx_in_domain_scope( + self, + ctx: ClientContext, + scope: DomainScope, + ) -> ProjectPermissionContext: + permissions = await self.calculate_permission(ctx, scope) + return ProjectPermissionContext( + domain_name_to_permission_map={scope.domain_name: permissions} + ) + + @override + async def build_ctx_in_project_scope( + self, ctx: ClientContext, scope: ProjectScope + ) -> ProjectPermissionContext: + permissions = await self.calculate_permission(ctx, scope) + return ProjectPermissionContext( + object_id_to_additional_permission_map={scope.project_id: permissions} + ) + + @override + async def build_ctx_in_user_scope( + self, ctx: ClientContext, scope: UserScope + ) -> ProjectPermissionContext: + return ProjectPermissionContext() + + @override + @classmethod + async def _permission_for_owner( + cls, + ) -> frozenset[ProjectPermission]: + return OWNER_PERMISSIONS + + @override + @classmethod + async def _permission_for_admin( + cls, + ) -> frozenset[ProjectPermission]: + return ADMIN_PERMISSIONS + + @override + @classmethod + async def _permission_for_monitor( + cls, + ) -> frozenset[ProjectPermission]: + return MONITOR_PERMISSIONS + + @override + @classmethod + async def _permission_for_privileged_member( + cls, + ) -> frozenset[ProjectPermission]: + return PRIVILEGED_MEMBER_PERMISSIONS + + @override + @classmethod + async def _permission_for_member( + cls, + ) -> frozenset[ProjectPermission]: + return MEMBER_PERMISSIONS + + +class ProjectWithPermissionSet(NamedTuple): + project_row: GroupRow + permissions: frozenset[ProjectPermission] + + +async def get_projects( + target_scope: ScopeType, + requested_permission: ProjectPermission, + project_id: Optional[uuid.UUID] = None, + project_name: Optional[str] = None, + *, + ctx: ClientContext, + db_conn: SAConnection, +) -> list[ProjectWithPermissionSet]: + async with ctx.db.begin_readonly_session(db_conn) as db_session: + builder = ProjectPermissionContextBuilder(db_session) + permission_ctx = await builder.build(ctx, target_scope, requested_permission) + query_stmt = await permission_ctx.build_query() + if query_stmt is None: + return [] + if project_id is not None: + query_stmt = query_stmt.where(GroupRow.id == project_id) + if project_name is not None: + query_stmt = query_stmt.where(GroupRow.name == project_name) + result: list[ProjectWithPermissionSet] = [] + async for row in await db_session.stream_scalars(query_stmt): + permissions = await permission_ctx.calculate_final_permission(row) + result.append(ProjectWithPermissionSet(row, permissions)) + return result diff --git a/src/ai/backend/manager/models/rbac/permission_defs.py b/src/ai/backend/manager/models/rbac/permission_defs.py index 6c41a7b0d5..85f8214e0c 100644 --- a/src/ai/backend/manager/models/rbac/permission_defs.py +++ b/src/ai/backend/manager/models/rbac/permission_defs.py @@ -77,3 +77,21 @@ class AgentPermission(BasePermission): CREATE_COMPUTE_SESSION = enum.auto() CREATE_SERVICE = enum.auto() + + +class DomainPermission(BasePermission): + # These permissions limit actions taken directly to domains + READ_ATTRIBUTE = enum.auto() + UPDATE_ATTRIBUTE = enum.auto() + + CREATE_USER = enum.auto() + CREATE_PROJECT = enum.auto() + + +class ProjectPermission(BasePermission): + # These permissions limit actions taken directly to projects(groups) + READ_ATTRIBUTE = enum.auto() + UPDATE_ATTRIBUTE = enum.auto() + DELETE_PROJECT = enum.auto() + + ASSOCIATE_WITH_USER = enum.auto()