Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from kubeflow:main #46

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1047,10 +1047,11 @@ components:
allOf:
- $ref: "#/components/schemas/BaseResourceCreate"
- $ref: "#/components/schemas/ModelVersionUpdate"
properties:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
- type: object
properties:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
ModelVersionUpdate:
description: Represents a ModelVersion belonging to a RegisteredModel.
allOf:
Expand Down
9 changes: 4 additions & 5 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Client for the model registry."""
from __future__ import annotations

from collections.abc import Sequence
from __future__ import annotations

from ml_metadata.proto import MetadataStoreClientConfig

Expand Down Expand Up @@ -130,7 +129,7 @@ def get_registered_model_by_params(

def get_registered_models(
self, options: ListOptions | None = None
) -> Sequence[RegisteredModel]:
) -> list[RegisteredModel]:
"""Fetch registered models.

Args:
Expand Down Expand Up @@ -194,7 +193,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:

def get_model_versions(
self, registered_model_id: str, options: ListOptions | None = None
) -> Sequence[ModelVersion]:
) -> list[ModelVersion]:
"""Fetch model versions by registered model ID.

Args:
Expand Down Expand Up @@ -344,7 +343,7 @@ def get_model_artifacts(
self,
model_version_id: str | None = None,
options: ListOptions | None = None,
) -> Sequence[ModelArtifact]:
) -> list[ModelArtifact]:
"""Fetches model artifacts.

Args:
Expand Down
32 changes: 21 additions & 11 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int:

def _filter_type(
self, type_name: str, protos: Sequence[ProtoType]
) -> Sequence[ProtoType]:
) -> list[ProtoType]:
return [proto for proto in protos if proto.type == type_name]

def get_context(
Expand Down Expand Up @@ -168,9 +168,7 @@ def get_context(

return None

def get_contexts(
self, ctx_type_name: str, options: ListOptions
) -> Sequence[Context]:
def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]:
"""Get contexts from the store.

Args:
Expand All @@ -179,6 +177,11 @@ def get_contexts(

Returns:
Contexts.

Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
# TODO: should we make options optional?
# if options is not None:
Expand All @@ -195,9 +198,11 @@ def get_contexts(
# else:
# contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name)

if not contexts:
if not contexts and ctx_type_name not in [
t.name for t in self._mlmd_store.get_context_types()
]:
msg = f"Context type {ctx_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return contexts

Expand Down Expand Up @@ -309,9 +314,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:

return None

def get_artifacts(
self, art_type_name: str, options: ListOptions
) -> Sequence[Artifact]:
def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]:
"""Get artifacts from the store.

Args:
Expand All @@ -320,6 +323,11 @@ def get_artifacts(

Returns:
Artifacts.

Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
Expand All @@ -331,8 +339,10 @@ def get_artifacts(
raise ServerException(msg) from e

artifacts = self._filter_type(art_type_name, artifacts)
if not artifacts:
if not artifacts and art_type_name not in [
t.name for t in self._mlmd_store.get_artifact_types()
]:
msg = f"Artifact type {art_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return artifacts
23 changes: 23 additions & 0 deletions clients/python/tests/store/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TypeNotFoundException,
)
from model_registry.store import MLMDStore
from model_registry.types.options import MLMDListOptions


@pytest.fixture()
Expand Down Expand Up @@ -53,6 +54,28 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore):
plain_wrapper.get_type_id(Context, "undefined")


@pytest.mark.usefixtures("artifact")
def test_get_no_artifacts(plain_wrapper: MLMDStore):
arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions())
assert arts == []


def test_get_undefined_artifacts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_artifacts("undefined", MLMDListOptions())


@pytest.mark.usefixtures("context")
def test_get_no_contexts(plain_wrapper: MLMDStore):
ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions())
assert ctxs == []


def test_get_undefined_contexts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_contexts("undefined", MLMDListOptions())


def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact):
artifact.properties["null"].int_value = 0

Expand Down
5 changes: 5 additions & 0 deletions internal/converter/generated/mlmd_openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions internal/converter/generated/openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions internal/converter/mlmd_converter_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,24 @@ func TestMapNameFromOwned(t *testing.T) {
assertion.Nil(name)
}

func TestMapRegisteredModelIdFromOwned(t *testing.T) {
assertion := setup(t)

result, err := MapRegisteredModelIdFromOwned(of("prefix:name"))
assertion.Nil(err)
assertion.Equal("prefix", result)

_, err = MapRegisteredModelIdFromOwned(of("name"))
assertion.NotNil(err)

_, err = MapRegisteredModelIdFromOwned(of("prefix:name:postfix"))
assertion.NotNil(err)

result, err = MapRegisteredModelIdFromOwned(nil)
assertion.Nil(err)
assertion.Equal("", result)
}

func TestMapArtifactType(t *testing.T) {
assertion := setup(t)

Expand Down
1 change: 1 addition & 0 deletions internal/converter/mlmd_openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type MLMDToOpenAPIConverter interface {
ConvertRegisteredModel(source *proto.Context) (*openapi.RegisteredModel, error)

// goverter:map Name | MapNameFromOwned
// goverter:map Name RegisteredModelId | MapRegisteredModelIdFromOwned
// goverter:map Properties Description | MapDescription
// goverter:map Properties State | MapModelVersionState
// goverter:map Properties Author | MapPropertyAuthor
Expand Down
12 changes: 12 additions & 0 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ func MapPropertyAuthor(properties map[string]*proto.Value) *string {
return MapStringProperty(properties, "author")
}

func MapRegisteredModelIdFromOwned(source *string) (string, error) {
if source == nil {
return "", nil
}

exploded := strings.Split(*source, ":")
if len(exploded) != 2 {
return "", fmt.Errorf("wrong owned format")
}
return exploded[0], nil
}

// ARTIFACT

func MapArtifactType(source *proto.Artifact) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/converter/openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type OpenAPIConverter interface {
// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch
ConvertModelVersionCreate(source *openapi.ModelVersionCreate) (*openapi.ModelVersion, error)

// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name
// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name RegisteredModelId
ConvertModelVersionUpdate(source *openapi.ModelVersionUpdate) (*openapi.ModelVersion, error)

// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType
Expand Down
Loading
Loading