diff --git a/changes/2934.feature.md b/changes/2934.feature.md new file mode 100644 index 0000000000..7c7c3995a5 --- /dev/null +++ b/changes/2934.feature.md @@ -0,0 +1 @@ +Add GQL Relay domain query schema and resolver diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index aa9c2d6e01..d541126aa0 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -14,6 +14,12 @@ type Queries { agents(scaling_group: String, status: String): [Agent] agent_summary(agent_id: String!): AgentSummary agent_summary_list(limit: Int!, offset: Int!, filter: String, order: String, scaling_group: String, status: String): AgentSummaryList + + """Added in 24.12.0.""" + domain_node(id: GlobalIDField!, permission: DomainPermissionValueField = "read_attribute"): DomainNode + + """Added in 24.12.0.""" + domain_nodes(filter: String, order: String, permission: DomainPermissionValueField = "read_attribute", offset: Int, before: String, after: String, first: Int, last: Int): DomainConnection domain(name: String): Domain domains(is_active: Boolean): [Domain] @@ -360,6 +366,115 @@ type AgentSummaryList implements PaginatedList { total_count: Int! } +"""Added in 24.12.0.""" +type DomainNode implements Node { + """The ID of the object""" + id: ID! + name: String + description: String + is_active: Boolean + created_at: DateTime + modified_at: DateTime + total_resource_slots: JSONString + allowed_vfolder_hosts: JSONString + allowed_docker_registries: [String] + dotfiles: Bytes + integration_id: String + scaling_groups(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ScalinGroupConnection +} + +"""Added in 24.09.1.""" +scalar Bytes + +"""Added in 24.12.0.""" +type ScalinGroupConnection { + """Pagination data for this connection.""" + pageInfo: PageInfo! + + """Contains the nodes in this connection.""" + edges: [ScalinGroupEdge]! + + """Total count of the GQL nodes of the query.""" + count: Int +} + +""" +The Relay compliant `PageInfo` type, containing data necessary to paginate this connection. +""" +type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String +} + +""" +Added in 24.12.0. A Relay edge containing a `ScalinGroup` and its cursor. +""" +type ScalinGroupEdge { + """The item at the end of the edge""" + node: ScalingGroupNode + + """A cursor for use in pagination""" + cursor: String! +} + +"""Added in 24.12.0.""" +type ScalingGroupNode implements Node { + """The ID of the object""" + id: ID! + name: String + description: String + is_active: Boolean + is_public: Boolean + created_at: DateTime + wsproxy_addr: String + wsproxy_api_token: String + driver: String + driver_opts: JSONString + scheduler: String + scheduler_opts: JSONString + use_host_network: Boolean +} + +""" +Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of ":". UUID or string type values are also allowed. +""" +scalar GlobalIDField + +""" +Added in 24.12.0. One of ['read_attribute', 'read_sensitive_attribute', 'update_attribute', 'create_user', 'create_project']. 
+""" +scalar DomainPermissionValueField + +"""Added in 24.12.0""" +type DomainConnection { + """Pagination data for this connection.""" + pageInfo: PageInfo! + + """Contains the nodes in this connection.""" + edges: [DomainEdge]! + + """Total count of the GQL nodes of the query.""" + count: Int +} + +"""Added in 24.12.0 A Relay edge containing a `Domain` and its cursor.""" +type DomainEdge { + """The item at the end of the edge""" + node: DomainNode + + """A cursor for use in pagination""" + cursor: String! +} + type Domain { name: String description: String @@ -411,23 +526,6 @@ type UserConnection { count: Int } -""" -The Relay compliant `PageInfo` type, containing data necessary to paginate this connection. -""" -type PageInfo { - """When paginating forwards, are there more items?""" - hasNextPage: Boolean! - - """When paginating backwards, are there more items?""" - hasPreviousPage: Boolean! - - """When paginating backwards, the cursor to continue.""" - startCursor: String - - """When paginating forwards, the cursor to continue.""" - endCursor: String -} - """Added in 24.03.0 A Relay edge containing a `User` and its cursor.""" type UserEdge { """The item at the end of the edge""" @@ -1041,11 +1139,6 @@ type ComputeSessionEdge { cursor: String! } -""" -Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of ":". UUID or string type values are also allowed. -""" -scalar GlobalIDField - type ComputeSessionList implements PaginatedList { items: [ComputeSession]! total_count: Int! @@ -1404,6 +1497,12 @@ type Mutations { To purge domain, there should be no users and groups in the target domain. """ purge_domain(name: String!): PurgeDomain + + """Added in 24.12.0.""" + create_domain_node(input: CreateDomainNodeInput!): CreateDomainNode + + """Added in 24.12.0.""" + modify_domain_node(input: ModifyDomainNodeInput!): ModifyDomainNode create_group(name: String!, props: GroupInput!): CreateGroup modify_group(gid: UUID!, props: ModifyGroupInput!): ModifyGroup @@ -1637,6 +1736,47 @@ type PurgeDomain { msg: String } +"""Added in 24.12.0.""" +type CreateDomainNode { + ok: Boolean + msg: String + item: DomainNode +} + +"""Added in 24.12.0.""" +input CreateDomainNodeInput { + name: String! + description: String + is_active: Boolean = true + total_resource_slots: JSONString = "{}" + allowed_vfolder_hosts: JSONString = "{}" + allowed_docker_registries: [String] = [] + integration_id: String = null + dotfiles: Bytes = "90" + scaling_groups: [String] +} + +"""Added in 24.12.0.""" +type ModifyDomainNode { + item: DomainNode + client_mutation_id: String +} + +"""Added in 24.12.0.""" +input ModifyDomainNodeInput { + id: GlobalIDField! 
+ description: String + is_active: Boolean + total_resource_slots: JSONString + allowed_vfolder_hosts: JSONString + allowed_docker_registries: [String] + integration_id: String + dotfiles: Bytes + sgroups_to_add: [String] + sgroups_to_remove: [String] + client_mutation_id: String +} + type CreateGroup { ok: Boolean msg: String diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index 9ce78df1ac..a30117cd6b 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -938,6 +938,29 @@ async def batch_multiresult_in_session( return [*objs_per_key.values()] +async def batch_multiresult_in_scalar_stream( + graph_ctx: GraphQueryContext, + db_sess: SASession, + query: sa.sql.Select, + obj_type: type[T_SQLBasedGQLObject], + key_list: Iterable[T_Key], + key_getter: Callable[[Row], T_Key], +) -> Sequence[Sequence[T_SQLBasedGQLObject]]: + """ + A batched query adaptor for (key -> [item]) resolving patterns. + stream the result in async session. + """ + objs_per_key: dict[T_Key, list[T_SQLBasedGQLObject]] + objs_per_key = dict() + for key in key_list: + objs_per_key[key] = list() + async for row in await db_sess.stream_scalars(query): + objs_per_key[key_getter(row)].append( + obj_type.from_row(graph_ctx, row), + ) + return [*objs_per_key.values()] + + def privileged_query(required_role: UserRole): def wrap(func): @functools.wraps(func) diff --git a/src/ai/backend/manager/models/domain.py b/src/ai/backend/manager/models/domain.py index 10f3f7a9c3..4c7b9fcec9 100644 --- a/src/ai/backend/manager/models/domain.py +++ b/src/ai/backend/manager/models/domain.py @@ -29,7 +29,7 @@ from sqlalchemy.orm import load_only, relationship from ai.backend.common import msgpack -from ai.backend.common.types import ResourceSlot +from ai.backend.common.types import ResourceSlot, VFolderHostPermissionMap from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.models.group import ProjectType @@ -135,11 +135,12 @@ class DomainModel(RBACModel[DomainPermission]): modified_at: datetime _total_resource_slots: Optional[dict] - _allowed_vfolder_hosts: dict + _allowed_vfolder_hosts: VFolderHostPermissionMap _allowed_docker_registries: list[str] _integration_id: Optional[str] _dotfiles: str + orm_obj: DomainRow _permissions: frozenset[DomainPermission] = field(default_factory=frozenset) @property @@ -153,7 +154,7 @@ def total_resource_slots(self) -> Optional[dict]: @property @required_permission(DomainPermission.READ_SENSITIVE_ATTRIBUTE) - def allowed_vfolder_hosts(self) -> dict: + def allowed_vfolder_hosts(self) -> VFolderHostPermissionMap: return self._allowed_vfolder_hosts @property @@ -185,6 +186,7 @@ def from_row(cls, row: DomainRow, permissions: Iterable[DomainPermission]) -> Se _integration_id=row.integration_id, _dotfiles=row.dotfiles, _permissions=frozenset(permissions), + orm_obj=row, ) @@ -658,24 +660,37 @@ async def _permission_for_member( return MEMBER_PERMISSIONS +async def get_permission_ctx( + target_scope: ScopeType, + requested_permission: DomainPermission, + *, + ctx: ClientContext, + db_session: SASession, +) -> DomainPermissionContext: + builder = DomainPermissionContextBuilder(db_session) + permission_ctx = await builder.build(ctx, target_scope, requested_permission) + return permission_ctx + + async def get_domains( target_scope: ScopeType, requested_permission: DomainPermission, - domain_name: Optional[str] = None, + domain_names: Optional[Iterable[str]] = None, *, ctx: ClientContext, - db_conn: SAConnection, + 
db_session: SASession, ) -> list[DomainModel]: - async with ctx.db.begin_readonly_session(db_conn) as db_session: - builder = DomainPermissionContextBuilder(db_session) - permission_ctx = await builder.build(ctx, target_scope, requested_permission) - query_stmt = await permission_ctx.build_query() - if query_stmt is None: - return [] - if domain_name is not None: - query_stmt = query_stmt.where(DomainRow.name == domain_name) - result: list[DomainModel] = [] - async for row in await db_session.stream_scalars(query_stmt): - permissions = await permission_ctx.calculate_final_permission(row) - result.append(DomainModel.from_row(row, permissions)) - return result + ret: list[DomainModel] = [] + permission_ctx = await get_permission_ctx( + target_scope, requested_permission, ctx=ctx, db_session=db_session + ) + cond = permission_ctx.query_condition + if cond is None: + return ret + query_stmt = sa.select(DomainRow).where(cond) + if domain_names is not None: + query_stmt = query_stmt.where(DomainRow.name.in_(domain_names)) + async for row in await db_session.stream_scalars(query_stmt): + permissions = await permission_ctx.calculate_final_permission(row) + ret.append(DomainModel.from_row(row, permissions)) + return ret diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 0c6eaf225b..c6298ca585 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -65,6 +65,13 @@ from .base import DataLoaderManager, PaginatedConnectionField, privileged_query, scoped_query from .domain import CreateDomain, DeleteDomain, Domain, ModifyDomain, PurgeDomain from .endpoint import Endpoint, EndpointList, EndpointToken, EndpointTokenList, ModifyEndpoint +from .gql_models.domain import ( + CreateDomainNode, + DomainConnection, + DomainNode, + DomainPermissionValueField, + ModifyDomainNode, +) from .gql_models.group import GroupConnection, GroupNode from .gql_models.image import ( AliasImage, @@ -113,7 +120,8 @@ LegacyComputeSessionList, ) from .keypair import CreateKeyPair, DeleteKeyPair, KeyPair, KeyPairList, ModifyKeyPair -from .rbac.permission_defs import ComputeSessionPermission +from .rbac import SystemScope +from .rbac.permission_defs import ComputeSessionPermission, DomainPermission from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission from .resource_policy import ( CreateKeyPairResourcePolicy, @@ -215,6 +223,9 @@ class Mutations(graphene.ObjectType): delete_domain = DeleteDomain.Field() purge_domain = PurgeDomain.Field() + create_domain_node = CreateDomainNode.Field(description="Added in 24.12.0.") + modify_domain_node = ModifyDomainNode.Field(description="Added in 24.12.0.") + # admin only create_group = CreateGroup.Field() modify_group = ModifyGroup.Field() @@ -367,6 +378,27 @@ class Queries(graphene.ObjectType): status=graphene.String(), ) + domain_node = graphene.Field( + DomainNode, + description="Added in 24.12.0.", + id=GlobalIDField(required=True), + permission=DomainPermissionValueField( + required=False, + default_value=DomainPermission.READ_ATTRIBUTE, + ), + ) + + domain_nodes = PaginatedConnectionField( + DomainConnection, + description="Added in 24.12.0.", + filter=graphene.String(), + order=graphene.String(), + permission=DomainPermissionValueField( + required=False, + default_value=DomainPermission.READ_ATTRIBUTE, + ), + ) + domain = graphene.Field( Domain, name=graphene.String(), @@ -943,6 +975,41 @@ async def resolve_agent_summary_list( ) return AgentSummaryList(agent_list, total_count) + 
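+    # Illustrative only: a minimal Relay-style query served by the two resolvers
+    # below. Per the GlobalIDField docs the id is a base64-encoded
+    # "DomainNode:<domain name>"; the value shown here is a made-up example.
+    #
+    #   query {
+    #     domain_node(id: "RG9tYWluTm9kZTpkZWZhdWx0") { name is_active }
+    #     domain_nodes(first: 10, order: "name") {
+    #       count
+    #       edges { node { name } }
+    #     }
+    #   }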
@staticmethod
+    async def resolve_domain_node(
+        root: Any,
+        info: graphene.ResolveInfo,
+        *,
+        id: str,
+        permission: DomainPermission,
+    ) -> Optional[DomainNode]:
+        return await DomainNode.get_node(info, id, permission)
+
+    @staticmethod
+    async def resolve_domain_nodes(
+        root: Any,
+        info: graphene.ResolveInfo,
+        *,
+        permission: DomainPermission,
+        filter: Optional[str] = None,
+        order: Optional[str] = None,
+        offset: Optional[int] = None,
+        after: Optional[str] = None,
+        first: Optional[int] = None,
+        before: Optional[str] = None,
+        last: Optional[int] = None,
+    ) -> ConnectionResolverResult[DomainNode]:
+        return await DomainNode.get_connection(
+            info,
+            SystemScope(),
+            permission,
+            filter_expr=filter,
+            order_expr=order,
+            offset=offset,
+            after=after,
+            first=first,
+            before=before,
+            last=last,
+        )
+
     @staticmethod
     async def resolve_domain(
         root: Any,
         info: graphene.ResolveInfo,
diff --git a/src/ai/backend/manager/models/gql_models/base.py b/src/ai/backend/manager/models/gql_models/base.py
index 6ed427ae79..82e2456a54 100644
--- a/src/ai/backend/manager/models/gql_models/base.py
+++ b/src/ai/backend/manager/models/gql_models/base.py
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+from typing import Any, Optional
+
 import graphene
+import graphql
 from graphene.types import Scalar
 from graphene.types.scalars import MAX_INT, MIN_INT
 from graphql.language.ast import IntValueNode
@@ -62,6 +65,26 @@ def parse_literal(node):
         return num
 
 
+class Bytes(Scalar):
+    class Meta:
+        description = "Added in 24.09.1."
+
+    @staticmethod
+    def serialize(val: bytes) -> str:
+        return val.hex()
+
+    @staticmethod
+    def parse_literal(node: Any, _variables=None) -> Optional[bytes]:
+        if isinstance(node, graphql.language.ast.StringValueNode):
+            return bytes.fromhex(node.value)
+        return None
+
+    @staticmethod
+    def parse_value(value: str) -> bytes:
+        return bytes.fromhex(value)
+
+
 class ImageRefType(graphene.InputObjectType):
     name = graphene.String(required=True)
     registry = graphene.String()
diff --git a/src/ai/backend/manager/models/gql_models/domain.py b/src/ai/backend/manager/models/gql_models/domain.py
new file mode 100644
index 0000000000..6f692ac229
--- /dev/null
+++ b/src/ai/backend/manager/models/gql_models/domain.py
@@ -0,0 +1,471 @@
+from __future__ import annotations
+
+from collections.abc import Iterable, Mapping
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Optional,
+    Self,
+    cast,
+)
+
+import graphene
+import graphql
+import sqlalchemy as sa
+from dateutil.parser import parse as dtparse
+from graphene.types.datetime import DateTime as GQLDateTime
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from ..base import (
+    FilterExprArg,
+    OrderExprArg,
+    PaginatedConnectionField,
+    generate_sql_info_for_gql_connection,
+)
+from ..domain import DomainRow, get_domains, get_permission_ctx
+from ..gql_relay import (
+    AsyncNode,
+    Connection,
+    ConnectionResolverResult,
+    GlobalIDField,
+    ResolvedGlobalID,
+)
+from ..minilang.ordering import OrderSpecItem, QueryOrderParser
+from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser
+from ..rbac import (
+    ClientContext,
+    ScopeType,
+    SystemScope,
+)
+from ..rbac.permission_defs import DomainPermission, ScalingGroupPermission
+from ..scaling_group import ScalingGroupForDomainRow, get_scaling_groups
+from ..user import UserRole
+from ..utils import execute_with_txn_retry
+from .base import Bytes
+from .scaling_group import ScalinGroupConnection
+
+if TYPE_CHECKING:
+    from ..domain import DomainModel
+    from ..gql import GraphQueryContext
+    from .scaling_group import ScalingGroupNode
+
+
+class DomainPermissionValueField(graphene.Scalar):
+    class Meta:
+        description = f"Added in 24.12.0. One of {[val.value for val in DomainPermission]}."
+
+    @staticmethod
+    def serialize(val: DomainPermission) -> str:
+        return val.value
+
+    @staticmethod
+    def parse_literal(node: Any, _variables=None):
+        if isinstance(node, graphql.language.ast.StringValueNode):
+            return DomainPermission(node.value)
+
+    @staticmethod
+    def parse_value(value: str) -> DomainPermission:
+        return DomainPermission(value)
+
+
+_queryfilter_fieldspec: Mapping[str, FieldSpecItem] = {
+    "id": ("id", None),
+    "row_id": ("id", None),
+    "name": ("name", None),
+    "is_active": ("is_active", None),
+    "created_at": ("created_at", dtparse),
+    "modified_at": ("modified_at", dtparse),
+}
+
+_queryorder_colmap: Mapping[str, OrderSpecItem] = {
+    "id": ("id", None),
+    "row_id": ("id", None),
+    "name": ("name", None),
+    "is_active": ("is_active", None),
+    "created_at": ("created_at", None),
+    "modified_at": ("modified_at", None),
+}
+
+
+class DomainNode(graphene.ObjectType):
+    class Meta:
+        interfaces = (AsyncNode,)
+        description = "Added in 24.12.0."
+
+    name = graphene.String()
+    description = graphene.String()
+    is_active = graphene.Boolean()
+    created_at = GQLDateTime()
+    modified_at = GQLDateTime()
+    total_resource_slots = graphene.JSONString()
+    allowed_vfolder_hosts = graphene.JSONString()
+    allowed_docker_registries = graphene.List(lambda: graphene.String)
+    dotfiles = Bytes()
+    integration_id = graphene.String()
+
+    # Dynamic fields.
+    scaling_groups = PaginatedConnectionField(ScalinGroupConnection)
+
+    @classmethod
+    def from_rbac_model(
+        cls,
+        graph_ctx: GraphQueryContext,
+        obj: DomainModel,
+    ) -> Self:
+        return cls(
+            id=obj.name,
+            name=obj.name,
+            description=obj.description,
+            is_active=obj.is_active,
+            created_at=obj.created_at,
+            modified_at=obj.modified_at,
+            total_resource_slots=obj.total_resource_slots,
+            allowed_vfolder_hosts=obj.allowed_vfolder_hosts.to_json(),
+            allowed_docker_registries=obj.allowed_docker_registries,
+            dotfiles=obj.dotfiles,
+            integration_id=obj.integration_id,
+        )
+
+    @classmethod
+    def from_orm_model(
+        cls,
+        graph_ctx: GraphQueryContext,
+        obj: DomainRow,
+    ) -> Self:
+        return cls(
+            id=obj.name,
+            name=obj.name,
+            description=obj.description,
+            is_active=obj.is_active,
+            created_at=obj.created_at,
+            modified_at=obj.modified_at,
+            total_resource_slots=obj.total_resource_slots,
+            allowed_vfolder_hosts=obj.allowed_vfolder_hosts.to_json(),
+            allowed_docker_registries=obj.allowed_docker_registries,
+            dotfiles=obj.dotfiles,
+            integration_id=obj.integration_id,
+        )
+
+    async def resolve_scaling_groups(
+        self, info: graphene.ResolveInfo
+    ) -> ConnectionResolverResult[ScalingGroupNode]:
+        from .scaling_group import ScalingGroupNode
+
+        graph_ctx: GraphQueryContext = info.context
+        loader = graph_ctx.dataloader_manager.get_loader_by_func(
+            graph_ctx, ScalingGroupNode.batch_load_by_domain
+        )
+
+        sgroups = await loader.load(self.name)
+        return ConnectionResolverResult(sgroups, None, None, None, total_count=len(sgroups))
+
+    @classmethod
+    async def get_node(
+        cls,
+        info: graphene.ResolveInfo,
+        id: str,
+        permission: DomainPermission = DomainPermission.READ_ATTRIBUTE,
+    ) -> Optional[Self]:
+        from ..domain import DomainModel
+
+        graph_ctx: GraphQueryContext = info.context
+        _, domain_name = AsyncNode.resolve_global_id(info, id)
+        user = graph_ctx.user
+        client_ctx = ClientContext(graph_ctx.db, user["domain_name"], user["uuid"], user["role"])
+        async with graph_ctx.db.begin_readonly_session() as db_session:
+            permission_ctx = await get_permission_ctx(
+                SystemScope(), permission, ctx=client_ctx, db_session=db_session
+            )
+            cond = permission_ctx.query_condition
+            if cond is None:
+                return None
+            row = await db_session.scalar(
+                sa.select(DomainRow).where(cond & (DomainRow.name == domain_name))
+            )
+            if row is None:
+                return None
+            permissions = await permission_ctx.calculate_final_permission(row)
+
+        return cls.from_rbac_model(graph_ctx, DomainModel.from_row(row, permissions))
+
+    @classmethod
+    async def get_connection(
+        cls,
+        info: graphene.ResolveInfo,
+        scope: ScopeType,
+        permission: DomainPermission,
+        filter_expr: Optional[str] = None,
+        order_expr: Optional[str] = None,
+        offset: Optional[int] = None,
+        after: Optional[str] = None,
+        first: Optional[int] = None,
+        before: Optional[str] = None,
+        last: Optional[int] = None,
+    ) -> ConnectionResolverResult[Self]:
+        from ..domain import DomainModel
+
+        graph_ctx: GraphQueryContext = info.context
+        _filter_arg = (
+            FilterExprArg(filter_expr, QueryFilterParser(_queryfilter_fieldspec))
+            if filter_expr is not None
+            else None
+        )
+        _order_expr = (
+            OrderExprArg(order_expr, QueryOrderParser(_queryorder_colmap))
+            if order_expr is not None
+            else None
+        )
+        (
+            query,
+            cnt_query,
+            _,
+            cursor,
+            pagination_order,
+            page_size,
+        ) = generate_sql_info_for_gql_connection(
+            info,
+            DomainRow,
+            DomainRow.name,
+            _filter_arg,
+            _order_expr,
+            offset,
+            after=after,
+            first=first,
+            before=before,
+            last=last,
+        )
+        user = graph_ctx.user
+        client_ctx = ClientContext(graph_ctx.db, user["domain_name"], user["uuid"], user["role"])
+        result: list[Self] = []
+        async with graph_ctx.db.begin_readonly_session() as db_session:
+            permission_ctx = await get_permission_ctx(
+                scope,
+                permission,
+                db_session=db_session,
+                ctx=client_ctx,
+            )
+            cond = permission_ctx.query_condition
+            if cond is None:
+                return ConnectionResolverResult([], cursor, pagination_order, page_size, 0)
+
+            query = query.where(cond)
+            cnt_query = cnt_query.where(cond)
+            total_cnt = await db_session.scalar(cnt_query)
+            async for row in await db_session.stream_scalars(query):
+                row = cast(DomainRow, row)
+                permissions = await permission_ctx.calculate_final_permission(row)
+                result.append(
+                    cls.from_rbac_model(graph_ctx, DomainModel.from_row(row, permissions))
+                )
+        return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt)
+
+
+class DomainConnection(Connection):
+    class Meta:
+        node = DomainNode
+        description = "Added in 24.12.0"
+
+
+async def _ensure_sgroup_permission(
+    graph_ctx: GraphQueryContext, sgroup_names: Iterable[str], *, db_session: AsyncSession
+) -> None:
+    user = graph_ctx.user
+    client_ctx = ClientContext(graph_ctx.db, user["domain_name"], user["uuid"], user["role"])
+    sgroup_models = await get_scaling_groups(
+        SystemScope(),
+        ScalingGroupPermission.ASSOCIATE_WITH_SCOPES,
+        sgroup_names,
+        db_session=db_session,
+        ctx=client_ctx,
+    )
+    not_allowed_sgroups = set(sgroup_names) - set([sg.name for sg in sgroup_models])
+    if not_allowed_sgroups:
+        raise ValueError(
+            f"Not allowed to associate the domain with the given scaling group(s) (sg:{not_allowed_sgroups})"
+        )
+
+
+class CreateDomainNodeInput(graphene.InputObjectType):
+    class Meta:
+        description = "Added in 24.12.0."
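+    # Note: the b"\x90" default for `dotfiles` below is the msgpack encoding of an
+    # empty list (rendered as the hex string "90" by the Bytes scalar), i.e. a new
+    # domain starts with no dotfiles.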
+ + name = graphene.String(required=True) + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False, default_value=True) + total_resource_slots = graphene.JSONString(required=False, default_value={}) + allowed_vfolder_hosts = graphene.JSONString(required=False, default_value={}) + allowed_docker_registries = graphene.List( + lambda: graphene.String, required=False, default_value=[] + ) + integration_id = graphene.String(required=False, default_value=None) + dotfiles = Bytes(required=False, default_value=b"\x90") + + scaling_groups = graphene.List(lambda: graphene.String, required=False) + + +class CreateDomainNode(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Meta: + description = "Added in 24.12.0." + + class Arguments: + input = CreateDomainNodeInput(required=True) + + # Output fields + ok = graphene.Boolean() + msg = graphene.String() + item = graphene.Field(lambda: DomainNode, required=False) + + @classmethod + async def mutate( + cls, + root: Any, + info: graphene.ResolveInfo, + input: CreateDomainNodeInput, + ) -> CreateDomainNode: + graph_ctx: GraphQueryContext = info.context + + if (raw_scaling_groups := input.pop("scaling_groups")) is not None: + scaling_groups = cast(list[str], raw_scaling_groups) + else: + scaling_groups = None + + async def _insert(db_session: AsyncSession) -> DomainRow: + if scaling_groups is not None: + await _ensure_sgroup_permission(graph_ctx, scaling_groups, db_session=db_session) + _insert_and_returning = sa.select(DomainRow).from_statement( + sa.insert(DomainRow).values(**input).returning(DomainRow) + ) + domain_row = await db_session.scalar(_insert_and_returning) + if scaling_groups is not None: + await db_session.execute( + sa.insert(ScalingGroupForDomainRow), + [ + {"scaling_group": sgroup_name, "domain": input.name} + for sgroup_name in scaling_groups + ], + ) + return domain_row + + async with graph_ctx.db.connect() as db_conn: + try: + domain_row = await execute_with_txn_retry( + _insert, graph_ctx.db.begin_session, db_conn + ) + except sa.exc.IntegrityError as e: + raise ValueError( + f"Cannot create the domain with given arguments. (arg:{input}, e:{str(e)})" + ) + return CreateDomainNode(True, "", DomainNode.from_orm_model(graph_ctx, domain_row)) + + +class ModifyDomainNodeInput(graphene.InputObjectType): + class Meta: + description = "Added in 24.12.0." + + id = GlobalIDField(required=True) + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False) + total_resource_slots = graphene.JSONString(required=False) + allowed_vfolder_hosts = graphene.JSONString(required=False) + allowed_docker_registries = graphene.List(lambda: graphene.String, required=False) + integration_id = graphene.String(required=False) + dotfiles = Bytes(required=False) + sgroups_to_add = graphene.List(lambda: graphene.String, required=False) + sgroups_to_remove = graphene.List(lambda: graphene.String, required=False) + client_mutation_id = graphene.String(required=False) + + +class ModifyDomainNode(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN, UserRole.ADMIN) + + class Meta: + description = "Added in 24.12.0." 
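+    # Illustrative only: the Relay-style mutation this class serves; the id value
+    # is a placeholder global ID, and all other input fields are optional partial
+    # updates.
+    #
+    #   mutation {
+    #     modify_domain_node(input: {id: "RG9tYWluTm9kZTpkZWZhdWx0", is_active: false}) {
+    #       item { name is_active }
+    #       client_mutation_id
+    #     }
+    #   }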
+
+    class Arguments:
+        input = ModifyDomainNodeInput(required=True)
+
+    # Output fields
+    item = graphene.Field(DomainNode)
+    client_mutation_id = graphene.String()  # Relay output
+
+    @classmethod
+    async def mutate(
+        cls,
+        root: Any,
+        info: graphene.ResolveInfo,
+        input: ModifyDomainNodeInput,
+    ) -> ModifyDomainNode:
+        graph_ctx: GraphQueryContext = info.context
+        _, domain_name = cast(ResolvedGlobalID, input.pop("id"))
+
+        if (raw_sgroups_to_add := input.pop("sgroups_to_add")) is not None:
+            sgroups_to_add = set(raw_sgroups_to_add)
+        else:
+            sgroups_to_add = None
+        if (raw_sgroups_to_remove := input.pop("sgroups_to_remove")) is not None:
+            sgroups_to_remove = set(raw_sgroups_to_remove)
+        else:
+            sgroups_to_remove = None
+
+        if sgroups_to_add is not None and sgroups_to_remove is not None:
+            if overlap := sgroups_to_add & sgroups_to_remove:
+                raise ValueError(
+                    "Scaling group names must not appear in both `sgroups_to_add` and `sgroups_to_remove` "
+                    f"(sg:{overlap})."
+                )
+
+        async def _update(db_session: AsyncSession) -> Optional[DomainRow]:
+            user = graph_ctx.user
+            client_ctx = ClientContext(
+                graph_ctx.db, user["domain_name"], user["uuid"], user["role"]
+            )
+            domain_models = await get_domains(
+                SystemScope(),
+                DomainPermission.UPDATE_ATTRIBUTE,
+                [domain_name],
+                ctx=client_ctx,
+                db_session=db_session,
+            )
+            if not domain_models:
+                raise ValueError(f"Not allowed to update domain (id:{domain_name})")
+
+            if sgroups_to_add is not None:
+                await _ensure_sgroup_permission(graph_ctx, sgroups_to_add, db_session=db_session)
+                await db_session.execute(
+                    sa.insert(ScalingGroupForDomainRow),
+                    [
+                        {"scaling_group": sgroup_name, "domain": domain_name}
+                        for sgroup_name in sgroups_to_add
+                    ],
+                )
+            if sgroups_to_remove is not None:
+                await _ensure_sgroup_permission(graph_ctx, sgroups_to_remove, db_session=db_session)
+                await db_session.execute(
+                    sa.delete(ScalingGroupForDomainRow).where(
+                        (ScalingGroupForDomainRow.domain == domain_name)
+                        & (ScalingGroupForDomainRow.scaling_group.in_(sgroups_to_remove))
+                    ),
+                )
+            _update_stmt = (
+                sa.update(DomainRow)
+                .where(DomainRow.name == domain_name)
+                .values(input)
+                .returning(DomainRow)
+            )
+            _stmt = sa.select(DomainRow).from_statement(_update_stmt)
+
+            return await db_session.scalar(_stmt)
+
+        async with graph_ctx.db.connect() as db_conn:
+            try:
+                domain_row = await execute_with_txn_retry(
+                    _update, graph_ctx.db.begin_session, db_conn
+                )
+            except sa.exc.IntegrityError as e:
+                raise ValueError(
+                    f"Cannot modify the domain with given arguments. 
(arg:{input}, e:{str(e)})" + ) + if domain_row is None: + raise ValueError(f"Domain not found (id:{domain_name})") + return ModifyDomainNode( + DomainNode.from_orm_model(graph_ctx, domain_row), + input.get("client_mutation_id"), + ) diff --git a/src/ai/backend/manager/models/gql_models/scaling_group.py b/src/ai/backend/manager/models/gql_models/scaling_group.py new file mode 100644 index 0000000000..eade789c76 --- /dev/null +++ b/src/ai/backend/manager/models/gql_models/scaling_group.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import uuid +from collections.abc import Sequence +from typing import ( + TYPE_CHECKING, + Self, +) + +import graphene +import sqlalchemy as sa +from graphene.types.datetime import DateTime as GQLDateTime + +from ai.backend.common.types import AccessKey + +from ..base import ( + batch_multiresult_in_scalar_stream, +) +from ..gql_relay import ( + AsyncNode, + Connection, +) +from ..scaling_group import ( + ScalingGroupForDomainRow, + ScalingGroupForKeypairsRow, + ScalingGroupForProjectRow, + ScalingGroupRow, +) + +if TYPE_CHECKING: + from ..gql import GraphQueryContext + + +class ScalingGroupNode(graphene.ObjectType): + class Meta: + interfaces = (AsyncNode,) + description = "Added in 24.12.0." + + name = graphene.String() + description = graphene.String() + is_active = graphene.Boolean() + is_public = graphene.Boolean() + created_at = GQLDateTime() + wsproxy_addr = graphene.String() + wsproxy_api_token = graphene.String() + driver = graphene.String() + driver_opts = graphene.JSONString() + scheduler = graphene.String() + scheduler_opts = graphene.JSONString() + use_host_network = graphene.Boolean() + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: ScalingGroupRow, + ) -> Self: + return cls( + name=row.name, + description=row.description, + is_active=row.is_active, + is_public=row.is_public, + created_at=row.created_at, + wsproxy_addr=row.wsproxy_addr, + wsproxy_api_token=row.wsproxy_api_token, + driver=row.driver, + driver_opts=row.driver_opts, + scheduler=row.scheduler, + scheduler_opts=row.scheduler_opts, + use_host_network=row.use_host_network, + ) + + @classmethod + async def batch_load_by_group( + cls, + ctx: GraphQueryContext, + group_ids: Sequence[uuid.UUID], + ) -> Sequence[Sequence[ScalingGroupNode]]: + j = sa.join( + ScalingGroupRow, + ScalingGroupForProjectRow, + ScalingGroupRow.name == ScalingGroupForProjectRow.scaling_group, + ) + _stmt = ( + sa.select(ScalingGroupRow) + .select_from(j) + .where(ScalingGroupForProjectRow.group.in_(group_ids)) + ) + async with ctx.db.begin_readonly_session() as db_session: + return await batch_multiresult_in_scalar_stream( + ctx, + db_session, + _stmt, + cls, + group_ids, + lambda row: row.name, + ) + + @classmethod + async def batch_load_by_domain( + cls, + ctx: GraphQueryContext, + domain_names: Sequence[str], + ) -> Sequence[Sequence[ScalingGroupNode]]: + j = sa.join( + ScalingGroupRow, + ScalingGroupForDomainRow, + ScalingGroupRow.name == ScalingGroupForDomainRow.scaling_group, + ) + _stmt = ( + sa.select(ScalingGroupRow) + .select_from(j) + .where(ScalingGroupForDomainRow.domain.in_(domain_names)) + ) + async with ctx.db.begin_readonly_session() as db_session: + return await batch_multiresult_in_scalar_stream( + ctx, + db_session, + _stmt, + cls, + domain_names, + lambda row: row.name, + ) + + @classmethod + async def batch_load_by_keypair( + cls, + ctx: GraphQueryContext, + access_keys: Sequence[AccessKey], + ) -> Sequence[Sequence[ScalingGroupNode]]: + j = sa.join( + 
ScalingGroupRow,
+            ScalingGroupForKeypairsRow,
+            ScalingGroupRow.name == ScalingGroupForKeypairsRow.scaling_group,
+        )
+        _stmt = (
+            sa.select(ScalingGroupRow)
+            .select_from(j)
+            .where(ScalingGroupForKeypairsRow.access_key.in_(access_keys))
+        )
+        async with ctx.db.begin_readonly_session() as db_session:
+            return await batch_multiresult_in_scalar_stream(
+                ctx,
+                db_session,
+                _stmt,
+                cls,
+                access_keys,
+                lambda row: row.name,
+            )
+
+
+class ScalinGroupConnection(Connection):
+    class Meta:
+        node = ScalingGroupNode
+        description = "Added in 24.12.0."
diff --git a/src/ai/backend/manager/models/rbac/__init__.py b/src/ai/backend/manager/models/rbac/__init__.py
index f649b32e6d..337682b1b3 100644
--- a/src/ai/backend/manager/models/rbac/__init__.py
+++ b/src/ai/backend/manager/models/rbac/__init__.py
@@ -7,6 +7,8 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Optional, Self, TypeAlias, TypeVar, cast
 
+import graphene
+import graphql
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import load_only, selectinload, with_loader_criteria
@@ -370,6 +372,44 @@ def deserialize(cls, val: str) -> Self:
 
 
 def deserialize_scope(val: str) -> ScopeType:
+    """
+    Deserialize a string value in the format '<scope_type>:<scope_id>'.
+
+    This function takes a string input representing a scope and deserializes it into a `Scope`
+    object containing the scope type, the scope ID (if applicable), and an optional additional
+    scope ID.
+
+    The input should adhere to one of the following formats:
+    1. '<scope_type>' (for system scope)
+    2. '<scope_type>:<scope_id>'
+
+    Scope types and their corresponding ID formats:
+    - system: No ID (covers the whole system)
+    - domain: String ID
+    - project: UUID
+    - user: UUID
+
+    Args:
+        val (str): The input string to deserialize, in one of the specified formats.
+
+    Returns:
+        An instance of one of [SystemScope, DomainScope, ProjectScope, UserScope].
+
+    Raises:
+        rbac.exceptions.InvalidScope:
+            If the input string does not conform to the expected formats or if the scope type is invalid.
+        ValueError:
+            If the scope ID format doesn't match the expected type for the given scope.
+
+    Examples:
+        >>> deserialize_scope("system")
+        SystemScope()
+        >>> deserialize_scope("domain:default")
+        DomainScope("default")
+        >>> deserialize_scope("project:123e4567-e89b-12d3-a456-426614174000")
+        ProjectScope(UUID('123e4567-e89b-12d3-a456-426614174000'))
+        >>> deserialize_scope("user:123e4567-e89b-12d3-a456-426614174000")
+        UserScope(UUID('123e4567-e89b-12d3-a456-426614174000'))
+    """
     for scope in (SystemScope, DomainScope, ProjectScope, UserScope):
         try:
             return scope.deserialize(val)
@@ -379,6 +419,30 @@ def deserialize_scope(val: str) -> ScopeType:
     raise InvalidScope(f"Invalid scope (s: {scope})")
 
 
+class ScopeField(graphene.Scalar):
+    class Meta:
+        description = (
+            "Added in 24.12.0. A string value in the format '<scope_type>:<scope_id>'. "
+            "<scope_type> should be one of [system, domain, project, user]. "
+            "<scope_id> should be the ID value of the scope. "
+            "e.g. `domain:default`, `user:123e4567-e89b-12d3-a456-426614174000`."
+ ) + + @staticmethod + def serialize(val: ScopeType) -> str: + return val.serialize() + + @staticmethod + def parse_literal(node: Any, _variables=None) -> Optional[ScopeType]: + if isinstance(node, graphql.language.ast.StringValueNode): + return deserialize_scope(node.value) + return None + + @staticmethod + def parse_value(value: str) -> ScopeType: + return deserialize_scope(value) + + # Extra scope is to address some scopes that contain specific object types # such as registries for images, scaling groups for agents, storage hosts for vfolders etc. class ExtraScope: diff --git a/src/ai/backend/manager/models/scaling_group.py b/src/ai/backend/manager/models/scaling_group.py index 2aecba2156..1910d6a310 100644 --- a/src/ai/backend/manager/models/scaling_group.py +++ b/src/ai/backend/manager/models/scaling_group.py @@ -1,17 +1,17 @@ from __future__ import annotations import uuid -from dataclasses import dataclass -from datetime import timedelta +from collections.abc import Container, Iterable, Mapping, Sequence +from dataclasses import dataclass, field +from datetime import datetime, timedelta from typing import ( TYPE_CHECKING, Any, Dict, - Iterable, - Mapping, Optional, - Sequence, + Self, Set, + TypeAlias, cast, overload, override, @@ -55,6 +55,7 @@ AbstractPermissionContextBuilder, DomainScope, ProjectScope, + RBACModel, ScopeType, UserScope, get_predefined_roles_in_scope, @@ -282,6 +283,49 @@ class ScalingGroupRow(Base): scaling_groups = ScalingGroupRow.__table__ +@dataclass +class ScalingGroupModel(RBACModel[ScalingGroupPermission]): + name: str + description: Optional[str] + is_active: bool + is_public: bool + created_at: datetime + + wsproxy_addr: Optional[str] + wsproxy_api_token: Optional[str] + driver: str + driver_opts: dict + scheduler: str + use_host_network: bool + scheduler_opts: ScalingGroupOpts + + orm_obj: ScalingGroupRow + _permissions: frozenset[ScalingGroupPermission] = field(default_factory=frozenset) + + @property + def permissions(self) -> Container[ScalingGroupPermission]: + return self._permissions + + @classmethod + def from_row(cls, row: ScalingGroupRow, permissions: Iterable[ScalingGroupPermission]) -> Self: + return cls( + name=row.name, + description=row.description, + is_active=row.is_active, + is_public=row.is_public, + created_at=row.created_at, + wsproxy_addr=row.wsproxy_addr, + wsproxy_api_token=row.wsproxy_api_token, + driver=row.driver, + driver_opts=row.driver_opts, + scheduler=row.scheduler, + use_host_network=row.use_host_network, + scheduler_opts=row.scheduler_opts, + _permissions=frozenset(permissions), + orm_obj=row, + ) + + @overload async def query_allowed_sgroups( db_conn: SAConnection, @@ -1111,6 +1155,10 @@ async def mutate( ScalingGroupToPermissionMap = Mapping[str, frozenset[ScalingGroupPermission]] +WhereClauseType: TypeAlias = ( + sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList +) + @dataclass class ScalingGroupPermissionContext(AbstractPermissionContext[ScalingGroupPermission, str, str]): @@ -1118,8 +1166,31 @@ class ScalingGroupPermissionContext(AbstractPermissionContext[ScalingGroupPermis def sgroup_to_permissions_map(self) -> ScalingGroupToPermissionMap: return self.object_id_to_additional_permission_map - async def build_query(self) -> sa.sql.Select | None: - return None + @property + def query_condition(self) -> Optional[WhereClauseType]: + cond: Optional[WhereClauseType] = None + + def _OR_coalesce( + base_cond: Optional[WhereClauseType], + _cond: sa.sql.expression.BinaryExpression, + ) -> 
WhereClauseType: + return base_cond | _cond if base_cond is not None else _cond + + if self.object_id_to_additional_permission_map: + cond = _OR_coalesce( + cond, ScalingGroupRow.name.in_(self.object_id_to_additional_permission_map.keys()) + ) + if self.object_id_to_overriding_permission_map: + cond = _OR_coalesce( + cond, ScalingGroupRow.name.in_(self.object_id_to_overriding_permission_map.keys()) + ) + return cond + + async def build_query(self) -> Optional[sa.sql.Select]: + cond = self.query_condition + if cond is None: + return None + return sa.select(ScalingGroupRow).where(cond) async def calculate_final_permission(self, rbac_obj: str) -> frozenset[ScalingGroupPermission]: host_name = rbac_obj @@ -1300,3 +1371,26 @@ async def _permission_for_member( cls, ) -> frozenset[ScalingGroupPermission]: return MEMBER_PERMISSIONS + + +async def get_scaling_groups( + target_scope: ScopeType, + requested_permission: ScalingGroupPermission, + sgroup_names: Optional[Iterable[str]] = None, + *, + ctx: ClientContext, + db_session: SASession, +) -> list[ScalingGroupModel]: + ret: list[ScalingGroupModel] = [] + builder = ScalingGroupPermissionContextBuilder(db_session) + permission_ctx = await builder.build(ctx, target_scope, requested_permission) + cond = permission_ctx.query_condition + if cond is None: + return ret + _stmt = sa.select(ScalingGroupRow).where(cond) + if sgroup_names is not None: + _stmt = _stmt.where(ScalingGroupRow.name.in_(sgroup_names)) + async for row in await db_session.stream_scalars(_stmt): + permissions = await permission_ctx.calculate_final_permission(row) + ret.append(ScalingGroupModel.from_row(row, permissions)) + return ret
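+
+
+# Usage sketch (illustrative only; `client_ctx` and `db_session` are assumed to be
+# an existing ClientContext and an already-open read-only session, as in the GQL
+# resolvers above):
+#
+#   sgroup_models = await get_scaling_groups(
+#       SystemScope(),
+#       ScalingGroupPermission.ASSOCIATE_WITH_SCOPES,
+#       ["default"],
+#       ctx=client_ctx,
+#       db_session=db_session,
+#   )
+#   visible_names = [sg.name for sg in sgroup_models]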