diff --git a/clients/python/README.md b/clients/python/README.md index 0c174fc0..c1deb672 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -146,18 +146,23 @@ for version in registry.get_model_versions("my-model"): ... ``` -To customize sorting order or query limits you can also use +You can also use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order ```py -latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20) +latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending() for version in latest_updates: ... ``` -You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order. +By default, all queries will be `ascending`, but this method is also available for explicitness. -> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1 -> you will still get all the models, with one query each. +> Note: You can also set the `page_size()` that you want the Pager to use when invoking the Model Registry backend. +> When using it as an iterator, it will automatically manage pages for you. + +#### Implementation notes + +The pager will manage pages for you in order to prevent infinite looping. +Currently, the Model Registry backend treats model lists as a circular buffer, and **will not end iteration** for you. ## Development diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 0cf2a7a7..5484a795 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from collections.abc import Mapping from pathlib import Path from typing import Any, TypeVar, Union, get_args from warnings import warn @@ -138,7 +139,7 @@ def register_model( author: str | None = None, owner: str | None = None, description: str | None = None, - metadata: dict[str, SupportedTypes] | None = None, + metadata: Mapping[str, SupportedTypes] | None = None, ) -> RegisteredModel: """Register a model. diff --git a/clients/python/src/model_registry/types/base.py b/clients/python/src/model_registry/types/base.py index bf1b8dd9..df1166e8 100644 --- a/clients/python/src/model_registry/types/base.py +++ b/clients/python/src/model_registry/types/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Union, get_args from pydantic import BaseModel, ConfigDict @@ -35,7 +35,7 @@ class BaseResourceModel(BaseModel, ABC): external_id: str | None = None create_time_since_epoch: str | None = None last_update_time_since_epoch: str | None = None - custom_properties: dict[str, SupportedTypes] | None = None + custom_properties: Mapping[str, SupportedTypes] | None = None @abstractmethod def create(self, **kwargs) -> Any: diff --git a/clients/python/src/model_registry/types/pager.py b/clients/python/src/model_registry/types/pager.py index 11dae8ab..a5e2e5e9 100644 --- a/clients/python/src/model_registry/types/pager.py +++ b/clients/python/src/model_registry/types/pager.py @@ -73,12 +73,15 @@ def order_by_id(self) -> Pager[T]: self.options.order_by = OrderByField.ID return self.restart() - def limit(self, limit: int) -> Pager[T]: - """Limit the number of items to return. + def page_size(self, n: int) -> Pager[T]: + """Set the page size for each request. This resets the pager. """ - self.options.limit = limit + if n < 1: + msg = f"Page size must be at least 1, got {n}" + raise ValueError(msg) + self.options.limit = n return self.restart() def ascending(self) -> Pager[T]: diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 002126da..bb958cf5 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -72,10 +72,10 @@ def test_register_existing_version(client: ModelRegistry): "model_format_version": "test_version", "version": "1.0.0", } - client.register_model(**params) + client.register_model(**params, metadata=None) with pytest.raises(StoreError): - client.register_model(**params) + client.register_model(**params, metadata=None) @pytest.mark.e2e @@ -124,8 +124,10 @@ async def test_update_logical_model_with_labels(client: ModelRegistry): ) assert rm.id mv = client.get_model_version(name, version) + assert mv assert mv.id ma = client.get_model_artifact(name, version) + assert ma assert ma.id rm_labels = { @@ -149,9 +151,15 @@ async def test_update_logical_model_with_labels(client: ModelRegistry): ma.custom_properties = ma_labels client.update(ma) - assert client.get_registered_model(name).custom_properties == rm_labels - assert client.get_model_version(name, version).custom_properties == mv_labels - assert client.get_model_artifact(name, version).custom_properties == ma_labels + rm = client.get_registered_model(name) + assert rm + assert rm.custom_properties == rm_labels + mv = client.get_model_version(name, version) + assert mv + assert mv.custom_properties == mv_labels + ma = client.get_model_artifact(name, version) + assert ma + assert ma.custom_properties == ma_labels @pytest.mark.e2e @@ -232,7 +240,7 @@ def test_get_registered_models(client: ModelRegistry): version="1.0.0", ) - rm_iter = client.get_registered_models().limit(10) + rm_iter = client.get_registered_models().page_size(10) i = 0 prev_tok = None changes = 0 @@ -315,6 +323,17 @@ def test_get_registered_models_order_by(client: ModelRegistry): assert i == models + # or if descending is explicitly set + i = 0 + for rm, by_update in zip( + rms, + client.get_registered_models().order_by_update_time().descending(), + ): + assert rm.id == by_update.id + i += 1 + + assert i == models + @pytest.mark.e2e def test_get_registered_models_and_reset(client: ModelRegistry): @@ -330,7 +349,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry): version="1.0.0", ) - rm_iter = client.get_registered_models().limit(model_count - 1) + rm_iter = client.get_registered_models().page_size(model_count - 1) models = [] for rm in islice(rm_iter, page): models.append(rm) @@ -355,7 +374,7 @@ def test_get_model_versions(client: ModelRegistry): version=v, ) - mv_iter = client.get_model_versions(name).limit(10) + mv_iter = client.get_model_versions(name).page_size(10) i = 0 prev_tok = None changes = 0 @@ -430,6 +449,18 @@ def test_get_model_versions_order_by(client: ModelRegistry): assert mv.id == by_update.id i += 1 + assert i == models + + i = 0 + for mv, by_update in zip( + mvs, + client.get_model_versions(name).order_by_update_time().descending(), + ): + assert mv.id == by_update.id + i += 1 + + assert i == models + @pytest.mark.e2e def test_get_model_versions_and_reset(client: ModelRegistry): @@ -447,7 +478,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry): version=v, ) - mv_iter = client.get_model_versions(name).limit(model_count - 1) + mv_iter = client.get_model_versions(name).page_size(model_count - 1) models = [] for rm in islice(mv_iter, page): models.append(rm) diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index 75c52f36..8796d801 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -76,6 +76,7 @@ async def test_get_registered_model_by_external_id( client: ModelRegistryAPIClient, registered_model: RegisteredModel, ): + assert registered_model.external_id assert ( rm := await client.get_registered_model_by_params( external_id=registered_model.external_id @@ -99,7 +100,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient): models = 6 for i in range(models): await client.upsert_registered_model(RegisteredModel(name=f"rm{i}")) - pager = Pager(client.get_registered_models).limit(5) + pager = Pager(client.get_registered_models).page_size(5) total = 0 async for _ in pager: total += 1 @@ -205,7 +206,7 @@ async def test_page_through_model_versions( ) pager = Pager( lambda o: client.get_model_versions(str(registered_model.id), o) - ).limit(5) + ).page_size(5) total = 0 async for _ in pager: total += 1 @@ -227,7 +228,8 @@ async def test_insert_model_artifact( "service_account_name": "test service account", } ma = await client.upsert_model_artifact( - ModelArtifact(**props), str(model_version.id) + ModelArtifact(**props), # type: ignore + str(model_version.id), ) assert ma.id assert ma.name == "test model" @@ -340,7 +342,7 @@ async def test_page_through_model_version_artifacts( await client.create_model_version_artifact(art, str(model_version.id)) pager = Pager( lambda o: client.get_model_version_artifacts(str(model_version.id), o) - ).limit(5) + ).page_size(5) total = 0 async for _ in pager: total += 1 diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 5a566e4a..5d912fa2 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -57,10 +57,10 @@ make docker-build | URL Pattern | Handler | Action | |------------------------------------------------------------------------------------|-------------------------|----------------------------------------------| | GET /v1/healthcheck | HealthcheckHandler | Show application information. | -| GET /v1/model-registry | ModelRegistryHandler | Get all model registries, | -| GET /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. | -| POST /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. | -| GET /v1/model-registry/{model_registry_id}/registered_models/{registered_model_id} | RegisteredModelHandler | Get a RegisteredModel entity by ID | +| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, | +| GET /v1/model_registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. | +| POST /v1/model_registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. | +| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | RegisteredModelHandler | Get a RegisteredModel entity by ID | ### Sample local calls ``` @@ -68,18 +68,18 @@ make docker-build curl -i localhost:4000/api/v1/healthcheck ``` ``` -# GET /v1/model-registry -curl -i localhost:4000/api/v1/model-registry +# GET /v1/model_registry +curl -i localhost:4000/api/v1/model_registry ``` ``` -# GET /v1/model-registry/{model_registry_id}/registered_models -curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models +# GET /v1/model_registry/{model_registry_id}/registered_models +curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models ``` ``` -#POST /v1/model-registry/{model_registry_id}/registered_models -curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/registered_models" \ +#POST /v1/model_registry/{model_registry_id}/registered_models +curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \ -H "Content-Type: application/json" \ - -d '{ + -d '{ "data": { "customProperties": { "my-label9": { "metadataType": "MetadataStringValue", @@ -91,9 +91,9 @@ curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/regi "name": "bella", "owner": "eder", "state": "LIVE" -}' +}}' ``` ``` -# GET /v1/model-registry/{model_registry_id}/registered_models/{registered_model_id} -curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models/1 +# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} +curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1 ``` diff --git a/clients/ui/bff/api/app.go b/clients/ui/bff/api/app.go index 41d1a52a..4040c2ef 100644 --- a/clients/ui/bff/api/app.go +++ b/clients/ui/bff/api/app.go @@ -2,13 +2,14 @@ package api import ( "fmt" + "log/slog" + "net/http" + "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/ui/bff/config" "github.com/kubeflow/model-registry/ui/bff/data" "github.com/kubeflow/model-registry/ui/bff/integrations" "github.com/kubeflow/model-registry/ui/bff/internals/mocks" - "log/slog" - "net/http" ) const ( @@ -17,7 +18,7 @@ const ( ModelRegistryId = "model_registry_id" RegisteredModelId = "registered_model_id" HealthCheckPath = PathPrefix + "/healthcheck" - ModelRegistry = PathPrefix + "/model-registry" + ModelRegistry = PathPrefix + "/model_registry" RegisteredModelsPath = ModelRegistry + "/:" + ModelRegistryId + "/registered_models" RegisteredModelPath = RegisteredModelsPath + "/:" + RegisteredModelId ) diff --git a/clients/ui/bff/api/errors.go b/clients/ui/bff/api/errors.go index 6a28e0fb..b70b8aaf 100644 --- a/clients/ui/bff/api/errors.go +++ b/clients/ui/bff/api/errors.go @@ -18,6 +18,10 @@ type ErrorResponse struct { Message string `json:"message"` } +type ErrorEnvelope struct { + Error *integrations.HTTPError `json:"error"` +} + func (app *App) LogError(r *http.Request, err error) { var ( method = r.Method @@ -40,7 +44,7 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) { - env := Envelope{"error": error} + env := ErrorEnvelope{Error: error} err := app.WriteJSON(w, error.StatusCode, env, nil) diff --git a/clients/ui/bff/api/helpers.go b/clients/ui/bff/api/helpers.go index f35851af..a129e335 100644 --- a/clients/ui/bff/api/helpers.go +++ b/clients/ui/bff/api/helpers.go @@ -9,9 +9,12 @@ import ( "strings" ) -type Envelope map[string]interface{} +type Envelope[D any, M any] struct { + Data D `json:"data,omitempty"` + Metadata M `json:"metadata,omitempty"` +} -type TypedEnvelope[T any] map[string]T +type None *struct{} func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error { @@ -29,7 +32,11 @@ func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers h w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - w.Write(js) + _, err = w.Write(js) + + if err != nil { + return err + } return nil } diff --git a/clients/ui/bff/api/model_registry_handler.go b/clients/ui/bff/api/model_registry_handler.go index 5a99b84f..8a85870c 100644 --- a/clients/ui/bff/api/model_registry_handler.go +++ b/clients/ui/bff/api/model_registry_handler.go @@ -2,9 +2,12 @@ package api import ( "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/data" "net/http" ) +type ModelRegistryListEnvelope Envelope[[]data.ModelRegistryModel, None] + func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { registries, err := app.models.ModelRegistry.FetchAllModelRegistries(app.kubernetesClient) @@ -13,8 +16,8 @@ func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps return } - modelRegistryRes := Envelope{ - "model_registry": registries, + modelRegistryRes := ModelRegistryListEnvelope{ + Data: registries, } err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil) diff --git a/clients/ui/bff/api/model_registry_handler_test.go b/clients/ui/bff/api/model_registry_handler_test.go index 8f05372e..0b4949d5 100644 --- a/clients/ui/bff/api/model_registry_handler_test.go +++ b/clients/ui/bff/api/model_registry_handler_test.go @@ -29,28 +29,20 @@ func TestModelRegistryHandler(t *testing.T) { defer rs.Body.Close() body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var modelRegistryRes Envelope - err = json.Unmarshal(body, &modelRegistryRes) + var actual ModelRegistryListEnvelope + err = json.Unmarshal(body, &actual) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) - // Convert the unmarshalled data to the expected type - actualModelRegistry := make([]data.ModelRegistryModel, 0) - for _, v := range modelRegistryRes["model_registry"].([]interface{}) { - model := v.(map[string]interface{}) - actualModelRegistry = append(actualModelRegistry, data.ModelRegistryModel{Name: model["name"].(string), Description: model["description"].(string), DisplayName: model["displayName"].(string)}) - } - modelRegistryRes["model_registry"] = actualModelRegistry - - var expected = Envelope{ - "model_registry": []data.ModelRegistryModel{ + var expected = ModelRegistryListEnvelope{ + Data: []data.ModelRegistryModel{ {Name: "model-registry", Description: "Model registry description", DisplayName: "Model Registry"}, {Name: "model-registry-dora", Description: "Model registry dora description", DisplayName: "Model Registry Dora"}, {Name: "model-registry-bella", Description: "Model registry bella description", DisplayName: "Model Registry Bella"}, }, } - assert.Equal(t, expected, modelRegistryRes) + assert.Equal(t, expected, actual) } diff --git a/clients/ui/bff/api/registered_models_handler.go b/clients/ui/bff/api/registered_models_handler.go index 9e47d833..3e8cc71c 100644 --- a/clients/ui/bff/api/registered_models_handler.go +++ b/clients/ui/bff/api/registered_models_handler.go @@ -11,6 +11,9 @@ import ( "net/http" ) +type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None] +type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] + func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { //TODO (ederign) implement pagination client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) @@ -25,8 +28,8 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req return } - modelRegistryRes := Envelope{ - "registered_model_list": modelList, + modelRegistryRes := RegisteredModelListEnvelope{ + Data: modelList, } err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil) @@ -42,18 +45,20 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ return } - var model openapi.RegisteredModel - if err := json.NewDecoder(r.Body).Decode(&model); err != nil { + var envelope RegisteredModelEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) return } - if err := validation.ValidateRegisteredModel(model); err != nil { + data := *envelope.Data + + if err := validation.ValidateRegisteredModel(data); err != nil { app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) return } - jsonData, err := json.Marshal(model) + jsonData, err := json.Marshal(data) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("error marshaling model to JSON: %w", err)) return @@ -75,8 +80,11 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ return } + response := RegisteredModelEnvelope{ + Data: createdModel, + } w.Header().Set("Location", fmt.Sprintf("%s/%s", RegisteredModelsPath, *createdModel.Id)) - err = app.WriteJSON(w, http.StatusCreated, createdModel, nil) + err = app.WriteJSON(w, http.StatusCreated, response, nil) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON")) return @@ -101,8 +109,8 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request return } - result := Envelope{ - "registered_model": model, + result := RegisteredModelEnvelope{ + Data: model, } err = app.WriteJSON(w, http.StatusOK, result, nil) diff --git a/clients/ui/bff/api/registered_models_handler_test.go b/clients/ui/bff/api/registered_models_handler_test.go index ec1eb039..0e6deba6 100644 --- a/clients/ui/bff/api/registered_models_handler_test.go +++ b/clients/ui/bff/api/registered_models_handler_test.go @@ -22,7 +22,7 @@ func TestGetRegisteredModelHandler(t *testing.T) { } req, err := http.NewRequest(http.MethodGet, - "/api/v1/model-registry/model-registry/registered_models/1", nil) + "/api/v1/model_registry/model-registry/registered_models/1", nil) assert.NoError(t, err) ctx := context.WithValue(req.Context(), httpClientKey, mockClient) @@ -37,19 +37,21 @@ func TestGetRegisteredModelHandler(t *testing.T) { body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var registeredModelRes TypedEnvelope[openapi.RegisteredModel] + var registeredModelRes RegisteredModelEnvelope err = json.Unmarshal(body, ®isteredModelRes) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) - var expected = TypedEnvelope[openapi.RegisteredModel]{ - "registered_model": mocks.GetRegisteredModelMocks()[0], + mockModel := mocks.GetRegisteredModelMocks()[0] + + var expected = RegisteredModelEnvelope{ + Data: &mockModel, } //TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values // this issue is in the test only - assert.Equal(t, expected["registered_model"].Name, registeredModelRes["registered_model"].Name) + assert.Equal(t, expected.Data.Name, registeredModelRes.Data.Name) } func TestGetAllRegisteredModelsHandler(t *testing.T) { @@ -61,7 +63,7 @@ func TestGetAllRegisteredModelsHandler(t *testing.T) { } req, err := http.NewRequest(http.MethodGet, - "/api/v1/model-registry/model-registry/registered_models", nil) + "/api/v1/model_registry/model-registry/registered_models", nil) assert.NoError(t, err) ctx := context.WithValue(req.Context(), httpClientKey, mockClient) @@ -76,20 +78,22 @@ func TestGetAllRegisteredModelsHandler(t *testing.T) { body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var registeredModelsListRes TypedEnvelope[openapi.RegisteredModelList] + var registeredModelsListRes RegisteredModelListEnvelope err = json.Unmarshal(body, ®isteredModelsListRes) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) - var expected = TypedEnvelope[openapi.RegisteredModelList]{ - "registered_model_list": mocks.GetRegisteredModelListMock(), + modelList := mocks.GetRegisteredModelListMock() + + var expected = RegisteredModelListEnvelope{ + Data: &modelList, } - assert.Equal(t, expected["registered_model_list"].Size, registeredModelsListRes["registered_model_list"].Size) - assert.Equal(t, expected["registered_model_list"].PageSize, registeredModelsListRes["registered_model_list"].PageSize) - assert.Equal(t, expected["registered_model_list"].NextPageToken, registeredModelsListRes["registered_model_list"].NextPageToken) - assert.Equal(t, len(expected["registered_model_list"].Items), len(registeredModelsListRes["registered_model_list"].Items)) + assert.Equal(t, expected.Data.Size, registeredModelsListRes.Data.Size) + assert.Equal(t, expected.Data.PageSize, registeredModelsListRes.Data.PageSize) + assert.Equal(t, expected.Data.NextPageToken, registeredModelsListRes.Data.NextPageToken) + assert.Equal(t, len(expected.Data.Items), len(registeredModelsListRes.Data.Items)) } func TestCreateRegisteredModelHandler(t *testing.T) { @@ -100,14 +104,16 @@ func TestCreateRegisteredModelHandler(t *testing.T) { modelRegistryClient: mockMRClient, } - newModel := openapi.NewRegisteredModelCreate("Model One") - newModelJSON, err := newModel.MarshalJSON() + newModel := openapi.NewRegisteredModel("Model One") + newEnvelope := RegisteredModelEnvelope{Data: newModel} + + newModelJSON, err := json.Marshal(newEnvelope) assert.NoError(t, err) reqBody := bytes.NewReader(newModelJSON) req, err := http.NewRequest(http.MethodPost, - "/api/v1/model-registry/model-registry/registered_models", reqBody) + "/api/v1/model_registry/model-registry/registered_models", reqBody) assert.NoError(t, err) ctx := context.WithValue(req.Context(), httpClientKey, mockClient) @@ -122,14 +128,14 @@ func TestCreateRegisteredModelHandler(t *testing.T) { body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var registeredModelRes openapi.RegisteredModel - err = json.Unmarshal(body, ®isteredModelRes) + var actual RegisteredModelEnvelope + err = json.Unmarshal(body, &actual) assert.NoError(t, err) assert.Equal(t, http.StatusCreated, rr.Code) var expected = mocks.GetRegisteredModelMocks()[0] - assert.Equal(t, expected.Name, registeredModelRes.Name) + assert.Equal(t, expected.Name, actual.Data.Name) assert.NotEmpty(t, rs.Header.Get("location")) } diff --git a/clients/ui/frontend/config/webpack.dev.js b/clients/ui/frontend/config/webpack.dev.js index f7a6c806..b2198a83 100644 --- a/clients/ui/frontend/config/webpack.dev.js +++ b/clients/ui/frontend/config/webpack.dev.js @@ -18,7 +18,6 @@ module.exports = merge(common('development'), { host: HOST, port: PORT, historyApiFallback: true, - open: true, static: { directory: path.resolve(relativeDir, 'dist'), }, diff --git a/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts b/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts new file mode 100644 index 00000000..8f2bb628 --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts @@ -0,0 +1,34 @@ +import { ModelArtifact, ModelArtifactState } from '~/app/types'; + +type MockModelArtifact = { + id?: string; + name?: string; + uri?: string; + state?: ModelArtifactState; + author?: string; +}; + +export const mockModelArtifact = ({ + id = '1', + name = 'test', + uri = 'test', + state = ModelArtifactState.LIVE, + author = 'Author 1', +}: MockModelArtifact): ModelArtifact => ({ + id, + name, + externalID: '1234132asdfasdf', + description: '', + createTimeSinceEpoch: '1710404288975', + lastUpdateTimeSinceEpoch: '1710404288975', + customProperties: {}, + uri, + state, + author, + modelFormatName: 'test', + storageKey: 'test', + storagePath: 'test', + modelFormatVersion: 'test', + serviceAccountName: 'test', + artifactType: 'test', +}); diff --git a/clients/ui/frontend/src/__mocks__/mockModelRegistry.ts b/clients/ui/frontend/src/__mocks__/mockModelRegistry.ts new file mode 100644 index 00000000..56fed2e3 --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelRegistry.ts @@ -0,0 +1,17 @@ +import { ModelRegistry } from '~/app/types'; + +type MockModelRegistry = { + name?: string; + description?: string; + displayName?: string; +}; + +export const mockModelRegistry = ({ + name = 'modelregistry-sample', + description = 'New model registry', + displayName = 'Model Registry Sample', +}: MockModelRegistry): ModelRegistry => ({ + name, + description, + displayName, +}); diff --git a/clients/ui/frontend/src/__mocks__/mockModelRegistryResponse.ts b/clients/ui/frontend/src/__mocks__/mockModelRegistryResponse.ts new file mode 100644 index 00000000..79b1e197 --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelRegistryResponse.ts @@ -0,0 +1,8 @@ +/* eslint-disable camelcase */ +import { ModelRegistryResponse } from '~/app/types'; + +export const mockModelRegistryResponse = ({ + model_registry = [], +}: Partial): ModelRegistryResponse => ({ + model_registry, +}); diff --git a/clients/ui/frontend/src/__mocks__/mockModelVersion.ts b/clients/ui/frontend/src/__mocks__/mockModelVersion.ts new file mode 100644 index 00000000..80a6f310 --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelVersion.ts @@ -0,0 +1,36 @@ +import { ModelVersion, ModelState } from '~/app/types'; +import { createModelRegistryLabelsObject } from './utils'; + +type MockModelVersionType = { + author?: string; + id?: string; + registeredModelId?: string; + name?: string; + labels?: string[]; + state?: ModelState; + description?: string; + createTimeSinceEpoch?: string; + lastUpdateTimeSinceEpoch?: string; +}; + +export const mockModelVersion = ({ + author = 'Test author', + registeredModelId = '1', + name = 'new model version', + labels = [], + id = '1', + state = ModelState.LIVE, + description = 'Description of model version', + createTimeSinceEpoch = '1712234877179', + lastUpdateTimeSinceEpoch = '1712234877179', +}: MockModelVersionType): ModelVersion => ({ + author, + createTimeSinceEpoch, + customProperties: createModelRegistryLabelsObject(labels), + id, + lastUpdateTimeSinceEpoch, + name, + state, + registeredModelId, + description, +}); diff --git a/clients/ui/frontend/src/__mocks__/mockModelVersionList.ts b/clients/ui/frontend/src/__mocks__/mockModelVersionList.ts new file mode 100644 index 00000000..16a83379 --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelVersionList.ts @@ -0,0 +1,11 @@ +/* eslint-disable camelcase */ +import { ModelVersionList } from '~/app/types'; + +export const mockModelVersionList = ({ + items = [], +}: Partial): ModelVersionList => ({ + items, + nextPageToken: '', + pageSize: 0, + size: items.length, +}); diff --git a/clients/ui/frontend/src/__mocks__/mockRegisteredModelsList.ts b/clients/ui/frontend/src/__mocks__/mockRegisteredModelsList.ts new file mode 100644 index 00000000..ddb525af --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockRegisteredModelsList.ts @@ -0,0 +1,53 @@ +import { RegisteredModelList } from '~/app/types'; +import { mockRegisteredModel } from './mockRegisteredModel'; + +export const mockRegisteredModelList = ({ + size = 5, + items = [ + mockRegisteredModel({ name: 'test-1' }), + mockRegisteredModel({ name: 'test-2' }), + mockRegisteredModel({ + name: 'Fraud detection model', + description: + 'A machine learning model trained to detect fraudulent transactions in financial data', + labels: [ + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + ], + }), + mockRegisteredModel({ + name: 'Credit Scoring', + labels: [ + 'Credit Score Predictor', + 'Creditworthiness scoring system', + 'Default Risk Analyzer', + 'Portfolio Management', + 'Risk Assessment', + ], + }), + mockRegisteredModel({ + name: 'Label modal', + description: + 'A machine learning model trained to detect fraudulent transactions in financial data', + labels: [ + 'Testing label', + 'Financial data', + 'Fraud detection', + 'Long label data to be truncated abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc', + 'Machine learning', + 'Next data to be overflow', + 'Label x', + 'Label y', + 'Label z', + ], + }), + ], +}: Partial): RegisteredModelList => ({ + items, + nextPageToken: '', + pageSize: 0, + size, +}); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/appChrome.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/appChrome.ts new file mode 100644 index 00000000..8d30c9a4 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/appChrome.ts @@ -0,0 +1,45 @@ +class AppChrome { + visit() { + cy.visit('/'); + this.wait(); + } + + private wait() { + cy.get('#dashboard-page-main'); + cy.testA11y(); + } + + // TODO: implement when authorization is enabled + // shouldBeUnauthorized() { + // cy.findByTestId('unauthorized-error'); + // return this; + // } + + findNavToggle() { + return cy.get('#page-nav-toggle'); + } + + findSideBar() { + return cy.get('#page-sidebar'); + } + + findNavSection(name: string) { + return this.findSideBar().findByRole('button', { name }); + } + + findNavItem(name: string, section?: string) { + if (section) { + this.findNavSection(section) + // do not fail if the section is not found + .should('have.length.at.least', 0) + .then(($el) => { + if ($el.attr('aria-expanded') === 'false') { + cy.wrap($el).click(); + } + }); + } + return this.findSideBar().findByRole('link', { name }); + } +} + +export const appChrome = new AppChrome(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/home.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/home.ts deleted file mode 100644 index 4299e844..00000000 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/home.ts +++ /dev/null @@ -1,11 +0,0 @@ -class Home { - visit() { - cy.visit(`/`); - } - - findTitle() { - cy.get(`h1`).should(`have.text`, `Model registry`); - } -} - -export const home = new Home(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts new file mode 100644 index 00000000..9133b2ce --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts @@ -0,0 +1,192 @@ +import { appChrome } from '~/__tests__/cypress/cypress/pages/appChrome'; +// import { TableRow } from './components/table'; +// import { Modal } from './components/Modal'; + +// TODO: Uncomment when the modal is implemented +// class LabelModal extends Modal { +// constructor() { +// super('Labels'); +// } + +// findModalSearchInput() { +// return cy.findByTestId('label-modal-search'); +// } + +// findCloseModal() { +// return cy.findByTestId('close-modal'); +// } + +// shouldContainsModalLabels(labels: string[]) { +// cy.findByTestId('modal-label-group').within(() => labels.map((label) => cy.contains(label))); +// return this; +// } +// } + +// TODO: Uncomment when the table is implemented +// class ModelRegistryTableRow extends TableRow { +// findName() { +// return this.find().findByTestId('model-name'); +// } + +// findDescription() { +// return this.find().findByTestId('description'); +// } + +// findOwner() { +// return this.find().findByTestId('registered-model-owner'); +// } + +// findLabelPopoverText() { +// return this.find().findByTestId('popover-label-text'); +// } + +// findLabelModalText() { +// return this.find().findByTestId('modal-label-text'); +// } + +// shouldContainsPopoverLabels(labels: string[]) { +// cy.findByTestId('popover-label-group').within(() => labels.map((label) => cy.contains(label))); +// return this; +// } + +// findModelVersionName() { +// return this.find().findByTestId('model-version-name'); +// } +// } + +class ModelRegistry { + landingPage() { + cy.visit('/'); + this.waitLanding(); + } + + visit() { + cy.visit(`/modelRegistry`); + this.wait(); + } + + navigate() { + appChrome.findNavItem('Model Registry').click(); + this.wait(); + } + + private wait() { + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Model Registry'); + cy.testA11y(); + } + + private waitLanding() { + cy.findByTestId('home-page').should('be.visible'); + } + + shouldBeEmpty() { + cy.findByTestId('empty-state-title').should('exist'); + return this; + } + + findModelRegistryEmptyState() { + return cy.findByTestId('empty-model-registries-state'); + } + + shouldregisteredModelsEmpty() { + cy.findByTestId('empty-registered-models').should('exist'); + } + + shouldmodelVersionsEmpty() { + cy.findByTestId('empty-model-versions').should('exist'); + } + + shouldModelRegistrySelectorExist() { + cy.findByTestId('model-registry-selector-dropdown').should('exist'); + } + + shouldtableToolbarExist() { + cy.findByTestId('registered-models-table-toolbar').should('exist'); + } + + tabEnabled() { + appChrome.findNavItem('Model Registry').should('exist'); + return this; + } + + tabDisabled() { + appChrome.findNavItem('Model Registry').should('not.exist'); + return this; + } + + findTable() { + return cy.findByTestId('registered-model-table'); + } + + findModelVersionsTable() { + return cy.findByTestId('model-versions-table'); + } + + findTableRows() { + return this.findTable().find('tbody tr'); + } + + findModelVersionsTableRows() { + return this.findModelVersionsTable().find('tbody tr'); + } + + // TODO: Uncomment when the table row is implemented + // getRow(name: string) { + // return new ModelRegistryTableRow(() => + // this.findTable().find(`[data-label="Model name"]`).contains(name).parents('tr'), + // ); + // } + + // getModelVersionRow(name: string) { + // return new ModelRegistryTableRow(() => + // this.findModelVersionsTable() + // .find(`[data-label="Version name"]`) + // .contains(name) + // .parents('tr'), + // ); + // } + + findRegisteredModelTableHeaderButton(name: string) { + return this.findTable().find('thead').findByRole('button', { name }); + } + + findModelRegistry() { + return cy.findByTestId('model-registry-selector-dropdown'); + } + + findModelVersionsTableHeaderButton(name: string) { + return this.findModelVersionsTable().find('thead').findByRole('button', { name }); + } + + findTableSearch() { + return cy.findByTestId('registered-model-table-search'); + } + + findModelVersionsTableSearch() { + return cy.findByTestId('model-versions-table-search'); + } + + findModelBreadcrumbItem() { + return cy.findByTestId('breadcrumb-model'); + } + + findModelVersionsTableKebab() { + return cy.findByTestId('model-versions-table-kebab-action'); + } + + findModelVersionsHeaderAction() { + return cy.findByTestId('model-version-action-toggle'); + } + + findModelVersionsTableFilter() { + return cy.findByTestId('model-versions-table-filter'); + } + + findRegisterModelButton() { + return cy.findByRole('button', { name: 'Register model' }); + } +} + +export const modelRegistry = new ModelRegistry(); +// export const labelModal = new LabelModal(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerModelPage.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerModelPage.ts new file mode 100644 index 00000000..7b55feae --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerModelPage.ts @@ -0,0 +1,61 @@ +export enum FormFieldSelector { + MODEL_NAME = '#model-name', + MODEL_DESCRIPTION = '#model-description', + VERSION_NAME = '#version-name', + VERSION_DESCRIPTION = '#version-description', + SOURCE_MODEL_FORMAT = '#source-model-format', + SOURCE_MODEL_FORMAT_VERSION = '#source-model-format-version', + LOCATION_TYPE_OBJECT_STORAGE = '#location-type-object-storage', + LOCATION_ENDPOINT = '#location-endpoint', + LOCATION_BUCKET = '#location-bucket', + LOCATION_REGION = '#location-region', + LOCATION_PATH = '#location-path', + LOCATION_TYPE_URI = '#location-type-uri', + LOCATION_URI = '#location-uri', +} + +class RegisterModelPage { + visit() { + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registerModel`); + this.wait(); + } + + private wait() { + const preferredModelRegistry = 'modelregistry-sample'; + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Register model'); + cy.findByText(`Model registry - ${preferredModelRegistry}`).should('exist'); + cy.testA11y(); + } + + findFormField(selector: FormFieldSelector) { + return cy.get(selector); + } + + findObjectStorageAutofillButton() { + return cy.findByTestId('object-storage-autofill-button'); + } + + findConnectionAutofillModal() { + return cy.findByTestId('connection-autofill-modal'); + } + + findProjectSelector() { + return this.findConnectionAutofillModal().findByTestId('project-selector-dropdown'); + } + + findConnectionSelector() { + return this.findConnectionAutofillModal().findByTestId('select-data-connection'); + } + + findAutofillButton() { + return cy.findByTestId('autofill-modal-button'); + } + + findSubmitButton() { + return cy.findByTestId('create-button'); + } +} + +export const registerModelPage = new RegisterModelPage(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts new file mode 100644 index 00000000..e96daa3c --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts @@ -0,0 +1,138 @@ +import type { GenericStaticResponse, RouteHandlerController } from 'cypress/types/net-stubbing'; +import type { + ModelArtifact, + ModelArtifactList, + ModelRegistryResponse, + ModelVersion, + ModelVersionList, + RegisteredModel, + RegisteredModelList, +} from '~/app/types'; + +type SuccessErrorResponse = { + success: boolean; + error?: string; +}; + +type ApiResponse = + | V + | GenericStaticResponse + | RouteHandlerController; + +type Replacement = Record; +type Query = Record; + +type Options = { path?: Replacement; query?: Query; times?: number } | null; + +/* eslint-disable @typescript-eslint/no-namespace */ +declare global { + namespace Cypress { + interface Chainable { + interceptApi: (( + type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models', + options: { path: { modelRegistryName: string; apiVersion: string } }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/registered_models', + options: { path: { modelRegistryName: string; apiVersion: string } }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions', + options: { + path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions', + options: { + path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId', + options: { + path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId', + options: { + path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId', + options: { + path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts', + options: { + path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts', + options: { + path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId', + options: { + path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; + }, + response: ApiResponse, + ) => Cypress.Chainable) & + (( + type: 'GET /api/:apiVersion/model_registry', + options: { path: { apiVersion: string } }, + response: ApiResponse, + ) => Cypress.Chainable); + } + } +} + +Cypress.Commands.add( + 'interceptApi', + (type: string, ...args: [Options | null, ApiResponse] | [ApiResponse]) => { + if (!type) { + throw new Error('Invalid type parameter.'); + } + const options = args.length === 2 ? args[0] : null; + const response = (args.length === 2 ? args[1] : args[0]) ?? ''; + + const pathParts = type.match(/:[a-z][a-zA-Z0-9-_]+/g); + const [method, staticPathname] = type.split(' '); + let pathname = staticPathname; + if (pathParts?.length) { + if (!options || !options.path) { + throw new Error(`${type}: missing path replacements`); + } + const { path: pathReplacements } = options; + pathParts.forEach((p) => { + // remove the starting colun from the regex match + const part = p.substring(1); + const replacement = pathReplacements[part]; + if (!replacement) { + throw new Error(`${type} missing path replacement: ${part}`); + } + pathname = pathname.replace(new RegExp(`:${part}\\b`), replacement); + }); + } + return cy.intercept( + { method, pathname, query: options?.query, ...(options?.times && { times: options.times }) }, + response, + ); + }, +); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/application.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/application.ts new file mode 100644 index 00000000..98f42c4a --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/application.ts @@ -0,0 +1,271 @@ +import type { MatcherOptions } from '@testing-library/cypress'; +import type { Matcher, MatcherOptions as DTLMatcherOptions } from '@testing-library/dom'; +// import type { UserAuthConfig } from '~/__tests__/cypress/cypress/types'; +// import { TEST_USER } from '~/__tests__/cypress/cypress/utils/e2eUsers'; + +/* eslint-disable @typescript-eslint/no-namespace */ +declare global { + namespace Cypress { + interface Chainable { + // TODO: Uncomment when authorization is enabled + // /** + // * Visits the URL and performs a login if necessary. + // * Uses credentials supplied by environment variables if not provided. + // * + // * @param url the URL to visit + // * @param credentials login credentials + // */ + // visitWithLogin: (url: string, user?: UserAuthConfig) => Cypress.Chainable; + + /** + * Find a patternfly kebab toggle button. + * + * @param isDropdownToggle - True to indicate that it is a dropdown toggle instead of table kebab actions + */ + findKebab: (isDropdownToggle?: boolean) => Cypress.Chainable; + + /** + * Finds a patternfly kebab toggle button, opens the menu, and finds the action. + * + * @param name the name of the action in the kebeb menu + * @param isDropdownToggle - True to indicate that it is a dropdown toggle instead of table kebab actions + */ + findKebabAction: ( + name: string | RegExp, + isDropdownToggle?: boolean, + ) => Cypress.Chainable; + + /** + * Finds a patternfly dropdown item by first opening the dropdown if not already opened. + * + * @param name the name of the item + */ + findDropdownItem: (name: string | RegExp) => Cypress.Chainable; + + /** + * Finds a patternfly dropdown item by data-testid, first opening the dropdown if not already opened. + * + * @param testId the name of the item + */ + findDropdownItemByTestId: (testId: string) => Cypress.Chainable; + /** + * Finds a patternfly select option by first opening the select menu if not already opened. + * + * @param name the name of the option + */ + findSelectOption: (name: string | RegExp) => Cypress.Chainable; + /** + * Finds a patternfly select option by first opening the select menu if not already opened. + * + * @param testId the name of the option + */ + findSelectOptionByTestId: (testId: string) => Cypress.Chainable; + + /** + * Shortcut to first clear the previous value and then type text into DOM element. + * + * @see https://on.cypress.io/type + */ + fill: ( + text: string, + options?: Partial | undefined, + ) => Cypress.Chainable; + + /** + * Returns a PF Switch label for clickable actions. + * + * @param dataId - the data test id you provided to the PF Switch + */ + pfSwitch: (dataId: string) => Cypress.Chainable; + + /** + * Returns a PF Switch input behind the checkbox to compare .should('be.checked') like ops + * + * @param dataId + */ + pfSwitchValue: (dataId: string) => Cypress.Chainable; + + /** + * The bottom two functions, findByTestId and findAllByTestId have the disabled rule + * method-signature-style because they are overwrites. + * Thus, we cannot change it to use the property signature for functions. + * https://typescript-eslint.io/rules/method-signature-style/ + */ + + /** + * Overwrite `findByTestId` to support an array of Matchers. + * When an array of Matches is supplied, parses the data-testid attribute value as a + * whitespace-separated list of words allowing the query to mimic the CSS selector `[data-testid~=value]`. + * + * data-testid="card my-id" + * + * cy.findByTestId(['card', 'my-id']); + * cy.findByTestId('card my-id'); + */ + // eslint-disable-next-line @typescript-eslint/method-signature-style + findByTestId(id: Matcher | Matcher[], options?: MatcherOptions): Chainable; + + /** + * Overwrite `findAllByTestId` to support an array of Matchers. + * When an array of Matches is supplied, parses the data-testid attribute value as a + * whitespace-separated list of words allowing the query to mimic the CSS selector `[data-testid~=value]`. + * + * data-testid="card my-id" + * + * cy.findAllByTestId(['card']); + * cy.findAllByTestId('card my-id'); + */ + // eslint-disable-next-line @typescript-eslint/method-signature-style + findAllByTestId(id: Matcher | Matcher[], options?: MatcherOptions): Chainable; + } + } +} + +// TODO: Uncomment when authorization is enabled +// Cypress.Commands.add('visitWithLogin', (url, user = TEST_USER) => { +// if (Cypress.env('MOCK')) { +// cy.visit(url); +// } else { +// cy.intercept('GET', url, { log: false }).as('visitWithLogin'); + +// cy.visit(url, { failOnStatusCode: false }); + +// cy.wait('@visitWithLogin', { log: false }).then((interception) => { +// if (interception.response?.statusCode === 403) { +// cy.log('Do login'); +// // do login +// cy.get('form[action="/oauth/start"]').submit(); +// cy.findAllByRole('link', user.AUTH_TYPE ? { name: user.AUTH_TYPE } : {}) +// .last() +// .click(); +// cy.get('input[name=username]').type(user.USERNAME); +// cy.get('input[name=password]').type(user.PASSWORD); +// cy.get('form').submit(); +// } else if (interception.response?.statusCode !== 200) { +// throw new Error( +// `Failed to visit '${url}'. Status code: ${ +// interception.response?.statusCode || 'unknown' +// }`, +// ); +// } +// }); +// } +// }); + +Cypress.Commands.add('findKebab', { prevSubject: 'element' }, (subject, isDropdownToggle) => { + Cypress.log({ displayName: 'findKebab' }); + return cy + .wrap(subject) + .findByRole('button', { name: isDropdownToggle ? 'Actions' : 'Kebab toggle' }); +}); + +Cypress.Commands.add( + 'findKebabAction', + { prevSubject: 'element' }, + (subject, name, isDropdownToggle) => { + Cypress.log({ displayName: 'findKebab', message: name }); + return cy + .wrap(subject) + .findKebab(isDropdownToggle) + .then(($el) => { + if ($el.attr('aria-expanded') === 'false') { + cy.wrap($el).click(); + } + return cy.wrap($el.parent()).findByRole('menuitem', { name }); + }); + }, +); + +Cypress.Commands.add('findDropdownItem', { prevSubject: 'element' }, (subject, name) => { + Cypress.log({ displayName: 'findDropdownItem', message: name }); + return cy.wrap(subject).then(($el) => { + if ($el.attr('aria-expanded') === 'false') { + cy.wrap($el).click(); + } + return cy.wrap($el).parent().findByRole('menuitem', { name }); + }); +}); + +Cypress.Commands.add('findDropdownItemByTestId', { prevSubject: 'element' }, (subject, testId) => { + Cypress.log({ displayName: 'findDropdownItemByTestId', message: testId }); + return cy.wrap(subject).then(($el) => { + if ($el.attr('aria-expanded') === 'false') { + cy.wrap($el).click(); + } + return cy.wrap($el).parent().findByTestId(testId); + }); +}); + +Cypress.Commands.add('findSelectOption', { prevSubject: 'element' }, (subject, name) => { + Cypress.log({ displayName: 'findSelectOption', message: name }); + return cy.wrap(subject).then(($el) => { + if ($el.attr('aria-expanded') === 'false') { + cy.wrap($el).click(); + } + //cy.get('[role=listbox]') TODO fix cases where there are multiple listboxes + return cy.findByRole('option', { name }); + }); +}); + +Cypress.Commands.add('findSelectOptionByTestId', { prevSubject: 'element' }, (subject, testId) => { + Cypress.log({ displayName: 'findSelectOptionByTestId', message: testId }); + return cy.wrap(subject).then(($el) => { + if ($el.attr('aria-expanded') === 'false') { + cy.wrap($el).click(); + } + return cy.wrap($el).parent().findByTestId(testId); + }); +}); + +Cypress.Commands.add('fill', { prevSubject: 'optional' }, (subject, text, options) => { + cy.wrap(subject).clear(); + return cy.wrap(subject).type(text, options); +}); + +Cypress.Commands.add('pfSwitch', { prevSubject: 'optional' }, (subject, dataId) => { + Cypress.log({ displayName: 'pfSwitch', message: dataId }); + return cy.wrap(subject).findByTestId(dataId).parent(); +}); + +Cypress.Commands.add('pfSwitchValue', { prevSubject: 'optional' }, (subject, dataId) => { + Cypress.log({ displayName: 'pfSwitchValue', message: dataId }); + return cy.wrap(subject).pfSwitch(dataId).find('[type=checkbox]'); +}); + +Cypress.Commands.overwriteQuery('findByTestId', function findByTestId(...args) { + return enhancedFindByTestId(this, ...args); +}); +Cypress.Commands.overwriteQuery('findAllByTestId', function findAllByTestId(...args) { + return enhancedFindByTestId(this, ...args); +}); + +const enhancedFindByTestId = ( + command: Cypress.Command, + originalFn: Cypress.QueryFn<'findAllByTestId' | 'findByTestId'>, + matcher: Matcher | Matcher[], + options?: MatcherOptions, +) => { + if (Array.isArray(matcher)) { + return originalFn.call( + command, + (content, node) => { + const values = content.trim().split(/\s+/); + return matcher.every((m) => + values.some((v) => { + if (typeof m === 'string' || typeof m === 'number') { + return options && (options as DTLMatcherOptions).exact + ? v.toLowerCase().includes(matcher.toString().toLowerCase()) + : v === String(m); + } + if (typeof m === 'function') { + return m(v, node); + } + return m.test(v); + }), + ); + }, + options, + ); + } + return originalFn.call(command, matcher, options); +}; diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/index.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/index.ts index 4464e031..79ed0c3d 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/index.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/index.ts @@ -1,2 +1,4 @@ import '@testing-library/cypress/add-commands'; import './axe'; +import './application'; +import './api'; diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/application.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/application.cy.ts deleted file mode 100644 index ae7d99d3..00000000 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/application.cy.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { pageNotfound } from '~/__tests__/cypress/cypress/pages/pageNoteFound'; -import { home } from '~/__tests__/cypress/cypress/pages/home'; - -describe('Application', () => { - it('Page not found should render', () => { - pageNotfound.visit(); - }); - - it('Home page should have primary button', () => { - home.visit(); - home.findTitle(); - }); -}); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry.cy.ts new file mode 100644 index 00000000..ac6c0729 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry.cy.ts @@ -0,0 +1,226 @@ +/* eslint-disable camelcase */ +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { mockRegisteredModelList } from '~/__mocks__/mockRegisteredModelsList'; +import { modelRegistry } from '~/__tests__/cypress/cypress/pages/modelRegistry'; +import { mockModelRegistryResponse } from '~/__mocks__/mockModelRegistryResponse'; +import type { ModelRegistry, ModelVersion, RegisteredModel } from '~/app/types'; + +const MODEL_REGISTRY_API_VERSION = 'v1'; + +type HandlersProps = { + modelRegistries?: ModelRegistry[]; + registeredModels?: RegisteredModel[]; + modelVersions?: ModelVersion[]; +}; + +const initIntercepts = ({ + modelRegistries = [ + mockModelRegistry({ + name: 'modelregistry-sample', + description: 'New model registry', + displayName: 'Model Registry Sample', + }), + mockModelRegistry({ + name: 'modelregistry-sample-2', + description: 'New model registry 2', + displayName: 'Model Registry Sample 2', + }), + ], + registeredModels = [ + mockRegisteredModel({ + name: 'Fraud detection model', + description: + 'A machine learning model trained to detect fraudulent transactions in financial data', + labels: [ + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + ], + }), + mockRegisteredModel({ + name: 'Label modal', + description: + 'A machine learning model trained to detect fraudulent transactions in financial data', + labels: [ + 'Testing label', + 'Financial data', + 'Fraud detection', + 'Long label data to be truncated abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc', + 'Machine learning', + 'Next data to be overflow', + 'Label x', + 'Label y', + 'Label z', + ], + }), + ], + modelVersions = [ + mockModelVersion({ author: 'Author 1' }), + mockModelVersion({ name: 'model version' }), + ], +}: HandlersProps) => { + cy.interceptApi( + `GET /api/:apiVersion/model_registry`, + { + path: { apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockModelRegistryResponse({ model_registry: modelRegistries }), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models`, + { + path: { modelRegistryName: 'modelregistry-sample', apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockRegisteredModelList({ items: registeredModels }), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockModelVersionList({ items: modelVersions }), + ); +}; + +describe('Model Registry core', () => { + it('Model Registry Enabled in the cluster', () => { + initIntercepts({}); + + modelRegistry.visit(); + modelRegistry.navigate(); + + modelRegistry.tabEnabled(); + }); + + // it('Renders empty state with no model registries', () => { + // initIntercepts({ + // disableModelRegistryFeature: false, + // modelRegistries: [], + // }); + + // modelRegistry.visit(); + // modelRegistry.navigate(); + // modelRegistry.findModelRegistryEmptyState().should('exist'); + // }); + + it('No registered models in the selected Model Registry', () => { + initIntercepts({ + registeredModels: [], + }); + + modelRegistry.visit(); + modelRegistry.navigate(); + modelRegistry.shouldModelRegistrySelectorExist(); + // modelRegistry.shouldregisteredModelsEmpty(); + }); + + // TODO: Enable when registered model table is enabled + // describe('Registered model table', () => { + // beforeEach(() => { + // initIntercepts({ disableModelRegistryFeature: false }); + // modelRegistry.visit(); + // }); + + // it('Renders row contents', () => { + // const registeredModelRow = modelRegistry.getRow('Fraud detection model'); + // registeredModelRow.findName().contains('Fraud detection model'); + // registeredModelRow + // .findDescription() + // .contains( + // 'A machine learning model trained to detect fraudulent transactions in financial data', + // ); + // registeredModelRow.findOwner().contains('Author 1'); + + // // Label popover + // registeredModelRow.findLabelPopoverText().contains('2 more'); + // registeredModelRow.findLabelPopoverText().click(); + // registeredModelRow.shouldContainsPopoverLabels([ + // 'Machine learning', + // 'Next data to be overflow', + // ]); + // }); + + // it('Renders labels in modal', () => { + // const registeredModelRow2 = modelRegistry.getRow('Label modal'); + // registeredModelRow2.findLabelModalText().contains('6 more'); + // registeredModelRow2.findLabelModalText().click(); + // labelModal.shouldContainsModalLabels([ + // 'Testing label', + // 'Financial', + // 'Financial data', + // 'Fraud detection', + // 'Machine learning', + // 'Next data to be overflow', + // 'Label x', + // 'Label y', + // 'Label z', + // ]); + // labelModal.findModalSearchInput().type('Financial'); + // labelModal.shouldContainsModalLabels(['Financial', 'Financial data']); + // labelModal.findCloseModal().click(); + // }); + + // it('Sort by Model name', () => { + // modelRegistry.findRegisteredModelTableHeaderButton('Model name').click(); + // modelRegistry.findRegisteredModelTableHeaderButton('Model name').should(be.sortAscending); + // modelRegistry.findRegisteredModelTableHeaderButton('Model name').click(); + // modelRegistry.findRegisteredModelTableHeaderButton('Model name').should(be.sortDescending); + // }); + + // it('Sort by Last modified', () => { + // modelRegistry.findRegisteredModelTableHeaderButton('Last modified').should(be.sortAscending); + // modelRegistry.findRegisteredModelTableHeaderButton('Last modified').click(); + // modelRegistry.findRegisteredModelTableHeaderButton('Last modified').should(be.sortDescending); + // }); + + // it('Filter by keyword', () => { + // modelRegistry.findTableSearch().type('Fraud detection model'); + // modelRegistry.findTableRows().should('have.length', 1); + // modelRegistry.findTableRows().contains('Fraud detection model'); + // }); + // }); +}); + +// TODO: Enable when model registration is there +// describe('Register Model button', () => { +// it('Navigates to register page from empty state', () => { +// initIntercepts({ disableModelRegistryFeature: false, registeredModels: [] }); +// modelRegistry.visit(); +// modelRegistry.findRegisterModelButton().click(); +// cy.findByTestId('app-page-title').should('exist'); +// cy.findByTestId('app-page-title').contains('Register model'); +// cy.findByText('Model registry - modelregistry-sample').should('exist'); +// }); + +// it('Navigates to register page from table toolbar', () => { +// initIntercepts({ disableModelRegistryFeature: false }); +// modelRegistry.visit(); +// modelRegistry.findRegisterModelButton().click(); +// cy.findByTestId('app-page-title').should('exist'); +// cy.findByTestId('app-page-title').contains('Register model'); +// cy.findByText('Model registry - modelregistry-sample').should('exist'); +// }); + +// it('should be accessible for non-admin users', () => { +// asProjectEditUser(); +// initIntercepts({ +// disableModelRegistryFeature: false, +// allowed: false, +// }); + +// modelRegistry.visit(); +// modelRegistry.navigate(); +// modelRegistry.shouldModelRegistrySelectorExist(); +// }); +// }); diff --git a/clients/ui/frontend/src/app/App.tsx b/clients/ui/frontend/src/app/App.tsx index 2a8a7a81..97e644f6 100644 --- a/clients/ui/frontend/src/app/App.tsx +++ b/clients/ui/frontend/src/app/App.tsx @@ -23,6 +23,7 @@ import NavSidebar from './NavSidebar'; import AppRoutes from './AppRoutes'; import { AppContext } from './AppContext'; import { useSettings } from './useSettings'; +import { ModelRegistrySelectorContextProvider } from './context/ModelRegistrySelectorContext'; const App: React.FC = () => { const { @@ -108,7 +109,9 @@ const App: React.FC = () => { isManagedSidebar sidebar={} > - + + + ); diff --git a/clients/ui/frontend/src/app/AppRoutes.tsx b/clients/ui/frontend/src/app/AppRoutes.tsx index 10051e6d..c2383da6 100644 --- a/clients/ui/frontend/src/app/AppRoutes.tsx +++ b/clients/ui/frontend/src/app/AppRoutes.tsx @@ -1,5 +1,5 @@ import * as React from 'react'; -import { Route, Routes } from 'react-router-dom'; +import { Navigate, Route, Routes } from 'react-router-dom'; import { NotFound } from './pages/notFound/NotFound'; import ModelRegistrySettingsRoutes from './pages/settings/ModelRegistrySettingsRoutes'; import ModelRegistryRoutes from './pages/modelRegistry/ModelRegistryRoutes'; @@ -42,7 +42,7 @@ export const useAdminSettings = (): NavDataItem[] => { export const useNavData = (): NavDataItem[] => [ { label: 'Model Registry', - path: '/', + path: '/modelRegistry', }, ...useAdminSettings(), ]; @@ -52,12 +52,13 @@ const AppRoutes: React.FC = () => { return ( - } /> + } /> + } /> } /> { // TODO: Remove the linter skip when we implement authentication // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition - isAdmin && } /> + isAdmin && } /> } ); diff --git a/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts b/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts index 3c225152..fdf473f1 100644 --- a/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts +++ b/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts @@ -1,5 +1,5 @@ import { NotReadyError } from '~/utilities/useFetchState'; -import { APIError } from '~/types'; +import { APIError } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; @@ -12,8 +12,10 @@ describe('handleRestFailures', () => { it('should handle and throw model registry errors', async () => { const statusMock: APIError = { - code: '', - message: 'error', + error: { + code: '', + message: 'error', + }, }; await expect(handleRestFailures(Promise.resolve(statusMock))).rejects.toThrow('error'); diff --git a/clients/ui/frontend/src/app/api/__tests__/service.spec.ts b/clients/ui/frontend/src/app/api/__tests__/service.spec.ts index 1e2a36e2..6bfe6e6a 100644 --- a/clients/ui/frontend/src/app/api/__tests__/service.spec.ts +++ b/clients/ui/frontend/src/app/api/__tests__/service.spec.ts @@ -45,18 +45,21 @@ const K8sAPIOptionsMock = {}; describe('createRegisteredModel', () => { it('should call restCREATE and handleRestFailures to create registered model', () => { expect( - createRegisteredModel('hostPath', 'model-registry-1')(K8sAPIOptionsMock, { - description: 'test', - externalID: '1', - name: 'test new registered model', - state: ModelState.LIVE, - customProperties: {}, - }), + createRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + { + description: 'test', + externalID: '1', + name: 'test new registered model', + state: ModelState.LIVE, + customProperties: {}, + }, + ), ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models`, { description: 'test', externalID: '1', @@ -75,20 +78,23 @@ describe('createRegisteredModel', () => { describe('createModelVersion', () => { it('should call restCREATE and handleRestFailures to create model version', () => { expect( - createModelVersion('hostPath', 'model-registry-1')(K8sAPIOptionsMock, { - description: 'test', - externalID: '1', - author: 'test author', - registeredModelId: '1', - name: 'test new model version', - state: ModelState.LIVE, - customProperties: {}, - }), + createModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + { + description: 'test', + externalID: '1', + author: 'test author', + registeredModelId: '1', + name: 'test new model version', + state: ModelState.LIVE, + customProperties: {}, + }, + ), ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions`, { description: 'test', externalID: '1', @@ -109,7 +115,9 @@ describe('createModelVersion', () => { describe('createModelVersionForRegisteredModel', () => { it('should call restCREATE and handleRestFailures to create model version for a model', () => { expect( - createModelVersionForRegisteredModel('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1', { + createModelVersionForRegisteredModel( + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + )(K8sAPIOptionsMock, '1', { description: 'test', externalID: '1', author: 'test author', @@ -121,8 +129,8 @@ describe('createModelVersionForRegisteredModel', () => { ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1/versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1/versions`, { description: 'test', externalID: '1', @@ -143,25 +151,28 @@ describe('createModelVersionForRegisteredModel', () => { describe('createModelArtifact', () => { it('should call restCREATE and handleRestFailures to create model artifact', () => { expect( - createModelArtifact('hostPath', 'model-registry-1')(K8sAPIOptionsMock, { - description: 'test', - externalID: 'test', - uri: 'test-uri', - state: ModelArtifactState.LIVE, - name: 'test-name', - modelFormatName: 'test-modelformatname', - storageKey: 'teststoragekey', - storagePath: 'teststoragePath', - modelFormatVersion: 'testmodelFormatVersion', - serviceAccountName: 'testserviceAccountname', - customProperties: {}, - artifactType: 'model-artifact', - }), + createModelArtifact(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + { + description: 'test', + externalID: 'test', + uri: 'test-uri', + state: ModelArtifactState.LIVE, + name: 'test-name', + modelFormatName: 'test-modelformatname', + storageKey: 'teststoragekey', + storagePath: 'teststoragePath', + modelFormatVersion: 'testmodelFormatVersion', + serviceAccountName: 'testserviceAccountname', + customProperties: {}, + artifactType: 'model-artifact', + }, + ), ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts`, { description: 'test', externalID: 'test', @@ -187,7 +198,9 @@ describe('createModelArtifact', () => { describe('createModelArtifactForModelVersion', () => { it('should call restCREATE and handleRestFailures to create model artifact for version', () => { expect( - createModelArtifactForModelVersion('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '2', { + createModelArtifactForModelVersion( + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + )(K8sAPIOptionsMock, '2', { description: 'test', externalID: 'test', uri: 'test-uri', @@ -204,8 +217,8 @@ describe('createModelArtifactForModelVersion', () => { ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/2/artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/2/artifacts`, { description: 'test', externalID: 'test', @@ -230,13 +243,16 @@ describe('createModelArtifactForModelVersion', () => { describe('getRegisteredModel', () => { it('should call restGET and handleRestFailures to fetch registered model', () => { - expect(getRegisteredModel('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1')).toBe( - mockResultPromise, - ); + expect( + getRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1`, {}, K8sAPIOptionsMock, ); @@ -247,13 +263,16 @@ describe('getRegisteredModel', () => { describe('getModelVersion', () => { it('should call restGET and handleRestFailures to fetch model version', () => { - expect(getModelVersion('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1')).toBe( - mockResultPromise, - ); + expect( + getModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/1`, {}, K8sAPIOptionsMock, ); @@ -264,13 +283,16 @@ describe('getModelVersion', () => { describe('getModelArtifact', () => { it('should call restGET and handleRestFailures to fetch model version', () => { - expect(getModelArtifact('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1')).toBe( - mockResultPromise, - ); + expect( + getModelArtifact(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts/1`, {}, K8sAPIOptionsMock, ); @@ -281,11 +303,13 @@ describe('getModelArtifact', () => { describe('getListRegisteredModels', () => { it('should call restGET and handleRestFailures to list registered models', () => { - expect(getListRegisteredModels('hostPath', 'model-registry-1')({})).toBe(mockResultPromise); + expect( + getListRegisteredModels(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)({}), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models`, {}, K8sAPIOptionsMock, ); @@ -296,11 +320,13 @@ describe('getListRegisteredModels', () => { describe('getListModelArtifacts', () => { it('should call restGET and handleRestFailures to list models artifacts', () => { - expect(getListModelArtifacts('hostPath', 'model-registry-1')({})).toBe(mockResultPromise); + expect( + getListModelArtifacts(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)({}), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts`, {}, K8sAPIOptionsMock, ); @@ -311,11 +337,13 @@ describe('getListModelArtifacts', () => { describe('getListModelVersions', () => { it('should call restGET and handleRestFailures to list models versions', () => { - expect(getListModelVersions('hostPath', 'model-registry-1')({})).toBe(mockResultPromise); + expect( + getListModelVersions(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)({}), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions`, {}, K8sAPIOptionsMock, ); @@ -326,13 +354,16 @@ describe('getListModelVersions', () => { describe('getModelVersionsByRegisteredModel', () => { it('should call restGET and handleRestFailures to list models versions by registered model', () => { - expect(getModelVersionsByRegisteredModel('hostPath', 'model-registry-1')({}, '1')).toBe( - mockResultPromise, - ); + expect( + getModelVersionsByRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + {}, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1/versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1/versions`, {}, K8sAPIOptionsMock, ); @@ -343,13 +374,16 @@ describe('getModelVersionsByRegisteredModel', () => { describe('getModelArtifactsByModelVersion', () => { it('should call restGET and handleRestFailures to list models artifacts by model version', () => { - expect(getModelArtifactsByModelVersion('hostPath', 'model-registry-1')({}, '1')).toBe( - mockResultPromise, - ); + expect( + getModelArtifactsByModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + {}, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/1/artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/1/artifacts`, {}, K8sAPIOptionsMock, ); @@ -361,7 +395,7 @@ describe('getModelArtifactsByModelVersion', () => { describe('patchRegisteredModel', () => { it('should call restPATCH and handleRestFailures to update registered model', () => { expect( - patchRegisteredModel('hostPath', 'model-registry-1')( + patchRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( K8sAPIOptionsMock, { description: 'new test' }, '1', @@ -369,8 +403,8 @@ describe('patchRegisteredModel', () => { ).toBe(mockResultPromise); expect(restPATCHMock).toHaveBeenCalledTimes(1); expect(restPATCHMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1`, { description: 'new test' }, K8sAPIOptionsMock, ); @@ -382,7 +416,7 @@ describe('patchRegisteredModel', () => { describe('patchModelVersion', () => { it('should call restPATCH and handleRestFailures to update model version', () => { expect( - patchModelVersion('hostPath', 'model-registry-1')( + patchModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( K8sAPIOptionsMock, { description: 'new test' }, '1', @@ -390,8 +424,8 @@ describe('patchModelVersion', () => { ).toBe(mockResultPromise); expect(restPATCHMock).toHaveBeenCalledTimes(1); expect(restPATCHMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/1`, { description: 'new test' }, K8sAPIOptionsMock, ); @@ -403,7 +437,7 @@ describe('patchModelVersion', () => { describe('patchModelArtifact', () => { it('should call restPATCH and handleRestFailures to update model artifact', () => { expect( - patchModelArtifact('hostPath', 'model-registry-1')( + patchModelArtifact(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( K8sAPIOptionsMock, { description: 'new test' }, '1', @@ -411,8 +445,8 @@ describe('patchModelArtifact', () => { ).toBe(mockResultPromise); expect(restPATCHMock).toHaveBeenCalledTimes(1); expect(restPATCHMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts/1`, { description: 'new test' }, K8sAPIOptionsMock, ); diff --git a/clients/ui/frontend/src/app/api/apiUtils.ts b/clients/ui/frontend/src/app/api/apiUtils.ts index d4adff6c..0af63847 100644 --- a/clients/ui/frontend/src/app/api/apiUtils.ts +++ b/clients/ui/frontend/src/app/api/apiUtils.ts @@ -1,5 +1,6 @@ -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; import { EitherOrNone } from '~/typeHelpers'; +import { ModelRegistryResponse } from '~/app/types'; export const mergeRequestInit = ( opts: APIOptions = {}, @@ -161,3 +162,12 @@ export const restDELETE = ( queryParams, parseJSON: options?.parseJSON, }); + +export const isModelRegistryResponse = (response: unknown): response is ModelRegistryResponse => { + if (typeof response === 'object' && response !== null) { + // eslint-disable-next-line @typescript-eslint/consistent-type-assertions + const modelRegistryResponse = response as { model_registry?: unknown }; + return Array.isArray(modelRegistryResponse.model_registry); + } + return false; +}; diff --git a/clients/ui/frontend/src/app/api/errorUtils.ts b/clients/ui/frontend/src/app/api/errorUtils.ts index 4cb92823..35fbdc2d 100644 --- a/clients/ui/frontend/src/app/api/errorUtils.ts +++ b/clients/ui/frontend/src/app/api/errorUtils.ts @@ -1,8 +1,7 @@ -import { APIError } from '~/types'; +import { APIError } from '~/app/api/types'; import { isCommonStateError } from '~/utilities/useFetchState'; -const isError = (e: unknown): e is APIError => - typeof e === 'object' && e !== null && ['code', 'message'].every((key) => key in e); +const isError = (e: unknown): e is APIError => typeof e === 'object' && e !== null && 'error' in e; export const handleRestFailures = (promise: Promise): Promise => promise @@ -14,7 +13,7 @@ export const handleRestFailures = (promise: Promise): Promise => }) .catch((e) => { if (isError(e)) { - throw new Error(e.message); + throw new Error(e.error.message); } if (isCommonStateError(e)) { // Common state errors are handled by useFetchState at storage level, let them deal with it diff --git a/clients/ui/frontend/src/app/api/k8s.ts b/clients/ui/frontend/src/app/api/k8s.ts index e17e55db..07f70e98 100644 --- a/clients/ui/frontend/src/app/api/k8s.ts +++ b/clients/ui/frontend/src/app/api/k8s.ts @@ -1,10 +1,17 @@ -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; -import { restGET } from '~/app/api/apiUtils'; +import { isModelRegistryResponse, restGET } from '~/app/api/apiUtils'; import { ModelRegistry } from '~/app/types'; import { BFF_API_VERSION } from '~/app/const'; -export const getModelRegistries = +export const getListModelRegistries = (hostPath: string) => - (opts: APIOptions): Promise => - handleRestFailures(restGET(hostPath, `/api/${BFF_API_VERSION}/model_registry`, {}, opts)); + (opts: APIOptions): Promise => + handleRestFailures(restGET(hostPath, `/api/${BFF_API_VERSION}/model_registry`, {}, opts)).then( + (response) => { + if (isModelRegistryResponse(response)) { + return response.model_registry; + } + throw new Error('Invalid response format'); + }, + ); diff --git a/clients/ui/frontend/src/app/api/service.ts b/clients/ui/frontend/src/app/api/service.ts index 42f8dbb5..696c46ba 100644 --- a/clients/ui/frontend/src/app/api/service.ts +++ b/clients/ui/frontend/src/app/api/service.ts @@ -10,218 +10,107 @@ import { RegisteredModel, } from '~/app/types'; import { restCREATE, restGET, restPATCH } from '~/app/api/apiUtils'; -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; -import { BFF_API_VERSION } from '~/app/const'; export const createRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: CreateRegisteredModelData): Promise => - handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models`, - data, - {}, - opts, - ), - ); + handleRestFailures(restCREATE(hostPath, `/registered_models`, data, {}, opts)); export const createModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: CreateModelVersionData): Promise => - handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions`, - data, - {}, - opts, - ), - ); + handleRestFailures(restCREATE(hostPath, `/model_versions`, data, {}, opts)); + export const createModelVersionForRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, registeredModelId: string, data: CreateModelVersionData, ): Promise => handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredModelId}/versions`, - data, - {}, - opts, - ), + restCREATE(hostPath, `/registered_models/${registeredModelId}/versions`, data, {}, opts), ); export const createModelArtifact = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: CreateModelArtifactData): Promise => - handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts`, - data, - {}, - opts, - ), - ); + handleRestFailures(restCREATE(hostPath, `/model_artifacts`, data, {}, opts)); export const createModelArtifactForModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, modelVersionId: string, data: CreateModelArtifactData, ): Promise => handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelVersionId}/artifacts`, - data, - {}, - opts, - ), + restCREATE(hostPath, `/model_versions/${modelVersionId}/artifacts`, data, {}, opts), ); export const getRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, registeredModelId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredModelId}`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/registered_models/${registeredModelId}`, {}, opts)); export const getModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, modelversionId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelversionId}`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_versions/${modelversionId}`, {}, opts)); export const getModelArtifact = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, modelArtifactId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts/${modelArtifactId}`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_artifacts/${modelArtifactId}`, {}, opts)); export const getListModelArtifacts = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_artifacts`, {}, opts)); export const getListModelVersions = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_versions`, {}, opts)); export const getListRegisteredModels = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/registered_models`, {}, opts)); export const getModelVersionsByRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, registeredmodelId: string): Promise => handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredmodelId}/versions`, - {}, - opts, - ), + restGET(hostPath, `/registered_models/${registeredmodelId}/versions`, {}, opts), ); export const getModelArtifactsByModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, modelVersionId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelVersionId}/artifacts`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_versions/${modelVersionId}/artifacts`, {}, opts)); export const patchRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, data: Partial, registeredModelId: string, ): Promise => - handleRestFailures( - restPATCH( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredModelId}`, - data, - opts, - ), - ); + handleRestFailures(restPATCH(hostPath, `/registered_models/${registeredModelId}`, data, opts)); export const patchModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: Partial, modelversionId: string): Promise => - handleRestFailures( - restPATCH( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelversionId}`, - data, - opts, - ), - ); + handleRestFailures(restPATCH(hostPath, `/model_versions/${modelversionId}`, data, opts)); export const patchModelArtifact = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, data: Partial, modelartifactId: string, ): Promise => - handleRestFailures( - restPATCH( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts/${modelartifactId}`, - data, - opts, - ), - ); + handleRestFailures(restPATCH(hostPath, `/model_artifacts/${modelartifactId}`, data, opts)); diff --git a/clients/ui/frontend/src/app/api/types.ts b/clients/ui/frontend/src/app/api/types.ts new file mode 100644 index 00000000..e7335512 --- /dev/null +++ b/clients/ui/frontend/src/app/api/types.ts @@ -0,0 +1,19 @@ +export type APIOptions = { + dryRun?: boolean; + signal?: AbortSignal; + parseJSON?: boolean; +}; + +export type APIError = { + error: { + code: string; + message: string; + }; +}; + +export type APIState = { + /** If API will successfully call */ + apiAvailable: boolean; + /** The available API functions */ + api: T; +}; diff --git a/clients/ui/frontend/src/app/api/useAPIState.ts b/clients/ui/frontend/src/app/api/useAPIState.ts new file mode 100644 index 00000000..4783e8cb --- /dev/null +++ b/clients/ui/frontend/src/app/api/useAPIState.ts @@ -0,0 +1,32 @@ +import * as React from 'react'; +import { APIState } from '~/app/api/types'; + +const useAPIState = ( + hostPath: string | null, + createAPI: (path: string) => T, +): [apiState: APIState, refreshAPIState: () => void] => { + const [internalAPIToggleState, setInternalAPIToggleState] = React.useState(false); + + const refreshAPIState = React.useCallback(() => { + setInternalAPIToggleState((v) => !v); + }, []); + + const apiState = React.useMemo>(() => { + let path = hostPath; + if (!path) { + // TODO: we need to figure out maybe a stopgap or something + path = ''; + } + const api = createAPI(path); + + return { + apiAvailable: !!path, + api, + }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [createAPI, hostPath, internalAPIToggleState]); + + return [apiState, refreshAPIState]; +}; + +export default useAPIState; diff --git a/clients/ui/frontend/src/app/components/EmptyStateErrorMessage.tsx b/clients/ui/frontend/src/app/components/EmptyStateErrorMessage.tsx new file mode 100644 index 00000000..af24ffaa --- /dev/null +++ b/clients/ui/frontend/src/app/components/EmptyStateErrorMessage.tsx @@ -0,0 +1,34 @@ +import * as React from 'react'; +import { + EmptyState, + EmptyStateBody, + Stack, + StackItem, + EmptyStateFooter, +} from '@patternfly/react-core'; +import { PathMissingIcon } from '@patternfly/react-icons'; + +type EmptyStateErrorMessageProps = { + children?: React.ReactNode; + title: string; + bodyText: string; +}; + +const EmptyStateErrorMessage: React.FC = ({ + title, + bodyText, + children, +}) => ( + + + + + {bodyText} + + {children && {children}} + + + +); + +export default EmptyStateErrorMessage; diff --git a/clients/ui/frontend/src/app/context/ModelRegistryContext.tsx b/clients/ui/frontend/src/app/context/ModelRegistryContext.tsx new file mode 100644 index 00000000..6c107e27 --- /dev/null +++ b/clients/ui/frontend/src/app/context/ModelRegistryContext.tsx @@ -0,0 +1,44 @@ +import * as React from 'react'; +import { BFF_API_VERSION } from '~/app/const'; +import useModelRegistryAPIState, { ModelRegistryAPIState } from './useModelRegistryAPIState'; + +export type ModelRegistryContextType = { + apiState: ModelRegistryAPIState; + refreshAPIState: () => void; +}; + +type ModelRegistryContextProviderProps = { + children: React.ReactNode; + modelRegistryName: string; +}; + +export const ModelRegistryContext = React.createContext({ + // eslint-disable-next-line @typescript-eslint/consistent-type-assertions + apiState: { apiAvailable: false, api: null as unknown as ModelRegistryAPIState['api'] }, + refreshAPIState: () => undefined, +}); + +export const ModelRegistryContextProvider: React.FC = ({ + children, + modelRegistryName, +}) => { + const hostPath = modelRegistryName + ? `/api/${BFF_API_VERSION}/model_registry/${modelRegistryName}` + : null; + + const [apiState, refreshAPIState] = useModelRegistryAPIState(hostPath); + + return ( + ({ + apiState, + refreshAPIState, + }), + [apiState, refreshAPIState], + )} + > + {children} + + ); +}; diff --git a/clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx b/clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx new file mode 100644 index 00000000..273c900a --- /dev/null +++ b/clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx @@ -0,0 +1,58 @@ +import * as React from 'react'; +import { ModelRegistry } from '~/app/types'; +import useModelRegistries from '~/app/hooks/useModelRegistries'; + +export type ModelRegistrySelectorContextType = { + modelRegistriesLoaded: boolean; + modelRegistriesLoadError?: Error; + modelRegistries: ModelRegistry[]; + preferredModelRegistry: ModelRegistry | undefined; + updatePreferredModelRegistry: (modelRegistry: ModelRegistry | undefined) => void; +}; + +type ModelRegistrySelectorContextProviderProps = { + children: React.ReactNode; +}; + +export const ModelRegistrySelectorContext = React.createContext({ + modelRegistriesLoaded: false, + modelRegistriesLoadError: undefined, + modelRegistries: [], + preferredModelRegistry: undefined, + updatePreferredModelRegistry: () => undefined, +}); + +export const ModelRegistrySelectorContextProvider: React.FC< + ModelRegistrySelectorContextProviderProps +> = ({ children, ...props }) => ( + + {children} + +); + +const EnabledModelRegistrySelectorContextProvider: React.FC< + ModelRegistrySelectorContextProviderProps +> = ({ children }) => { + const [modelRegistries, isLoaded, error] = useModelRegistries(); + const [preferredModelRegistry, setPreferredModelRegistry] = + React.useState(undefined); + + const firstModelRegistry = modelRegistries.length > 0 ? modelRegistries[0] : null; + + const contextValue = React.useMemo( + () => ({ + modelRegistriesLoaded: isLoaded, + modelRegistriesLoadError: error, + modelRegistries, + preferredModelRegistry: preferredModelRegistry ?? firstModelRegistry ?? undefined, + updatePreferredModelRegistry: setPreferredModelRegistry, + }), + [isLoaded, error, modelRegistries, preferredModelRegistry, firstModelRegistry], + ); + + return ( + + {children} + + ); +}; diff --git a/clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx b/clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx new file mode 100644 index 00000000..9b1465ba --- /dev/null +++ b/clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx @@ -0,0 +1,54 @@ +import React from 'react'; +import { APIState } from '~/app/api/types'; +import { ModelRegistryAPIs } from '~/app/types'; +import { + createModelArtifact, + createModelArtifactForModelVersion, + createModelVersion, + createModelVersionForRegisteredModel, + createRegisteredModel, + getListModelArtifacts, + getListModelVersions, + getListRegisteredModels, + getModelArtifact, + getModelArtifactsByModelVersion, + getModelVersion, + getModelVersionsByRegisteredModel, + getRegisteredModel, + patchModelArtifact, + patchModelVersion, + patchRegisteredModel, +} from '~/app/api/service'; +import useAPIState from '~/app/api/useAPIState'; + +export type ModelRegistryAPIState = APIState; + +const useModelRegistryAPIState = ( + hostPath: string | null, +): [apiState: ModelRegistryAPIState, refreshAPIState: () => void] => { + const createAPI = React.useCallback( + (path: string) => ({ + createRegisteredModel: createRegisteredModel(path), + createModelVersion: createModelVersion(path), + createModelVersionForRegisteredModel: createModelVersionForRegisteredModel(path), + createModelArtifact: createModelArtifact(path), + createModelArtifactForModelVersion: createModelArtifactForModelVersion(path), + getRegisteredModel: getRegisteredModel(path), + getModelVersion: getModelVersion(path), + getModelArtifact: getModelArtifact(path), + listModelArtifacts: getListModelArtifacts(path), + listModelVersions: getListModelVersions(path), + listRegisteredModels: getListRegisteredModels(path), + getModelVersionsByRegisteredModel: getModelVersionsByRegisteredModel(path), + getModelArtifactsByModelVersion: getModelArtifactsByModelVersion(path), + patchRegisteredModel: patchRegisteredModel(path), + patchModelVersion: patchModelVersion(path), + patchModelArtifact: patchModelArtifact(path), + }), + [], + ); + + return useAPIState(hostPath, createAPI); +}; + +export default useModelRegistryAPIState; diff --git a/clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts b/clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts new file mode 100644 index 00000000..aefe8267 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts @@ -0,0 +1,89 @@ +import { waitFor } from '@testing-library/react'; +import useModelArtifactsByVersionId from '~/app/hooks/useModelArtifactsByVersionId'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; +import { ModelRegistryAPIs } from '~/app/types'; +import { mockModelArtifact } from '~/__mocks__/mockModelArtifact'; +import { testHook } from '~/__tests__/unit/testUtils/hooks'; + +global.fetch = jest.fn(); +// Mock the useModelRegistryAPI hook +jest.mock('~/app/hooks/useModelRegistryAPI', () => ({ + useModelRegistryAPI: jest.fn(), +})); + +const mockUseModelRegistryAPI = jest.mocked(useModelRegistryAPI); + +const mockModelRegistryAPIs: ModelRegistryAPIs = { + createRegisteredModel: jest.fn(), + createModelVersionForRegisteredModel: jest.fn(), + createModelArtifactForModelVersion: jest.fn(), + getRegisteredModel: jest.fn(), + getModelVersion: jest.fn(), + listRegisteredModels: jest.fn(), + getModelVersionsByRegisteredModel: jest.fn(), + getModelArtifactsByModelVersion: jest.fn(), + patchRegisteredModel: jest.fn(), + patchModelVersion: jest.fn(), +}; + +describe('useModelArtifactsByVersionId', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should return NotReadyError if API is not available', async () => { + mockUseModelRegistryAPI.mockReturnValue({ + api: mockModelRegistryAPIs, + apiAvailable: false, + refreshAllAPI: jest.fn(), + }); + + const { result } = testHook(useModelArtifactsByVersionId)('version-id'); + + await waitFor(() => { + const [, , error] = result.current; + expect(error?.message).toBe('API not yet available'); + expect(error).toBeInstanceOf(Error); + }); + }); + + it('should return NotReadyError if modelVersionId is not provided', async () => { + mockUseModelRegistryAPI.mockReturnValue({ + api: mockModelRegistryAPIs, + apiAvailable: true, + refreshAllAPI: jest.fn(), + }); + + const { result } = testHook(useModelArtifactsByVersionId)(); + + await waitFor(() => { + const [, , error] = result.current; + expect(error?.message).toBe('No model registeredModel id'); + expect(error).toBeInstanceOf(Error); + }); + }); + + it('should fetch model artifacts if API is available and modelVersionId is provided', async () => { + const mockedResponse = { + items: [mockModelArtifact({ id: 'artifact-1' })], + size: 1, + pageSize: 1, + }; + + mockUseModelRegistryAPI.mockReturnValue({ + api: { + ...mockModelRegistryAPIs, + getModelArtifactsByModelVersion: jest.fn().mockResolvedValue(mockedResponse), + }, + apiAvailable: true, + refreshAllAPI: jest.fn(), + }); + + const { result } = testHook(useModelArtifactsByVersionId)('version-id'); + + await waitFor(() => { + const [data] = result.current; + expect(data).toEqual(mockedResponse); + }); + }); +}); diff --git a/clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts b/clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts new file mode 100644 index 00000000..5fb90a17 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts @@ -0,0 +1,27 @@ +import * as React from 'react'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { ModelArtifactList } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useModelArtifactsByVersionId = (modelVersionId?: string): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + const callback = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new Error('API not yet available')); + } + if (!modelVersionId) { + return Promise.reject(new Error('No model registeredModel id')); + } + return api.getModelArtifactsByModelVersion(opts, modelVersionId); + }, + [api, apiAvailable, modelVersionId], + ); + return useFetchState( + callback, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useModelArtifactsByVersionId; diff --git a/clients/ui/frontend/src/app/hooks/useModelRegistries.ts b/clients/ui/frontend/src/app/hooks/useModelRegistries.ts new file mode 100644 index 00000000..705256a8 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelRegistries.ts @@ -0,0 +1,15 @@ +import * as React from 'react'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { ModelRegistry } from '~/app/types'; +import { getListModelRegistries } from '~/app/api/k8s'; + +const useModelRegistries = (): FetchState => { + const listModelRegistries = React.useMemo(() => getListModelRegistries(''), []); + const callback = React.useCallback>( + (opts) => listModelRegistries(opts), + [listModelRegistries], + ); + return useFetchState(callback, [], { initialPromisePurity: true }); +}; + +export default useModelRegistries; diff --git a/clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts b/clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts new file mode 100644 index 00000000..5a211568 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts @@ -0,0 +1,16 @@ +import * as React from 'react'; +import { ModelRegistryAPIState } from '~/app/context/useModelRegistryAPIState'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; + +type UseModelRegistryAPI = ModelRegistryAPIState & { + refreshAllAPI: () => void; +}; + +export const useModelRegistryAPI = (): UseModelRegistryAPI => { + const { apiState, refreshAPIState: refreshAllAPI } = React.useContext(ModelRegistryContext); + + return { + refreshAllAPI, + ...apiState, + }; +}; diff --git a/clients/ui/frontend/src/app/hooks/useModelVersionById.ts b/clients/ui/frontend/src/app/hooks/useModelVersionById.ts new file mode 100644 index 00000000..19b7ecd9 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelVersionById.ts @@ -0,0 +1,26 @@ +import * as React from 'react'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { ModelVersion } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useModelVersionById = (modelVersionId?: string): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + + const call = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new Error('API not yet available')); + } + if (!modelVersionId) { + return Promise.reject(new Error('No model version id')); + } + + return api.getModelVersion(opts, modelVersionId); + }, + [api, apiAvailable, modelVersionId], + ); + + return useFetchState(call, null); +}; + +export default useModelVersionById; diff --git a/clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts b/clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts new file mode 100644 index 00000000..c8f82f9e --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts @@ -0,0 +1,32 @@ +import * as React from 'react'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { ModelVersionList } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useModelVersionsByRegisteredModel = ( + registeredModelId?: string, +): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + + const call = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new Error('API not yet available')); + } + if (!registeredModelId) { + return Promise.reject(new Error('No model registeredModel id')); + } + + return api.getModelVersionsByRegisteredModel(opts, registeredModelId); + }, + [api, apiAvailable, registeredModelId], + ); + + return useFetchState( + call, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useModelVersionsByRegisteredModel; diff --git a/clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts b/clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts new file mode 100644 index 00000000..c2d45bc8 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts @@ -0,0 +1,26 @@ +import * as React from 'react'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { RegisteredModel } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useRegisteredModelById = (registeredModel?: string): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + + const call = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new Error('API not yet available')); + } + if (!registeredModel) { + return Promise.reject(new Error('No registered model id')); + } + + return api.getRegisteredModel(opts, registeredModel); + }, + [api, apiAvailable, registeredModel], + ); + + return useFetchState(call, null); +}; + +export default useRegisteredModelById; diff --git a/clients/ui/frontend/src/app/hooks/useRegisteredModels.ts b/clients/ui/frontend/src/app/hooks/useRegisteredModels.ts new file mode 100644 index 00000000..6553c7ae --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useRegisteredModels.ts @@ -0,0 +1,24 @@ +import * as React from 'react'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { RegisteredModelList } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useRegisteredModels = (): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + const callback = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new Error('API not yet available')); + } + return api.listRegisteredModels(opts); + }, + [api, apiAvailable], + ); + return useFetchState( + callback, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useRegisteredModels; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryCoreLoader.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryCoreLoader.tsx new file mode 100644 index 00000000..0c4ec86f --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryCoreLoader.tsx @@ -0,0 +1,137 @@ +import * as React from 'react'; +import { Navigate, Outlet, useParams } from 'react-router-dom'; +import { Bullseye, Alert, Popover, List, ListItem, Button } from '@patternfly/react-core'; +import { OutlinedQuestionCircleIcon } from '@patternfly/react-icons'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ProjectObjectType, typedEmptyImage } from '~/app/components/design/utils'; +import { ModelRegistryContextProvider } from '~/app/context/ModelRegistryContext'; +import TitleWithIcon from '~/app/components/design/TitleWithIcon'; +import EmptyModelRegistryState from './screens/components/EmptyModelRegistryState'; +import InvalidModelRegistry from './screens/InvalidModelRegistry'; +import ModelRegistrySelectorNavigator from './screens/ModelRegistrySelectorNavigator'; +import { modelRegistryUrl } from './screens/routeUtils'; + +type ApplicationPageProps = React.ComponentProps; + +type ModelRegistryCoreLoaderProps = { + getInvalidRedirectPath: (modelRegistry: string) => string; +}; + +type ApplicationPageRenderState = Pick< + ApplicationPageProps, + 'emptyStatePage' | 'empty' | 'headerContent' +>; + +const ModelRegistryCoreLoader: React.FC = ({ + getInvalidRedirectPath, +}) => { + const { modelRegistry } = useParams<{ modelRegistry: string }>(); + + const { + modelRegistriesLoaded, + modelRegistriesLoadError, + modelRegistries, + preferredModelRegistry, + updatePreferredModelRegistry, + } = React.useContext(ModelRegistrySelectorContext); + + const modelRegistryFromRoute = modelRegistries.find((mr) => mr.name === modelRegistry); + + React.useEffect(() => { + if (modelRegistryFromRoute && preferredModelRegistry?.name !== modelRegistryFromRoute.name) { + updatePreferredModelRegistry(modelRegistryFromRoute); + } + }, [modelRegistryFromRoute, updatePreferredModelRegistry, preferredModelRegistry?.name]); + + if (modelRegistriesLoadError) { + return ( + + + {modelRegistriesLoadError.message} + + + ); + } + if (!modelRegistriesLoaded) { + return Loading model registries...; + } + + let renderStateProps: ApplicationPageRenderState & { children?: React.ReactNode }; + if (modelRegistries.length === 0) { + renderStateProps = { + empty: true, + emptyStatePage: ( + ( + + )} + customAction={ + + + The person who gave you your username, or who helped you to log in for the first + time + + Someone in your IT department or help desk + A project manager or developer + + } + > + + + } + /> + ), + headerContent: null, + }; + } else if (modelRegistry) { + const foundModelRegistry = modelRegistries.find((mr) => mr.name === modelRegistry); + if (foundModelRegistry) { + // Render the content + return ( + + + + ); + } + + // They ended up on a non-valid project path + renderStateProps = { + empty: true, + emptyStatePage: , + }; + } else { + // Redirect the namespace suffix into the URL + const redirectModelRegistry = preferredModelRegistry ?? modelRegistries[0]; + return ; + } + + return ( + + } + description="Select a model registry to view and manage your registered models. Model registries provide a structured and organized way to store, share, version, deploy, and track models." + headerContent={ + modelRegistryUrl(modelRegistryName)} + /> + } + {...renderStateProps} + loaded + provideChildrenPadding + /> + ); +}; + +export default ModelRegistryCoreLoader; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx index 1d3e4c0c..40050b5a 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx @@ -1,10 +1,21 @@ import * as React from 'react'; import { Route, Routes } from 'react-router-dom'; -import ModelRegistry from './ModelRegistry'; +import ModelRegistry from './screens/ModelRegistry'; +import ModelRegistryCoreLoader from './ModelRegistryCoreLoader'; +import { modelRegistryUrl } from './screens/routeUtils'; const ModelRegistryRoutes: React.FC = () => ( - } /> + modelRegistryUrl(modelRegistry)} + /> + } + > + } /> + ); diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/InvalidModelRegistry.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/InvalidModelRegistry.tsx new file mode 100644 index 00000000..c1559a72 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/InvalidModelRegistry.tsx @@ -0,0 +1,25 @@ +import * as React from 'react'; +import EmptyStateErrorMessage from '~/app/components/EmptyStateErrorMessage'; +import { modelRegistryUrl } from './routeUtils'; +import ModelRegistrySelectorNavigator from './ModelRegistrySelectorNavigator'; + +type InvalidModelRegistryProps = { + title?: string; + modelRegistry?: string; +}; + +const InvalidModelRegistry: React.FC = ({ title, modelRegistry }) => ( + + modelRegistryUrl(modelRegistryName)} + primary + /> + +); + +export default InvalidModelRegistry; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistry.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx similarity index 63% rename from clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistry.tsx rename to clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx index a9edd66c..d37dda7a 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistry.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx @@ -2,6 +2,9 @@ import React from 'react'; import ApplicationsPage from '~/app/components/ApplicationsPage'; import TitleWithIcon from '~/app/components/design/TitleWithIcon'; import { ProjectObjectType } from '~/app/components/design/utils'; +import useRegisteredModels from '~/app/hooks/useRegisteredModels'; +import ModelRegistrySelectorNavigator from './ModelRegistrySelectorNavigator'; +import { modelRegistryUrl } from './routeUtils'; type ModelRegistryProps = Omit< React.ComponentProps, @@ -15,20 +18,27 @@ type ModelRegistryProps = Omit< >; const ModelRegistry: React.FC = ({ ...pageProps }) => { - const [loaded, loadError] = [true, undefined]; // TODO: change with real usage + const [, loaded, loadError] = useRegisteredModels(); return ( + } description="Select a model registry to view and manage your registered models. Model registries provide a structured and organized way to store, share, version, deploy, and track models." + headerContent={ + modelRegistryUrl(modelRegistryName)} + /> + } loadError={loadError} loaded={loaded} provideChildrenPadding removeChildrenTopPadding - /> + > + TODO: Add table of registered models; + ); }; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistrySelector.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistrySelector.tsx new file mode 100644 index 00000000..757eeed4 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistrySelector.tsx @@ -0,0 +1,196 @@ +import * as React from 'react'; +import { + Bullseye, + Button, + DescriptionList, + DescriptionListDescription, + DescriptionListGroup, + DescriptionListTerm, + Divider, + Flex, + FlexItem, + Icon, + MenuToggle, + Popover, + Select, + SelectGroup, + SelectList, + SelectOption, + Tooltip, +} from '@patternfly/react-core'; +import truncateStyles from '@patternfly/react-styles/css/components/Truncate/truncate'; +import { InfoCircleIcon, BlueprintIcon } from '@patternfly/react-icons'; +import { useBrowserStorage } from '~/components/browserStorage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ModelRegistry } from '~/app/types'; + +const MODEL_REGISTRY_FAVORITE_STORAGE_KEY = 'kubeflow.dashboard.model.registry.favorite'; + +type ModelRegistrySelectorProps = { + modelRegistry: string; + onSelection: (modelRegistry: string) => void; + primary?: boolean; +}; + +const ModelRegistrySelector: React.FC = ({ + modelRegistry, + onSelection, + primary, +}) => { + const { modelRegistries, updatePreferredModelRegistry } = React.useContext( + ModelRegistrySelectorContext, + ); + + const selection = modelRegistries.find((mr) => mr.name === modelRegistry); + const [isOpen, setIsOpen] = React.useState(false); + const [favorites, setFavorites] = useBrowserStorage( + MODEL_REGISTRY_FAVORITE_STORAGE_KEY, + [], + ); + + const selectionDisplayName = selection ? selection.displayName : modelRegistry; + + const toggleLabel = modelRegistries.length === 0 ? 'No model registries' : selectionDisplayName; + + const getMRSelectDescription = (mr: ModelRegistry) => { + const desc = mr.description || mr.name; + if (!desc) { + return; + } + const tooltipContent = ( + + + {`${mr.displayName} description`} + {desc} + + + ); + return ( + + + {desc} + + + ); + }; + + const options = [ + + + {modelRegistries.map((mr) => ( + + {mr.displayName} + + ))} + + , + ]; + + const createFavorites = (favIds: string[]) => { + const favorite: JSX.Element[] = []; + + options.forEach((item) => { + if (item.type === SelectList) { + item.props.children.filter( + (child: JSX.Element) => favIds.includes(child.props.value) && favorite.push(child), + ); + } else if (item.type === SelectGroup) { + item.props.children.props.children.filter( + (child: JSX.Element) => favIds.includes(child.props.value) && favorite.push(child), + ); + } else if (favIds.includes(item.props.value)) { + favorite.push(item); + } + }); + + return favorite; + }; + + const selector = ( + + ); + + if (primary) { + return selector; + } + + return ( + + + + + + + Model registry + + {selector} + {selection && selection.description && ( + + + + + + )} + + + ); +}; + +export default ModelRegistrySelector; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistrySelectorNavigator.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistrySelectorNavigator.tsx new file mode 100644 index 00000000..7c606fd3 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistrySelectorNavigator.tsx @@ -0,0 +1,27 @@ +import * as React from 'react'; +import { useNavigate, useParams } from 'react-router-dom'; +import ModelRegistrySelector from './ModelRegistrySelector'; + +type ModelRegistrySelectorNavigatorProps = { + getRedirectPath: (namespace: string) => string; +} & Omit, 'onSelection' | 'modelRegistry'>; + +const ModelRegistrySelectorNavigator: React.FC = ({ + getRedirectPath, + ...modelRegistrySelectorProps +}) => { + const navigate = useNavigate(); + const { modelRegistry } = useParams<{ modelRegistry: string }>(); + + return ( + { + navigate(getRedirectPath(modelRegistryName)); + }} + modelRegistry={modelRegistry ?? ''} + /> + ); +}; + +export default ModelRegistrySelectorNavigator; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/EmptyModelRegistryState.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/EmptyModelRegistryState.tsx new file mode 100644 index 00000000..e15ce55e --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/EmptyModelRegistryState.tsx @@ -0,0 +1,73 @@ +import React from 'react'; +import { + Button, + ButtonVariant, + EmptyState, + EmptyStateActions, + EmptyStateBody, + EmptyStateFooter, + EmptyStateVariant, +} from '@patternfly/react-core'; +import { PlusCircleIcon } from '@patternfly/react-icons'; + +type EmptyModelRegistryStateType = { + testid?: string; + title: string; + description: string; + primaryActionText?: string; + primaryActionOnClick?: () => void; + secondaryActionText?: string; + secondaryActionOnClick?: () => void; + headerIcon?: React.ComponentType; + customAction?: React.ReactNode; +}; + +const EmptyModelRegistryState: React.FC = ({ + testid, + title, + description, + primaryActionText, + secondaryActionText, + primaryActionOnClick, + secondaryActionOnClick, + headerIcon, + customAction, +}) => ( + + {description} + + {primaryActionText && ( + + + + )} + + {secondaryActionText && ( + + + + )} + + {customAction && {customAction}} + + +); + +export default EmptyModelRegistryState; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/routeUtils.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/routeUtils.ts new file mode 100644 index 00000000..e7ec95ef --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/routeUtils.ts @@ -0,0 +1,51 @@ +export const modelRegistryUrl = (preferredModelRegistry?: string): string => + `/modelRegistry/${preferredModelRegistry}`; + +export const registeredModelsUrl = (preferredModelRegistry?: string): string => + `${modelRegistryUrl(preferredModelRegistry)}/registeredModels`; + +export const registeredModelUrl = (rmId?: string, preferredModelRegistry?: string): string => + `${registeredModelsUrl(preferredModelRegistry)}/${rmId}`; + +export const registeredModelArchiveUrl = (preferredModelRegistry?: string): string => + `${registeredModelsUrl(preferredModelRegistry)}/archive`; + +export const registeredModelArchiveDetailsUrl = ( + rmId?: string, + preferredModelRegistry?: string, +): string => `${registeredModelArchiveUrl(preferredModelRegistry)}/${rmId}`; + +export const modelVersionListUrl = (rmId?: string, preferredModelRegistry?: string): string => + `${registeredModelUrl(rmId, preferredModelRegistry)}/versions`; + +export const modelVersionUrl = ( + mvId: string, + rmId?: string, + preferredModelRegistry?: string, +): string => `${modelVersionListUrl(rmId, preferredModelRegistry)}/${mvId}`; + +export const modelVersionArchiveUrl = (rmId?: string, preferredModelRegistry?: string): string => + `${modelVersionListUrl(rmId, preferredModelRegistry)}/archive`; + +export const modelVersionArchiveDetailsUrl = ( + mvId: string, + rmId?: string, + preferredModelRegistry?: string, +): string => `${modelVersionArchiveUrl(rmId, preferredModelRegistry)}/${mvId}`; + +export const registerModelUrl = (preferredModelRegistry?: string): string => + `${modelRegistryUrl(preferredModelRegistry)}/registerModel`; + +export const registerVersionUrl = (preferredModelRegistry?: string): string => + `${modelRegistryUrl(preferredModelRegistry)}/registerVersion`; + +export const registerVersionForModelUrl = ( + rmId?: string, + preferredModelRegistry?: string, +): string => `${registeredModelUrl(rmId, preferredModelRegistry)}/registerVersion`; + +export const modelVersionDeploymentsUrl = ( + mvId: string, + rmId?: string, + preferredModelRegistry?: string, +): string => `${modelVersionUrl(mvId, rmId, preferredModelRegistry)}/deployments`; diff --git a/clients/ui/frontend/src/app/types.ts b/clients/ui/frontend/src/app/types.ts index 17fcc589..da39d133 100644 --- a/clients/ui/frontend/src/app/types.ts +++ b/clients/ui/frontend/src/app/types.ts @@ -1,4 +1,4 @@ -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; export enum ModelState { LIVE = 'LIVE', @@ -21,6 +21,11 @@ export type ModelRegistry = { description: string; }; +// TODO: Change in the backend AND frontend to "items" instead of "model-registries" +export type ModelRegistryResponse = { + model_registry: ModelRegistry[]; +}; + export enum ModelRegistryMetadataType { INT = 'MetadataIntValue', DOUBLE = 'MetadataDoubleValue', diff --git a/clients/ui/frontend/src/types.ts b/clients/ui/frontend/src/types.ts index 0be5cb1a..34f4c36f 100644 --- a/clients/ui/frontend/src/types.ts +++ b/clients/ui/frontend/src/types.ts @@ -19,14 +19,3 @@ export type CommonConfig = { export type FeatureFlag = { modelRegistry: boolean; }; - -export type APIOptions = { - dryRun?: boolean; - signal?: AbortSignal; - parseJSON?: boolean; -}; - -export type APIError = { - code: string; - message: string; -}; diff --git a/clients/ui/frontend/src/utilities/useFetchState.ts b/clients/ui/frontend/src/utilities/useFetchState.ts index 64b2e3eb..aa688d34 100644 --- a/clients/ui/frontend/src/utilities/useFetchState.ts +++ b/clients/ui/frontend/src/utilities/useFetchState.ts @@ -1,5 +1,5 @@ import * as React from 'react'; -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; /** * Allows "I'm not ready" rejections if you lack a lazy provided prop