From 54945cf0fbde2eff00c89852628014a099b8b389 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Tue, 22 Oct 2024 19:14:56 +0900 Subject: [PATCH] chore: Relocate GQL schema codes related to Image (#2938) Backported-from: main (24.12) Backported-to: 24.09 Backport-of: 2938 --- src/ai/backend/manager/models/endpoint.py | 4 +- src/ai/backend/manager/models/gql.py | 26 +- .../backend/manager/models/gql_models/base.py | 68 ++ .../manager/models/gql_models/image.py | 823 ++++++++++++++++++ .../manager/models/gql_models/kernel.py | 2 +- src/ai/backend/manager/models/image.py | 794 +---------------- src/ai/backend/manager/models/kernel.py | 3 +- 7 files changed, 915 insertions(+), 805 deletions(-) create mode 100644 src/ai/backend/manager/models/gql_models/base.py create mode 100644 src/ai/backend/manager/models/gql_models/image.py diff --git a/src/ai/backend/manager/models/endpoint.py b/src/ai/backend/manager/models/endpoint.py index 9ebccb1dd6..2bfd3bd021 100644 --- a/src/ai/backend/manager/models/endpoint.py +++ b/src/ai/backend/manager/models/endpoint.py @@ -62,8 +62,10 @@ URLColumn, gql_mutation_wrapper, ) +from .gql_models.base import ImageRefType +from .gql_models.image import ImageNode from .gql_models.vfolder import VirtualFolderNode -from .image import ImageIdentifier, ImageNode, ImageRefType, ImageRow +from .image import ImageIdentifier, ImageRow from .minilang import EnumFieldItem from .minilang.ordering import OrderSpecItem, QueryOrderParser from .minilang.queryfilter import FieldSpecItem, QueryFilterParser diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 1f00aba97d..0c6eaf225b 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -66,6 +66,20 @@ from .domain import CreateDomain, DeleteDomain, Domain, ModifyDomain, PurgeDomain from .endpoint import Endpoint, EndpointList, EndpointToken, EndpointTokenList, ModifyEndpoint from .gql_models.group import GroupConnection, GroupNode +from .gql_models.image import ( + AliasImage, + ClearImages, + DealiasImage, + ForgetImage, + ForgetImageById, + Image, + ImageNode, + ModifyImage, + PreloadImage, + RescanImages, + UnloadImage, + UntagImageFromRegistry, +) from .gql_models.session import ( ComputeSessionConnection, ComputeSessionNode, @@ -89,20 +103,8 @@ PurgeGroup, ) from .image import ( - AliasImage, - ClearImages, - DealiasImage, - ForgetImage, - ForgetImageById, - Image, ImageLoadFilter, - ImageNode, - ModifyImage, - PreloadImage, PublicImageLoadFilter, - RescanImages, - UnloadImage, - UntagImageFromRegistry, ) from .kernel import ( ComputeContainer, diff --git a/src/ai/backend/manager/models/gql_models/base.py b/src/ai/backend/manager/models/gql_models/base.py new file mode 100644 index 0000000000..6ed427ae79 --- /dev/null +++ b/src/ai/backend/manager/models/gql_models/base.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import graphene +from graphene.types import Scalar +from graphene.types.scalars import MAX_INT, MIN_INT +from graphql.language.ast import IntValueNode + +SAFE_MIN_INT = -9007199254740991 +SAFE_MAX_INT = 9007199254740991 + + +class ResourceLimit(graphene.ObjectType): + key = graphene.String() + min = graphene.String() + max = graphene.String() + + +class KVPair(graphene.ObjectType): + key = graphene.String() + value = graphene.String() + + +class ResourceLimitInput(graphene.InputObjectType): + key = graphene.String() + min = graphene.String() + max = graphene.String() + + +class KVPairInput(graphene.InputObjectType): + key = 
graphene.String() + value = graphene.String() + + +class BigInt(Scalar): + """ + BigInt is an extension of the regular graphene.Int scalar type + to support integers outside the range of a signed 32-bit integer. + """ + + @staticmethod + def coerce_bigint(value): + num = int(value) + if not (SAFE_MIN_INT <= num <= SAFE_MAX_INT): + raise ValueError("Cannot serialize integer out of the safe range.") + if not (MIN_INT <= num <= MAX_INT): + # treat as float + return float(int(num)) + return num + + serialize = coerce_bigint + parse_value = coerce_bigint + + @staticmethod + def parse_literal(node): + if isinstance(node, IntValueNode): + num = int(node.value) + if not (SAFE_MIN_INT <= num <= SAFE_MAX_INT): + raise ValueError("Cannot parse integer out of the safe range.") + if not (MIN_INT <= num <= MAX_INT): + # treat as float + return float(int(num)) + return num + + +class ImageRefType(graphene.InputObjectType): + name = graphene.String(required=True) + registry = graphene.String() + architecture = graphene.String() diff --git a/src/ai/backend/manager/models/gql_models/image.py b/src/ai/backend/manager/models/gql_models/image.py new file mode 100644 index 0000000000..7dcd9d7b9a --- /dev/null +++ b/src/ai/backend/manager/models/gql_models/image.py @@ -0,0 +1,823 @@ +from __future__ import annotations + +import logging +from collections.abc import MutableMapping, Sequence +from decimal import Decimal +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + List, + Optional, + overload, +) +from uuid import UUID + +import graphene +import sqlalchemy as sa +from graphql import Undefined +from redis.asyncio import Redis +from redis.asyncio.client import Pipeline +from sqlalchemy.orm import load_only, selectinload + +from ai.backend.common import redis_helper +from ai.backend.common.docker import ImageRef +from ai.backend.common.exception import UnknownImageReference +from ai.backend.common.types import ( + ImageAlias, +) +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType + +from ...api.exceptions import ImageNotFound, ObjectNotFound +from ...defs import DEFAULT_IMAGE_ARCH +from ..base import set_if_set +from ..gql_relay import AsyncNode +from ..image import ( + ImageAliasRow, + ImageIdentifier, + ImageLoadFilter, + ImageRow, + rescan_images, +) +from ..user import UserRole +from .base import ( + BigInt, + KVPair, + KVPairInput, + ResourceLimit, + ResourceLimitInput, +) + +if TYPE_CHECKING: + from ai.backend.common.bgtask import ProgressReporter + + from ..gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + +__all__ = ( + "Image", + "ImageNode", + "PreloadImage", + "RescanImages", + "ForgetImage", + "ForgetImageById", + "UntagImageFromRegistry", + "ModifyImage", + "AliasImage", + "DealiasImage", + "ClearImages", +) + + +class Image(graphene.ObjectType): + id = graphene.UUID() + name = graphene.String() + project = graphene.String(description="Added in 24.03.10.") + humanized_name = graphene.String() + tag = graphene.String() + registry = graphene.String() + architecture = graphene.String() + is_local = graphene.Boolean() + digest = graphene.String() + labels = graphene.List(KVPair) + aliases = graphene.List(graphene.String) + size_bytes = BigInt() + resource_limits = graphene.List(ResourceLimit) + supported_accelerators = graphene.List(graphene.String) + installed = graphene.Boolean() + installed_agents = graphene.List(graphene.String) + # legacy field 
+ hash = graphene.String() + + # internal attributes + raw_labels: dict[str, Any] + + @classmethod + def populate_row( + cls, + ctx: GraphQueryContext, + row: ImageRow, + installed_agents: List[str], + ) -> Image: + is_superadmin = ctx.user["role"] == UserRole.SUPERADMIN + hide_agents = False if is_superadmin else ctx.local_config["manager"]["hide-agents"] + ret = cls( + id=row.id, + name=row.image, + project=row.project, + humanized_name=row.image, + tag=row.tag, + registry=row.registry, + architecture=row.architecture, + is_local=row.is_local, + digest=row.trimmed_digest or None, + labels=[KVPair(key=k, value=v) for k, v in row.labels.items()], + aliases=[alias_row.alias for alias_row in row.aliases], + size_bytes=row.size_bytes, + resource_limits=[ + ResourceLimit( + key=k, + min=v.get("min", Decimal(0)), + max=v.get("max", Decimal("Infinity")), + ) + for k, v in row.resources.items() + ], + supported_accelerators=(row.accelerators or "").split(","), + installed=len(installed_agents) > 0, + installed_agents=installed_agents if not hide_agents else None, + # legacy + hash=row.trimmed_digest or None, + ) + ret.raw_labels = row.labels + return ret + + @classmethod + async def from_row( + cls, + ctx: GraphQueryContext, + row: ImageRow, + ) -> Image: + # TODO: add architecture + _installed_agents = await redis_helper.execute( + ctx.redis_image, + lambda r: r.smembers(row.name), + ) + installed_agents: List[str] = [] + for agent_id in _installed_agents: + if isinstance(agent_id, bytes): + installed_agents.append(agent_id.decode()) + else: + installed_agents.append(agent_id) + return cls.populate_row(ctx, row, installed_agents) + + @classmethod + async def bulk_load( + cls, + ctx: GraphQueryContext, + rows: List[ImageRow], + ) -> AsyncIterator[Image]: + async def _pipe(r: Redis) -> Pipeline: + pipe = r.pipeline() + for row in rows: + await pipe.smembers(row.name) + return pipe + + results = await redis_helper.execute(ctx.redis_image, _pipe) + for idx, row in enumerate(rows): + installed_agents: List[str] = [] + _installed_agents = results[idx] + for agent_id in _installed_agents: + if isinstance(agent_id, bytes): + installed_agents.append(agent_id.decode()) + else: + installed_agents.append(agent_id) + yield cls.populate_row(ctx, row, installed_agents) + + @classmethod + async def batch_load_by_canonical( + cls, + graph_ctx: GraphQueryContext, + image_names: Sequence[str], + ) -> Sequence[Optional[Image]]: + query = ( + sa.select(ImageRow) + .where(ImageRow.name.in_(image_names)) + .options(selectinload(ImageRow.aliases)) + ) + async with graph_ctx.db.begin_readonly_session() as session: + result = await session.execute(query) + return [await Image.from_row(graph_ctx, row) for row in result.scalars().all()] + + @classmethod + async def batch_load_by_image_ref( + cls, + graph_ctx: GraphQueryContext, + image_refs: Sequence[ImageRef], + ) -> Sequence[Optional[Image]]: + image_names = [x.canonical for x in image_refs] + return await cls.batch_load_by_canonical(graph_ctx, image_names) + + @classmethod + async def load_item_by_id( + cls, + ctx: GraphQueryContext, + id: UUID, + ) -> Image: + async with ctx.db.begin_readonly_session() as session: + row = await ImageRow.get(session, id, load_aliases=True) + if not row: + raise ImageNotFound + + return await cls.from_row(ctx, row) + + @classmethod + async def load_item( + cls, + ctx: GraphQueryContext, + reference: str, + architecture: str, + ) -> Image: + try: + async with ctx.db.begin_readonly_session() as session: + image_row = await 
ImageRow.resolve( + session, + [ + ImageIdentifier(reference, architecture), + ImageAlias(reference), + ], + ) + except UnknownImageReference: + raise ImageNotFound + return await cls.from_row(ctx, image_row) + + @classmethod + async def load_all( + cls, + ctx: GraphQueryContext, + *, + types: set[ImageLoadFilter] = set(), + ) -> Sequence[Image]: + async with ctx.db.begin_readonly_session() as session: + rows = await ImageRow.list(session, load_aliases=True) + items: list[Image] = [ + item async for item in cls.bulk_load(ctx, rows) if item.matches_filter(ctx, types) + ] + + return items + + @staticmethod + async def filter_allowed( + ctx: GraphQueryContext, + items: Sequence[Image], + domain_name: str, + ) -> Sequence[Image]: + from ..domain import domains + + async with ctx.db.begin() as conn: + query = ( + sa.select([domains.c.allowed_docker_registries]) + .select_from(domains) + .where(domains.c.name == domain_name) + ) + result = await conn.execute(query) + allowed_docker_registries = result.scalar() + + filtered_items: list[Image] = [ + item for item in items if item.registry in allowed_docker_registries + ] + + return filtered_items + + def matches_filter( + self, + ctx: GraphQueryContext, + load_filters: set[ImageLoadFilter], + ) -> bool: + """ + Determine if the image is filtered according to the `load_filters` parameter. + """ + user_role = ctx.user["role"] + + # If the image filtered by any of its labels, return False early. + # If the image is not filtered and is determiend to be valid by any of its labels, `is_valid = True`. + is_valid = ImageLoadFilter.GENERAL in load_filters + for label in self.labels: + match label.key: + case "ai.backend.features" if "operation" in label.value: + if ImageLoadFilter.OPERATIONAL in load_filters: + is_valid = True + else: + return False + case "ai.backend.customized-image.owner": + if ( + ImageLoadFilter.CUSTOMIZED not in load_filters + and ImageLoadFilter.CUSTOMIZED_GLOBAL not in load_filters + ): + return False + if ImageLoadFilter.CUSTOMIZED in load_filters: + if label.value == f"user:{ctx.user["uuid"]}": + is_valid = True + else: + return False + if ImageLoadFilter.CUSTOMIZED_GLOBAL in load_filters: + if user_role == UserRole.SUPERADMIN: + is_valid = True + else: + return False + return is_valid + + +class ImageNode(graphene.ObjectType): + class Meta: + interfaces = (AsyncNode,) + + row_id = graphene.UUID(description="Added in 24.03.4. The undecoded id value stored in DB.") + name = graphene.String() + project = graphene.String(description="Added in 24.03.10.") + humanized_name = graphene.String() + tag = graphene.String() + registry = graphene.String() + architecture = graphene.String() + is_local = graphene.Boolean() + digest = graphene.String() + labels = graphene.List(KVPair) + size_bytes = BigInt() + resource_limits = graphene.List(ResourceLimit) + supported_accelerators = graphene.List(graphene.String) + aliases = graphene.List( + graphene.String, description="Added in 24.03.4. The array of image aliases." + ) + + @overload + @classmethod + def from_row(cls, row: ImageRow) -> ImageNode: ... + + @overload + @classmethod + def from_row(cls, row: None) -> None: ... 
+ + @classmethod + def from_row(cls, row: ImageRow | None) -> ImageNode | None: + if row is None: + return None + return cls( + id=row.id, + row_id=row.id, + name=row.image, + project=row.project, + humanized_name=row.image, + tag=row.tag, + registry=row.registry, + architecture=row.architecture, + is_local=row.is_local, + digest=row.trimmed_digest or None, + labels=[KVPair(key=k, value=v) for k, v in row.labels.items()], + size_bytes=row.size_bytes, + resource_limits=[ + ResourceLimit( + key=k, + min=v.get("min", Decimal(0)), + max=v.get("max", Decimal("Infinity")), + ) + for k, v in row.resources.items() + ], + supported_accelerators=(row.accelerators or "").split(","), + aliases=[alias_row.alias for alias_row in row.aliases], + ) + + @classmethod + def from_legacy_image(cls, row: Image) -> ImageNode: + return cls( + id=row.id, + row_id=row.id, + name=row.name, + humanized_name=row.humanized_name, + tag=row.tag, + project=row.project, + registry=row.registry, + architecture=row.architecture, + is_local=row.is_local, + digest=row.trimmed_digest, + labels=row.labels, + size_bytes=row.size_bytes, + resource_limits=row.resource_limits, + supported_accelerators=row.supported_accelerators, + aliases=row.aliases, + ) + + @classmethod + async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode: + graph_ctx: GraphQueryContext = info.context + + _, image_id = AsyncNode.resolve_global_id(info, id) + query = ( + sa.select(ImageRow) + .where(ImageRow.id == image_id) + .options(selectinload(ImageRow.aliases).options(load_only(ImageAliasRow.alias))) + ) + async with graph_ctx.db.begin_readonly_session() as db_session: + image_row = await db_session.scalar(query) + if image_row is None: + raise ValueError(f"Image not found (id: {image_id})") + return cls.from_row(image_row) + + +class ForgetImageById(graphene.Mutation): + """Added in 24.03.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + UserRole.USER, + ) + + class Arguments: + image_id = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + image = graphene.Field(ImageNode, description="Added since 24.03.1.") + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + image_id: str, + ) -> ForgetImageById: + _, raw_image_id = AsyncNode.resolve_global_id(info, image_id) + if not raw_image_id: + raw_image_id = image_id + + try: + _image_id = UUID(raw_image_id) + except ValueError: + raise ObjectNotFound("image") + + log.info("forget image {0} by API request", image_id) + ctx: GraphQueryContext = info.context + client_role = ctx.user["role"] + + async with ctx.db.begin_session() as session: + image_row = await ImageRow.get(session, _image_id, load_aliases=True) + if not image_row: + raise ObjectNotFound("image") + if client_role != UserRole.SUPERADMIN: + customized_image_owner = (image_row.labels or {}).get( + "ai.backend.customized-image.owner" + ) + if ( + not customized_image_owner + or customized_image_owner != f"user:{ctx.user["uuid"]}" + ): + return ForgetImageById(ok=False, msg="Forbidden") + await session.delete(image_row) + return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row)) + + +class ForgetImage(graphene.Mutation): + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + UserRole.USER, + ) + + class Arguments: + reference = graphene.String(required=True) + architecture = graphene.String(default_value=DEFAULT_IMAGE_ARCH) + + ok = graphene.Boolean() + msg = graphene.String() + image = graphene.Field(ImageNode, 
description="Added since 24.03.1.") + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + reference: str, + architecture: str, + ) -> ForgetImage: + log.info("forget image {0} by API request", reference) + ctx: GraphQueryContext = info.context + client_role = ctx.user["role"] + + async with ctx.db.begin_session() as session: + image_row = await ImageRow.resolve( + session, + [ + ImageIdentifier(reference, architecture), + ImageAlias(reference), + ], + ) + if client_role != UserRole.SUPERADMIN: + customized_image_owner = (image_row.labels or {}).get( + "ai.backend.customized-image.owner" + ) + if ( + not customized_image_owner + or customized_image_owner != f"user:{ctx.user["uuid"]}" + ): + return ForgetImage(ok=False, msg="Forbidden") + await session.delete(image_row) + return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row)) + + +class UntagImageFromRegistry(graphene.Mutation): + """Added in 24.03.1""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + UserRole.USER, + ) + + class Arguments: + image_id = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + image = graphene.Field(ImageNode, description="Added since 24.03.1.") + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + image_id: str, + ) -> UntagImageFromRegistry: + from ai.backend.manager.container_registry.harbor import HarborRegistry_v2 + + _, raw_image_id = AsyncNode.resolve_global_id(info, image_id) + if not raw_image_id: + raw_image_id = image_id + + try: + _image_id = UUID(raw_image_id) + except ValueError: + raise ObjectNotFound("image") + + log.info("remove image from registry {0} by API request", str(_image_id)) + ctx: GraphQueryContext = info.context + client_role = ctx.user["role"] + + async with ctx.db.begin_readonly_session() as session: + image_row = await ImageRow.get(session, _image_id, load_aliases=True) + if not image_row: + raise ImageNotFound + if client_role != UserRole.SUPERADMIN: + customized_image_owner = (image_row.labels or {}).get( + "ai.backend.customized-image.owner" + ) + if ( + not customized_image_owner + or customized_image_owner != f"user:{ctx.user["uuid"]}" + ): + return UntagImageFromRegistry(ok=False, msg="Forbidden") + + query = sa.select(ContainerRegistryRow).where( + ContainerRegistryRow.registry_name == image_row.image_ref.registry + ) + + registry_info = (await session.execute(query)).scalar() + + if registry_info.type != ContainerRegistryType.HARBOR2: + raise NotImplementedError("This feature is only supported for Harbor 2 registries") + + scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info) + await scanner.untag(image_row.image_ref) + + return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row)) + + +class PreloadImage(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + references = graphene.List(graphene.String, required=True) + target_agents = graphene.List(graphene.String, required=True) + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.String() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + references: Sequence[str], + target_agents: Sequence[str], + ) -> PreloadImage: + return PreloadImage(ok=False, msg="Not implemented.", task_id=None) + + +class UnloadImage(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + references = graphene.List(graphene.String, required=True) + target_agents = 
graphene.List(graphene.String, required=True) + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.String() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + references: Sequence[str], + target_agents: Sequence[str], + ) -> UnloadImage: + return UnloadImage(ok=False, msg="Not implemented.", task_id=None) + + +class RescanImages(graphene.Mutation): + allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + registry = graphene.String() + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.UUID() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + registry: Optional[str] = None, + ) -> RescanImages: + log.info( + "rescanning docker registry {0} by API request", + f"({registry})" if registry else "(all)", + ) + ctx: GraphQueryContext = info.context + + async def _rescan_task(reporter: ProgressReporter) -> None: + await rescan_images(ctx.db, registry, reporter=reporter) + + task_id = await ctx.background_task_manager.start(_rescan_task) + return RescanImages(ok=True, msg="", task_id=task_id) + + +class AliasImage(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + alias = graphene.String(required=True) + target = graphene.String(required=True) + architecture = graphene.String(default_value=DEFAULT_IMAGE_ARCH) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + alias: str, + target: str, + architecture: str, + ) -> AliasImage: + log.info("alias image {0} -> {1} by API request", alias, target) + ctx: GraphQueryContext = info.context + try: + async with ctx.db.begin_session() as session: + try: + image_row = await ImageRow.resolve( + session, [ImageIdentifier(target, architecture)] + ) + except UnknownImageReference: + raise ImageNotFound + else: + image_row.aliases.append(ImageAliasRow(alias=alias, image_id=image_row.id)) + except ValueError as e: + return AliasImage(ok=False, msg=str(e)) + return AliasImage(ok=True, msg="") + + +class DealiasImage(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + alias = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + alias: str, + ) -> DealiasImage: + log.info("dealias image {0} by API request", alias) + ctx: GraphQueryContext = info.context + try: + async with ctx.db.begin_session() as session: + existing_alias = await session.scalar( + sa.select(ImageAliasRow).where(ImageAliasRow.alias == alias), + ) + if existing_alias is None: + raise DealiasImage(ok=False, msg=str("No such alias")) + await session.delete(existing_alias) + except ValueError as e: + return DealiasImage(ok=False, msg=str(e)) + return DealiasImage(ok=True, msg="") + + +class ClearImages(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + registry = graphene.String() + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + registry: str, + ) -> ClearImages: + ctx: GraphQueryContext = info.context + try: + async with ctx.db.begin_session() as session: + result = await session.execute( + sa.select(ImageRow).where(ImageRow.registry == registry) + ) + image_ids = [x.id for x in result.scalars().all()] + + await session.execute( + sa.delete(ImageAliasRow).where(ImageAliasRow.image_id.in_(image_ids)) 
+ ) + await session.execute(sa.delete(ImageRow).where(ImageRow.registry == registry)) + except ValueError as e: + return ClearImages(ok=False, msg=str(e)) + return ClearImages(ok=True, msg="") + + +class ModifyImageInput(graphene.InputObjectType): + name = graphene.String(required=False) + registry = graphene.String(required=False) + image = graphene.String(required=False) + tag = graphene.String(required=False) + architecture = graphene.String(required=False) + is_local = graphene.Boolean(required=False) + size_bytes = graphene.Int(required=False) + type = graphene.String(required=False) + + digest = graphene.String(required=False) + labels = graphene.List(lambda: KVPairInput, required=False) + supported_accelerators = graphene.List(graphene.String, required=False) + resource_limits = graphene.List(lambda: ResourceLimitInput, required=False) + + +class ModifyImage(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + target = graphene.String(required=True, default_value=None) + architecture = graphene.String(required=False, default_value=DEFAULT_IMAGE_ARCH) + props = ModifyImageInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + root: Any, + info: graphene.ResolveInfo, + target: str, + architecture: str, + props: ModifyImageInput, + ) -> AliasImage: + ctx: GraphQueryContext = info.context + data: MutableMapping[str, Any] = {} + set_if_set(props, data, "name") + set_if_set(props, data, "registry") + set_if_set(props, data, "image") + set_if_set(props, data, "tag") + set_if_set(props, data, "architecture") + set_if_set(props, data, "is_local") + set_if_set(props, data, "size_bytes") + set_if_set(props, data, "type") + set_if_set(props, data, "digest", target_key="config_digest") + set_if_set( + props, + data, + "supported_accelerators", + clean_func=lambda v: ",".join(v), + target_key="accelerators", + ) + set_if_set(props, data, "labels", clean_func=lambda v: {pair.key: pair.value for pair in v}) + + if props.resource_limits is not Undefined: + resources_data = {} + for limit_option in props.resource_limits: + limit_data = {} + if limit_option.min is not Undefined and len(limit_option.min) > 0: + limit_data["min"] = limit_option.min + if limit_option.max is not Undefined and len(limit_option.max) > 0: + limit_data["max"] = limit_option.max + resources_data[limit_option.key] = limit_data + data["resources"] = resources_data + + try: + async with ctx.db.begin_session() as session: + try: + image_row = await ImageRow.resolve( + session, + [ + ImageIdentifier(target, architecture), + ImageAlias(target), + ], + ) + except UnknownImageReference: + return ModifyImage(ok=False, msg="Image not found") + for k, v in data.items(): + setattr(image_row, k, v) + except ValueError as e: + return ModifyImage(ok=False, msg=str(e)) + return ModifyImage(ok=True, msg="") diff --git a/src/ai/backend/manager/models/gql_models/kernel.py b/src/ai/backend/manager/models/gql_models/kernel.py index 8ac9026b01..f782af8058 100644 --- a/src/ai/backend/manager/models/gql_models/kernel.py +++ b/src/ai/backend/manager/models/gql_models/kernel.py @@ -17,9 +17,9 @@ from ai.backend.manager.models.base import batch_multiresult_in_session from ..gql_relay import AsyncNode, Connection -from ..image import ImageNode from ..kernel import KernelRow, KernelStatus from ..user import UserRole +from .image import ImageNode if TYPE_CHECKING: from ..gql import GraphQueryContext diff --git a/src/ai/backend/manager/models/image.py 
b/src/ai/backend/manager/models/image.py index 0afbd98416..fe8982e547 100644 --- a/src/ai/backend/manager/models/image.py +++ b/src/ai/backend/manager/models/image.py @@ -3,32 +3,25 @@ import enum import functools import logging -from collections.abc import Iterable, Mapping, MutableMapping, Sequence +from collections.abc import Iterable, Mapping from decimal import Decimal from typing import ( TYPE_CHECKING, Any, - AsyncIterator, List, NamedTuple, Optional, Tuple, cast, - overload, ) from uuid import UUID import aiotools -import graphene import sqlalchemy as sa import trafaret as t -from graphql import Undefined -from redis.asyncio import Redis -from redis.asyncio.client import Pipeline from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession -from sqlalchemy.orm import load_only, relationship, selectinload +from sqlalchemy.orm import relationship, selectinload -from ai.backend.common import redis_helper from ai.backend.common.docker import ImageRef from ai.backend.common.etcd import AsyncEtcd from ai.backend.common.exception import UnknownImageReference @@ -41,33 +34,23 @@ ResourceSlot, ) from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType +from ai.backend.manager.models.container_registry import ContainerRegistryRow -from ..api.exceptions import ImageNotFound, ObjectNotFound +from ..api.exceptions import ImageNotFound from ..container_registry import get_container_registry_cls -from ..defs import DEFAULT_IMAGE_ARCH from .base import ( GUID, Base, - BigInt, ForeignKeyIDColumn, IDColumn, - KVPair, - KVPairInput, - ResourceLimit, - ResourceLimitInput, StructuredJSONColumn, - set_if_set, ) -from .gql_relay import AsyncNode -from .user import UserRole from .utils import ExtendedAsyncSAEngine if TYPE_CHECKING: from ai.backend.common.bgtask import ProgressReporter from ..config import SharedConfig - from .gql import GraphQueryContext log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -78,19 +61,8 @@ "ImageAliasRow", "ImageLoadFilter", "ImageRow", - "Image", - "ImageNode", "ImageIdentifier", - "PreloadImage", "PublicImageLoadFilter", - "RescanImages", - "ForgetImage", - "ForgetImageById", - "UntagImageFromRegistry", - "ModifyImage", - "AliasImage", - "DealiasImage", - "ClearImages", ) @@ -621,761 +593,3 @@ async def create( ImageRow.aliases = relationship("ImageAliasRow", back_populates="image") ImageAliasRow.image = relationship("ImageRow", back_populates="aliases") - - -class Image(graphene.ObjectType): - id = graphene.UUID() - name = graphene.String() - project = graphene.String(description="Added in 24.03.10.") - humanized_name = graphene.String() - tag = graphene.String() - registry = graphene.String() - architecture = graphene.String() - is_local = graphene.Boolean() - digest = graphene.String() - labels = graphene.List(KVPair) - aliases = graphene.List(graphene.String) - size_bytes = BigInt() - resource_limits = graphene.List(ResourceLimit) - supported_accelerators = graphene.List(graphene.String) - installed = graphene.Boolean() - installed_agents = graphene.List(graphene.String) - # legacy field - hash = graphene.String() - - # internal attributes - raw_labels: dict[str, Any] - - @classmethod - def populate_row( - cls, - ctx: GraphQueryContext, - row: ImageRow, - installed_agents: List[str], - ) -> Image: - is_superadmin = ctx.user["role"] == UserRole.SUPERADMIN - hide_agents = False if is_superadmin else ctx.local_config["manager"]["hide-agents"] - ret = cls( - 
id=row.id, - name=row.image, - project=row.project, - humanized_name=row.image, - tag=row.tag, - registry=row.registry, - architecture=row.architecture, - is_local=row.is_local, - digest=row.trimmed_digest or None, - labels=[KVPair(key=k, value=v) for k, v in row.labels.items()], - aliases=[alias_row.alias for alias_row in row.aliases], - size_bytes=row.size_bytes, - resource_limits=[ - ResourceLimit( - key=k, - min=v.get("min", Decimal(0)), - max=v.get("max", Decimal("Infinity")), - ) - for k, v in row.resources.items() - ], - supported_accelerators=(row.accelerators or "").split(","), - installed=len(installed_agents) > 0, - installed_agents=installed_agents if not hide_agents else None, - # legacy - hash=row.trimmed_digest or None, - ) - ret.raw_labels = row.labels - return ret - - @classmethod - async def from_row( - cls, - ctx: GraphQueryContext, - row: ImageRow, - ) -> Image: - # TODO: add architecture - _installed_agents = await redis_helper.execute( - ctx.redis_image, - lambda r: r.smembers(row.name), - ) - installed_agents: List[str] = [] - for agent_id in _installed_agents: - if isinstance(agent_id, bytes): - installed_agents.append(agent_id.decode()) - else: - installed_agents.append(agent_id) - return cls.populate_row(ctx, row, installed_agents) - - @classmethod - async def bulk_load( - cls, - ctx: GraphQueryContext, - rows: List[ImageRow], - ) -> AsyncIterator[Image]: - async def _pipe(r: Redis) -> Pipeline: - pipe = r.pipeline() - for row in rows: - await pipe.smembers(row.name) - return pipe - - results = await redis_helper.execute(ctx.redis_image, _pipe) - for idx, row in enumerate(rows): - installed_agents: List[str] = [] - _installed_agents = results[idx] - for agent_id in _installed_agents: - if isinstance(agent_id, bytes): - installed_agents.append(agent_id.decode()) - else: - installed_agents.append(agent_id) - yield cls.populate_row(ctx, row, installed_agents) - - @classmethod - async def batch_load_by_canonical( - cls, - graph_ctx: GraphQueryContext, - image_names: Sequence[str], - ) -> Sequence[Optional[Image]]: - query = ( - sa.select(ImageRow) - .where(ImageRow.name.in_(image_names)) - .options(selectinload(ImageRow.aliases)) - ) - async with graph_ctx.db.begin_readonly_session() as session: - result = await session.execute(query) - return [await Image.from_row(graph_ctx, row) for row in result.scalars().all()] - - @classmethod - async def batch_load_by_image_ref( - cls, - graph_ctx: GraphQueryContext, - image_refs: Sequence[ImageRef], - ) -> Sequence[Optional[Image]]: - image_names = [x.canonical for x in image_refs] - return await cls.batch_load_by_canonical(graph_ctx, image_names) - - @classmethod - async def load_item_by_id( - cls, - ctx: GraphQueryContext, - id: UUID, - ) -> Image: - async with ctx.db.begin_readonly_session() as session: - row = await ImageRow.get(session, id, load_aliases=True) - if not row: - raise ImageNotFound - - return await cls.from_row(ctx, row) - - @classmethod - async def load_item( - cls, - ctx: GraphQueryContext, - reference: str, - architecture: str, - ) -> Image: - try: - async with ctx.db.begin_readonly_session() as session: - image_row = await ImageRow.resolve( - session, - [ - ImageIdentifier(reference, architecture), - ImageAlias(reference), - ], - ) - except UnknownImageReference: - raise ImageNotFound - return await cls.from_row(ctx, image_row) - - @classmethod - async def load_all( - cls, - ctx: GraphQueryContext, - *, - types: set[ImageLoadFilter] = set(), - ) -> Sequence[Image]: - async with 
ctx.db.begin_readonly_session() as session: - rows = await ImageRow.list(session, load_aliases=True) - items: list[Image] = [ - item async for item in cls.bulk_load(ctx, rows) if item.matches_filter(ctx, types) - ] - - return items - - @staticmethod - async def filter_allowed( - ctx: GraphQueryContext, - items: Sequence[Image], - domain_name: str, - ) -> Sequence[Image]: - from .domain import domains - - async with ctx.db.begin() as conn: - query = ( - sa.select([domains.c.allowed_docker_registries]) - .select_from(domains) - .where(domains.c.name == domain_name) - ) - result = await conn.execute(query) - allowed_docker_registries = result.scalar() - - filtered_items: list[Image] = [ - item for item in items if item.registry in allowed_docker_registries - ] - - return filtered_items - - def matches_filter( - self, - ctx: GraphQueryContext, - load_filters: set[ImageLoadFilter], - ) -> bool: - """ - Determine if the image is filtered according to the `load_filters` parameter. - """ - user_role = ctx.user["role"] - - # If the image filtered by any of its labels, return False early. - # If the image is not filtered and is determiend to be valid by any of its labels, `is_valid = True`. - is_valid = ImageLoadFilter.GENERAL in load_filters - for label in self.labels: - match label.key: - case "ai.backend.features" if "operation" in label.value: - if ImageLoadFilter.OPERATIONAL in load_filters: - is_valid = True - else: - return False - case "ai.backend.customized-image.owner": - if ( - ImageLoadFilter.CUSTOMIZED not in load_filters - and ImageLoadFilter.CUSTOMIZED_GLOBAL not in load_filters - ): - return False - if ImageLoadFilter.CUSTOMIZED in load_filters: - if label.value == f"user:{ctx.user["uuid"]}": - is_valid = True - else: - return False - if ImageLoadFilter.CUSTOMIZED_GLOBAL in load_filters: - if user_role == UserRole.SUPERADMIN: - is_valid = True - else: - return False - return is_valid - - -class ImageNode(graphene.ObjectType): - class Meta: - interfaces = (AsyncNode,) - - row_id = graphene.UUID(description="Added in 24.03.4. The undecoded id value stored in DB.") - name = graphene.String() - project = graphene.String(description="Added in 24.03.10.") - humanized_name = graphene.String() - tag = graphene.String() - registry = graphene.String() - architecture = graphene.String() - is_local = graphene.Boolean() - digest = graphene.String() - labels = graphene.List(KVPair) - size_bytes = BigInt() - resource_limits = graphene.List(ResourceLimit) - supported_accelerators = graphene.List(graphene.String) - aliases = graphene.List( - graphene.String, description="Added in 24.03.4. The array of image aliases." - ) - - @overload - @classmethod - def from_row(cls, row: ImageRow) -> ImageNode: ... - - @overload - @classmethod - def from_row(cls, row: None) -> None: ... 
- - @classmethod - def from_row(cls, row: ImageRow | None) -> ImageNode | None: - if row is None: - return None - return cls( - id=row.id, - row_id=row.id, - name=row.image, - project=row.project, - humanized_name=row.image, - tag=row.tag, - registry=row.registry, - architecture=row.architecture, - is_local=row.is_local, - digest=row.trimmed_digest or None, - labels=[KVPair(key=k, value=v) for k, v in row.labels.items()], - size_bytes=row.size_bytes, - resource_limits=[ - ResourceLimit( - key=k, - min=v.get("min", Decimal(0)), - max=v.get("max", Decimal("Infinity")), - ) - for k, v in row.resources.items() - ], - supported_accelerators=(row.accelerators or "").split(","), - aliases=[alias_row.alias for alias_row in row.aliases], - ) - - @classmethod - def from_legacy_image(cls, row: Image) -> ImageNode: - return cls( - id=row.id, - row_id=row.id, - name=row.name, - humanized_name=row.humanized_name, - tag=row.tag, - project=row.project, - registry=row.registry, - architecture=row.architecture, - is_local=row.is_local, - digest=row.trimmed_digest, - labels=row.labels, - size_bytes=row.size_bytes, - resource_limits=row.resource_limits, - supported_accelerators=row.supported_accelerators, - aliases=row.aliases, - ) - - @classmethod - async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode: - graph_ctx: GraphQueryContext = info.context - - _, image_id = AsyncNode.resolve_global_id(info, id) - query = ( - sa.select(ImageRow) - .where(ImageRow.id == image_id) - .options(selectinload(ImageRow.aliases).options(load_only(ImageAliasRow.alias))) - ) - async with graph_ctx.db.begin_readonly_session() as db_session: - image_row = await db_session.scalar(query) - if image_row is None: - raise ValueError(f"Image not found (id: {image_id})") - return cls.from_row(image_row) - - -class PreloadImage(graphene.Mutation): - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - references = graphene.List(graphene.String, required=True) - target_agents = graphene.List(graphene.String, required=True) - - ok = graphene.Boolean() - msg = graphene.String() - task_id = graphene.String() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - references: Sequence[str], - target_agents: Sequence[str], - ) -> PreloadImage: - return PreloadImage(ok=False, msg="Not implemented.", task_id=None) - - -class UnloadImage(graphene.Mutation): - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - references = graphene.List(graphene.String, required=True) - target_agents = graphene.List(graphene.String, required=True) - - ok = graphene.Boolean() - msg = graphene.String() - task_id = graphene.String() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - references: Sequence[str], - target_agents: Sequence[str], - ) -> UnloadImage: - return UnloadImage(ok=False, msg="Not implemented.", task_id=None) - - -class RescanImages(graphene.Mutation): - allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) - - class Arguments: - registry = graphene.String() - - ok = graphene.Boolean() - msg = graphene.String() - task_id = graphene.UUID() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - registry: Optional[str] = None, - ) -> RescanImages: - log.info( - "rescanning docker registry {0} by API request", - f"({registry})" if registry else "(all)", - ) - ctx: GraphQueryContext = info.context - - async def _rescan_task(reporter: ProgressReporter) -> None: - await rescan_images(ctx.db, registry, reporter=reporter) - - 
task_id = await ctx.background_task_manager.start(_rescan_task) - return RescanImages(ok=True, msg="", task_id=task_id) - - -class ForgetImageById(graphene.Mutation): - """Added in 24.03.0.""" - - allowed_roles = ( - UserRole.SUPERADMIN, - UserRole.ADMIN, - UserRole.USER, - ) - - class Arguments: - image_id = graphene.String(required=True) - - ok = graphene.Boolean() - msg = graphene.String() - image = graphene.Field(ImageNode, description="Added since 24.03.1.") - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - image_id: str, - ) -> ForgetImageById: - _, raw_image_id = AsyncNode.resolve_global_id(info, image_id) - if not raw_image_id: - raw_image_id = image_id - - try: - _image_id = UUID(raw_image_id) - except ValueError: - raise ObjectNotFound("image") - - log.info("forget image {0} by API request", image_id) - ctx: GraphQueryContext = info.context - client_role = ctx.user["role"] - - async with ctx.db.begin_session() as session: - image_row = await ImageRow.get(session, _image_id, load_aliases=True) - if not image_row: - raise ObjectNotFound("image") - if client_role != UserRole.SUPERADMIN: - customized_image_owner = (image_row.labels or {}).get( - "ai.backend.customized-image.owner" - ) - if ( - not customized_image_owner - or customized_image_owner != f"user:{ctx.user["uuid"]}" - ): - return ForgetImageById(ok=False, msg="Forbidden") - await session.delete(image_row) - return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row)) - - -class ForgetImage(graphene.Mutation): - allowed_roles = ( - UserRole.SUPERADMIN, - UserRole.ADMIN, - UserRole.USER, - ) - - class Arguments: - reference = graphene.String(required=True) - architecture = graphene.String(default_value=DEFAULT_IMAGE_ARCH) - - ok = graphene.Boolean() - msg = graphene.String() - image = graphene.Field(ImageNode, description="Added since 24.03.1.") - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - reference: str, - architecture: str, - ) -> ForgetImage: - log.info("forget image {0} by API request", reference) - ctx: GraphQueryContext = info.context - client_role = ctx.user["role"] - - async with ctx.db.begin_session() as session: - image_row = await ImageRow.resolve( - session, - [ - ImageIdentifier(reference, architecture), - ImageAlias(reference), - ], - ) - if client_role != UserRole.SUPERADMIN: - customized_image_owner = (image_row.labels or {}).get( - "ai.backend.customized-image.owner" - ) - if ( - not customized_image_owner - or customized_image_owner != f"user:{ctx.user["uuid"]}" - ): - return ForgetImage(ok=False, msg="Forbidden") - await session.delete(image_row) - return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row)) - - -class UntagImageFromRegistry(graphene.Mutation): - """Added in 24.03.1""" - - allowed_roles = ( - UserRole.SUPERADMIN, - UserRole.ADMIN, - UserRole.USER, - ) - - class Arguments: - image_id = graphene.String(required=True) - - ok = graphene.Boolean() - msg = graphene.String() - image = graphene.Field(ImageNode, description="Added since 24.03.1.") - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - image_id: str, - ) -> UntagImageFromRegistry: - from ai.backend.manager.container_registry.harbor import HarborRegistry_v2 - - _, raw_image_id = AsyncNode.resolve_global_id(info, image_id) - if not raw_image_id: - raw_image_id = image_id - - try: - _image_id = UUID(raw_image_id) - except ValueError: - raise ObjectNotFound("image") - - log.info("remove image from registry {0} 
by API request", str(_image_id)) - ctx: GraphQueryContext = info.context - client_role = ctx.user["role"] - - async with ctx.db.begin_readonly_session() as session: - image_row = await ImageRow.get(session, _image_id, load_aliases=True) - if not image_row: - raise ImageNotFound - if client_role != UserRole.SUPERADMIN: - customized_image_owner = (image_row.labels or {}).get( - "ai.backend.customized-image.owner" - ) - if ( - not customized_image_owner - or customized_image_owner != f"user:{ctx.user["uuid"]}" - ): - return UntagImageFromRegistry(ok=False, msg="Forbidden") - - query = sa.select(ContainerRegistryRow).where( - ContainerRegistryRow.registry_name == image_row.image_ref.registry - ) - - registry_info = (await session.execute(query)).scalar() - - if registry_info.type != ContainerRegistryType.HARBOR2: - raise NotImplementedError("This feature is only supported for Harbor 2 registries") - - scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info) - await scanner.untag(image_row.image_ref) - - return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row)) - - -class AliasImage(graphene.Mutation): - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - alias = graphene.String(required=True) - target = graphene.String(required=True) - architecture = graphene.String(default_value=DEFAULT_IMAGE_ARCH) - - ok = graphene.Boolean() - msg = graphene.String() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - alias: str, - target: str, - architecture: str, - ) -> AliasImage: - log.info("alias image {0} -> {1} by API request", alias, target) - ctx: GraphQueryContext = info.context - try: - async with ctx.db.begin_session() as session: - try: - image_row = await ImageRow.resolve( - session, [ImageIdentifier(target, architecture)] - ) - except UnknownImageReference: - raise ImageNotFound - else: - image_row.aliases.append(ImageAliasRow(alias=alias, image_id=image_row.id)) - except ValueError as e: - return AliasImage(ok=False, msg=str(e)) - return AliasImage(ok=True, msg="") - - -class DealiasImage(graphene.Mutation): - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - alias = graphene.String(required=True) - - ok = graphene.Boolean() - msg = graphene.String() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - alias: str, - ) -> DealiasImage: - log.info("dealias image {0} by API request", alias) - ctx: GraphQueryContext = info.context - try: - async with ctx.db.begin_session() as session: - existing_alias = await session.scalar( - sa.select(ImageAliasRow).where(ImageAliasRow.alias == alias), - ) - if existing_alias is None: - raise DealiasImage(ok=False, msg=str("No such alias")) - await session.delete(existing_alias) - except ValueError as e: - return DealiasImage(ok=False, msg=str(e)) - return DealiasImage(ok=True, msg="") - - -class ClearImages(graphene.Mutation): - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - registry = graphene.String() - - ok = graphene.Boolean() - msg = graphene.String() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - registry: str, - ) -> ClearImages: - ctx: GraphQueryContext = info.context - try: - async with ctx.db.begin_session() as session: - result = await session.execute( - sa.select(ImageRow).where(ImageRow.registry == registry) - ) - image_ids = [x.id for x in result.scalars().all()] - - await session.execute( - sa.delete(ImageAliasRow).where(ImageAliasRow.image_id.in_(image_ids)) - ) - 
await session.execute(sa.delete(ImageRow).where(ImageRow.registry == registry)) - except ValueError as e: - return ClearImages(ok=False, msg=str(e)) - return ClearImages(ok=True, msg="") - - -class ModifyImageInput(graphene.InputObjectType): - name = graphene.String(required=False) - registry = graphene.String(required=False) - image = graphene.String(required=False) - tag = graphene.String(required=False) - architecture = graphene.String(required=False) - is_local = graphene.Boolean(required=False) - size_bytes = graphene.Int(required=False) - type = graphene.String(required=False) - - digest = graphene.String(required=False) - labels = graphene.List(lambda: KVPairInput, required=False) - supported_accelerators = graphene.List(graphene.String, required=False) - resource_limits = graphene.List(lambda: ResourceLimitInput, required=False) - - -class ModifyImage(graphene.Mutation): - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - target = graphene.String(required=True, default_value=None) - architecture = graphene.String(required=False, default_value=DEFAULT_IMAGE_ARCH) - props = ModifyImageInput(required=True) - - ok = graphene.Boolean() - msg = graphene.String() - - @staticmethod - async def mutate( - root: Any, - info: graphene.ResolveInfo, - target: str, - architecture: str, - props: ModifyImageInput, - ) -> AliasImage: - ctx: GraphQueryContext = info.context - data: MutableMapping[str, Any] = {} - set_if_set(props, data, "name") - set_if_set(props, data, "registry") - set_if_set(props, data, "image") - set_if_set(props, data, "tag") - set_if_set(props, data, "architecture") - set_if_set(props, data, "is_local") - set_if_set(props, data, "size_bytes") - set_if_set(props, data, "type") - set_if_set(props, data, "digest", target_key="config_digest") - set_if_set( - props, - data, - "supported_accelerators", - clean_func=lambda v: ",".join(v), - target_key="accelerators", - ) - set_if_set(props, data, "labels", clean_func=lambda v: {pair.key: pair.value for pair in v}) - - if props.resource_limits is not Undefined: - resources_data = {} - for limit_option in props.resource_limits: - limit_data = {} - if limit_option.min is not Undefined and len(limit_option.min) > 0: - limit_data["min"] = limit_option.min - if limit_option.max is not Undefined and len(limit_option.max) > 0: - limit_data["max"] = limit_option.max - resources_data[limit_option.key] = limit_data - data["resources"] = resources_data - - try: - async with ctx.db.begin_session() as session: - try: - image_row = await ImageRow.resolve( - session, - [ - ImageIdentifier(target, architecture), - ImageAlias(target), - ], - ) - except UnknownImageReference: - return ModifyImage(ok=False, msg="Image not found") - for k, v in data.items(): - setattr(image_row, k, v) - except ValueError as e: - return ModifyImage(ok=False, msg=str(e)) - return ModifyImage(ok=True, msg="") - - -class ImageRefType(graphene.InputObjectType): - name = graphene.String(required=True) - registry = graphene.String() - architecture = graphene.String() diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index 011354ed88..f5639dd51d 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -74,8 +74,9 @@ batch_multiresult, batch_result, ) +from .gql_models.image import ImageNode from .group import groups -from .image import ImageNode, ImageRow +from .image import ImageRow from .minilang import JSONFieldItem from .minilang.ordering import ColumnMapType, 
QueryOrderParser from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter
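For reference, a minimal sketch (assuming a downstream module that previously imported these symbols from models.image, such as the endpoint.py and kernel.py call sites patched above) of the import change this relocation implies; module paths follow the hunks in this patch:

    # Before (assumed layout): GraphQL types and ORM rows lived together in models.image.
    # from ai.backend.manager.models.image import ImageNode, ImageRefType, ImageRow

    # After this patch: ORM rows stay in models.image, while the GraphQL schema
    # types are imported from the new gql_models package.
    from ai.backend.manager.models.image import ImageIdentifier, ImageRow
    from ai.backend.manager.models.gql_models.base import ImageRefType
    from ai.backend.manager.models.gql_models.image import ImageNode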