From ea09bab55e3b977a7108d3246f578b158c8a328c Mon Sep 17 00:00:00 2001
From: Gyubong Lee
Date: Thu, 27 Jun 2024 04:20:30 +0000
Subject: [PATCH] WIP

---
 .../manager/container_registry/base.py   | 234 ++++++++++++++----
 .../manager/container_registry/github.py | 123 +--------
 2 files changed, 189 insertions(+), 168 deletions(-)

diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py
index 919ce1ac6ce..9dc8467e53f 100644
--- a/src/ai/backend/manager/container_registry/base.py
+++ b/src/ai/backend/manager/container_registry/base.py
@@ -217,7 +217,6 @@ async def _scan_tag(
         image: str,
         tag: str,
     ) -> None:
-        manifests = {}
         async with concurrency_sema.get():
             rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST_LIST
             async with sess.get(
@@ -230,63 +229,192 @@ async def _scan_tag(
                 content_type = resp.headers["Content-Type"]
                 resp.raise_for_status()
                 resp_json = await resp.json()
-                match content_type:
-                    # TODO: Support `self.MEDIA_TYPE_DOCKER_MANIFEST`
-                    case self.MEDIA_TYPE_DOCKER_MANIFEST_LIST:
-                        manifest_list = resp_json["manifests"]
-                        request_type = self.MEDIA_TYPE_DOCKER_MANIFEST
-                    case self.MEDIA_TYPE_OCI_INDEX:
-                        manifest_list = [
-                            item
-                            for item in resp_json["manifests"]
-                            if "annotations" not in item  # skip attestation manifests
-                        ]
-                        request_type = self.MEDIA_TYPE_OCI_MANIFEST
-                    case _:
-                        log.warn("Unknown content type: {}", content_type)
-                        raise RuntimeError(
-                            "The registry does not support the standard way of "
-                            "listing multiarch images."
-                        )
-                rqst_args["headers"]["Accept"] = request_type
-                for manifest in manifest_list:
-                    platform_arg = (
-                        f"{manifest['platform']['os']}/{manifest['platform']['architecture']}"
-                    )
-                    if variant := manifest["platform"].get("variant", None):
-                        platform_arg += f"/{variant}"
-                    architecture = manifest["platform"]["architecture"]
-                    architecture = arch_name_aliases.get(architecture, architecture)
-                    async with sess.get(
-                        self.registry_url / f"v2/{image}/manifests/{manifest['digest']}", **rqst_args
-                    ) as resp:
-                        data = await resp.json()
-                        config_digest = data["config"]["digest"]
-                        size_bytes = sum(layer["size"] for layer in data["layers"]) + data["config"]["size"]
-                    async with sess.get(
-                        self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args
-                    ) as resp:
-                        resp.raise_for_status()
-                        data = json.loads(await resp.read())
-                        labels = {}
-                        # we should favor `config` instead of `container_config` since `config` can contain additional datas
-                        # set when commiting image via `--change` flag
-                        if _config_labels := data.get("config", {}).get("Labels"):
-                            labels = _config_labels
-                        elif _container_config_labels := data.get("container_config", {}).get("Labels"):
-                            labels = _container_config_labels
-
-                        if not labels:
-                            log.warning(
-                                "Labels section not found on image {}:{}/{}", image, tag, architecture
-                            )
-                        manifests[architecture] = {
+                async with aiotools.TaskGroup() as tg:
+                    match content_type:
+                        case self.MEDIA_TYPE_DOCKER_MANIFEST:
+                            await self._process_docker_v2_image(
+                                tg, sess, rqst_args, image, tag, resp_json
+                            )
+                        case self.MEDIA_TYPE_DOCKER_MANIFEST_LIST:
+                            await self._process_docker_v2_multiplatform_image(
+                                tg, sess, rqst_args, image, tag, resp_json
+                            )
+                        case self.MEDIA_TYPE_OCI_INDEX:
+                            await self._process_oci_index(
+                                tg, sess, rqst_args, image, tag, resp_json
+                            )
+                        case _:
+                            log.warning("Unknown content type: {}", content_type)
+                            raise RuntimeError(
+                                "The registry does not support the standard way of "
+                                "listing multiarch images."
+                            )
+
+    async def _process_oci_index(
+        self,
+        tg: aiotools.TaskGroup,
+        sess: aiohttp.ClientSession,
+        rqst_args: Mapping[str, Any],
+        image: str,
+        tag: str,
+        image_info: Mapping[str, Any],
+    ) -> None:
+        manifests = {}
+        manifest_list = [
+            item
+            for item in image_info["manifests"]
+            if "annotations" not in item  # skip attestation manifests
+        ]
+        rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_OCI_MANIFEST
+
+        for manifest in manifest_list:
+            platform_arg = f"{manifest['platform']['os']}/{manifest['platform']['architecture']}"
+            if variant := manifest["platform"].get("variant", None):
+                platform_arg += f"/{variant}"
+            architecture = manifest["platform"]["architecture"]
+            architecture = arch_name_aliases.get(architecture, architecture)
+
+            async with sess.get(
+                self.registry_url / f"v2/{image}/manifests/{manifest['digest']}", **rqst_args
+            ) as resp:
+                data = await resp.json()
+                config_digest = data["config"]["digest"]
+                size_bytes = sum(layer["size"] for layer in data["layers"]) + data["config"]["size"]
+
+            async with sess.get(
+                self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args
+            ) as resp:
+                resp.raise_for_status()
+                data = json.loads(await resp.read())
+                labels = {}
+                # we should favor `config` instead of `container_config` since `config` can contain additional data
+                # set when committing an image via the `--change` flag
+                if _config_labels := data.get("config", {}).get("Labels"):
+                    labels = _config_labels
+                elif _container_config_labels := data.get("container_config", {}).get("Labels"):
+                    labels = _container_config_labels
+
+                if not labels:
+                    log.warning("Labels section not found on image {}:{}/{}", image, tag, architecture)
+
+            manifests[architecture] = {
+                "size": size_bytes,
+                "labels": labels,
+                "digest": config_digest,
+            }
+        await self._read_manifest(image, tag, manifests)
+
+    async def _process_docker_v2_multiplatform_image(
+        self,
+        tg: aiotools.TaskGroup,
+        sess: aiohttp.ClientSession,
+        rqst_args: Mapping[str, Any],
+        image: str,
+        tag: str,
+        image_info: Mapping[str, Any],
+    ) -> None:
+        manifests = {}
+        manifest_list = image_info["manifests"]
+        rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST
+
+        for manifest in manifest_list:
+            platform_arg = f"{manifest['platform']['os']}/{manifest['platform']['architecture']}"
+            if variant := manifest["platform"].get("variant", None):
+                platform_arg += f"/{variant}"
+            architecture = manifest["platform"]["architecture"]
+            architecture = arch_name_aliases.get(architecture, architecture)
+
+            async with sess.get(
+                self.registry_url / f"v2/{image}/manifests/{manifest['digest']}", **rqst_args
+            ) as resp:
+                data = await resp.json()
+                config_digest = data["config"]["digest"]
+                size_bytes = sum(layer["size"] for layer in data["layers"]) + data["config"]["size"]
+
+            async with sess.get(
+                self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args
+            ) as resp:
+                resp.raise_for_status()
+                data = json.loads(await resp.read())
+                labels = {}
+                # we should favor `config` instead of `container_config` since `config` can contain additional data
+                # set when committing an image via the `--change` flag
+                if _config_labels := data.get("config", {}).get("Labels"):
+                    labels = _config_labels
+                elif _container_config_labels := data.get("container_config", {}).get("Labels"):
+                    labels = _container_config_labels
+
+                if not labels:
+                    log.warning("Labels section not found on image {}:{}/{}", image, tag, architecture)
+
+            manifests[architecture] = {
+                "size": size_bytes,
+                "labels": labels,
+                "digest": config_digest,
+            }
+        await self._read_manifest(image, tag, manifests)
+
+    async def _process_docker_v2_image(
+        self,
+        tg: aiotools.TaskGroup,
+        sess: aiohttp.ClientSession,
+        rqst_args: Mapping[str, Any],
+        image: str,
+        tag: str,
+        image_info: Mapping[str, Any],
+    ) -> None:
+        config_digest = image_info["config"]["digest"]
+        rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST
+
+        async with sess.get(
+            self.registry_url / f"v2/{image}/blobs/{config_digest}",
+            **rqst_args,
+        ) as resp:
+            resp.raise_for_status()
+            blob_data = json.loads(await resp.read())
+
+        manifest_os = blob_data["os"]
+        manifest_arch = blob_data["architecture"]
+        architecture = arch_name_aliases.get(manifest_arch, manifest_arch)
+        manifest_variant = blob_data.get("variant", None)
+
+        platform_arg = f"{manifest_os}/{manifest_arch}"
+        if manifest_variant:
+            platform_arg += f"/{manifest_variant}"
+
+        size_bytes = (
+            sum(layer["size"] for layer in image_info["layers"]) + image_info["config"]["size"]
+        )
+
+        async with sess.get(
+            self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args
+        ) as resp:
+            resp.raise_for_status()
+            data = json.loads(await resp.read())
+            labels = {}
+
+            # we should favor `config` instead of `container_config` since `config` can contain additional data
+            # set when committing an image via the `--change` flag
+            if _config_labels := data.get("config", {}).get("Labels"):
+                labels = _config_labels
+            elif _container_config_labels := data.get("container_config", {}).get("Labels"):
+                labels = _container_config_labels
+
+            if not labels:
+                log.warning("Labels section not found on image {}:{}/{}", image, tag, architecture)
+
+        await self._read_manifest(
+            image,
+            tag,
+            {
+                architecture: {
                    "size": size_bytes,
                    "labels": labels,
                    "digest": config_digest,
                }
-        await self._read_manifest(image, tag, manifests)
+            },
+        )
 
     async def _read_manifest(
         self,
diff --git a/src/ai/backend/manager/container_registry/github.py b/src/ai/backend/manager/container_registry/github.py
index 5518e258565..c6e02831219 100644
--- a/src/ai/backend/manager/container_registry/github.py
+++ b/src/ai/backend/manager/container_registry/github.py
@@ -1,20 +1,12 @@
-import asyncio
-import json
 import logging
-from typing import AsyncIterator, Optional, cast
+from typing import AsyncIterator
 
 import aiohttp
-import aiotools
-import yarl
 
-from ai.backend.common.bgtask import ProgressReporter
 from ai.backend.common.logging import BraceStyleAdapter
 
 from .base import (
     BaseContainerRegistry,
-    all_updates,
-    concurrency_sema,
-    progress_reporter,
 )
 
 log = BraceStyleAdapter(logging.getLogger(__spec__.name))  # type: ignore[name-defined]
@@ -25,13 +17,14 @@ async def fetch_repositories(
         self,
         sess: aiohttp.ClientSession,
     ) -> AsyncIterator[str]:
-        name, type_, access_token = (
-            self.registry_info["name"],
+        name, access_token, type_ = (
+            self.registry_info["username"],
+            self.registry_info["password"],
             self.registry_info["name_type"],
-            self.registry_info["token"],
         )
 
         base_url = f"https://api.github.com/{type_}/{name}/packages"
+
         headers = {
             "Authorization": f"Bearer {access_token}",
             "Accept": "application/vnd.github.v3+json",
@@ -47,112 +40,12 @@ async def fetch_repositories(
                 if response.status == 200:
                     data = await response.json()
                     for repo in data:
-                        yield repo["name"]
+                        yield f"{self.registry_info["username"]}/{repo["name"]}"
                     if "next" in response.links:
                         page += 1
                     else:
                         break
                 else:
-                    print(f"Failed to fetch repositories: {response.status}")
-                    break
-
-    async def get_ghcr_token(self, image: str):
-        url = f"https://ghcr.io/token?scope=repository:{image}:pull"
f"https://ghcr.io/token?scope=repository:{image}:pull" - auth = aiohttp.BasicAuth( - login=self.registry_info["name"], password=self.registry_info["token"] - ) - - async with aiohttp.ClientSession() as session: - async with session.get(url, auth=auth) as response: - if response.status == 200: - data = await response.json() - return data["token"] - else: - raise Exception("Failed to get token") - - async def rescan_single_registry( - self, - reporter: ProgressReporter | None = None, - ) -> None: - log.info("rescan_single_registry()") - all_updates_token = all_updates.set({}) - concurrency_sema.set(asyncio.Semaphore(self.max_concurrency_per_registry)) - progress_reporter.set(reporter) - try: - username = self.registry_info["name"] - if username is not None: - self.credentials["username"] = username - password = self.registry_info["token"] - if password is not None: - self.credentials["password"] = password - async with self.prepare_client_session() as (url, client_session): - self.registry_url = url - async with aiotools.TaskGroup() as tg: - async for image in self.fetch_repositories(client_session): - tg.create_task( - self._scan_image( - client_session, f"{self.registry_info["name"]}/{image}" - ) - ) - await self.commit_rescan_result() - finally: - all_updates.reset(all_updates_token) - - async def _scan_image( - self, - sess: aiohttp.ClientSession, - image: str, - ) -> None: - log.info("_scan_image()") - - ghcr_token = await self.get_ghcr_token( - image, - ) - - tags = [] - tag_list_url: Optional[yarl.URL] - tag_list_url = (self.registry_url / f"v2/{image}/tags/list").with_query( - {"n": "10"}, - ) - rqst_args = {"headers": {"Authorization": f"Bearer {ghcr_token}"}} - - while tag_list_url is not None: - async with sess.get(tag_list_url, allow_redirects=False, **rqst_args) as resp: - data = json.loads(await resp.read()) - - if "tags" in data: - # sometimes there are dangling image names in the hub. - tags.extend(data["tags"]) - tag_list_url = None - next_page_link = resp.links.get("next") - if next_page_link: - next_page_url = cast(yarl.URL, next_page_link["url"]) - tag_list_url = self.registry_url.with_path(next_page_url.path).with_query( - next_page_url.query + raise RuntimeError( + f"Failed to fetch repositories! {response.status} error occured." ) - - if (reporter := progress_reporter.get()) is not None: - reporter.total_progress += len(tags) - - async with aiotools.TaskGroup() as tg: - for tag in tags: - tg.create_task(self._scan_tag(sess, rqst_args, image, tag)) - - # async def _scan_tag(self, sess: aiohttp.ClientSession, image: str): - # url = f"https://ghcr.io/v2/{image}/tags/list" - # headers = {'Authorization': f'Bearer {token}'} - # async with aiohttp.ClientSession() as session: - # async with session.get(url, headers=headers) as response: - # if response.status == 200: - # data = await response.json() - # return data - # else: - # print('response.status', response.status) - # print("Failed to fetch tags") - # return {} - - # async def _scan_image(self): - # pass - - # async def _read_manifest(self): - # pass