diff --git a/changes/2549.feature.md b/changes/2549.feature.md new file mode 100644 index 0000000000..5bcdf5921f --- /dev/null +++ b/changes/2549.feature.md @@ -0,0 +1 @@ +Support AWS ECR Public Container Registry \ No newline at end of file diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 58d9a7fb6e..549f3d2ffb 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -258,7 +258,9 @@ async def login( return {"auth": basic_auth, "headers": {}} elif ping_status == 404: raise RuntimeError(f"Unsupported docker registry: {registry_url}! (API v2 not implemented)") - elif ping_status == 401: + # Should check also 400 status since the AWS ECR Public server returns a 400 response + # when given invalid credential authorization. + elif ping_status in [400, 401]: params = { "scope": scope, "offline_token": "true", diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index d3f773fbe2..a9d5ca0c22 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -36,6 +36,10 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon from .gitlab import GitLabRegistry cr_cls = GitLabRegistry + elif registry_type == "ecr" or registry_type == "ecr-public": + from .aws_ecr import AWSElasticContainerRegistry + + cr_cls = AWSElasticContainerRegistry elif registry_type == "local": from .local import LocalRegistry diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py new file mode 100644 index 0000000000..98466f03a4 --- /dev/null +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -0,0 +1,54 @@ +import logging +from typing import AsyncIterator + +import aiohttp +import boto3 + +from ai.backend.common.logging import BraceStyleAdapter + +from .base import ( + BaseContainerRegistry, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] + + +class AWSElasticContainerRegistry(BaseContainerRegistry): + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + access_key, secret_access_key, region, type_ = ( + self.registry_info["access_key"], + self.registry_info["secret_access_key"], + self.registry_info["region"], + self.registry_info["type"], + ) + + client = boto3.client( + type_, + region_name=region, + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + next_token = None + try: + while True: + if next_token: + response = client.describe_repositories(nextToken=next_token, maxResults=30) + else: + response = client.describe_repositories(maxResults=30) + + for repo in response["repositories"]: + # repositoryUri format: + # public.ecr.aws// + registry_alias = (repo["repositoryUri"].split("/"))[1] + yield f"{registry_alias}/{repo["repositoryName"]}" + + next_token = response.get("nextToken") + + if not next_token: + break + except Exception as e: + log.error(f"Error occurred: {e}")