From 65613f95ee8a3d7f1f06e7df177ec2003adc82fc Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Tue, 16 Apr 2024 09:52:52 -0300 Subject: [PATCH 1/2] Py: Fix misleading errors on missing list results (#65) * py: fix type annotations to return concrete types Signed-off-by: Isabella Basso do Amaral * py: fix missing type check on empty list results Signed-off-by: Isabella Basso do Amaral --------- Signed-off-by: Isabella Basso do Amaral --- clients/python/src/model_registry/core.py | 9 +++--- .../src/model_registry/store/wrapper.py | 32 ++++++++++++------- clients/python/tests/store/test_wrapper.py | 23 +++++++++++++ 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index 270d89b8..9e57da09 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -1,7 +1,6 @@ """Client for the model registry.""" -from __future__ import annotations -from collections.abc import Sequence +from __future__ import annotations from ml_metadata.proto import MetadataStoreClientConfig @@ -130,7 +129,7 @@ def get_registered_model_by_params( def get_registered_models( self, options: ListOptions | None = None - ) -> Sequence[RegisteredModel]: + ) -> list[RegisteredModel]: """Fetch registered models. Args: @@ -194,7 +193,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None: def get_model_versions( self, registered_model_id: str, options: ListOptions | None = None - ) -> Sequence[ModelVersion]: + ) -> list[ModelVersion]: """Fetch model versions by registered model ID. Args: @@ -344,7 +343,7 @@ def get_model_artifacts( self, model_version_id: str | None = None, options: ListOptions | None = None, - ) -> Sequence[ModelArtifact]: + ) -> list[ModelArtifact]: """Fetches model artifacts. Args: diff --git a/clients/python/src/model_registry/store/wrapper.py b/clients/python/src/model_registry/store/wrapper.py index 57c56dd7..750d0642 100644 --- a/clients/python/src/model_registry/store/wrapper.py +++ b/clients/python/src/model_registry/store/wrapper.py @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int: def _filter_type( self, type_name: str, protos: Sequence[ProtoType] - ) -> Sequence[ProtoType]: + ) -> list[ProtoType]: return [proto for proto in protos if proto.type == type_name] def get_context( @@ -168,9 +168,7 @@ def get_context( return None - def get_contexts( - self, ctx_type_name: str, options: ListOptions - ) -> Sequence[Context]: + def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]: """Get contexts from the store. Args: @@ -179,6 +177,11 @@ def get_contexts( Returns: Contexts. + + Raises: + TypeNotFoundException: If the type doesn't exist. + ServerException: If there was an error getting the type. + StoreException: Invalid arguments. """ # TODO: should we make options optional? # if options is not None: @@ -195,9 +198,11 @@ def get_contexts( # else: # contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name) - if not contexts: + if not contexts and ctx_type_name not in [ + t.name for t in self._mlmd_store.get_context_types() + ]: msg = f"Context type {ctx_type_name} does not exist" - raise StoreException(msg) + raise TypeNotFoundException(msg) return contexts @@ -309,9 +314,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact: return None - def get_artifacts( - self, art_type_name: str, options: ListOptions - ) -> Sequence[Artifact]: + def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]: """Get artifacts from the store. Args: @@ -320,6 +323,11 @@ def get_artifacts( Returns: Artifacts. + + Raises: + TypeNotFoundException: If the type doesn't exist. + ServerException: If there was an error getting the type. + StoreException: Invalid arguments. """ try: artifacts = self._mlmd_store.get_artifacts(options) @@ -331,8 +339,10 @@ def get_artifacts( raise ServerException(msg) from e artifacts = self._filter_type(art_type_name, artifacts) - if not artifacts: + if not artifacts and art_type_name not in [ + t.name for t in self._mlmd_store.get_artifact_types() + ]: msg = f"Artifact type {art_type_name} does not exist" - raise StoreException(msg) + raise TypeNotFoundException(msg) return artifacts diff --git a/clients/python/tests/store/test_wrapper.py b/clients/python/tests/store/test_wrapper.py index c2d379de..c6f4dbe2 100644 --- a/clients/python/tests/store/test_wrapper.py +++ b/clients/python/tests/store/test_wrapper.py @@ -17,6 +17,7 @@ TypeNotFoundException, ) from model_registry.store import MLMDStore +from model_registry.types.options import MLMDListOptions @pytest.fixture() @@ -53,6 +54,28 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore): plain_wrapper.get_type_id(Context, "undefined") +@pytest.mark.usefixtures("artifact") +def test_get_no_artifacts(plain_wrapper: MLMDStore): + arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions()) + assert arts == [] + + +def test_get_undefined_artifacts(plain_wrapper: MLMDStore): + with pytest.raises(TypeNotFoundException): + plain_wrapper.get_artifacts("undefined", MLMDListOptions()) + + +@pytest.mark.usefixtures("context") +def test_get_no_contexts(plain_wrapper: MLMDStore): + ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions()) + assert ctxs == [] + + +def test_get_undefined_contexts(plain_wrapper: MLMDStore): + with pytest.raises(TypeNotFoundException): + plain_wrapper.get_contexts("undefined", MLMDListOptions()) + + def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact): artifact.properties["null"].int_value = 0 From 734eb9e91c34c99a6dfd65de5ef39505f1bbb02d Mon Sep 17 00:00:00 2001 From: Matteo Mortari Date: Tue, 16 Apr 2024 20:12:52 +0200 Subject: [PATCH 2/2] fix: OpenAPI ModelVersion shall contain registeredModelId property (#61) need to adapt property definition in OpenAPI to accomodate openapi-codegen result; according to contract (as also visible in swagger) the ModelVersion is to contain property: registeredModelId Signed-off-by: Matteo Mortari --- api/openapi/model-registry.yaml | 9 +- .../generated/mlmd_openapi_converter.gen.go | 5 + .../generated/openapi_converter.gen.go | 10 ++ .../converter/mlmd_converter_util_test.go | 18 ++ internal/converter/mlmd_openapi_converter.go | 1 + .../converter/mlmd_openapi_converter_util.go | 12 ++ internal/converter/openapi_converter.go | 2 +- internal/server/openapi/type_asserts.go | 157 +++++++++--------- pkg/core/core_test.go | 2 + pkg/openapi/model_model_version.go | 30 +++- pkg/openapi/model_model_version_create.go | 55 +++--- test/robot/UserStory.robot | 1 + 12 files changed, 195 insertions(+), 107 deletions(-) diff --git a/api/openapi/model-registry.yaml b/api/openapi/model-registry.yaml index a7a22c71..f6503839 100644 --- a/api/openapi/model-registry.yaml +++ b/api/openapi/model-registry.yaml @@ -1047,10 +1047,11 @@ components: allOf: - $ref: "#/components/schemas/BaseResourceCreate" - $ref: "#/components/schemas/ModelVersionUpdate" - properties: - registeredModelId: - description: ID of the `RegisteredModel` to which this version belongs. - type: string + - type: object + properties: + registeredModelId: + description: ID of the `RegisteredModel` to which this version belongs. + type: string ModelVersionUpdate: description: Represents a ModelVersion belonging to a RegisteredModel. allOf: diff --git a/internal/converter/generated/mlmd_openapi_converter.gen.go b/internal/converter/generated/mlmd_openapi_converter.gen.go index 3152711d..00663e4c 100755 --- a/internal/converter/generated/mlmd_openapi_converter.gen.go +++ b/internal/converter/generated/mlmd_openapi_converter.gen.go @@ -136,6 +136,11 @@ func (c *MLMDToOpenAPIConverterImpl) ConvertModelVersion(source *proto.Context) openapiModelVersion.Name = converter.MapNameFromOwned((*source).Name) openapiModelVersion.State = converter.MapModelVersionState((*source).Properties) openapiModelVersion.Author = converter.MapPropertyAuthor((*source).Properties) + xstring2, err := converter.MapRegisteredModelIdFromOwned((*source).Name) + if err != nil { + return nil, fmt.Errorf("error setting field RegisteredModelId: %w", err) + } + openapiModelVersion.RegisteredModelId = xstring2 openapiModelVersion.Id = converter.Int64ToString((*source).Id) openapiModelVersion.CreateTimeSinceEpoch = converter.Int64ToString((*source).CreateTimeSinceEpoch) openapiModelVersion.LastUpdateTimeSinceEpoch = converter.Int64ToString((*source).LastUpdateTimeSinceEpoch) diff --git a/internal/converter/generated/openapi_converter.gen.go b/internal/converter/generated/openapi_converter.gen.go index f88ef4ac..067315c6 100755 --- a/internal/converter/generated/openapi_converter.gen.go +++ b/internal/converter/generated/openapi_converter.gen.go @@ -302,6 +302,7 @@ func (c *OpenAPIConverterImpl) ConvertModelVersionCreate(source *openapi.ModelVe pString4 = &xstring4 } openapiModelVersion.Author = pString4 + openapiModelVersion.RegisteredModelId = (*source).RegisteredModelId pOpenapiModelVersion = &openapiModelVersion } return pOpenapiModelVersion, nil @@ -636,6 +637,15 @@ func (c *OpenAPIConverterImpl) OverrideNotEditableForModelVersion(source convert pString2 = &xstring } openapiModelVersion.Name = pString2 + var pString3 *string + if source.Existing != nil { + pString3 = &source.Existing.RegisteredModelId + } + var xstring2 string + if pString3 != nil { + xstring2 = *pString3 + } + openapiModelVersion.RegisteredModelId = xstring2 return openapiModelVersion, nil } func (c *OpenAPIConverterImpl) OverrideNotEditableForRegisteredModel(source converter.OpenapiUpdateWrapper[openapi.RegisteredModel]) (openapi.RegisteredModel, error) { diff --git a/internal/converter/mlmd_converter_util_test.go b/internal/converter/mlmd_converter_util_test.go index 7fad4c33..fb40ea1e 100644 --- a/internal/converter/mlmd_converter_util_test.go +++ b/internal/converter/mlmd_converter_util_test.go @@ -570,6 +570,24 @@ func TestMapNameFromOwned(t *testing.T) { assertion.Nil(name) } +func TestMapRegisteredModelIdFromOwned(t *testing.T) { + assertion := setup(t) + + result, err := MapRegisteredModelIdFromOwned(of("prefix:name")) + assertion.Nil(err) + assertion.Equal("prefix", result) + + _, err = MapRegisteredModelIdFromOwned(of("name")) + assertion.NotNil(err) + + _, err = MapRegisteredModelIdFromOwned(of("prefix:name:postfix")) + assertion.NotNil(err) + + result, err = MapRegisteredModelIdFromOwned(nil) + assertion.Nil(err) + assertion.Equal("", result) +} + func TestMapArtifactType(t *testing.T) { assertion := setup(t) diff --git a/internal/converter/mlmd_openapi_converter.go b/internal/converter/mlmd_openapi_converter.go index 92bb4516..dcd55649 100644 --- a/internal/converter/mlmd_openapi_converter.go +++ b/internal/converter/mlmd_openapi_converter.go @@ -19,6 +19,7 @@ type MLMDToOpenAPIConverter interface { ConvertRegisteredModel(source *proto.Context) (*openapi.RegisteredModel, error) // goverter:map Name | MapNameFromOwned + // goverter:map Name RegisteredModelId | MapRegisteredModelIdFromOwned // goverter:map Properties Description | MapDescription // goverter:map Properties State | MapModelVersionState // goverter:map Properties Author | MapPropertyAuthor diff --git a/internal/converter/mlmd_openapi_converter_util.go b/internal/converter/mlmd_openapi_converter_util.go index aa326da7..99a9ba5e 100644 --- a/internal/converter/mlmd_openapi_converter_util.go +++ b/internal/converter/mlmd_openapi_converter_util.go @@ -108,6 +108,18 @@ func MapPropertyAuthor(properties map[string]*proto.Value) *string { return MapStringProperty(properties, "author") } +func MapRegisteredModelIdFromOwned(source *string) (string, error) { + if source == nil { + return "", nil + } + + exploded := strings.Split(*source, ":") + if len(exploded) != 2 { + return "", fmt.Errorf("wrong owned format") + } + return exploded[0], nil +} + // ARTIFACT func MapArtifactType(source *proto.Artifact) (string, error) { diff --git a/internal/converter/openapi_converter.go b/internal/converter/openapi_converter.go index ac44d7eb..e5204bc0 100644 --- a/internal/converter/openapi_converter.go +++ b/internal/converter/openapi_converter.go @@ -22,7 +22,7 @@ type OpenAPIConverter interface { // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ConvertModelVersionCreate(source *openapi.ModelVersionCreate) (*openapi.ModelVersion, error) - // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name + // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name RegisteredModelId ConvertModelVersionUpdate(source *openapi.ModelVersionUpdate) (*openapi.ModelVersion, error) // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType diff --git a/internal/server/openapi/type_asserts.go b/internal/server/openapi/type_asserts.go index fa32e4c3..10a2f2e4 100644 --- a/internal/server/openapi/type_asserts.go +++ b/internal/server/openapi/type_asserts.go @@ -62,23 +62,23 @@ func AssertArtifactStateConstraints(obj model.ArtifactState) error { return nil } -// AssertBaseArtifactCreateRequired checks if the required fields are not zero-ed -func AssertBaseArtifactCreateRequired(obj model.BaseArtifactCreate) error { +// AssertBaseArtifactRequired checks if the required fields are not zero-ed +func AssertBaseArtifactRequired(obj model.BaseArtifact) error { return nil } -// AssertBaseArtifactCreateConstraints checks if the values respects the defined constraints -func AssertBaseArtifactCreateConstraints(obj model.BaseArtifactCreate) error { +// AssertBaseArtifactConstraints checks if the values respects the defined constraints +func AssertBaseArtifactConstraints(obj model.BaseArtifact) error { return nil } -// AssertBaseArtifactRequired checks if the required fields are not zero-ed -func AssertBaseArtifactRequired(obj model.BaseArtifact) error { +// AssertBaseArtifactCreateRequired checks if the required fields are not zero-ed +func AssertBaseArtifactCreateRequired(obj model.BaseArtifactCreate) error { return nil } -// AssertBaseArtifactConstraints checks if the values respects the defined constraints -func AssertBaseArtifactConstraints(obj model.BaseArtifact) error { +// AssertBaseArtifactCreateConstraints checks if the values respects the defined constraints +func AssertBaseArtifactCreateConstraints(obj model.BaseArtifactCreate) error { return nil } @@ -92,23 +92,23 @@ func AssertBaseArtifactUpdateConstraints(obj model.BaseArtifactUpdate) error { return nil } -// AssertBaseExecutionCreateRequired checks if the required fields are not zero-ed -func AssertBaseExecutionCreateRequired(obj model.BaseExecutionCreate) error { +// AssertBaseExecutionRequired checks if the required fields are not zero-ed +func AssertBaseExecutionRequired(obj model.BaseExecution) error { return nil } -// AssertBaseExecutionCreateConstraints checks if the values respects the defined constraints -func AssertBaseExecutionCreateConstraints(obj model.BaseExecutionCreate) error { +// AssertBaseExecutionConstraints checks if the values respects the defined constraints +func AssertBaseExecutionConstraints(obj model.BaseExecution) error { return nil } -// AssertBaseExecutionRequired checks if the required fields are not zero-ed -func AssertBaseExecutionRequired(obj model.BaseExecution) error { +// AssertBaseExecutionCreateRequired checks if the required fields are not zero-ed +func AssertBaseExecutionCreateRequired(obj model.BaseExecutionCreate) error { return nil } -// AssertBaseExecutionConstraints checks if the values respects the defined constraints -func AssertBaseExecutionConstraints(obj model.BaseExecution) error { +// AssertBaseExecutionCreateConstraints checks if the values respects the defined constraints +func AssertBaseExecutionCreateConstraints(obj model.BaseExecutionCreate) error { return nil } @@ -122,23 +122,23 @@ func AssertBaseExecutionUpdateConstraints(obj model.BaseExecutionUpdate) error { return nil } -// AssertBaseResourceCreateRequired checks if the required fields are not zero-ed -func AssertBaseResourceCreateRequired(obj model.BaseResourceCreate) error { +// AssertBaseResourceRequired checks if the required fields are not zero-ed +func AssertBaseResourceRequired(obj model.BaseResource) error { return nil } -// AssertBaseResourceCreateConstraints checks if the values respects the defined constraints -func AssertBaseResourceCreateConstraints(obj model.BaseResourceCreate) error { +// AssertBaseResourceConstraints checks if the values respects the defined constraints +func AssertBaseResourceConstraints(obj model.BaseResource) error { return nil } -// AssertBaseResourceRequired checks if the required fields are not zero-ed -func AssertBaseResourceRequired(obj model.BaseResource) error { +// AssertBaseResourceCreateRequired checks if the required fields are not zero-ed +func AssertBaseResourceCreateRequired(obj model.BaseResourceCreate) error { return nil } -// AssertBaseResourceConstraints checks if the values respects the defined constraints -func AssertBaseResourceConstraints(obj model.BaseResource) error { +// AssertBaseResourceCreateConstraints checks if the values respects the defined constraints +func AssertBaseResourceCreateConstraints(obj model.BaseResourceCreate) error { return nil } @@ -222,8 +222,8 @@ func AssertExecutionStateConstraints(obj model.ExecutionState) error { return nil } -// AssertInferenceServiceCreateRequired checks if the required fields are not zero-ed -func AssertInferenceServiceCreateRequired(obj model.InferenceServiceCreate) error { +// AssertInferenceServiceRequired checks if the required fields are not zero-ed +func AssertInferenceServiceRequired(obj model.InferenceService) error { elements := map[string]interface{}{ "registeredModelId": obj.RegisteredModelId, "servingEnvironmentId": obj.ServingEnvironmentId, @@ -237,13 +237,13 @@ func AssertInferenceServiceCreateRequired(obj model.InferenceServiceCreate) erro return nil } -// AssertInferenceServiceCreateConstraints checks if the values respects the defined constraints -func AssertInferenceServiceCreateConstraints(obj model.InferenceServiceCreate) error { +// AssertInferenceServiceConstraints checks if the values respects the defined constraints +func AssertInferenceServiceConstraints(obj model.InferenceService) error { return nil } -// AssertInferenceServiceRequired checks if the required fields are not zero-ed -func AssertInferenceServiceRequired(obj model.InferenceService) error { +// AssertInferenceServiceCreateRequired checks if the required fields are not zero-ed +func AssertInferenceServiceCreateRequired(obj model.InferenceServiceCreate) error { elements := map[string]interface{}{ "registeredModelId": obj.RegisteredModelId, "servingEnvironmentId": obj.ServingEnvironmentId, @@ -257,8 +257,8 @@ func AssertInferenceServiceRequired(obj model.InferenceService) error { return nil } -// AssertInferenceServiceConstraints checks if the values respects the defined constraints -func AssertInferenceServiceConstraints(obj model.InferenceService) error { +// AssertInferenceServiceCreateConstraints checks if the values respects the defined constraints +func AssertInferenceServiceCreateConstraints(obj model.InferenceServiceCreate) error { return nil } @@ -456,16 +456,6 @@ func AssertMetadataValueConstraints(obj model.MetadataValue) error { return nil } -// AssertModelArtifactCreateRequired checks if the required fields are not zero-ed -func AssertModelArtifactCreateRequired(obj model.ModelArtifactCreate) error { - return nil -} - -// AssertModelArtifactCreateConstraints checks if the values respects the defined constraints -func AssertModelArtifactCreateConstraints(obj model.ModelArtifactCreate) error { - return nil -} - // AssertModelArtifactRequired checks if the required fields are not zero-ed func AssertModelArtifactRequired(obj model.ModelArtifact) error { elements := map[string]interface{}{ @@ -485,6 +475,16 @@ func AssertModelArtifactConstraints(obj model.ModelArtifact) error { return nil } +// AssertModelArtifactCreateRequired checks if the required fields are not zero-ed +func AssertModelArtifactCreateRequired(obj model.ModelArtifactCreate) error { + return nil +} + +// AssertModelArtifactCreateConstraints checks if the values respects the defined constraints +func AssertModelArtifactCreateConstraints(obj model.ModelArtifactCreate) error { + return nil +} + // AssertModelArtifactListRequired checks if the required fields are not zero-ed func AssertModelArtifactListRequired(obj model.ModelArtifactList) error { elements := map[string]interface{}{ @@ -521,8 +521,8 @@ func AssertModelArtifactUpdateConstraints(obj model.ModelArtifactUpdate) error { return nil } -// AssertModelVersionCreateRequired checks if the required fields are not zero-ed -func AssertModelVersionCreateRequired(obj model.ModelVersionCreate) error { +// AssertModelVersionRequired checks if the required fields are not zero-ed +func AssertModelVersionRequired(obj model.ModelVersion) error { elements := map[string]interface{}{ "registeredModelId": obj.RegisteredModelId, } @@ -535,18 +535,27 @@ func AssertModelVersionCreateRequired(obj model.ModelVersionCreate) error { return nil } -// AssertModelVersionCreateConstraints checks if the values respects the defined constraints -func AssertModelVersionCreateConstraints(obj model.ModelVersionCreate) error { +// AssertModelVersionConstraints checks if the values respects the defined constraints +func AssertModelVersionConstraints(obj model.ModelVersion) error { return nil } -// AssertModelVersionRequired checks if the required fields are not zero-ed -func AssertModelVersionRequired(obj model.ModelVersion) error { +// AssertModelVersionCreateRequired checks if the required fields are not zero-ed +func AssertModelVersionCreateRequired(obj model.ModelVersionCreate) error { + elements := map[string]interface{}{ + "registeredModelId": obj.RegisteredModelId, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + return nil } -// AssertModelVersionConstraints checks if the values respects the defined constraints -func AssertModelVersionConstraints(obj model.ModelVersion) error { +// AssertModelVersionCreateConstraints checks if the values respects the defined constraints +func AssertModelVersionCreateConstraints(obj model.ModelVersionCreate) error { return nil } @@ -606,23 +615,23 @@ func AssertOrderByFieldConstraints(obj model.OrderByField) error { return nil } -// AssertRegisteredModelCreateRequired checks if the required fields are not zero-ed -func AssertRegisteredModelCreateRequired(obj model.RegisteredModelCreate) error { +// AssertRegisteredModelRequired checks if the required fields are not zero-ed +func AssertRegisteredModelRequired(obj model.RegisteredModel) error { return nil } -// AssertRegisteredModelCreateConstraints checks if the values respects the defined constraints -func AssertRegisteredModelCreateConstraints(obj model.RegisteredModelCreate) error { +// AssertRegisteredModelConstraints checks if the values respects the defined constraints +func AssertRegisteredModelConstraints(obj model.RegisteredModel) error { return nil } -// AssertRegisteredModelRequired checks if the required fields are not zero-ed -func AssertRegisteredModelRequired(obj model.RegisteredModel) error { +// AssertRegisteredModelCreateRequired checks if the required fields are not zero-ed +func AssertRegisteredModelCreateRequired(obj model.RegisteredModelCreate) error { return nil } -// AssertRegisteredModelConstraints checks if the values respects the defined constraints -func AssertRegisteredModelConstraints(obj model.RegisteredModel) error { +// AssertRegisteredModelCreateConstraints checks if the values respects the defined constraints +func AssertRegisteredModelCreateConstraints(obj model.RegisteredModelCreate) error { return nil } @@ -672,8 +681,8 @@ func AssertRegisteredModelUpdateConstraints(obj model.RegisteredModelUpdate) err return nil } -// AssertServeModelCreateRequired checks if the required fields are not zero-ed -func AssertServeModelCreateRequired(obj model.ServeModelCreate) error { +// AssertServeModelRequired checks if the required fields are not zero-ed +func AssertServeModelRequired(obj model.ServeModel) error { elements := map[string]interface{}{ "modelVersionId": obj.ModelVersionId, } @@ -686,13 +695,13 @@ func AssertServeModelCreateRequired(obj model.ServeModelCreate) error { return nil } -// AssertServeModelCreateConstraints checks if the values respects the defined constraints -func AssertServeModelCreateConstraints(obj model.ServeModelCreate) error { +// AssertServeModelConstraints checks if the values respects the defined constraints +func AssertServeModelConstraints(obj model.ServeModel) error { return nil } -// AssertServeModelRequired checks if the required fields are not zero-ed -func AssertServeModelRequired(obj model.ServeModel) error { +// AssertServeModelCreateRequired checks if the required fields are not zero-ed +func AssertServeModelCreateRequired(obj model.ServeModelCreate) error { elements := map[string]interface{}{ "modelVersionId": obj.ModelVersionId, } @@ -705,8 +714,8 @@ func AssertServeModelRequired(obj model.ServeModel) error { return nil } -// AssertServeModelConstraints checks if the values respects the defined constraints -func AssertServeModelConstraints(obj model.ServeModel) error { +// AssertServeModelCreateConstraints checks if the values respects the defined constraints +func AssertServeModelCreateConstraints(obj model.ServeModelCreate) error { return nil } @@ -746,23 +755,23 @@ func AssertServeModelUpdateConstraints(obj model.ServeModelUpdate) error { return nil } -// AssertServingEnvironmentCreateRequired checks if the required fields are not zero-ed -func AssertServingEnvironmentCreateRequired(obj model.ServingEnvironmentCreate) error { +// AssertServingEnvironmentRequired checks if the required fields are not zero-ed +func AssertServingEnvironmentRequired(obj model.ServingEnvironment) error { return nil } -// AssertServingEnvironmentCreateConstraints checks if the values respects the defined constraints -func AssertServingEnvironmentCreateConstraints(obj model.ServingEnvironmentCreate) error { +// AssertServingEnvironmentConstraints checks if the values respects the defined constraints +func AssertServingEnvironmentConstraints(obj model.ServingEnvironment) error { return nil } -// AssertServingEnvironmentRequired checks if the required fields are not zero-ed -func AssertServingEnvironmentRequired(obj model.ServingEnvironment) error { +// AssertServingEnvironmentCreateRequired checks if the required fields are not zero-ed +func AssertServingEnvironmentCreateRequired(obj model.ServingEnvironmentCreate) error { return nil } -// AssertServingEnvironmentConstraints checks if the values respects the defined constraints -func AssertServingEnvironmentConstraints(obj model.ServingEnvironment) error { +// AssertServingEnvironmentCreateConstraints checks if the values respects the defined constraints +func AssertServingEnvironmentCreateConstraints(obj model.ServingEnvironmentCreate) error { return nil } diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go index 7948f7ab..71a9770a 100644 --- a/pkg/core/core_test.go +++ b/pkg/core/core_test.go @@ -970,6 +970,7 @@ func (suite *CoreTestSuite) TestCreateModelVersion() { createdVersion, err := service.UpsertModelVersion(modelVersion, ®isteredModelId) suite.Nilf(err, "error creating new model version for %d", registeredModelId) + suite.Equal((*createdVersion).RegisteredModelId, registeredModelId, "RegisteredModelId should match the actual owner") suite.NotNilf(createdVersion.Id, "created model version should not have nil Id") @@ -1045,6 +1046,7 @@ func (suite *CoreTestSuite) TestUpdateModelVersion() { updatedVersion, err := service.UpsertModelVersion(createdVersion, ®isteredModelId) suite.Nilf(err, "error updating new model version for %s: %v", registeredModelId, err) + suite.Equal((*updatedVersion).RegisteredModelId, registeredModelId, "RegisteredModelId should match the actual owner") updateVersionId, _ := converter.StringToInt64(updatedVersion.Id) suite.Equal(*createdVersionId, *updateVersionId, "created and updated model version should have same id") diff --git a/pkg/openapi/model_model_version.go b/pkg/openapi/model_model_version.go index 236f84f9..c0b89f8c 100644 --- a/pkg/openapi/model_model_version.go +++ b/pkg/openapi/model_model_version.go @@ -30,6 +30,8 @@ type ModelVersion struct { State *ModelVersionState `json:"state,omitempty"` // Name of the author. Author *string `json:"author,omitempty"` + // ID of the `RegisteredModel` to which this version belongs. + RegisteredModelId string `json:"registeredModelId"` // Output only. The unique server generated id of the resource. Id *string `json:"id,omitempty"` // Output only. Create time of the resource in millisecond since epoch. @@ -42,10 +44,11 @@ type ModelVersion struct { // This constructor will assign default values to properties that have it defined, // and makes sure properties required by API are set, but the set of arguments // will change when the set of required properties is changed -func NewModelVersion() *ModelVersion { +func NewModelVersion(registeredModelId string) *ModelVersion { this := ModelVersion{} var state ModelVersionState = MODELVERSIONSTATE_LIVE this.State = &state + this.RegisteredModelId = registeredModelId return &this } @@ -251,6 +254,30 @@ func (o *ModelVersion) SetAuthor(v string) { o.Author = &v } +// GetRegisteredModelId returns the RegisteredModelId field value +func (o *ModelVersion) GetRegisteredModelId() string { + if o == nil { + var ret string + return ret + } + + return o.RegisteredModelId +} + +// GetRegisteredModelIdOk returns a tuple with the RegisteredModelId field value +// and a boolean to check if the value has been set. +func (o *ModelVersion) GetRegisteredModelIdOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.RegisteredModelId, true +} + +// SetRegisteredModelId sets field value +func (o *ModelVersion) SetRegisteredModelId(v string) { + o.RegisteredModelId = v +} + // GetId returns the Id field value if set, zero value otherwise. func (o *ModelVersion) GetId() string { if o == nil || IsNil(o.Id) { @@ -375,6 +402,7 @@ func (o ModelVersion) ToMap() (map[string]interface{}, error) { if !IsNil(o.Author) { toSerialize["author"] = o.Author } + toSerialize["registeredModelId"] = o.RegisteredModelId if !IsNil(o.Id) { toSerialize["id"] = o.Id } diff --git a/pkg/openapi/model_model_version_create.go b/pkg/openapi/model_model_version_create.go index e99675d3..e964055a 100644 --- a/pkg/openapi/model_model_version_create.go +++ b/pkg/openapi/model_model_version_create.go @@ -19,8 +19,6 @@ var _ MappedNullable = &ModelVersionCreate{} // ModelVersionCreate Represents a ModelVersion belonging to a RegisteredModel. type ModelVersionCreate struct { - // ID of the `RegisteredModel` to which this version belongs. - RegisteredModelId string `json:"registeredModelId"` // User provided custom properties which are not defined by its type. CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` // An optional description about the resource. @@ -32,6 +30,8 @@ type ModelVersionCreate struct { State *ModelVersionState `json:"state,omitempty"` // Name of the author. Author *string `json:"author,omitempty"` + // ID of the `RegisteredModel` to which this version belongs. + RegisteredModelId string `json:"registeredModelId"` } // NewModelVersionCreate instantiates a new ModelVersionCreate object @@ -42,6 +42,7 @@ func NewModelVersionCreate(registeredModelId string) *ModelVersionCreate { this := ModelVersionCreate{} var state ModelVersionState = MODELVERSIONSTATE_LIVE this.State = &state + this.RegisteredModelId = registeredModelId return &this } @@ -55,30 +56,6 @@ func NewModelVersionCreateWithDefaults() *ModelVersionCreate { return &this } -// GetRegisteredModelId returns the RegisteredModelId field value -func (o *ModelVersionCreate) GetRegisteredModelId() string { - if o == nil { - var ret string - return ret - } - - return o.RegisteredModelId -} - -// GetRegisteredModelIdOk returns a tuple with the RegisteredModelId field value -// and a boolean to check if the value has been set. -func (o *ModelVersionCreate) GetRegisteredModelIdOk() (*string, bool) { - if o == nil { - return nil, false - } - return &o.RegisteredModelId, true -} - -// SetRegisteredModelId sets field value -func (o *ModelVersionCreate) SetRegisteredModelId(v string) { - o.RegisteredModelId = v -} - // GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. func (o *ModelVersionCreate) GetCustomProperties() map[string]MetadataValue { if o == nil || IsNil(o.CustomProperties) { @@ -271,6 +248,30 @@ func (o *ModelVersionCreate) SetAuthor(v string) { o.Author = &v } +// GetRegisteredModelId returns the RegisteredModelId field value +func (o *ModelVersionCreate) GetRegisteredModelId() string { + if o == nil { + var ret string + return ret + } + + return o.RegisteredModelId +} + +// GetRegisteredModelIdOk returns a tuple with the RegisteredModelId field value +// and a boolean to check if the value has been set. +func (o *ModelVersionCreate) GetRegisteredModelIdOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.RegisteredModelId, true +} + +// SetRegisteredModelId sets field value +func (o *ModelVersionCreate) SetRegisteredModelId(v string) { + o.RegisteredModelId = v +} + func (o ModelVersionCreate) MarshalJSON() ([]byte, error) { toSerialize, err := o.ToMap() if err != nil { @@ -281,7 +282,6 @@ func (o ModelVersionCreate) MarshalJSON() ([]byte, error) { func (o ModelVersionCreate) ToMap() (map[string]interface{}, error) { toSerialize := map[string]interface{}{} - toSerialize["registeredModelId"] = o.RegisteredModelId if !IsNil(o.CustomProperties) { toSerialize["customProperties"] = o.CustomProperties } @@ -300,6 +300,7 @@ func (o ModelVersionCreate) ToMap() (map[string]interface{}, error) { if !IsNil(o.Author) { toSerialize["author"] = o.Author } + toSerialize["registeredModelId"] = o.RegisteredModelId return toSerialize, nil } diff --git a/test/robot/UserStory.robot b/test/robot/UserStory.robot index f6d43554..0acf173f 100644 --- a/test/robot/UserStory.robot +++ b/test/robot/UserStory.robot @@ -17,6 +17,7 @@ As a MLOps engineer I would like to store Model name And Should be equal ${r["name"]} ${name} ${r} Then I get ModelVersionByID id=${vId} And Should be equal ${r["name"]} v1 + And Should be equal ${r["registeredModelId"]} ${rId} ${r} Then I get ModelArtifactByID id=${aId} And Should be equal ${r["uri"]} s3://12345