diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 24c27384..5d912fa2 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -73,13 +73,13 @@ curl -i localhost:4000/api/v1/model_registry ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models -curl -i localhost:4000/api/v1/model_registry/model_registry_1/registered_models +curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models ``` ``` #POST /v1/model_registry/{model_registry_id}/registered_models -curl -i -X POST "http://localhost:4000/api/v1/model_registry/model_registry/registered_models" \ +curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \ -H "Content-Type: application/json" \ - -d '{ + -d '{ "data": { "customProperties": { "my-label9": { "metadataType": "MetadataStringValue", @@ -91,9 +91,9 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model_registry/regi "name": "bella", "owner": "eder", "state": "LIVE" -}' +}}' ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} -curl -i localhost:4000/api/v1/model_registry/model_registry/registered_models/1 +curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1 ``` diff --git a/clients/ui/bff/api/errors.go b/clients/ui/bff/api/errors.go index 6a28e0fb..b70b8aaf 100644 --- a/clients/ui/bff/api/errors.go +++ b/clients/ui/bff/api/errors.go @@ -18,6 +18,10 @@ type ErrorResponse struct { Message string `json:"message"` } +type ErrorEnvelope struct { + Error *integrations.HTTPError `json:"error"` +} + func (app *App) LogError(r *http.Request, err error) { var ( method = r.Method @@ -40,7 +44,7 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) { - env := Envelope{"error": error} + env := ErrorEnvelope{Error: error} err := app.WriteJSON(w, error.StatusCode, env, nil) diff --git a/clients/ui/bff/api/helpers.go b/clients/ui/bff/api/helpers.go index f35851af..a129e335 100644 --- a/clients/ui/bff/api/helpers.go +++ b/clients/ui/bff/api/helpers.go @@ -9,9 +9,12 @@ import ( "strings" ) -type Envelope map[string]interface{} +type Envelope[D any, M any] struct { + Data D `json:"data,omitempty"` + Metadata M `json:"metadata,omitempty"` +} -type TypedEnvelope[T any] map[string]T +type None *struct{} func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error { @@ -29,7 +32,11 @@ func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers h w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - w.Write(js) + _, err = w.Write(js) + + if err != nil { + return err + } return nil } diff --git a/clients/ui/bff/api/model_registry_handler.go b/clients/ui/bff/api/model_registry_handler.go index 5a99b84f..8a85870c 100644 --- a/clients/ui/bff/api/model_registry_handler.go +++ b/clients/ui/bff/api/model_registry_handler.go @@ -2,9 +2,12 @@ package api import ( "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/data" "net/http" ) +type ModelRegistryListEnvelope Envelope[[]data.ModelRegistryModel, None] + func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { registries, err := app.models.ModelRegistry.FetchAllModelRegistries(app.kubernetesClient) @@ -13,8 +16,8 @@ func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps return } - modelRegistryRes := Envelope{ - "model_registry": registries, + modelRegistryRes := ModelRegistryListEnvelope{ + Data: registries, } err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil) diff --git a/clients/ui/bff/api/model_registry_handler_test.go b/clients/ui/bff/api/model_registry_handler_test.go index 8f05372e..0b4949d5 100644 --- a/clients/ui/bff/api/model_registry_handler_test.go +++ b/clients/ui/bff/api/model_registry_handler_test.go @@ -29,28 +29,20 @@ func TestModelRegistryHandler(t *testing.T) { defer rs.Body.Close() body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var modelRegistryRes Envelope - err = json.Unmarshal(body, &modelRegistryRes) + var actual ModelRegistryListEnvelope + err = json.Unmarshal(body, &actual) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) - // Convert the unmarshalled data to the expected type - actualModelRegistry := make([]data.ModelRegistryModel, 0) - for _, v := range modelRegistryRes["model_registry"].([]interface{}) { - model := v.(map[string]interface{}) - actualModelRegistry = append(actualModelRegistry, data.ModelRegistryModel{Name: model["name"].(string), Description: model["description"].(string), DisplayName: model["displayName"].(string)}) - } - modelRegistryRes["model_registry"] = actualModelRegistry - - var expected = Envelope{ - "model_registry": []data.ModelRegistryModel{ + var expected = ModelRegistryListEnvelope{ + Data: []data.ModelRegistryModel{ {Name: "model-registry", Description: "Model registry description", DisplayName: "Model Registry"}, {Name: "model-registry-dora", Description: "Model registry dora description", DisplayName: "Model Registry Dora"}, {Name: "model-registry-bella", Description: "Model registry bella description", DisplayName: "Model Registry Bella"}, }, } - assert.Equal(t, expected, modelRegistryRes) + assert.Equal(t, expected, actual) } diff --git a/clients/ui/bff/api/registered_models_handler.go b/clients/ui/bff/api/registered_models_handler.go index 9e47d833..3e8cc71c 100644 --- a/clients/ui/bff/api/registered_models_handler.go +++ b/clients/ui/bff/api/registered_models_handler.go @@ -11,6 +11,9 @@ import ( "net/http" ) +type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None] +type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] + func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { //TODO (ederign) implement pagination client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) @@ -25,8 +28,8 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req return } - modelRegistryRes := Envelope{ - "registered_model_list": modelList, + modelRegistryRes := RegisteredModelListEnvelope{ + Data: modelList, } err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil) @@ -42,18 +45,20 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ return } - var model openapi.RegisteredModel - if err := json.NewDecoder(r.Body).Decode(&model); err != nil { + var envelope RegisteredModelEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) return } - if err := validation.ValidateRegisteredModel(model); err != nil { + data := *envelope.Data + + if err := validation.ValidateRegisteredModel(data); err != nil { app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) return } - jsonData, err := json.Marshal(model) + jsonData, err := json.Marshal(data) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("error marshaling model to JSON: %w", err)) return @@ -75,8 +80,11 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ return } + response := RegisteredModelEnvelope{ + Data: createdModel, + } w.Header().Set("Location", fmt.Sprintf("%s/%s", RegisteredModelsPath, *createdModel.Id)) - err = app.WriteJSON(w, http.StatusCreated, createdModel, nil) + err = app.WriteJSON(w, http.StatusCreated, response, nil) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON")) return @@ -101,8 +109,8 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request return } - result := Envelope{ - "registered_model": model, + result := RegisteredModelEnvelope{ + Data: model, } err = app.WriteJSON(w, http.StatusOK, result, nil) diff --git a/clients/ui/bff/api/registered_models_handler_test.go b/clients/ui/bff/api/registered_models_handler_test.go index ec1eb039..0e6deba6 100644 --- a/clients/ui/bff/api/registered_models_handler_test.go +++ b/clients/ui/bff/api/registered_models_handler_test.go @@ -22,7 +22,7 @@ func TestGetRegisteredModelHandler(t *testing.T) { } req, err := http.NewRequest(http.MethodGet, - "/api/v1/model-registry/model-registry/registered_models/1", nil) + "/api/v1/model_registry/model-registry/registered_models/1", nil) assert.NoError(t, err) ctx := context.WithValue(req.Context(), httpClientKey, mockClient) @@ -37,19 +37,21 @@ func TestGetRegisteredModelHandler(t *testing.T) { body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var registeredModelRes TypedEnvelope[openapi.RegisteredModel] + var registeredModelRes RegisteredModelEnvelope err = json.Unmarshal(body, ®isteredModelRes) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) - var expected = TypedEnvelope[openapi.RegisteredModel]{ - "registered_model": mocks.GetRegisteredModelMocks()[0], + mockModel := mocks.GetRegisteredModelMocks()[0] + + var expected = RegisteredModelEnvelope{ + Data: &mockModel, } //TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values // this issue is in the test only - assert.Equal(t, expected["registered_model"].Name, registeredModelRes["registered_model"].Name) + assert.Equal(t, expected.Data.Name, registeredModelRes.Data.Name) } func TestGetAllRegisteredModelsHandler(t *testing.T) { @@ -61,7 +63,7 @@ func TestGetAllRegisteredModelsHandler(t *testing.T) { } req, err := http.NewRequest(http.MethodGet, - "/api/v1/model-registry/model-registry/registered_models", nil) + "/api/v1/model_registry/model-registry/registered_models", nil) assert.NoError(t, err) ctx := context.WithValue(req.Context(), httpClientKey, mockClient) @@ -76,20 +78,22 @@ func TestGetAllRegisteredModelsHandler(t *testing.T) { body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var registeredModelsListRes TypedEnvelope[openapi.RegisteredModelList] + var registeredModelsListRes RegisteredModelListEnvelope err = json.Unmarshal(body, ®isteredModelsListRes) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) - var expected = TypedEnvelope[openapi.RegisteredModelList]{ - "registered_model_list": mocks.GetRegisteredModelListMock(), + modelList := mocks.GetRegisteredModelListMock() + + var expected = RegisteredModelListEnvelope{ + Data: &modelList, } - assert.Equal(t, expected["registered_model_list"].Size, registeredModelsListRes["registered_model_list"].Size) - assert.Equal(t, expected["registered_model_list"].PageSize, registeredModelsListRes["registered_model_list"].PageSize) - assert.Equal(t, expected["registered_model_list"].NextPageToken, registeredModelsListRes["registered_model_list"].NextPageToken) - assert.Equal(t, len(expected["registered_model_list"].Items), len(registeredModelsListRes["registered_model_list"].Items)) + assert.Equal(t, expected.Data.Size, registeredModelsListRes.Data.Size) + assert.Equal(t, expected.Data.PageSize, registeredModelsListRes.Data.PageSize) + assert.Equal(t, expected.Data.NextPageToken, registeredModelsListRes.Data.NextPageToken) + assert.Equal(t, len(expected.Data.Items), len(registeredModelsListRes.Data.Items)) } func TestCreateRegisteredModelHandler(t *testing.T) { @@ -100,14 +104,16 @@ func TestCreateRegisteredModelHandler(t *testing.T) { modelRegistryClient: mockMRClient, } - newModel := openapi.NewRegisteredModelCreate("Model One") - newModelJSON, err := newModel.MarshalJSON() + newModel := openapi.NewRegisteredModel("Model One") + newEnvelope := RegisteredModelEnvelope{Data: newModel} + + newModelJSON, err := json.Marshal(newEnvelope) assert.NoError(t, err) reqBody := bytes.NewReader(newModelJSON) req, err := http.NewRequest(http.MethodPost, - "/api/v1/model-registry/model-registry/registered_models", reqBody) + "/api/v1/model_registry/model-registry/registered_models", reqBody) assert.NoError(t, err) ctx := context.WithValue(req.Context(), httpClientKey, mockClient) @@ -122,14 +128,14 @@ func TestCreateRegisteredModelHandler(t *testing.T) { body, err := io.ReadAll(rs.Body) assert.NoError(t, err) - var registeredModelRes openapi.RegisteredModel - err = json.Unmarshal(body, ®isteredModelRes) + var actual RegisteredModelEnvelope + err = json.Unmarshal(body, &actual) assert.NoError(t, err) assert.Equal(t, http.StatusCreated, rr.Code) var expected = mocks.GetRegisteredModelMocks()[0] - assert.Equal(t, expected.Name, registeredModelRes.Name) + assert.Equal(t, expected.Name, actual.Data.Name) assert.NotEmpty(t, rs.Header.Get("location")) }