Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support AWS ECR Public Container Registry #2549

1 change: 1 addition & 0 deletions changes/2549.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support AWS ECR Public Container Registry
4 changes: 3 additions & 1 deletion src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ async def login(
return {"auth": basic_auth, "headers": {}}
elif ping_status == 404:
raise RuntimeError(f"Unsupported docker registry: {registry_url}! (API v2 not implemented)")
elif ping_status == 401:
# Should check also 400 status since the AWS ECR Public server returns a 400 response
# when given invalid credential authorization.
elif ping_status in [400, 401]:
jopemachine marked this conversation as resolved.
Show resolved Hide resolved
params = {
"scope": scope,
"offline_token": "true",
Expand Down
4 changes: 4 additions & 0 deletions src/ai/backend/manager/container_registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon
from .gitlab import GitLabRegistry

cr_cls = GitLabRegistry
elif registry_type == "ecr" or registry_type == "ecr-public":
from .aws_ecr import AWSElasticContainerRegistry

cr_cls = AWSElasticContainerRegistry
elif registry_type == "local":
from .local import LocalRegistry

Expand Down
54 changes: 54 additions & 0 deletions src/ai/backend/manager/container_registry/aws_ecr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
from typing import AsyncIterator

import aiohttp
import boto3

from ai.backend.common.logging import BraceStyleAdapter

from .base import (
BaseContainerRegistry,
)

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]


class AWSElasticContainerRegistry(BaseContainerRegistry):
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
) -> AsyncIterator[str]:
access_key, secret_access_key, region, type_ = (
self.registry_info["access_key"],
self.registry_info["secret_access_key"],
self.registry_info["region"],
self.registry_info["type"],
)

client = boto3.client(
type_,
region_name=region,
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
)

next_token = None
try:
while True:
if next_token:
response = client.describe_repositories(nextToken=next_token, maxResults=30)
else:
response = client.describe_repositories(maxResults=30)

for repo in response["repositories"]:
# repositoryUri format:
# public.ecr.aws/<registry_alias>/<repository>
registry_alias = (repo["repositoryUri"].split("/"))[1]
yield f"{registry_alias}/{repo["repositoryName"]}"

next_token = response.get("nextToken")

if not next_token:
break
except Exception as e:
log.error(f"Error occurred: {e}")
Loading