From 4a2ab0949663973d47b845b3c634f2c9b569b7b0 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Wed, 24 Jul 2024 07:59:47 +0000 Subject: [PATCH 01/10] feat: Support AWS ECR Container Registry --- src/ai/backend/agent/docker/agent.py | 1 + src/ai/backend/common/docker.py | 9 ++ .../manager/container_registry/__init__.py | 4 + .../backend/manager/container_registry/aws.py | 85 +++++++++++++++++++ 4 files changed, 99 insertions(+) create mode 100644 src/ai/backend/manager/container_registry/aws.py diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 0d4c64e7a9..37fec73c0d 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -1344,6 +1344,7 @@ async def push_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> } async with closing_async(Docker()) as docker: + # TODO: Fix the error where no error is displayed even when an authentication error occurs. await docker.images.push(image_ref.canonical, auth=auth_config) async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None: diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 58d9a7fb6e..4c7e7f0cb4 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -234,6 +234,15 @@ async def login( """ basic_auth: Optional[aiohttp.BasicAuth] + # TODO: Fix this. 
+ if "public.ecr" in str(registry_url): + return { + "auth": None, + "headers": { + "Authorization": f"Bearer {credentials['password']}", + }, + } + if credentials.get("username") and credentials.get("password"): basic_auth = aiohttp.BasicAuth( credentials["username"], diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index d3f773fbe2..e1763dfa2f 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -36,6 +36,10 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon from .gitlab import GitLabRegistry cr_cls = GitLabRegistry + elif registry_type == "ecr" or registry_type == "ecr-public": + from .aws import AWSElasticContainerRegistry_v2 + + cr_cls = AWSElasticContainerRegistry_v2 elif registry_type == "local": from .local import LocalRegistry diff --git a/src/ai/backend/manager/container_registry/aws.py b/src/ai/backend/manager/container_registry/aws.py new file mode 100644 index 0000000000..1f4c813b83 --- /dev/null +++ b/src/ai/backend/manager/container_registry/aws.py @@ -0,0 +1,85 @@ +import logging +from typing import Any, AsyncIterator, Mapping + +import aiohttp +import boto3 + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine + +from .base import ( + BaseContainerRegistry, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] + + +class AWSElasticContainerRegistry_v2(BaseContainerRegistry): + def __init__( + self, + db: ExtendedAsyncSAEngine, + registry_name: str, + registry_info: Mapping[str, Any], + *, + max_concurrency_per_registry: int = 4, + ssl_verify: bool = True, + ) -> None: + super().__init__( + db, + registry_name, + registry_info, + max_concurrency_per_registry=max_concurrency_per_registry, + ssl_verify=ssl_verify, + ) + + access_key, secret_access_key, region, 
type_ = ( + self.registry_info["access_key"], + self.registry_info["secret_access_key"], + self.registry_info["region"], + self.registry_info["type"], + ) + + self.ecr_client = boto3.client( + type_, + region_name=region, + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + access_key, secret_access_key, region, type_ = ( + self.registry_info["access_key"], + self.registry_info["secret_access_key"], + self.registry_info["region"], + self.registry_info["type"], + ) + + client = boto3.client( + type_, + region_name=region, + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + next_token = None + try: + while True: + if next_token: + response = client.describe_repositories(nextToken=next_token) + else: + response = client.describe_repositories() + + for repo in response["repositories"]: + # TODO: Verify this logic + repo_id = (repo["repositoryUri"].split("/"))[1] + yield f"{repo_id}/{repo["repositoryName"]}" + + next_token = response.get("nextToken") + + if not next_token: + break + except Exception as e: + print(f"An error occurred: {e}") From d1d203a2113e2811e4943b52d41428739ec7f863 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Wed, 24 Jul 2024 08:02:57 +0000 Subject: [PATCH 02/10] chore: Add fragment --- changes/2549.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/2549.feature.md diff --git a/changes/2549.feature.md b/changes/2549.feature.md new file mode 100644 index 0000000000..a0c41a63c2 --- /dev/null +++ b/changes/2549.feature.md @@ -0,0 +1 @@ +Support AWS ECR Container Registry \ No newline at end of file From dd2edc7da21abe4c41b4998d23c694651d8f6988 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Thu, 25 Jul 2024 02:02:17 +0000 Subject: [PATCH 03/10] fix: Support private repository login --- src/ai/backend/common/docker.py | 3 +-- 1 file changed, 1 insertion(+), 2 
deletions(-) diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 4c7e7f0cb4..cd4cac0969 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -234,8 +234,7 @@ async def login( """ basic_auth: Optional[aiohttp.BasicAuth] - # TODO: Fix this. - if "public.ecr" in str(registry_url): + if "public.ecr" in str(registry_url) or "dkr.ecr" in str(registry_url): return { "auth": None, "headers": { From 4045e59773e58f18be820e8295128df6d7ad923f Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Thu, 25 Jul 2024 04:48:17 +0000 Subject: [PATCH 04/10] feat: Automatic issuance and use of `auth_token` --- src/ai/backend/common/docker.py | 7 ++- src/ai/backend/manager/api/session.py | 13 +++++- .../manager/container_registry/__init__.py | 2 +- .../container_registry/{aws.py => aws_ecr.py} | 45 ++++++++++++------- 4 files changed, 46 insertions(+), 21 deletions(-) rename src/ai/backend/manager/container_registry/{aws.py => aws_ecr.py} (69%) diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index cd4cac0969..2c89c8d5e8 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import enum import functools import ipaddress @@ -235,10 +236,14 @@ async def login( basic_auth: Optional[aiohttp.BasicAuth] if "public.ecr" in str(registry_url) or "dkr.ecr" in str(registry_url): + auth_token = base64.b64encode( + f"{credentials["username"]}:{credentials["password"]}".encode() + ).decode() + return { "auth": None, "headers": { - "Authorization": f"Bearer {credentials['password']}", + "Authorization": f"Bearer {auth_token}", }, } diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 032be7457d..d8a96c86a3 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -50,6 +50,7 @@ from ai.backend.common.bgtask import ProgressReporter from 
ai.backend.common.docker import ImageRef +from ai.backend.manager.container_registry.aws_ecr import AWSElasticContainerRegistry_v2 from ai.backend.manager.models.group import GroupRow from ai.backend.manager.models.image import rescan_images @@ -1188,12 +1189,20 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None: raise BackendError(extra_msg="Operation cancelled") if not new_image_ref.is_local: + if "public.ecr" in registry_hostname or "dkr.ecr" in str(registry_hostname): + credential = AWSElasticContainerRegistry_v2.get_credential(registry_conf) + else: + credential = registry_conf + + username = credential.get("username") + password = credential.get("password") + # push image to registry from local agent image_registry = ImageRegistry( name=registry_hostname, url=str(registry_conf[""]), - username=registry_conf.get("username"), - password=registry_conf.get("password"), + username=username, + password=password, ) resp = await root_ctx.registry.push_image( session.main_kernel.agent, diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index e1763dfa2f..d913905cea 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -37,7 +37,7 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon cr_cls = GitLabRegistry elif registry_type == "ecr" or registry_type == "ecr-public": - from .aws import AWSElasticContainerRegistry_v2 + from .aws_ecr import AWSElasticContainerRegistry_v2 cr_cls = AWSElasticContainerRegistry_v2 elif registry_type == "local": diff --git a/src/ai/backend/manager/container_registry/aws.py b/src/ai/backend/manager/container_registry/aws_ecr.py similarity index 69% rename from src/ai/backend/manager/container_registry/aws.py rename to src/ai/backend/manager/container_registry/aws_ecr.py index 1f4c813b83..4871edf5f9 100644 --- 
a/src/ai/backend/manager/container_registry/aws.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -1,3 +1,4 @@ +import base64 import logging from typing import Any, AsyncIterator, Mapping @@ -15,6 +16,28 @@ class AWSElasticContainerRegistry_v2(BaseContainerRegistry): + @staticmethod + def get_credential(registry_info: Mapping[str, Any]) -> dict[str, Any]: + access_key, secret_access_key, region, type_ = ( + registry_info["access_key"], + registry_info["secret_access_key"], + registry_info["region"], + registry_info["type"], + ) + + ecr_client = boto3.client( + type_, + region_name=region, + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + auth_token = ecr_client.get_authorization_token()["authorizationData"]["authorizationToken"] + decoded_auth_token = base64.b64decode(auth_token).decode("utf-8") + username, password = decoded_auth_token.split(":") + + return {"username": username, "password": password} + def __init__( self, db: ExtendedAsyncSAEngine, @@ -32,19 +55,7 @@ def __init__( ssl_verify=ssl_verify, ) - access_key, secret_access_key, region, type_ = ( - self.registry_info["access_key"], - self.registry_info["secret_access_key"], - self.registry_info["region"], - self.registry_info["type"], - ) - - self.ecr_client = boto3.client( - type_, - region_name=region, - aws_access_key_id=access_key, - aws_secret_access_key=secret_access_key, - ) + self.credentials = AWSElasticContainerRegistry_v2.get_credential(registry_info) async def fetch_repositories( self, @@ -73,13 +84,13 @@ async def fetch_repositories( response = client.describe_repositories() for repo in response["repositories"]: - # TODO: Verify this logic - repo_id = (repo["repositoryUri"].split("/"))[1] - yield f"{repo_id}/{repo["repositoryName"]}" + # TODO: Fix this. 
+ registry_alias = (repo["repositoryUri"].split("/"))[1] + yield f"{registry_alias}/{repo["repositoryName"]}" next_token = response.get("nextToken") if not next_token: break except Exception as e: - print(f"An error occurred: {e}") + log.error(f"Error occurred: {e}") From 231a6c10e12950ebe487843cb80feb0146536848 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Thu, 25 Jul 2024 05:15:28 +0000 Subject: [PATCH 05/10] feat: Apply pagination to ECR fetch_repositories --- src/ai/backend/manager/container_registry/aws_ecr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py index 4871edf5f9..0b9ae85923 100644 --- a/src/ai/backend/manager/container_registry/aws_ecr.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -79,9 +79,9 @@ async def fetch_repositories( try: while True: if next_token: - response = client.describe_repositories(nextToken=next_token) + response = client.describe_repositories(nextToken=next_token, maxResults=30) else: - response = client.describe_repositories() + response = client.describe_repositories(maxResults=30) for repo in response["repositories"]: # TODO: Fix this. 
From 6c067f77a0f4eda49109804d025a329ba1f57326 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Fri, 26 Jul 2024 03:05:30 +0000 Subject: [PATCH 06/10] fix: Wrong approach to auth --- src/ai/backend/common/docker.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 2c89c8d5e8..284c80ee62 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import enum import functools import ipaddress @@ -235,18 +234,6 @@ async def login( """ basic_auth: Optional[aiohttp.BasicAuth] - if "public.ecr" in str(registry_url) or "dkr.ecr" in str(registry_url): - auth_token = base64.b64encode( - f"{credentials["username"]}:{credentials["password"]}".encode() - ).decode() - - return { - "auth": None, - "headers": { - "Authorization": f"Bearer {auth_token}", - }, - } - if credentials.get("username") and credentials.get("password"): basic_auth = aiohttp.BasicAuth( credentials["username"], @@ -271,7 +258,7 @@ async def login( return {"auth": basic_auth, "headers": {}} elif ping_status == 404: raise RuntimeError(f"Unsupported docker registry: {registry_url}! 
(API v2 not implemented)") - elif ping_status == 401: + elif ping_status in [400, 401]: params = { "scope": scope, "offline_token": "true", From 02da19cd45b7b95d40f4fa88ab0ea1f6371f08b6 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Fri, 26 Jul 2024 03:12:38 +0000 Subject: [PATCH 07/10] chore: Rename fragment --- changes/2549.feature.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changes/2549.feature.md b/changes/2549.feature.md index a0c41a63c2..5bcdf5921f 100644 --- a/changes/2549.feature.md +++ b/changes/2549.feature.md @@ -1 +1 @@ -Support AWS ECR Container Registry \ No newline at end of file +Support AWS ECR Public Container Registry \ No newline at end of file From 6e332d03ed74b58c940ccd098ed5d1f02b175589 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Fri, 26 Jul 2024 03:17:26 +0000 Subject: [PATCH 08/10] chore: Remove useless comments --- src/ai/backend/agent/docker/agent.py | 1 - src/ai/backend/manager/container_registry/aws_ecr.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 37fec73c0d..0d4c64e7a9 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -1344,7 +1344,6 @@ async def push_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> } async with closing_async(Docker()) as docker: - # TODO: Fix the error where no error is displayed even when an authentication error occurs. 
await docker.images.push(image_ref.canonical, auth=auth_config) async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None: diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py index 0b9ae85923..a1a4069e79 100644 --- a/src/ai/backend/manager/container_registry/aws_ecr.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -84,7 +84,6 @@ async def fetch_repositories( response = client.describe_repositories(maxResults=30) for repo in response["repositories"]: - # TODO: Fix this. registry_alias = (repo["repositoryUri"].split("/"))[1] yield f"{registry_alias}/{repo["repositoryName"]}" From 828d6eb0b1ac195393d3c013f459c64498c9b2d0 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Fri, 26 Jul 2024 04:45:29 +0000 Subject: [PATCH 09/10] revert: Remove automatic auth_token issuance logic --- src/ai/backend/manager/api/session.py | 13 +----- .../manager/container_registry/aws_ecr.py | 45 +------------------ 2 files changed, 3 insertions(+), 55 deletions(-) diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index d8a96c86a3..032be7457d 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -50,7 +50,6 @@ from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.docker import ImageRef -from ai.backend.manager.container_registry.aws_ecr import AWSElasticContainerRegistry_v2 from ai.backend.manager.models.group import GroupRow from ai.backend.manager.models.image import rescan_images @@ -1189,20 +1188,12 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None: raise BackendError(extra_msg="Operation cancelled") if not new_image_ref.is_local: - if "public.ecr" in registry_hostname or "dkr.ecr" in str(registry_hostname): - credential = AWSElasticContainerRegistry_v2.get_credential(registry_conf) - else: - credential = registry_conf - - username = 
credential.get("username") - password = credential.get("password") - # push image to registry from local agent image_registry = ImageRegistry( name=registry_hostname, url=str(registry_conf[""]), - username=username, - password=password, + username=registry_conf.get("username"), + password=registry_conf.get("password"), ) resp = await root_ctx.registry.push_image( session.main_kernel.agent, diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py index a1a4069e79..5a4a1b9a3e 100644 --- a/src/ai/backend/manager/container_registry/aws_ecr.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -1,12 +1,10 @@ -import base64 import logging -from typing import Any, AsyncIterator, Mapping +from typing import AsyncIterator import aiohttp import boto3 from ai.backend.common.logging import BraceStyleAdapter -from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from .base import ( BaseContainerRegistry, @@ -16,47 +14,6 @@ class AWSElasticContainerRegistry_v2(BaseContainerRegistry): - @staticmethod - def get_credential(registry_info: Mapping[str, Any]) -> dict[str, Any]: - access_key, secret_access_key, region, type_ = ( - registry_info["access_key"], - registry_info["secret_access_key"], - registry_info["region"], - registry_info["type"], - ) - - ecr_client = boto3.client( - type_, - region_name=region, - aws_access_key_id=access_key, - aws_secret_access_key=secret_access_key, - ) - - auth_token = ecr_client.get_authorization_token()["authorizationData"]["authorizationToken"] - decoded_auth_token = base64.b64decode(auth_token).decode("utf-8") - username, password = decoded_auth_token.split(":") - - return {"username": username, "password": password} - - def __init__( - self, - db: ExtendedAsyncSAEngine, - registry_name: str, - registry_info: Mapping[str, Any], - *, - max_concurrency_per_registry: int = 4, - ssl_verify: bool = True, - ) -> None: - super().__init__( - db, - registry_name, - 
registry_info, - max_concurrency_per_registry=max_concurrency_per_registry, - ssl_verify=ssl_verify, - ) - - self.credentials = AWSElasticContainerRegistry_v2.get_credential(registry_info) - async def fetch_repositories( self, sess: aiohttp.ClientSession, From e95013cd56bd744db51b926d8b88d2add2f5443b Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Mon, 5 Aug 2024 00:47:12 +0000 Subject: [PATCH 10/10] docs: Add comments --- src/ai/backend/common/docker.py | 2 ++ src/ai/backend/manager/container_registry/__init__.py | 4 ++-- src/ai/backend/manager/container_registry/aws_ecr.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 284c80ee62..549f3d2ffb 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -258,6 +258,8 @@ async def login( return {"auth": basic_auth, "headers": {}} elif ping_status == 404: raise RuntimeError(f"Unsupported docker registry: {registry_url}! (API v2 not implemented)") + # Also handle a 400 status here, since the AWS ECR Public server returns a 400 response + # when it is given invalid credentials. 
elif ping_status in [400, 401]: params = { "scope": scope, diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index d913905cea..a9d5ca0c22 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -37,9 +37,9 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon cr_cls = GitLabRegistry elif registry_type == "ecr" or registry_type == "ecr-public": - from .aws_ecr import AWSElasticContainerRegistry_v2 + from .aws_ecr import AWSElasticContainerRegistry - cr_cls = AWSElasticContainerRegistry_v2 + cr_cls = AWSElasticContainerRegistry elif registry_type == "local": from .local import LocalRegistry diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py index 5a4a1b9a3e..98466f03a4 100644 --- a/src/ai/backend/manager/container_registry/aws_ecr.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -13,7 +13,7 @@ log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] -class AWSElasticContainerRegistry_v2(BaseContainerRegistry): +class AWSElasticContainerRegistry(BaseContainerRegistry): async def fetch_repositories( self, sess: aiohttp.ClientSession, @@ -41,6 +41,8 @@ async def fetch_repositories( response = client.describe_repositories(maxResults=30) for repo in response["repositories"]: + # repositoryUri format: + # public.ecr.aws/{registry_alias}/{repositoryName} registry_alias = (repo["repositoryUri"].split("/"))[1] yield f"{registry_alias}/{repo["repositoryName"]}"