Skip to content

Commit

Permalink
[pull] main from kubeflow:main (#46)
Browse files Browse the repository at this point in the history
* Py: Fix misleading errors on missing list results (#65)

* py: fix type annotations to return concrete types

Signed-off-by: Isabella Basso do Amaral <idoamara@redhat.com>

* py: fix missing type check on empty list results

Signed-off-by: Isabella Basso do Amaral <idoamara@redhat.com>

---------

Signed-off-by: Isabella Basso do Amaral <idoamara@redhat.com>

* fix: OpenAPI ModelVersion shall contain registeredModelId property (#61)

need to adapt property definition in OpenAPI
to accomodate openapi-codegen result;
according to contract (as also visible in swagger)
the ModelVersion is to contain property: registeredModelId

Signed-off-by: Matteo Mortari <matteo.mortari@gmail.com>

---------

Signed-off-by: Isabella Basso do Amaral <idoamara@redhat.com>
Signed-off-by: Matteo Mortari <matteo.mortari@gmail.com>
Co-authored-by: Isabella Basso <idoamara@redhat.com>
Co-authored-by: Matteo Mortari <matteo.mortari@gmail.com>
  • Loading branch information
3 people authored Apr 17, 2024
1 parent 6f72a00 commit 4e2a58c
Show file tree
Hide file tree
Showing 15 changed files with 243 additions and 123 deletions.
9 changes: 5 additions & 4 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1047,10 +1047,11 @@ components:
allOf:
- $ref: "#/components/schemas/BaseResourceCreate"
- $ref: "#/components/schemas/ModelVersionUpdate"
properties:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
- type: object
properties:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
ModelVersionUpdate:
description: Represents a ModelVersion belonging to a RegisteredModel.
allOf:
Expand Down
9 changes: 4 additions & 5 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Client for the model registry."""
from __future__ import annotations

from collections.abc import Sequence
from __future__ import annotations

from ml_metadata.proto import MetadataStoreClientConfig

Expand Down Expand Up @@ -130,7 +129,7 @@ def get_registered_model_by_params(

def get_registered_models(
self, options: ListOptions | None = None
) -> Sequence[RegisteredModel]:
) -> list[RegisteredModel]:
"""Fetch registered models.
Args:
Expand Down Expand Up @@ -194,7 +193,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:

def get_model_versions(
self, registered_model_id: str, options: ListOptions | None = None
) -> Sequence[ModelVersion]:
) -> list[ModelVersion]:
"""Fetch model versions by registered model ID.
Args:
Expand Down Expand Up @@ -344,7 +343,7 @@ def get_model_artifacts(
self,
model_version_id: str | None = None,
options: ListOptions | None = None,
) -> Sequence[ModelArtifact]:
) -> list[ModelArtifact]:
"""Fetches model artifacts.
Args:
Expand Down
32 changes: 21 additions & 11 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int:

def _filter_type(
self, type_name: str, protos: Sequence[ProtoType]
) -> Sequence[ProtoType]:
) -> list[ProtoType]:
return [proto for proto in protos if proto.type == type_name]

def get_context(
Expand Down Expand Up @@ -168,9 +168,7 @@ def get_context(

return None

def get_contexts(
self, ctx_type_name: str, options: ListOptions
) -> Sequence[Context]:
def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]:
"""Get contexts from the store.
Args:
Expand All @@ -179,6 +177,11 @@ def get_contexts(
Returns:
Contexts.
Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
# TODO: should we make options optional?
# if options is not None:
Expand All @@ -195,9 +198,11 @@ def get_contexts(
# else:
# contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name)

if not contexts:
if not contexts and ctx_type_name not in [
t.name for t in self._mlmd_store.get_context_types()
]:
msg = f"Context type {ctx_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return contexts

Expand Down Expand Up @@ -309,9 +314,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:

return None

def get_artifacts(
self, art_type_name: str, options: ListOptions
) -> Sequence[Artifact]:
def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]:
"""Get artifacts from the store.
Args:
Expand All @@ -320,6 +323,11 @@ def get_artifacts(
Returns:
Artifacts.
Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
Expand All @@ -331,8 +339,10 @@ def get_artifacts(
raise ServerException(msg) from e

artifacts = self._filter_type(art_type_name, artifacts)
if not artifacts:
if not artifacts and art_type_name not in [
t.name for t in self._mlmd_store.get_artifact_types()
]:
msg = f"Artifact type {art_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return artifacts
23 changes: 23 additions & 0 deletions clients/python/tests/store/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TypeNotFoundException,
)
from model_registry.store import MLMDStore
from model_registry.types.options import MLMDListOptions


@pytest.fixture()
Expand Down Expand Up @@ -53,6 +54,28 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore):
plain_wrapper.get_type_id(Context, "undefined")


@pytest.mark.usefixtures("artifact")
def test_get_no_artifacts(plain_wrapper: MLMDStore):
arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions())
assert arts == []


def test_get_undefined_artifacts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_artifacts("undefined", MLMDListOptions())


@pytest.mark.usefixtures("context")
def test_get_no_contexts(plain_wrapper: MLMDStore):
ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions())
assert ctxs == []


def test_get_undefined_contexts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_contexts("undefined", MLMDListOptions())


def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact):
artifact.properties["null"].int_value = 0

Expand Down
5 changes: 5 additions & 0 deletions internal/converter/generated/mlmd_openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions internal/converter/generated/openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions internal/converter/mlmd_converter_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,24 @@ func TestMapNameFromOwned(t *testing.T) {
assertion.Nil(name)
}

func TestMapRegisteredModelIdFromOwned(t *testing.T) {
assertion := setup(t)

result, err := MapRegisteredModelIdFromOwned(of("prefix:name"))
assertion.Nil(err)
assertion.Equal("prefix", result)

_, err = MapRegisteredModelIdFromOwned(of("name"))
assertion.NotNil(err)

_, err = MapRegisteredModelIdFromOwned(of("prefix:name:postfix"))
assertion.NotNil(err)

result, err = MapRegisteredModelIdFromOwned(nil)
assertion.Nil(err)
assertion.Equal("", result)
}

func TestMapArtifactType(t *testing.T) {
assertion := setup(t)

Expand Down
1 change: 1 addition & 0 deletions internal/converter/mlmd_openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type MLMDToOpenAPIConverter interface {
ConvertRegisteredModel(source *proto.Context) (*openapi.RegisteredModel, error)

// goverter:map Name | MapNameFromOwned
// goverter:map Name RegisteredModelId | MapRegisteredModelIdFromOwned
// goverter:map Properties Description | MapDescription
// goverter:map Properties State | MapModelVersionState
// goverter:map Properties Author | MapPropertyAuthor
Expand Down
12 changes: 12 additions & 0 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ func MapPropertyAuthor(properties map[string]*proto.Value) *string {
return MapStringProperty(properties, "author")
}

func MapRegisteredModelIdFromOwned(source *string) (string, error) {
if source == nil {
return "", nil
}

exploded := strings.Split(*source, ":")
if len(exploded) != 2 {
return "", fmt.Errorf("wrong owned format")
}
return exploded[0], nil
}

// ARTIFACT

func MapArtifactType(source *proto.Artifact) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/converter/openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type OpenAPIConverter interface {
// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch
ConvertModelVersionCreate(source *openapi.ModelVersionCreate) (*openapi.ModelVersion, error)

// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name
// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name RegisteredModelId
ConvertModelVersionUpdate(source *openapi.ModelVersionUpdate) (*openapi.ModelVersion, error)

// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType
Expand Down
Loading

0 comments on commit 4e2a58c

Please sign in to comment.