Skip to content

Commit

Permalink
feat: Automatic issuance and use of auth_token
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jul 26, 2024
1 parent d85a358 commit b15f9fc
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 21 deletions.
7 changes: 6 additions & 1 deletion src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import enum
import functools
import ipaddress
Expand Down Expand Up @@ -235,10 +236,14 @@ async def login(
basic_auth: Optional[aiohttp.BasicAuth]

if "public.ecr" in str(registry_url) or "dkr.ecr" in str(registry_url):
auth_token = base64.b64encode(
f"{credentials["username"]}:{credentials["password"]}".encode()
).decode()

return {
"auth": None,
"headers": {
"Authorization": f"Bearer {credentials['password']}",
"Authorization": f"Bearer {auth_token}",
},
}

Expand Down
13 changes: 11 additions & 2 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.docker import ImageRef
from ai.backend.manager.container_registry.aws_ecr import AWSElasticContainerRegistry_v2
from ai.backend.manager.models.group import GroupRow
from ai.backend.manager.models.image import rescan_images

Expand Down Expand Up @@ -1187,12 +1188,20 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None:
raise BackendError(extra_msg="Operation cancelled")

if not new_image_ref.is_local:
if "public.ecr" in registry_hostname or "dkr.ecr" in str(registry_hostname):
credential = AWSElasticContainerRegistry_v2.get_credential(registry_conf)
else:
credential = registry_conf

username = credential.get("username")
password = credential.get("password")

# push image to registry from local agent
image_registry = ImageRegistry(
name=registry_hostname,
url=str(registry_conf[""]),
username=registry_conf.get("username"),
password=registry_conf.get("password"),
username=username,
password=password,
)
resp = await root_ctx.registry.push_image(
session.main_kernel.agent,
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/container_registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_container_registry_cls(registry_info: Mapping[str, Any]) -> Type[BaseCon

cr_cls = GitLabRegistry_v2
elif registry_type == "ecr" or registry_type == "ecr-public":
from .aws import AWSElasticContainerRegistry_v2
from .aws_ecr import AWSElasticContainerRegistry_v2

cr_cls = AWSElasticContainerRegistry_v2
elif registry_type == "local":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import logging
from typing import Any, AsyncIterator, Mapping

Expand All @@ -15,6 +16,28 @@


class AWSElasticContainerRegistry_v2(BaseContainerRegistry):
@staticmethod
def get_credential(registry_info: Mapping[str, Any]) -> dict[str, Any]:
access_key, secret_access_key, region, type_ = (
registry_info["access_key"],
registry_info["secret_access_key"],
registry_info["region"],
registry_info["type"],
)

ecr_client = boto3.client(
type_,
region_name=region,
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
)

auth_token = ecr_client.get_authorization_token()["authorizationData"]["authorizationToken"]
decoded_auth_token = base64.b64decode(auth_token).decode("utf-8")
username, password = decoded_auth_token.split(":")

return {"username": username, "password": password}

def __init__(
self,
db: ExtendedAsyncSAEngine,
Expand All @@ -32,19 +55,7 @@ def __init__(
ssl_verify=ssl_verify,
)

access_key, secret_access_key, region, type_ = (
self.registry_info["access_key"],
self.registry_info["secret_access_key"],
self.registry_info["region"],
self.registry_info["type"],
)

self.ecr_client = boto3.client(
type_,
region_name=region,
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
)
self.credentials = AWSElasticContainerRegistry_v2.get_credential(registry_info)

async def fetch_repositories(
self,
Expand Down Expand Up @@ -73,13 +84,13 @@ async def fetch_repositories(
response = client.describe_repositories()

for repo in response["repositories"]:
# TODO: Verify this logic
repo_id = (repo["repositoryUri"].split("/"))[1]
yield f"{repo_id}/{repo["repositoryName"]}"
# TODO: Fix this.
registry_alias = (repo["repositoryUri"].split("/"))[1]
yield f"{registry_alias}/{repo["repositoryName"]}"

next_token = response.get("nextToken")

if not next_token:
break
except Exception as e:
print(f"An error occurred: {e}")
log.error(f"Error occurred: {e}")

0 comments on commit b15f9fc

Please sign in to comment.