Skip to content

Commit

Permalink
Revert "[pull] main from kubeflow:main (#46)" due to squash
Browse files Browse the repository at this point in the history
This reverts commit 4e2a58c.

Missing the correct label, the automated merge did
not merge-commit but squash-commit.
Reverting.
  • Loading branch information
tarilabs committed Apr 18, 2024
1 parent 34bcfbc commit 7ee172e
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 243 deletions.
9 changes: 4 additions & 5 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1047,11 +1047,10 @@ components:
allOf:
- $ref: "#/components/schemas/BaseResourceCreate"
- $ref: "#/components/schemas/ModelVersionUpdate"
- type: object
properties:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
properties:
registeredModelId:
description: ID of the `RegisteredModel` to which this version belongs.
type: string
ModelVersionUpdate:
description: Represents a ModelVersion belonging to a RegisteredModel.
allOf:
Expand Down
9 changes: 5 additions & 4 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Client for the model registry."""

from __future__ import annotations

from collections.abc import Sequence

from ml_metadata.proto import MetadataStoreClientConfig

from .exceptions import StoreException
Expand Down Expand Up @@ -129,7 +130,7 @@ def get_registered_model_by_params(

def get_registered_models(
self, options: ListOptions | None = None
) -> list[RegisteredModel]:
) -> Sequence[RegisteredModel]:
"""Fetch registered models.
Args:
Expand Down Expand Up @@ -193,7 +194,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:

def get_model_versions(
self, registered_model_id: str, options: ListOptions | None = None
) -> list[ModelVersion]:
) -> Sequence[ModelVersion]:
"""Fetch model versions by registered model ID.
Args:
Expand Down Expand Up @@ -343,7 +344,7 @@ def get_model_artifacts(
self,
model_version_id: str | None = None,
options: ListOptions | None = None,
) -> list[ModelArtifact]:
) -> Sequence[ModelArtifact]:
"""Fetches model artifacts.
Args:
Expand Down
32 changes: 11 additions & 21 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int:

def _filter_type(
self, type_name: str, protos: Sequence[ProtoType]
) -> list[ProtoType]:
) -> Sequence[ProtoType]:
return [proto for proto in protos if proto.type == type_name]

def get_context(
Expand Down Expand Up @@ -168,7 +168,9 @@ def get_context(

return None

def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]:
def get_contexts(
self, ctx_type_name: str, options: ListOptions
) -> Sequence[Context]:
"""Get contexts from the store.
Args:
Expand All @@ -177,11 +179,6 @@ def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context
Returns:
Contexts.
Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
# TODO: should we make options optional?
# if options is not None:
Expand All @@ -198,11 +195,9 @@ def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context
# else:
# contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name)

if not contexts and ctx_type_name not in [
t.name for t in self._mlmd_store.get_context_types()
]:
if not contexts:
msg = f"Context type {ctx_type_name} does not exist"
raise TypeNotFoundException(msg)
raise StoreException(msg)

return contexts

Expand Down Expand Up @@ -314,7 +309,9 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:

return None

def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]:
def get_artifacts(
self, art_type_name: str, options: ListOptions
) -> Sequence[Artifact]:
"""Get artifacts from the store.
Args:
Expand All @@ -323,11 +320,6 @@ def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifa
Returns:
Artifacts.
Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
Expand All @@ -339,10 +331,8 @@ def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifa
raise ServerException(msg) from e

artifacts = self._filter_type(art_type_name, artifacts)
if not artifacts and art_type_name not in [
t.name for t in self._mlmd_store.get_artifact_types()
]:
if not artifacts:
msg = f"Artifact type {art_type_name} does not exist"
raise TypeNotFoundException(msg)
raise StoreException(msg)

return artifacts
23 changes: 0 additions & 23 deletions clients/python/tests/store/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
TypeNotFoundException,
)
from model_registry.store import MLMDStore
from model_registry.types.options import MLMDListOptions


@pytest.fixture()
Expand Down Expand Up @@ -54,28 +53,6 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore):
plain_wrapper.get_type_id(Context, "undefined")


@pytest.mark.usefixtures("artifact")
def test_get_no_artifacts(plain_wrapper: MLMDStore):
arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions())
assert arts == []


def test_get_undefined_artifacts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_artifacts("undefined", MLMDListOptions())


@pytest.mark.usefixtures("context")
def test_get_no_contexts(plain_wrapper: MLMDStore):
ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions())
assert ctxs == []


def test_get_undefined_contexts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_contexts("undefined", MLMDListOptions())


def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact):
artifact.properties["null"].int_value = 0

Expand Down
5 changes: 0 additions & 5 deletions internal/converter/generated/mlmd_openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 0 additions & 10 deletions internal/converter/generated/openapi_converter.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 0 additions & 18 deletions internal/converter/mlmd_converter_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,24 +570,6 @@ func TestMapNameFromOwned(t *testing.T) {
assertion.Nil(name)
}

func TestMapRegisteredModelIdFromOwned(t *testing.T) {
assertion := setup(t)

result, err := MapRegisteredModelIdFromOwned(of("prefix:name"))
assertion.Nil(err)
assertion.Equal("prefix", result)

_, err = MapRegisteredModelIdFromOwned(of("name"))
assertion.NotNil(err)

_, err = MapRegisteredModelIdFromOwned(of("prefix:name:postfix"))
assertion.NotNil(err)

result, err = MapRegisteredModelIdFromOwned(nil)
assertion.Nil(err)
assertion.Equal("", result)
}

func TestMapArtifactType(t *testing.T) {
assertion := setup(t)

Expand Down
1 change: 0 additions & 1 deletion internal/converter/mlmd_openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ type MLMDToOpenAPIConverter interface {
ConvertRegisteredModel(source *proto.Context) (*openapi.RegisteredModel, error)

// goverter:map Name | MapNameFromOwned
// goverter:map Name RegisteredModelId | MapRegisteredModelIdFromOwned
// goverter:map Properties Description | MapDescription
// goverter:map Properties State | MapModelVersionState
// goverter:map Properties Author | MapPropertyAuthor
Expand Down
12 changes: 0 additions & 12 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,6 @@ func MapPropertyAuthor(properties map[string]*proto.Value) *string {
return MapStringProperty(properties, "author")
}

func MapRegisteredModelIdFromOwned(source *string) (string, error) {
if source == nil {
return "", nil
}

exploded := strings.Split(*source, ":")
if len(exploded) != 2 {
return "", fmt.Errorf("wrong owned format")
}
return exploded[0], nil
}

// ARTIFACT

func MapArtifactType(source *proto.Artifact) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/converter/openapi_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type OpenAPIConverter interface {
// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch
ConvertModelVersionCreate(source *openapi.ModelVersionCreate) (*openapi.ModelVersion, error)

// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name RegisteredModelId
// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name
ConvertModelVersionUpdate(source *openapi.ModelVersionUpdate) (*openapi.ModelVersion, error)

// goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType
Expand Down
Loading

0 comments on commit 7ee172e

Please sign in to comment.