From b15f9fce0a061e79fbb4030bea6e2f1d1dfea238 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Thu, 25 Jul 2024 04:48:17 +0000 Subject: [PATCH] feat: Automatic issuance and use of `auth_token` --- src/ai/backend/common/docker.py | 7 ++- src/ai/backend/manager/api/session.py | 13 +++++- .../manager/container_registry/__init__.py | 2 +- .../container_registry/{aws.py => aws_ecr.py} | 45 ++++++++++++------- 4 files changed, 46 insertions(+), 21 deletions(-) rename src/ai/backend/manager/container_registry/{aws.py => aws_ecr.py} (69%) diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index cd4cac0969f..2c89c8d5e89 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import enum import functools import ipaddress @@ -235,10 +236,14 @@ async def login( basic_auth: Optional[aiohttp.BasicAuth] if "public.ecr" in str(registry_url) or "dkr.ecr" in str(registry_url): + auth_token = base64.b64encode( + f"{credentials["username"]}:{credentials["password"]}".encode() + ).decode() + return { "auth": None, "headers": { - "Authorization": f"Bearer {credentials['password']}", + "Authorization": f"Bearer {auth_token}", }, } diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 9446381372d..2e08d2e2bcd 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -50,6 +50,7 @@ from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.docker import ImageRef +from ai.backend.manager.container_registry.aws_ecr import AWSElasticContainerRegistry_v2 from ai.backend.manager.models.group import GroupRow from ai.backend.manager.models.image import rescan_images @@ -1187,12 +1188,20 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None: raise BackendError(extra_msg="Operation cancelled") if not new_image_ref.is_local: + if "public.ecr" in registry_hostname or "dkr.ecr" in str(registry_hostname): + credential = AWSElasticContainerRegistry_v2.get_credential(registry_conf) + else: + credential = registry_conf + + username = credential.get("username") + password = credential.get("password") + # push image to registry from local agent image_registry = ImageRegistry( name=registry_hostname, url=str(registry_conf[""]), - username=registry_conf.get("username"), - password=registry_conf.get("password"), + username=username, + password=password, ) resp = await root_ctx.registry.push_image( session.main_kernel.agent, diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index 00f57849c7f..be251258b43 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -37,7 +37,7 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon cr_cls = GitLabRegistry_v2 elif registry_type == "ecr" or registry_type == "ecr-public": - from .aws import AWSElasticContainerRegistry_v2 + from .aws_ecr import AWSElasticContainerRegistry_v2 cr_cls = AWSElasticContainerRegistry_v2 elif registry_type == "local": diff --git a/src/ai/backend/manager/container_registry/aws.py b/src/ai/backend/manager/container_registry/aws_ecr.py similarity index 69% rename from src/ai/backend/manager/container_registry/aws.py rename to src/ai/backend/manager/container_registry/aws_ecr.py index 1f4c813b839..4871edf5f9b 100644 --- a/src/ai/backend/manager/container_registry/aws.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -1,3 +1,4 @@ +import base64 import logging from typing import Any, AsyncIterator, Mapping @@ -15,6 +16,28 @@ class AWSElasticContainerRegistry_v2(BaseContainerRegistry): + @staticmethod + def get_credential(registry_info: Mapping[str, Any]) -> dict[str, Any]: + access_key, secret_access_key, region, type_ = ( + registry_info["access_key"], + registry_info["secret_access_key"], + registry_info["region"], + registry_info["type"], + ) + + ecr_client = boto3.client( + type_, + region_name=region, + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + auth_token = ecr_client.get_authorization_token()["authorizationData"]["authorizationToken"] + decoded_auth_token = base64.b64decode(auth_token).decode("utf-8") + username, password = decoded_auth_token.split(":") + + return {"username": username, "password": password} + def __init__( self, db: ExtendedAsyncSAEngine, @@ -32,19 +55,7 @@ def __init__( ssl_verify=ssl_verify, ) - access_key, secret_access_key, region, type_ = ( - self.registry_info["access_key"], - self.registry_info["secret_access_key"], - self.registry_info["region"], - self.registry_info["type"], - ) - - self.ecr_client = boto3.client( - type_, - region_name=region, - aws_access_key_id=access_key, - aws_secret_access_key=secret_access_key, - ) + self.credentials = AWSElasticContainerRegistry_v2.get_credential(registry_info) async def fetch_repositories( self, @@ -73,13 +84,13 @@ async def fetch_repositories( response = client.describe_repositories() for repo in response["repositories"]: - # TODO: Verify this logic - repo_id = (repo["repositoryUri"].split("/"))[1] - yield f"{repo_id}/{repo["repositoryName"]}" + # TODO: Fix this. + registry_alias = (repo["repositoryUri"].split("/"))[1] + yield f"{registry_alias}/{repo["repositoryName"]}" next_token = response.get("nextToken") if not next_token: break except Exception as e: - print(f"An error occurred: {e}") + log.error(f"Error occurred: {e}")