diff --git a/.github/workflows/build-and-push-image.yml b/.github/workflows/build-and-push-image.yml index f0303083..ec18b55d 100644 --- a/.github/workflows/build-and-push-image.yml +++ b/.github/workflows/build-and-push-image.yml @@ -14,10 +14,10 @@ on: - '.github/dependabot.yml' - 'docs/**' env: - QUAY_ORG: opendatahub - QUAY_IMG_REPO: model-registry - QUAY_USERNAME: ${{ secrets.QUAY_USERNAME }} - QUAY_PASSWORD: ${{ secrets.QUAY_PASSWORD }} + IMG_ORG: opendatahub + IMG_REPO: model-registry + DOCKER_USER: ${{ secrets.QUAY_USERNAME }} + DOCKER_PWD: ${{ secrets.QUAY_PASSWORD }} PUSH_IMAGE: true jobs: build-image: @@ -50,8 +50,8 @@ jobs: if: env.BUILD_CONTEXT == 'main' shell: bash env: - IMG: quay.io/${{ env.QUAY_ORG }}/${{ env.QUAY_IMG_REPO }} - BUILD_IMAGE: false + IMG: quay.io/${{ env.IMG_ORG }}/${{ env.IMG_REPO }} + BUILD_IMAGE: false # image is already built in "Build and Push Image" step run: | docker tag ${{ env.IMG }}:$VERSION ${{ env.IMG }}:latest # BUILD_IMAGE=false skip the build, just push the tag made above @@ -60,8 +60,8 @@ jobs: if: env.BUILD_CONTEXT == 'main' shell: bash env: - IMG: quay.io/${{ env.QUAY_ORG }}/${{ env.QUAY_IMG_REPO }} - BUILD_IMAGE: false + IMG: quay.io/${{ env.IMG_ORG }}/${{ env.IMG_REPO }} + BUILD_IMAGE: false # image is already built in "Build and Push Image" step run: | docker tag ${{ env.IMG }}:$VERSION ${{ env.IMG }}:main # BUILD_IMAGE=false skip the build, just push the tag made above diff --git a/.github/workflows/build-image-pr.yml b/.github/workflows/build-image-pr.yml index 55f62e2f..c79e8bb5 100644 --- a/.github/workflows/build-image-pr.yml +++ b/.github/workflows/build-image-pr.yml @@ -11,7 +11,8 @@ on: - 'docs/**' - 'clients/python/**' env: - QUAY_IMG_REPO: model-registry + IMG_ORG: opendatahub + IMG_REPO: model-registry PUSH_IMAGE: false BRANCH: ${{ github.base_ref }} jobs: @@ -35,12 +36,12 @@ jobs: uses: helm/kind-action@v1.9.0 - name: Load Local Registry Test Image env: - IMG: "quay.io/opendatahub/model-registry:${{ 
steps.tags.outputs.tag }}" + IMG: "quay.io/${{ env.IMG_ORG }}/${{ env.IMG_REPO }}:${{ steps.tags.outputs.tag }}" run: | kind load docker-image -n chart-testing ${IMG} - name: Deploy Operator With Test Image env: - IMG: "quay.io/opendatahub/model-registry:${{ steps.tags.outputs.tag }}" + IMG: "quay.io/${{ env.IMG_ORG }}/${{ env.IMG_REPO }}:${{ steps.tags.outputs.tag }}" run: | echo "Deploying operator from model-registry-operator branch ${BRANCH}" kubectl apply -k "https://github.com/opendatahub-io/model-registry-operator.git/config/default?ref=${BRANCH}" diff --git a/Makefile b/Makefile index e248fcdc..6006d890 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ MLMD_VERSION ?= 1.14.0 DOCKER ?= docker # default Dockerfile DOCKERFILE ?= Dockerfile -# container registry +# container registry, default to empty (dockerhub) if not explicitly set IMG_REGISTRY ?= quay.io # container image organization IMG_ORG ?= opendatahub @@ -21,7 +21,11 @@ IMG_VERSION ?= main # container image repository IMG_REPO ?= model-registry # container image -IMG ?= ${IMG_REGISTRY}/$(IMG_ORG)/$(IMG_REPO) +ifdef IMG_REGISTRY + IMG := ${IMG_REGISTRY}/${IMG_ORG}/${IMG_REPO} +else + IMG := ${IMG_ORG}/${IMG_REPO} +endif model-registry: build @@ -190,7 +194,12 @@ proxy: build # login to docker .PHONY: docker/login docker/login: - $(DOCKER) login -u "${DOCKER_USER}" -p "${DOCKER_PWD}" "${IMG_REGISTRY}" + ifdef IMG_REGISTRY + $(DOCKER) login -u "${DOCKER_USER}" -p "${DOCKER_PWD}" "${IMG_REGISTRY}" + else + $(DOCKER) login -u "${DOCKER_USER}" -p "${DOCKER_PWD}" + endif + # build docker image .PHONY: image/build diff --git a/clients/python/README.md b/clients/python/README.md index e542e8e2..6af08cf1 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -14,12 +14,12 @@ registry = ModelRegistry(server_address="server-address", port=9090, author="aut model = registry.register_model( "my-model", # model name - "s3://path/to/model", # model URI + "https://storage-place.my-company.com", # 
model URI version="2.0.0", description="lorem ipsum", model_format_name="onnx", model_format_version="1", - storage_key="aws-connection-path", + storage_key="my-data-connection", storage_path="path/to/model", metadata={ # can be one of the following types @@ -37,10 +37,33 @@ version = registry.get_model_version("my-model", "v2.0") experiment = registry.get_model_artifact("my-model", "v2.0") ``` -### Default values for metadata +### Importing from S3 -If not supplied, `metadata` values defaults to a predefined set of conventional values. -Reference the technical documentation in the pydoc of the client. +When registering models stored on S3-compatible object storage, you should use `utils.s3_uri_from` to build an +unambiguous URI for your artifact. + +```py +from model_registry import ModelRegistry, utils + +registry = ModelRegistry(server_address="server-address", port=9090, author="author") + +model = registry.register_model( + "my-model", # model name + uri=utils.s3_uri_from("path/to/model", "my-bucket"), + version="2.0.0", + description="lorem ipsum", + model_format_name="onnx", + model_format_version="1", + storage_key="my-data-connection", + metadata={ + # can be one of the following types + "int_key": 1, + "bool_key": False, + "float_key": 3.14, + "str_key": "str_value", + } +) +``` ### Importing from Hugging Face Hub diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 826f061f..5a3074e9 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -1,7 +1,7 @@ """Standard client for the model registry.""" + from __future__ import annotations -import os from typing import get_args from warnings import warn @@ -75,16 +75,22 @@ def register_model( model_format_name: str, model_format_version: str, version: str, - author: str | None = None, - description: str | None = None, storage_key: str | None = None, storage_path: str | None = None, 
service_account_name: str | None = None, + author: str | None = None, + description: str | None = None, metadata: dict[str, ScalarType] | None = None, ) -> RegisteredModel: """Register a model. - Either `storage_key` and `storage_path`, or `service_account_name` must be provided. + This registers a model in the model registry. The model is not downloaded, and has to be stored prior to + registration. + + Most models can be registered using their URI, along with optional connection-specific parameters, `storage_key` + and `storage_path` or, simply a `service_account_name`. + URI builder utilities are recommended when referring to specialized storage; for example `utils.s3_uri_from` + helper when using S3 object storage data connections. Args: name: Name of the model. @@ -110,7 +116,7 @@ def register_model( version, author or self._author, description=description, - metadata=metadata or self.default_metadata(), + metadata=metadata or {}, ) self._register_model_artifact( mv, @@ -124,19 +130,6 @@ def register_model( return rm - def default_metadata(self) -> dict[str, ScalarType]: - """Default metadata valorisations. - - When not explicitly supplied by the end users, these valorisations will be used - by default. - - Returns: - default metadata valorisations. 
- """ - return { - key: os.environ[key] for key in ["AWS_S3_ENDPOINT", "AWS_S3_BUCKET", "AWS_DEFAULT_REGION"] if key in os.environ - } - def register_hf_model( self, repo: str, @@ -202,7 +195,6 @@ def register_hf_model( model_author = author source_uri = hf_hub_url(repo, path, revision=git_ref) metadata = { - **self.default_metadata(), "repo": repo, "source_uri": source_uri, "model_origin": "huggingface_hub", diff --git a/clients/python/src/model_registry/_utils.py b/clients/python/src/model_registry/_utils.py new file mode 100644 index 00000000..b2a32cb8 --- /dev/null +++ b/clients/python/src/model_registry/_utils.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import functools +import inspect +from collections.abc import Sequence +from typing import Any, Callable, TypeVar + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +# copied from https://github.com/Rapptz/RoboDanny +def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: + size = len(seq) + if size == 0: + return "" + + if size == 1: + return seq[0] + + if size == 2: + return f"{seq[0]} {final} {seq[1]}" + + return delim.join(seq[:-1]) + f" {final} {seq[-1]}" + + +def quote(string: str) -> str: + """Add single quotation marks around the given string. Does *not* do any escaping.""" + return f"'{string}'" + + +# copied from https://github.com/openai/openai-python +def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: # noqa: C901 + """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. + + Useful for enforcing runtime validation of overloaded functions. + + Example usage: + ```py + @overload + def foo(*, a: str) -> str: + ... + + + @overload + def foo(*, b: bool) -> str: + ... + + + # This enforces the same constraints that a static type checker would + # i.e. 
that either a or b must be passed to the function + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: bool | None = None) -> str: + ... + ``` + """ + + def inner(func: CallableT) -> CallableT: # noqa: C901 + params = inspect.signature(func).parameters + positional = [ + name + for name, param in params.items() + if param.kind + in { + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + } + ] + + @functools.wraps(func) + def wrapper(*args: object, **kwargs: object) -> object: + given_params: set[str] = set() + for i, _ in enumerate(args): + try: + given_params.add(positional[i]) + except IndexError: + msg = f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" + raise TypeError(msg) from None + + for key in kwargs: + given_params.add(key) + + for variant in variants: + matches = all(param in given_params for param in variant) + if matches: + break + else: # no break + if len(variants) > 1: + variations = human_join( + [ + "(" + + human_join([quote(arg) for arg in variant], final="and") + + ")" + for variant in variants + ] + ) + msg = f"Missing required arguments; Expected either {variations} arguments to be given" + else: + # TODO: this error message is not deterministic + missing = list(set(variants[0]) - given_params) + if len(missing) > 1: + msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" + else: + msg = f"Missing required argument: {quote(missing[0])}" + raise TypeError(msg) + return func(*args, **kwargs) + + return wrapper # type: ignore + + return inner diff --git a/clients/python/src/model_registry/exceptions.py b/clients/python/src/model_registry/exceptions.py index bac601d2..7a52ddac 100644 --- a/clients/python/src/model_registry/exceptions.py +++ b/clients/python/src/model_registry/exceptions.py @@ -5,6 +5,10 @@ class StoreException(Exception): """Storage related error.""" +class MissingMetadata(Exception): + """Not enough metadata to complete operation.""" + + class 
UnsupportedTypeException(StoreException): """Raised when an unsupported type is encountered.""" diff --git a/clients/python/src/model_registry/utils.py b/clients/python/src/model_registry/utils.py new file mode 100644 index 00000000..e60dcf5d --- /dev/null +++ b/clients/python/src/model_registry/utils.py @@ -0,0 +1,92 @@ +"""Utilities for the model registry.""" + +from __future__ import annotations + +import os + +from typing_extensions import overload + +from ._utils import required_args +from .exceptions import MissingMetadata + + +@overload +def s3_uri_from( + path: str, +) -> str: ... + + +@overload +def s3_uri_from( + path: str, + bucket: str, +) -> str: ... + + +@overload +def s3_uri_from( + path: str, + bucket: str, + *, + endpoint: str, + region: str, +) -> str: ... + + +@required_args( + (), + ( # pre-configured env + "bucket", + ), + ( # custom env or non-default bucket + "bucket", + "endpoint", + "region", + ), +) +def s3_uri_from( + path: str, + bucket: str | None = None, + *, + endpoint: str | None = None, + region: str | None = None, +) -> str: + """Build an S3 URI. + + This helper function builds an S3 URI from a path and a bucket name, assuming you have a configured environment + with a default bucket, endpoint, and region set. + If you don't, you must provide all three optional arguments. + That is also the case for custom environments, where the default bucket is not set, or if you want to use a + different bucket. + + Args: + path: Storage path. + bucket: Name of the S3 bucket. Defaults to AWS_S3_BUCKET. + endpoint: Endpoint of the S3 bucket. Defaults to AWS_S3_ENDPOINT. + region: Region of the S3 bucket. Defaults to AWS_DEFAULT_REGION. + + Returns: + S3 URI. 
+ """ + default_bucket = os.environ.get("AWS_S3_BUCKET") + if not bucket: + if not default_bucket: + msg = "Custom environment requires all arguments" + raise MissingMetadata(msg) + bucket = default_bucket + elif (not default_bucket or default_bucket != bucket) and not endpoint: + msg = ( + "bucket_endpoint and bucket_region must be provided for non-default bucket" + ) + raise MissingMetadata(msg) + + endpoint = endpoint or os.getenv("AWS_S3_ENDPOINT") + region = region or os.getenv("AWS_DEFAULT_REGION") + + if not (endpoint and region): + msg = "Missing environment variables: bucket_endpoint and bucket_region are required" + raise MissingMetadata(msg) + + # https://alexwlchan.net/2020/s3-keys-are-not-file-paths/ nor do they resolve to valid URls + # FIXME: is this safe? + return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}" diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 9e33b07b..e498b892 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -1,7 +1,7 @@ import os import pytest -from model_registry import ModelRegistry +from model_registry import ModelRegistry, utils from model_registry.core import ModelRegistryAPIClient from model_registry.exceptions import StoreException @@ -31,6 +31,27 @@ def test_register_new(mr_client: ModelRegistry): assert mr_api.get_model_artifact_by_params(mv.id) is not None +def test_register_new_using_s3_uri_builder(mr_client: ModelRegistry): + name = "test_model" + version = "1.0.0" + uri = utils.s3_uri_from( + "storage/path", "my-bucket", endpoint="my-endpoint", region="my-region" + ) + rm = mr_client.register_model( + name, + uri, + model_format_name="test_format", + model_format_version="test_version", + version=version, + ) + assert rm.id is not None + + mr_api = mr_client._api + assert (mv := mr_api.get_model_version_by_params(rm.id, version)) is not None + assert (ma := mr_api.get_model_artifact_by_params(mv.id)) is not None 
+ assert ma.uri == uri + + def test_register_existing_version(mr_client: ModelRegistry): params = { "name": "test_model", @@ -56,7 +77,7 @@ def test_get(mr_client: ModelRegistry): model_format_name="test_format", model_format_version="test_version", version=version, - metadata=metadata + metadata=metadata, ) assert (_rm := mr_client.get_registered_model(name)) @@ -73,28 +94,6 @@ def test_get(mr_client: ModelRegistry): assert ma.id == _ma.id -def test_default_md(mr_client: ModelRegistry): - name = "test_model" - version = "1.0.0" - env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"} - for k, v in env_values.items(): - os.environ[k] = v - - assert mr_client.register_model( - name, - "s3", - model_format_name="test_format", - model_format_version="test_version", - version=version, - # ensure leave empty metadata - ) - assert (mv := mr_client.get_model_version(name, version)) - assert mv.metadata == env_values - - for k in env_values: - os.environ.pop(k) - - def test_hf_import(mr_client: ModelRegistry): pytest.importorskip("huggingface_hub") name = "openai-community/gpt2" @@ -113,19 +112,25 @@ def test_hf_import(mr_client: ModelRegistry): assert mv.author == author assert mv.metadata["model_author"] == author assert mv.metadata["model_origin"] == "huggingface_hub" - assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + assert ( + mv.metadata["source_uri"] + == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + ) assert mv.metadata["repo"] == name assert mr_client.get_model_artifact(name, version) def test_hf_import_default_env(mr_client: ModelRegistry): - """Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata - """ + """Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata""" pytest.importorskip("huggingface_hub") name = 
"openai-community/gpt2" version = "1.2.3" author = "test author" - env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"} + env_values = { + "AWS_S3_ENDPOINT": "value1", + "AWS_S3_BUCKET": "value2", + "AWS_DEFAULT_REGION": "value3", + } for k, v in env_values.items(): os.environ[k] = v @@ -140,7 +145,10 @@ def test_hf_import_default_env(mr_client: ModelRegistry): assert (mv := mr_client.get_model_version(name, version)) assert mv.metadata["model_author"] == author assert mv.metadata["model_origin"] == "huggingface_hub" - assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + assert ( + mv.metadata["source_uri"] + == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + ) assert mv.metadata["repo"] == name assert mr_client.get_model_artifact(name, version) diff --git a/clients/python/tests/test_utils.py b/clients/python/tests/test_utils.py new file mode 100644 index 00000000..a29f04f3 --- /dev/null +++ b/clients/python/tests/test_utils.py @@ -0,0 +1,72 @@ +import os + +import pytest +from model_registry.exceptions import MissingMetadata +from model_registry.utils import s3_uri_from + + +def test_s3_uri_builder(): + s3_uri = s3_uri_from( + "test-path", + "test-bucket", + endpoint="test-endpoint", + region="test-region", + ) + assert ( + s3_uri + == "s3://test-bucket/test-path?endpoint=test-endpoint&defaultRegion=test-region" + ) + + +def test_s3_uri_builder_without_env(): + os.environ.pop("AWS_S3_BUCKET", None) + os.environ.pop("AWS_S3_ENDPOINT", None) + os.environ.pop("AWS_DEFAULT_REGION", None) + with pytest.raises(MissingMetadata) as e: + s3_uri_from( + "test-path", + ) + assert "custom environment" in str(e.value).lower() + + with pytest.raises(MissingMetadata) as e: + s3_uri_from( + "test-path", + "test-bucket", + ) + assert "non-default bucket" in str(e.value).lower() + + +def 
test_s3_uri_builder_with_only_default_bucket_env(): + os.environ["AWS_S3_BUCKET"] = "test-bucket" + os.environ.pop("AWS_S3_ENDPOINT", None) + os.environ.pop("AWS_DEFAULT_REGION", None) + with pytest.raises(MissingMetadata) as e: + s3_uri_from( + "test-path", + ) + assert "missing environment variables" in str(e.value).lower() + + +def test_s3_uri_builder_with_other_default_variables(): + os.environ.pop("AWS_S3_BUCKET", None) + os.environ["AWS_S3_ENDPOINT"] = "test-endpoint" + os.environ["AWS_DEFAULT_REGION"] = "test-region" + with pytest.raises(MissingMetadata) as e: + s3_uri_from( + "test-path", + ) + assert "custom environment" in str(e.value).lower() + + with pytest.raises(MissingMetadata) as e: + s3_uri_from( + "test-path", + "test-bucket", + ) + assert "non-default bucket" in str(e.value).lower() + + +def test_s3_uri_builder_with_complete_env(): + os.environ["AWS_S3_BUCKET"] = "test-bucket" + os.environ["AWS_S3_ENDPOINT"] = "test-endpoint" + os.environ["AWS_DEFAULT_REGION"] = "test-region" + assert s3_uri_from("test-path") == s3_uri_from("test-path", "test-bucket") diff --git a/go.mod b/go.mod index de27e15d..62910805 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/cpuguy83/dockercfg v0.3.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/docker/distribution v2.8.2+incompatible // indirect - github.com/docker/docker v24.0.7+incompatible // indirect + github.com/docker/docker v24.0.9+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect diff --git a/go.sum b/go.sum index b5da6005..c7bed18f 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/distribution 
v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v24.0.7+incompatible h1:Wo6l37AuwP3JaMnZa226lzVXGA3F9Ig1seQen0cKYlM= -github.com/docker/docker v24.0.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0= +github.com/docker/docker v24.0.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= diff --git a/internal/apiutils/api_utils.go b/internal/apiutils/api_utils.go index 6c5a10da..d601097c 100644 --- a/internal/apiutils/api_utils.go +++ b/internal/apiutils/api_utils.go @@ -51,6 +51,13 @@ func Of[E any](e E) *E { return &e } +func StrPtr(notEmpty string) *string { + if notEmpty == "" { + return nil + } + return &notEmpty +} + func BuildListOption(pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (api.ListOptions, error) { var pageSizeInt32 *int32 if pageSize != "" { diff --git a/internal/server/openapi/api_model_registry_service_service.go b/internal/server/openapi/api_model_registry_service_service.go index 18ce3323..e5d5b98f 100644 --- a/internal/server/openapi/api_model_registry_service_service.go +++ b/internal/server/openapi/api_model_registry_service_service.go @@ -138,7 +138,7 @@ func (s *ModelRegistryServiceAPIService) CreateRegisteredModel(ctx context.Conte // CreateRegisteredModelVersion - Create a ModelVersion in RegisteredModel func (s *ModelRegistryServiceAPIService) CreateRegisteredModelVersion(ctx context.Context, registeredmodelId string, modelVersion model.ModelVersion)
(ImplResponse, error) { - result, err := s.coreApi.UpsertModelVersion(&modelVersion, &registeredmodelId) + result, err := s.coreApi.UpsertModelVersion(&modelVersion, apiutils.StrPtr(registeredmodelId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -166,7 +166,7 @@ func (s *ModelRegistryServiceAPIService) CreateServingEnvironment(ctx context.Co // FindInferenceService - Get an InferenceServices that matches search parameters. func (s *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { - result, err := s.coreApi.GetInferenceServiceByParams(&name, &parentResourceId, &externalId) + result, err := s.coreApi.GetInferenceServiceByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -178,7 +178,7 @@ func (s *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Contex // FindModelArtifact - Get a ModelArtifact that matches search parameters. func (s *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { - result, err := s.coreApi.GetModelArtifactByParams(&name, &parentResourceId, &externalId) + result, err := s.coreApi.GetModelArtifactByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -190,7 +190,7 @@ func (s *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context, // FindModelVersion - Get a ModelVersion that matches search parameters.
func (s *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context, name string, externalId string, registeredModelId string) (ImplResponse, error) { - result, err := s.coreApi.GetModelVersionByParams(&name, &registeredModelId, &externalId) + result, err := s.coreApi.GetModelVersionByParams(apiutils.StrPtr(name), apiutils.StrPtr(registeredModelId), apiutils.StrPtr(externalId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -202,7 +202,7 @@ func (s *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context, n // FindRegisteredModel - Get a RegisteredModel that matches search parameters. func (s *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context, name string, externalID string) (ImplResponse, error) { - result, err := s.coreApi.GetRegisteredModelByParams(&name, &externalID) + result, err := s.coreApi.GetRegisteredModelByParams(apiutils.StrPtr(name), apiutils.StrPtr(externalID)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -214,7 +214,7 @@ func (s *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context // FindServingEnvironment - Find ServingEnvironment func (s *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Context, name string, externalID string) (ImplResponse, error) { - result, err := s.coreApi.GetServingEnvironmentByParams(&name, &externalID) + result, err := s.coreApi.GetServingEnvironmentByParams(apiutils.StrPtr(name), apiutils.StrPtr(externalID)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -229,7 +229,7 @@ func (s *ModelRegistryServiceAPIService) GetEnvironmentInferenceServices(ctx con if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } - result, err := s.coreApi.GetInferenceServices(listOpts, &servingenvironmentId, nil) + result, err := s.coreApi.GetInferenceServices(listOpts, apiutils.StrPtr(servingenvironmentId), nil) if err != nil { return
Response(500, model.Error{Message: err.Error()}), nil } @@ -266,7 +266,7 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServiceServes(ctx context.C if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } - result, err := s.coreApi.GetServeModels(listOpts, &inferenceserviceId) + result, err := s.coreApi.GetServeModels(listOpts, apiutils.StrPtr(inferenceserviceId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -347,7 +347,7 @@ func (s *ModelRegistryServiceAPIService) GetModelVersionArtifacts(ctx context.Co if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } - result, err := s.coreApi.GetArtifacts(listOpts, &modelversionId) + result, err := s.coreApi.GetArtifacts(listOpts, apiutils.StrPtr(modelversionId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } @@ -391,7 +391,7 @@ func (s *ModelRegistryServiceAPIService) GetRegisteredModelVersions(ctx context. if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } - result, err := s.coreApi.GetModelVersions(listOpts, &registeredmodelId) + result, err := s.coreApi.GetModelVersions(listOpts, apiutils.StrPtr(registeredmodelId)) if err != nil { return Response(500, model.Error{Message: err.Error()}), nil } diff --git a/pkg/core/core.go b/pkg/core/core.go index dc0ec9fe..afb0eaf8 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -245,6 +245,7 @@ func (serv *ModelRegistryService) GetRegisteredModelByParams(name *string, exter } else { return nil, fmt.Errorf("invalid parameters call, supply either name or externalId") } + glog.Info("filterQuery ", filterQuery) getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ TypeName: &serv.nameConfig.RegisteredModelTypeName, @@ -779,6 +780,7 @@ func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, } else { return nil, fmt.Errorf("invalid parameters call,
supply either (artifactName and modelVersionId), or externalId") } + glog.Info("filterQuery ", filterQuery) artifactsResponse, err := serv.mlmdClient.GetArtifactsByType(context.Background(), &proto.GetArtifactsByTypeRequest{ TypeName: &serv.nameConfig.ModelArtifactTypeName, diff --git a/scripts/build_deploy.sh b/scripts/build_deploy.sh index df610c37..9ede030f 100755 --- a/scripts/build_deploy.sh +++ b/scripts/build_deploy.sh @@ -2,12 +2,12 @@ set -e -# quay.io credentials -QUAY_REGISTRY=quay.io -QUAY_ORG="${QUAY_ORG:-opendatahub}" -QUAY_IMG_REPO="${QUAY_IMG_REPO:-model-registry}" -QUAY_USERNAME="${QUAY_USERNAME}" -QUAY_PASSWORD="${QUAY_PASSWORD}" +# see Makefile for the IMG_ variables semantic +IMG_REGISTRY="quay.io" +IMG_ORG="${IMG_ORG:-opendatahub}" +IMG_REPO="${IMG_REPO:-model-registry}" +DOCKER_USER="${DOCKER_USER}" +DOCKER_PWD="${DOCKER_PWD}" # image version HASH="$(git rev-parse --short=7 HEAD)" @@ -27,15 +27,15 @@ SKIP_IF_EXISTING="${SKIP_IF_EXISTING:-false}" # assure docker exists docker -v foo >/dev/null 2>&1 || { echo >&2 "::error:: Docker is required. Aborting."; exit 1; } -# skip if image already existing -if [[ "${SKIP_IF_EXISTING,,}" == "true" ]]; then - TAGS=$(curl --request GET "https://$QUAY_REGISTRY/api/v1/repository/${QUAY_ORG}/${QUAY_IMG_REPO}/tag/?specificTag=${VERSION}") +# if quay.io, can opt to skip if image already existing +if [[ "${SKIP_IF_EXISTING,,}" == "true" && "${IMG_REGISTRY,,}" == "quay.io" ]]; then + TAGS=$(curl --request GET "https://$IMG_REGISTRY/api/v1/repository/${IMG_ORG}/${IMG_REPO}/tag/?specificTag=${VERSION}") LATEST_TAG_HAS_END_TS=$(echo $TAGS | jq .tags - | jq 'sort_by(.start_ts) | reverse' | jq '.[0].end_ts') NOT_EMPTY=$(echo ${TAGS} | jq .tags - | jq any) # Image only exists if there is a tag that does not have "end_ts" (i.e. it is still present). 
if [[ "$NOT_EMPTY" == "true" && $LATEST_TAG_HAS_END_TS == "null" ]]; then - echo "::error:: The image ${QUAY_ORG}/${QUAY_IMG_REPO}:${VERSION} already exists" + echo "::error:: The image ${IMG_ORG}/${IMG_REPO}:${VERSION} already exists" exit 1 else echo "Image does not exist...proceeding with build & push." @@ -46,9 +46,9 @@ fi if [[ "${BUILD_IMAGE,,}" == "true" ]]; then echo "Building container image.." make \ - IMG_REGISTRY="${QUAY_REGISTRY}" \ - IMG_ORG="${QUAY_ORG}" \ - IMG_REPO="${QUAY_IMG_REPO}" \ + IMG_REGISTRY="${IMG_REGISTRY}" \ + IMG_ORG="${IMG_ORG}" \ + IMG_REPO="${IMG_REPO}" \ IMG_VERSION="${VERSION}" \ image/build else @@ -59,12 +59,12 @@ fi if [[ "${PUSH_IMAGE,,}" == "true" ]]; then echo "Pushing container image.." make \ - IMG_REGISTRY="${QUAY_REGISTRY}" \ - IMG_ORG="${QUAY_ORG}" \ - IMG_REPO="${QUAY_IMG_REPO}" \ + IMG_REGISTRY="${IMG_REGISTRY}" \ + IMG_ORG="${IMG_ORG}" \ + IMG_REPO="${IMG_REPO}" \ IMG_VERSION="${VERSION}" \ - DOCKER_USER="${QUAY_USERNAME}"\ - DOCKER_PWD="${QUAY_PASSWORD}" \ + DOCKER_USER="${DOCKER_USER}"\ + DOCKER_PWD="${DOCKER_PWD}" \ docker/login \ image/push else diff --git a/test/robot/MRkeywords.resource b/test/robot/MRkeywords.resource index 2bd25b44..46bc0f05 100644 --- a/test/robot/MRkeywords.resource +++ b/test/robot/MRkeywords.resource @@ -121,6 +121,18 @@ I get RegisteredModelByID END RETURN ${result} +I findRegisteredModel by name + [Arguments] ${name} + IF $MODE == "REST" + ${resp}= GET url=http://${MR_HOST}:${MR_PORT}/api/model_registry/v1alpha3/registered_model?name=${name} expected_status=200 + ${result} Set Variable ${resp.json()} + Log to console ${resp.json()} + ELSE + Log to console ${MODE} + Fail Not Implemented + END + RETURN ${result} + I get ModelVersionByID [Arguments] ${id} @@ -135,6 +147,19 @@ I get ModelVersionByID RETURN ${result} +I findModelVersion by name and parentResourceId + [Arguments] ${name} ${parentResourceId} + IF $MODE == "REST" + ${resp}= GET 
url=http://${MR_HOST}:${MR_PORT}/api/model_registry/v1alpha3/model_version?name=${name}&parentResourceId=${parentResourceId} expected_status=200 + ${result} Set Variable ${resp.json()} + Log to console ${resp.json()} + ELSE + Log to console ${MODE} + Fail Not Implemented + END + RETURN ${result} + + I get ModelArtifactByID [Arguments] ${id} IF $MODE == "REST" @@ -148,6 +173,19 @@ I get ModelArtifactByID RETURN ${result} +I findModelArtifact by name and parentResourceId + [Arguments] ${name} ${parentResourceId} + IF $MODE == "REST" + ${resp}= GET url=http://${MR_HOST}:${MR_PORT}/api/model_registry/v1alpha3/model_artifact?name=${name}&parentResourceId=${parentResourceId} expected_status=200 + ${result} Set Variable ${resp.json()} + Log to console ${resp.json()} + ELSE + Log to console ${MODE} + Fail Not Implemented + END + RETURN ${result} + + I get ArtifactsByModelVersionID [Arguments] ${id} IF $MODE == "REST" diff --git a/test/robot/UserStory.robot b/test/robot/UserStory.robot index 4fa234c5..6fe598d1 100644 --- a/test/robot/UserStory.robot +++ b/test/robot/UserStory.robot @@ -29,10 +29,16 @@ As a MLOps engineer I would like to store a description of the model ${aId} And I create a child ModelArtifact modelversionId=${vId} payload=&{model_artifact} ${r} Then I get RegisteredModelByID id=${rId} And Should be equal ${r["description"]} Lorem ipsum dolor sit amet + ${r} Then I findRegisteredModel by name name=${name} + And Should be equal ${r["description"]} Lorem ipsum dolor sit amet ${r} Then I get ModelVersionByID id=${vId} And Should be equal ${r["description"]} consectetur adipiscing elit + ${r} Then I findModelVersion by name and parentResourceId name=v1.2.3 parentResourceId=${rId} + And Should be equal ${r["description"]} consectetur adipiscing elit ${r} Then I get ModelArtifactByID id=${aId} And Should be equal ${r["description"]} sed do eiusmod tempor incididunt + ${r} Then I findModelArtifact by name and parentResourceId name=ModelArtifactName 
parentResourceId=${vId} + And Should be equal ${r["description"]} sed do eiusmod tempor incididunt As a MLOps engineer I would like to store a longer documentation for the model Set To Dictionary ${registered_model} description=Lorem ipsum dolor sit amet name=${name}