From 7958ab6a8a2c2cfff00d5e5aa84abb00f6226a57 Mon Sep 17 00:00:00 2001 From: Alessio Pragliola <83355398+Al-Pragliola@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:00:33 +0200 Subject: [PATCH 01/13] feat(docs): add kind - ingress instructions to deploy model registry (#410) * feat(docs): added documentation about using mr on kind with ingress Signed-off-by: Alessio Pragliola * feat(docs): added link to the ingress guide in CONTRIBUTING.md Signed-off-by: Alessio Pragliola * chore(docs): apply suggestions from code review Co-authored-by: Matteo Mortari Signed-off-by: Alessio Pragliola <83355398+Al-Pragliola@users.noreply.github.com> * feat(docs): added port forwarding guide Signed-off-by: Alessio Pragliola Co-authored-by: Matteo Mortari * chore(docs): apply suggestions from code review Co-authored-by: Matteo Mortari Signed-off-by: Alessio Pragliola <83355398+Al-Pragliola@users.noreply.github.com> --------- Signed-off-by: Alessio Pragliola Signed-off-by: Alessio Pragliola <83355398+Al-Pragliola@users.noreply.github.com> Co-authored-by: Matteo Mortari --- CONTRIBUTING.md | 28 ++++++++++ docs/mr_kind_deploy_ingress.md | 93 ++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 docs/mr_kind_deploy_ingress.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 02ef869c..7fd47cc2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -61,6 +61,34 @@ to your PATH from your bashrc like: and now you can substitute `gmake` every time the make command is mentioned in guides (or perform the path management per the caveat). +## Local kubernetes deployment of Model Registry + +To test the Model Registry locally without mocking the k8s calls, the Model Registry backend can be deployed using Kind. + +### Prerequisites + +The following tools need to be installed in your local environment: + +- [Podman](https://podman.io/) (Docker should also work) +- [kubectl](https://kubernetes.io/docs/tasks/tools/#kubectl) +- [kind](https://kind.sigs.k8s.io/docs/user/quick-start/#installation) + +Choose the networking setup that fits your needs, either port-forwarding or Ingress. + +### Port-forwarding guide + +Create a Kind cluster with the following command: + +```sh +kind create cluster +``` + +and then follow the steps from the [Installation guide](https://www.kubeflow.org/docs/components/model-registry/installation/#standalone-installation) on the Kubeflow website, to set up the port-forwarding and deploy the Model Registry on the cluster. + +### Ingress guide + +Follow the [Ingress guide](docs/mr_kind_deploy_ingress.md) to set up the Ingress controller and deploy the Model Registry on the cluster. + ## Docker engine Several options of docker engines are available for Mac. diff --git a/docs/mr_kind_deploy_ingress.md b/docs/mr_kind_deploy_ingress.md new file mode 100644 index 00000000..28fcdaf3 --- /dev/null +++ b/docs/mr_kind_deploy_ingress.md @@ -0,0 +1,93 @@ +# Model Registry - Kind Ingress Guide + +## Create a Kind cluster ready for the ingress controller + +1. Create a file named `kind-config.yaml` with the following content: + +```yaml +kind: Cluster +apiVersion: kind.x-k8s.io/v1alpha4 +nodes: +- role: control-plane + kubeadmConfigPatches: + - | + kind: InitConfiguration + nodeRegistration: + kubeletExtraArgs: + node-labels: "ingress-ready=true" + extraPortMappings: + - containerPort: 3080 + hostPort: 3080 + protocol: TCP + - containerPort: 30443 + hostPort: 30443 + protocol: TCP +``` + +> 📖 **NOTE** +> +> ContainerPorts 3080 and 30443 are customisable, you can change them to any other port number, but make sure to update the port number in the following kubectl patch commands. + +2. Run the following command `kind create cluster --config=kind-config.yaml` + +## Install the ingress controller (nginx) on the cluster + +1. Install the controller by using `kubectl apply -f https://raw.githubusercontent.com/kubernetes/ingress-nginx/main/deploy/static/provider/kind/deploy.yaml` + + +2. Patch the ports inside the controller's deployment, by running the following commands: + +```shell +kubectl patch deployment -n ingress-nginx ingress-nginx-controller --type='json' -p='[{"op": "replace", "path": "/spec/template/spec/containers/0/ports/0/hostPort", "value": 3080}]' + +kubectl patch deployment -n ingress-nginx ingress-nginx-controller --type='json' -p='[{"op": "replace", "path": "/spec/template/spec/containers/0/ports/1/hostPort", "value": 30443}]' +``` + +## Install model registry on the cluster + +`kubectl create namespace kubeflow` + +`kubectl apply -k "https://github.com/kubeflow/model-registry/manifests/kustomize/overlays/db"` + +`kubectl wait --for=condition=available -n kubeflow deployment/model-registry-deployment --timeout=1m` + +## Apply the ingress + +1. Create a file named `mr-ingress.yaml` with the following content: + +```yaml +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: model-registry +spec: + rules: + - host: "model-registry.io" # choose a name of your liking + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: model-registry-service + port: + number: 8080 +``` + +2. Run the following command `kubectl apply -f mr-ingress.yaml -n kubeflow` + +3. Add the following line to the file `/etc/hosts`: + +`127.0.0.1 model-registry.io` + +## Test the ingress + +Run `curl http://model-registry.io:3080/api/model_registry/v1alpha3/registered_models`, you should see and output similar to this: + +```json +{"items":[],"nextPageToken":"","pageSize":0,"size":0} +``` + +## Teardown + +kind delete cluster From 99251f47129b4d71f70ebc53efa9e84291b9d868 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:50:33 +0000 Subject: [PATCH 02/13] build(deps): bump @patternfly/react-core from 6.0.0-alpha.102 to 6.0.0-prerelease.14 in /clients/ui/frontend (#416) Bumps [@patternfly/react-core](https://github.com/patternfly/patternfly-react) from 6.0.0-alpha.102 to 6.0.0-prerelease.14. - [Release notes](https://github.com/patternfly/patternfly-react/releases) - [Commits](https://github.com/patternfly/patternfly-react/compare/@patternfly/react-core@6.0.0-alpha.102...@patternfly/react-core@6.0.0-prerelease.14) --- updated-dependencies: - dependency-name: "@patternfly/react-core" dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- clients/ui/frontend/package-lock.json | 52 ++++++++++++++++----------- clients/ui/frontend/package.json | 2 +- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/clients/ui/frontend/package-lock.json b/clients/ui/frontend/package-lock.json index 7dbae7b0..25a2827d 100644 --- a/clients/ui/frontend/package-lock.json +++ b/clients/ui/frontend/package-lock.json @@ -9,7 +9,7 @@ "version": "0.0.1", "license": "Apache-2.0", "dependencies": { - "@patternfly/react-core": "6.0.0-alpha.102", + "@patternfly/react-core": "6.0.0-prerelease.14", "@patternfly/react-icons": "6.0.0-alpha.37", "@patternfly/react-styles": "6.0.0-alpha.35", "@patternfly/react-table": "6.0.0-alpha.101", @@ -3582,23 +3582,36 @@ } }, "node_modules/@patternfly/react-core": { - "version": "6.0.0-alpha.102", - "resolved": "https://registry.npmjs.org/@patternfly/react-core/-/react-core-6.0.0-alpha.102.tgz", - "integrity": "sha512-NjnLhgYwJ3LuA3/DDwzM10X9dlZyR9ICAKOXI2FlxRM0kTAnAK0kLx7MuQjZ7wboIjxpRaA3lG8c9zlHDrCWdQ==", - "license": "MIT", - "dependencies": { - "@patternfly/react-icons": "^6.0.0-alpha.36", - "@patternfly/react-styles": "^6.0.0-alpha.35", - "@patternfly/react-tokens": "^6.0.0-alpha.35", - "focus-trap": "7.5.4", + "version": "6.0.0-prerelease.14", + "resolved": "https://registry.npmjs.org/@patternfly/react-core/-/react-core-6.0.0-prerelease.14.tgz", + "integrity": "sha512-FkKx9p76tLXc8kVOeYs16hNOjuwn7AxCI6OPB9lEEZOcIGTPixmvNLLy/AKeCyi1rQeZvh74XZJZTIDJLhVB1w==", + "dependencies": { + "@patternfly/react-icons": "^6.0.0-prerelease.4", + "@patternfly/react-styles": "^6.0.0-prerelease.3", + "@patternfly/react-tokens": "^6.0.0-prerelease.4", + "focus-trap": "7.6.0", "react-dropzone": "^14.2.3", - "tslib": "^2.6.2" + "tslib": "^2.7.0" }, "peerDependencies": { "react": "^17 || ^18", "react-dom": "^17 || ^18" } }, + "node_modules/@patternfly/react-core/node_modules/@patternfly/react-icons": { + "version": "6.0.0-prerelease.4", + "resolved": "https://registry.npmjs.org/@patternfly/react-icons/-/react-icons-6.0.0-prerelease.4.tgz", + "integrity": "sha512-KHo0v4XG4vS5wSZ76EUOrXDM636/ikXe6lNYqbAL/KRfqhfvXHEESZnK+0p1tpoBwwEUivAmJNSdIjppBPhACg==", + "peerDependencies": { + "react": "^17 || ^18", + "react-dom": "^17 || ^18" + } + }, + "node_modules/@patternfly/react-core/node_modules/@patternfly/react-styles": { + "version": "6.0.0-prerelease.3", + "resolved": "https://registry.npmjs.org/@patternfly/react-styles/-/react-styles-6.0.0-prerelease.3.tgz", + "integrity": "sha512-VyAODCKA/PkyGMVT0A2G2TVVx1H1QKBrmXBwY11Ba3ggvuLZ2zWu+vU9LyM/HhmefOwy+5/P8bmRtLM+37D/CA==" + }, "node_modules/@patternfly/react-icons": { "version": "6.0.0-alpha.37", "resolved": "https://registry.npmjs.org/@patternfly/react-icons/-/react-icons-6.0.0-alpha.37.tgz", @@ -3634,10 +3647,9 @@ } }, "node_modules/@patternfly/react-tokens": { - "version": "6.0.0-prerelease.1", - "resolved": "https://registry.npmjs.org/@patternfly/react-tokens/-/react-tokens-6.0.0-prerelease.1.tgz", - "integrity": "sha512-drKu/J78V0De+SOhFmLKwMxKzfCj0rvC3SR8+8MoqNPSmXxholN3s4+TA4aEx1CNvuIZj0IMTFOeakS1RMdeDA==", - "license": "MIT" + "version": "6.0.0-prerelease.4", + "resolved": "https://registry.npmjs.org/@patternfly/react-tokens/-/react-tokens-6.0.0-prerelease.4.tgz", + "integrity": "sha512-T1/C6nj78zvk4zLUp3VvNI3hChPR2vEy0BasIG3AYakWoLJsdOY6qS3PlLulawNZvlK7KH0X7VRfT1Zk073R6A==" }, "node_modules/@pkgjs/parseargs": { "version": "0.11.0", @@ -10294,10 +10306,9 @@ "optional": true }, "node_modules/focus-trap": { - "version": "7.5.4", - "resolved": "https://registry.npmjs.org/focus-trap/-/focus-trap-7.5.4.tgz", - "integrity": "sha512-N7kHdlgsO/v+iD/dMoJKtsSqs5Dz/dXZVebRgJw23LDk+jMi/974zyiOYDziY2JPp8xivq9BmUGwIJMiuSBi7w==", - "license": "MIT", + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/focus-trap/-/focus-trap-7.6.0.tgz", + "integrity": "sha512-1td0l3pMkWJLFipobUcGaf+5DTY4PLDDrcqoSaKP8ediO/CoWCCYk/fT/Y2A4e6TNB+Sh6clRJCjOPPnKoNHnQ==", "dependencies": { "tabbable": "^6.2.0" } @@ -20709,8 +20720,7 @@ "node_modules/tabbable": { "version": "6.2.0", "resolved": "https://registry.npmjs.org/tabbable/-/tabbable-6.2.0.tgz", - "integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==", - "license": "MIT" + "integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==" }, "node_modules/tapable": { "version": "2.2.1", diff --git a/clients/ui/frontend/package.json b/clients/ui/frontend/package.json index ce958388..6f14a588 100644 --- a/clients/ui/frontend/package.json +++ b/clients/ui/frontend/package.json @@ -90,7 +90,7 @@ "webpack-merge": "^6.0.1" }, "dependencies": { - "@patternfly/react-core": "6.0.0-alpha.102", + "@patternfly/react-core": "6.0.0-prerelease.14", "@patternfly/react-icons": "6.0.0-alpha.37", "@patternfly/react-styles": "6.0.0-alpha.35", "@patternfly/react-table": "6.0.0-alpha.101", From a263591f11f4a8b3dfb02448e54e84a07472a098 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:52:33 +0000 Subject: [PATCH 03/13] build(deps): bump dompurify and @types/dompurify in /clients/ui/frontend (#418) Bumps [dompurify](https://github.com/cure53/DOMPurify) and [@types/dompurify](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/dompurify). These dependencies needed to be updated together. Updates `dompurify` from 2.5.6 to 3.1.6 - [Release notes](https://github.com/cure53/DOMPurify/releases) - [Commits](https://github.com/cure53/DOMPurify/compare/2.5.6...3.1.6) Updates `@types/dompurify` from 2.4.0 to 3.0.5 - [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases) - [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/dompurify) --- updated-dependencies: - dependency-name: dompurify dependency-type: direct:production update-type: version-update:semver-major - dependency-name: "@types/dompurify" dependency-type: direct:development update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- clients/ui/frontend/package-lock.json | 18 ++++++++---------- clients/ui/frontend/package.json | 4 ++-- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/clients/ui/frontend/package-lock.json b/clients/ui/frontend/package-lock.json index 25a2827d..e36704e1 100644 --- a/clients/ui/frontend/package-lock.json +++ b/clients/ui/frontend/package-lock.json @@ -14,7 +14,7 @@ "@patternfly/react-styles": "6.0.0-alpha.35", "@patternfly/react-table": "6.0.0-alpha.101", "classnames": "^2.2.6", - "dompurify": "^2.2.6", + "dompurify": "^3.1.6", "lodash-es": "^4.17.15", "npm-run-all": "^4.1.5", "react": "^18", @@ -33,7 +33,7 @@ "@testing-library/react": "^16.0.0", "@testing-library/user-event": "14.5.2", "@types/classnames": "^2.3.1", - "@types/dompurify": "^2.2.6", + "@types/dompurify": "^3.0.5", "@types/jest": "^29.5.12", "@types/lodash-es": "^4.17.8", "@types/react-router-dom": "^5.3.3", @@ -4148,11 +4148,10 @@ } }, "node_modules/@types/dompurify": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/@types/dompurify/-/dompurify-2.4.0.tgz", - "integrity": "sha512-IDBwO5IZhrKvHFUl+clZxgf3hn2b/lU6H1KaBShPkQyGJUQ0xwebezIPSuiyGwfz1UzJWQl4M7BDxtHtCCPlTg==", + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/@types/dompurify/-/dompurify-3.0.5.tgz", + "integrity": "sha512-1Wg0g3BtQF7sSb27fJQAKck1HECM6zV1EB66j8JH9i3LCjYabJa0FSdiSgsD5K/RbrsR0SiraKacLB+T8ZVYAg==", "dev": true, - "license": "MIT", "dependencies": { "@types/trusted-types": "*" } @@ -8400,10 +8399,9 @@ } }, "node_modules/dompurify": { - "version": "2.5.6", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-2.5.6.tgz", - "integrity": "sha512-zUTaUBO8pY4+iJMPE1B9XlO2tXVYIcEA4SNGtvDELzTSCQO7RzH+j7S180BmhmJId78lqGU2z19vgVx2Sxs/PQ==", - "license": "(MPL-2.0 OR Apache-2.0)" + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.1.6.tgz", + "integrity": "sha512-cTOAhc36AalkjtBpfG6O8JimdTMWNXjiePT2xQH/ppBGi/4uIpmj8eKyIkMJErXWARyINV/sB38yf8JCLF5pbQ==" }, "node_modules/domutils": { "version": "2.8.0", diff --git a/clients/ui/frontend/package.json b/clients/ui/frontend/package.json index 6f14a588..eb33f8e3 100644 --- a/clients/ui/frontend/package.json +++ b/clients/ui/frontend/package.json @@ -43,7 +43,7 @@ "@types/jest": "^29.5.12", "@types/react-router-dom": "^5.3.3", "@types/classnames": "^2.3.1", - "@types/dompurify": "^2.2.6", + "@types/dompurify": "^3.0.5", "@types/showdown": "^2.0.3", "@types/lodash-es": "^4.17.8", "chai-subset": "^1.6.0", @@ -99,7 +99,7 @@ "react": "^18", "react-dom": "^18", "sass": "^1.78.0", - "dompurify": "^2.2.6", + "dompurify": "^3.1.6", "showdown": "^2.1.0", "classnames": "^2.2.6" }, From 85417f2d7a6332f6734ee8453f0af55270244fce Mon Sep 17 00:00:00 2001 From: Eder Ignatowicz Date: Tue, 24 Sep 2024 08:14:33 -0400 Subject: [PATCH 04/13] Small cleanup on cypress gitignore (#408) Signed-off-by: Eder Ignatowicz --- clients/ui/frontend/.gitignore | 1 - clients/ui/frontend/src/__tests__/cypress/.gitignore | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/ui/frontend/.gitignore b/clients/ui/frontend/.gitignore index e236f1ea..4731abfe 100644 --- a/clients/ui/frontend/.gitignore +++ b/clients/ui/frontend/.gitignore @@ -1,5 +1,4 @@ **/node_modules -src/__tests__/cypress/cypress/downloads/ dist yarn-error.log yarn.lock diff --git a/clients/ui/frontend/src/__tests__/cypress/.gitignore b/clients/ui/frontend/src/__tests__/cypress/.gitignore index c1f05361..4539e50c 100644 --- a/clients/ui/frontend/src/__tests__/cypress/.gitignore +++ b/clients/ui/frontend/src/__tests__/cypress/.gitignore @@ -1,2 +1,3 @@ coverage results +cypress/downloads/ \ No newline at end of file From 3afc987ff12b0d49edb69d707c83eb262a9504e1 Mon Sep 17 00:00:00 2001 From: Matteo Mortari Date: Tue, 24 Sep 2024 15:09:33 +0200 Subject: [PATCH 05/13] py: show Kubeflow alpha banner like website (#422) Signed-off-by: Matteo Mortari --- clients/python/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/clients/python/README.md b/clients/python/README.md index f5e0cb88..05419dcf 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -7,6 +7,12 @@ This library provides a high level interface for interacting with a model registry server. +> **Alpha** +> +> This Kubeflow component has **alpha** status with limited support. +> See the [Kubeflow versioning policies](https://www.kubeflow.org/docs/started/support/#application-status). +> The Kubeflow team is interested in your [feedback](https://github.com/kubeflow/model-registry) about the usability of the feature. + ## Installation In your Python environment, you can install the latest version of the Model Registry Python client with: From e0598fd4d0d0a3d5671234ff7baee2a2460acd54 Mon Sep 17 00:00:00 2001 From: Griffin Sullivan <48397354+Griffin-Sullivan@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:06:34 -0400 Subject: [PATCH 06/13] Add views for Model Versions and Model Details (#409) Signed-off-by: Griffin-Sullivan --- clients/ui/frontend/package-lock.json | 4 +- clients/ui/frontend/package.json | 9 +- .../cypress/cypress/support/commands/api.ts | 2 +- .../cypress/tests/mocked/modelVersions.cy.ts | 222 +++++++++++ .../__tests__/cypress/cypress/utils/url.ts | 12 + .../modelRegistry/ModelRegistryRoutes.tsx | 14 + .../ModelPropertiesDescriptionListGroup.tsx | 129 +++++++ .../screens/ModelPropertiesTableRow.tsx | 187 ++++++++++ .../ModelVersions/ModelDetailsView.tsx | 106 ++++++ .../ModelVersions/ModelVersionListView.tsx | 176 +++++++++ .../screens/ModelVersions/ModelVersions.tsx | 64 ++++ .../ModelVersionsHeaderActions.tsx | 86 +++++ .../ModelVersions/ModelVersionsTable.tsx | 34 ++ .../ModelVersionsTableColumns.ts | 40 ++ .../ModelVersions/ModelVersionsTableRow.tsx | 128 +++++++ .../ModelVersions/ModelVersionsTabs.tsx | 63 ++++ .../screens/ModelVersions/const.ts | 9 + .../RegisteredModelTableRow.tsx | 3 +- .../screens/__tests__/utils.spec.ts | 348 ++++++++++++++++++ .../components/ArchiveModelVersionModal.tsx | 90 +++++ .../components/RestoreModelVersionModal.tsx | 62 ++++ .../DashboardDescriptionListGroup.scss | 4 + .../DashboardDescriptionListGroup.tsx | 120 ++++++ .../EditableLabelsDescriptionListGroup.tsx | 218 +++++++++++ .../EditableTextDescriptionListGroup.tsx | 78 ++++ clients/ui/frontend/tsconfig.json | 1 + 26 files changed, 2200 insertions(+), 9 deletions(-) create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/utils/url.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesTableRow.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsHeaderActions.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTable.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableColumns.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/const.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/components/ArchiveModelVersionModal.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/components/RestoreModelVersionModal.tsx create mode 100644 clients/ui/frontend/src/components/DashboardDescriptionListGroup.scss create mode 100644 clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx create mode 100644 clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx create mode 100644 clients/ui/frontend/src/components/EditableTextDescriptionListGroup.tsx diff --git a/clients/ui/frontend/package-lock.json b/clients/ui/frontend/package-lock.json index e36704e1..3975f834 100644 --- a/clients/ui/frontend/package-lock.json +++ b/clients/ui/frontend/package-lock.json @@ -19,6 +19,7 @@ "npm-run-all": "^4.1.5", "react": "^18", "react-dom": "^18", + "react-router": "^6.26.2", "sass": "^1.78.0", "showdown": "^2.1.0" }, @@ -3686,7 +3687,6 @@ "version": "1.19.2", "resolved": "https://registry.npmjs.org/@remix-run/router/-/router-1.19.2.tgz", "integrity": "sha512-baiMx18+IMuD1yyvOGaHM9QrVUPGGG0jC+z+IPHnRJWUAUvaKuWKyE8gjDj2rzv3sz9zOGoRSPgeBVHRhZnBlA==", - "dev": true, "license": "MIT", "engines": { "node": ">=14.0.0" @@ -18631,8 +18631,6 @@ "version": "6.26.2", "resolved": "https://registry.npmjs.org/react-router/-/react-router-6.26.2.tgz", "integrity": "sha512-tvN1iuT03kHgOFnLPfLJ8V95eijteveqdOSk+srqfePtQvqCExB8eHOYnlilbOcyJyKnYkr1vJvf7YqotAJu1A==", - "dev": true, - "license": "MIT", "dependencies": { "@remix-run/router": "1.19.2" }, diff --git a/clients/ui/frontend/package.json b/clients/ui/frontend/package.json index eb33f8e3..8748f1b6 100644 --- a/clients/ui/frontend/package.json +++ b/clients/ui/frontend/package.json @@ -24,7 +24,7 @@ "test:fix": "eslint --ext .js,.ts,.jsx,.tsx ./src --fix", "test:lint": "eslint --max-warnings 0 --ext .js,.ts,.jsx,.tsx ./src", "cypress:open": "cypress open --project src/__tests__/cypress", - "cypress:open:mock": "CY_MOCK=1 CY_WS_PORT=9002 npm run cypress:open -- ", + "cypress:open:mock": "CY_MOCK=1 npm run cypress:open -- ", "cypress:run": "cypress run -b chrome --project src/__tests__/cypress", "cypress:run:mock": "CY_MOCK=1 npm run cypress:run -- ", "cypress:server:build": "POLL_INTERVAL=9999999 FAST_POLL_INTERVAL=9999999 npm run build", @@ -40,12 +40,12 @@ "@testing-library/jest-dom": "^6.5.0", "@testing-library/react": "^16.0.0", "@testing-library/user-event": "14.5.2", - "@types/jest": "^29.5.12", - "@types/react-router-dom": "^5.3.3", "@types/classnames": "^2.3.1", "@types/dompurify": "^3.0.5", - "@types/showdown": "^2.0.3", + "@types/jest": "^29.5.12", "@types/lodash-es": "^4.17.8", + "@types/react-router-dom": "^5.3.3", + "@types/showdown": "^2.0.3", "chai-subset": "^1.6.0", "copy-webpack-plugin": "^12.0.2", "core-js": "^3.37.1", @@ -98,6 +98,7 @@ "npm-run-all": "^4.1.5", "react": "^18", "react-dom": "^18", + "react-router": "^6.26.2", "sass": "^1.78.0", "dompurify": "^3.1.6", "showdown": "^2.1.0", diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts index 730064da..edf7c007 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts @@ -58,7 +58,7 @@ declare global { options: { path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; }, - response: ApiResponse, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId', diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts new file mode 100644 index 00000000..31a7ab37 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts @@ -0,0 +1,222 @@ +/* eslint-disable camelcase */ +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { mockRegisteredModelList } from '~/__mocks__/mockRegisteredModelsList'; +import { labelModal, modelRegistry } from '~/__tests__/cypress/cypress/pages/modelRegistry'; +import { be } from '~/__tests__/cypress/cypress/utils/should'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import type { ModelRegistry, ModelVersion } from '~/app/types'; +import { verifyRelativeURL } from '~/__tests__/cypress/cypress/utils/url'; +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockBFFResponse } from '~/__mocks__/utils'; + +const MODEL_REGISTRY_API_VERSION = 'v1'; + +type HandlersProps = { + registeredModelsSize?: number; + modelVersions?: ModelVersion[]; + modelRegistries?: ModelRegistry[]; +}; + +const initIntercepts = ({ + registeredModelsSize = 4, + modelRegistries = [ + mockModelRegistry({ + name: 'modelregistry-sample', + description: 'New model registry', + displayName: 'Model Registry Sample', + }), + mockModelRegistry({ + name: 'modelregistry-sample-2', + description: 'New model registry 2', + displayName: 'Model Registry Sample 2', + }), + ], + modelVersions = [ + mockModelVersion({ + author: 'Author 1', + id: '1', + labels: [ + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + 'Test label x', + 'Test label y', + 'Test label z', + ], + }), + mockModelVersion({ id: '2', name: 'model version' }), + ], +}: HandlersProps) => { + cy.interceptApi( + `GET /api/:apiVersion/model_registry`, + { + path: { apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(modelRegistries), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models`, + { + path: { modelRegistryName: 'modelregistry-sample', apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(mockRegisteredModelList({ size: registeredModelsSize })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockBFFResponse(mockModelVersionList({ items: modelVersions })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockBFFResponse(mockRegisteredModel({})), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 1, + }, + }, + mockModelVersion({ id: '1', name: 'model version' }), + ); +}; + +describe('Model Versions', () => { + it('No model versions in the selected registered model', () => { + initIntercepts({ + modelVersions: [], + }); + + modelRegistry.visit(); + const registeredModelRow = modelRegistry.getRow('Fraud detection model'); + registeredModelRow.findName().contains('Fraud detection model').click(); + verifyRelativeURL(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + modelRegistry.shouldmodelVersionsEmpty(); + }); + + it('Model versions table browser back button should lead to Registered models table', () => { + initIntercepts({ + modelVersions: [], + }); + + modelRegistry.visit(); + const registeredModelRow = modelRegistry.getRow('Fraud detection model'); + registeredModelRow.findName().contains('Fraud detection model').click(); + verifyRelativeURL(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + cy.go('back'); + verifyRelativeURL(`/modelRegistry/modelregistry-sample`); + registeredModelRow.findName().contains('Fraud detection model').should('exist'); + }); + + it('Model versions table', () => { + // TODO: Uncomment when we fix finding dropdown items + + initIntercepts({ + modelRegistries: [ + // mockModelRegistry({ name: 'modelRegistry-1', displayName: 'modelRegistry-1' }), + mockModelRegistry({}), + ], + }); + + modelRegistry.visit(); + //modelRegistry.findModelRegistry().findSelectOption('Model Registry Sample').click(); + //cy.reload(); + const registeredModelRow = modelRegistry.getRow('Fraud detection model'); + registeredModelRow.findName().contains('Fraud detection model').click(); + verifyRelativeURL(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + modelRegistry.findModelBreadcrumbItem().contains('test'); + //modelRegistry.findModelVersionsTableKebab().findDropdownItem('View archived versions'); + //modelRegistry.findModelVersionsHeaderAction().findDropdownItem('Archive model'); + modelRegistry.findModelVersionsTable().should('be.visible'); + modelRegistry.findModelVersionsTableRows().should('have.length', 2); + + // Label modal + const modelVersionRow = modelRegistry.getModelVersionRow('new model version'); + + modelVersionRow.findLabelModalText().contains('5 more'); + modelVersionRow.findLabelModalText().click(); + labelModal.shouldContainsModalLabels([ + 'Financial', + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + 'Test label x', + 'Test label y', + 'Test label y', + ]); + labelModal.findModalSearchInput().type('Financial'); + labelModal.shouldContainsModalLabels(['Financial', 'Financial data']); + labelModal.findCloseModal().click(); + + // sort by model version name + modelRegistry.findModelVersionsTableHeaderButton('Version name').click(); + modelRegistry.findModelVersionsTableHeaderButton('Version name').should(be.sortAscending); + modelRegistry.findModelVersionsTableHeaderButton('Version name').click(); + modelRegistry.findModelVersionsTableHeaderButton('Version name').should(be.sortDescending); + + // sort by Last modified + modelRegistry.findModelVersionsTableHeaderButton('Last modified').click(); + modelRegistry.findModelVersionsTableHeaderButton('Last modified').should(be.sortAscending); + modelRegistry.findModelVersionsTableHeaderButton('Last modified').click(); + modelRegistry.findModelVersionsTableHeaderButton('Last modified').should(be.sortDescending); + + // sort by model version author + modelRegistry.findModelVersionsTableHeaderButton('Author').click(); + modelRegistry.findModelVersionsTableHeaderButton('Author').should(be.sortAscending); + modelRegistry.findModelVersionsTableHeaderButton('Author').click(); + modelRegistry.findModelVersionsTableHeaderButton('Author').should(be.sortDescending); + + // filtering by keyword + modelRegistry.findModelVersionsTableSearch().type('new model version'); + modelRegistry.findModelVersionsTableRows().should('have.length', 1); + modelRegistry.findModelVersionsTableRows().contains('new model version'); + modelRegistry.findModelVersionsTableSearch().focused().clear(); + + // filtering by model version author + modelRegistry.findModelVersionsTableFilter().findSelectOption('Author').click(); + modelRegistry.findModelVersionsTableSearch().type('Test author'); + modelRegistry.findModelVersionsTableRows().should('have.length', 1); + modelRegistry.findModelVersionsTableRows().contains('Test author'); + }); + + it('Model version details back button should lead to versions table', () => { + initIntercepts({}); + + modelRegistry.visit(); + const registeredModelRow = modelRegistry.getRow('Fraud detection model'); + registeredModelRow.findName().contains('Fraud detection model').click(); + verifyRelativeURL(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + // TODO: Uncomment when we have model version details + // const modelVersionRow = modelRegistry.getModelVersionRow('model version'); + // modelVersionRow.findModelVersionName().contains('model version').click(); + // verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions/1/details'); + // cy.findByTestId('app-page-title').should('have.text', 'model version'); + // cy.findByTestId('breadcrumb-version-name').should('have.text', 'model version'); + // cy.go('back'); + // verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + }); +}); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/utils/url.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/utils/url.ts new file mode 100644 index 00000000..b976a070 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/utils/url.ts @@ -0,0 +1,12 @@ +/** + * Verify the relative route to the cypress host + * e.g. If page is running on `https://localhost:9001/pipelines` + * calling verifyRelativeURL('/pipelines') will check whether the full URL matches the URL above + */ +export const verifyRelativeURL = (relativeURL: string): Cypress.Chainable => { + return cy + .location() + .then((location) => + cy.url().should('eq', `${location.protocol}//${location.host}${relativeURL}`), + ); +}; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx index 6e318440..c7b78d10 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx @@ -4,6 +4,8 @@ import ModelRegistry from './screens/ModelRegistry'; import ModelRegistryCoreLoader from './ModelRegistryCoreLoader'; import { modelRegistryUrl } from './screens/routeUtils'; import RegisteredModelsArchive from './screens/RegisteredModelsArchive/RegisteredModelsArchive'; +import { ModelVersionsTab } from './screens/ModelVersions/const'; +import ModelVersions from './screens/ModelVersions/ModelVersions'; const ModelRegistryRoutes: React.FC = () => ( @@ -16,6 +18,18 @@ const ModelRegistryRoutes: React.FC = () => ( } > } /> + + } /> + } + /> + } + /> + } /> + } /> } /> diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx new file mode 100644 index 00000000..aeb98464 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx @@ -0,0 +1,129 @@ +import * as React from 'react'; +import { Button } from '@patternfly/react-core'; +import { Table, Tbody, Th, Thead, Tr } from '@patternfly/react-table'; +import { PlusCircleIcon } from '@patternfly/react-icons'; +import text from '@patternfly/react-styles/css/utilities/Text/text'; +import spacing from '@patternfly/react-styles/css/utilities/Spacing/spacing'; +import DashboardDescriptionListGroup from '~/components/DashboardDescriptionListGroup'; +import { getProperties, mergeUpdatedProperty } from '~/app/pages/modelRegistry/screens/utils'; +import { ModelRegistryCustomProperties } from '~/app/types'; +import ModelPropertiesTableRow from '~/app/pages/modelRegistry/screens/ModelPropertiesTableRow'; + +type ModelPropertiesDescriptionListGroupProps = { + customProperties: ModelRegistryCustomProperties; + saveEditedCustomProperties: (properties: ModelRegistryCustomProperties) => Promise; +}; + +const ModelPropertiesDescriptionListGroup: React.FC = ({ + customProperties = {}, + saveEditedCustomProperties, +}) => { + const [editingPropertyKeys, setEditingPropertyKeys] = React.useState([]); + const setIsEditingKey = (key: string, isEditing: boolean) => + setEditingPropertyKeys([ + ...editingPropertyKeys.filter((k) => k !== key), + ...(isEditing ? [key] : []), + ]); + const [isAdding, setIsAdding] = React.useState(false); + const isEditingSomeRow = isAdding || editingPropertyKeys.length > 0; + + const [isSavingEdits, setIsSavingEdits] = React.useState(false); + + // We only show string properties with a defined value (no labels or other property types) + const filteredProperties = getProperties(customProperties); + + const [isShowingMoreProperties, setIsShowingMoreProperties] = React.useState(false); + const keys = Object.keys(filteredProperties); + const needExpandControl = keys.length > 5; + const shownKeys = isShowingMoreProperties ? keys : keys.slice(0, 5); + const numHiddenKeys = keys.length - shownKeys.length; + + // Includes keys reserved by non-string properties and labels + const allExistingKeys = Object.keys(customProperties); + + const requiredAsterisk = ( + + ); + + return ( + } + iconPosition="start" + isDisabled={isAdding || isSavingEdits} + onClick={() => setIsAdding(true)} + > + Add property + + } + isEmpty={!isAdding && keys.length === 0} + contentWhenEmpty="No properties" + > + + + + + + + + + {shownKeys.map((key) => ( + setIsEditingKey(key, isEditing)} + isSavingEdits={isSavingEdits} + setIsSavingEdits={setIsSavingEdits} + saveEditedProperty={(oldKey, newPair) => + saveEditedCustomProperties( + mergeUpdatedProperty({ customProperties, op: 'update', oldKey, newPair }), + ) + } + deleteProperty={(oldKey) => + saveEditedCustomProperties( + mergeUpdatedProperty({ customProperties, op: 'delete', oldKey }), + ) + } + /> + ))} + {isAdding && ( + + saveEditedCustomProperties( + mergeUpdatedProperty({ customProperties, op: 'create', newPair }), + ) + } + /> + )} + +
Key {isEditingSomeRow && requiredAsterisk}Value {isEditingSomeRow && requiredAsterisk} +
+ {needExpandControl && ( + + )} +
+ ); +}; + +export default ModelPropertiesDescriptionListGroup; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesTableRow.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesTableRow.tsx new file mode 100644 index 00000000..c3d75a2c --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesTableRow.tsx @@ -0,0 +1,187 @@ +import * as React from 'react'; +import { ActionsColumn, Td, Tr } from '@patternfly/react-table'; +import { + ActionList, + ActionListItem, + Button, + ExpandableSection, + FormHelperText, + HelperText, + HelperTextItem, + TextInput, +} from '@patternfly/react-core'; +import { CheckIcon, TimesIcon } from '@patternfly/react-icons'; +import { KeyValuePair } from '~/types'; +import { EitherNotBoth } from '~/typeHelpers'; + +type ModelPropertiesTableRowProps = { + allExistingKeys: string[]; + setIsEditing: (isEditing: boolean) => void; + isSavingEdits: boolean; + setIsSavingEdits: (isSaving: boolean) => void; + saveEditedProperty: (oldKey: string, newPair: KeyValuePair) => Promise; +} & EitherNotBoth< + { isAddRow: true }, + { + isEditing: boolean; + keyValuePair: KeyValuePair; + deleteProperty: (key: string) => Promise; + } +>; + +const ModelPropertiesTableRow: React.FC = ({ + isAddRow, + isEditing = isAddRow, + keyValuePair = { key: '', value: '' }, + deleteProperty = () => Promise.resolve(), + allExistingKeys, + setIsEditing, + isSavingEdits, + setIsSavingEdits, + saveEditedProperty, +}) => { + const { key, value } = keyValuePair; + const [unsavedKey, setUnsavedKey] = React.useState(key); + const [unsavedValue, setUnsavedValue] = React.useState(value); + + const [isValueExpanded, setIsValueExpanded] = React.useState(false); + + let keyValidationError: string | null = null; + if (unsavedKey !== key && allExistingKeys.includes(unsavedKey)) { + keyValidationError = 'Key must not match an existing property key or label'; + } else if (unsavedKey.length > 63) { + keyValidationError = "Key text can't exceed 63 characters"; + } + + const clearUnsavedInputs = () => { + setUnsavedKey(key); + setUnsavedValue(value); + }; + + const onEditClick = () => { + clearUnsavedInputs(); + setIsEditing(true); + }; + + const onDeleteClick = async () => { + setIsSavingEdits(true); + try { + await deleteProperty(key); + } finally { + setIsSavingEdits(false); + } + }; + + const onSaveEditsClick = async () => { + setIsSavingEdits(true); + try { + await saveEditedProperty(key, { key: unsavedKey, value: unsavedValue }); + } finally { + setIsSavingEdits(false); + } + setIsEditing(false); + }; + + const onDiscardEditsClick = () => { + clearUnsavedInputs(); + setIsEditing(false); + }; + + return ( + + + {isEditing ? ( + <> + setUnsavedKey(str)} + validated={keyValidationError ? 'error' : 'default'} + /> + {keyValidationError && ( + + + {keyValidationError} + + + )} + + ) : ( + key + )} + + + {isEditing ? ( + setUnsavedValue(str)} + /> + ) : ( + setIsValueExpanded(isExpanded)} + isExpanded={isValueExpanded} + > + {value} + + )} + + + {isEditing ? ( + + + + + + + + + ) : ( + + )} + + + ); +}; + +export default ModelPropertiesTableRow; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx new file mode 100644 index 00000000..ec145465 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx @@ -0,0 +1,106 @@ +import * as React from 'react'; +import { ClipboardCopy, DescriptionList, Flex, FlexItem, Content } from '@patternfly/react-core'; +import { RegisteredModel } from '~/app/types'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import EditableTextDescriptionListGroup from '~/components/EditableTextDescriptionListGroup'; +import EditableLabelsDescriptionListGroup from '~/components/EditableLabelsDescriptionListGroup'; +import { getLabels, mergeUpdatedLabels } from '~/app/pages/modelRegistry/screens/utils'; +import ModelPropertiesDescriptionListGroup from '~/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup'; +import DashboardDescriptionListGroup from '~/components/DashboardDescriptionListGroup'; +import ModelTimestamp from '~/app/pages/modelRegistry/screens/components/ModelTimestamp'; + +type ModelDetailsViewProps = { + registeredModel: RegisteredModel; + refresh: () => void; +}; + +const ModelDetailsView: React.FC = ({ registeredModel: rm, refresh }) => { + const { apiState } = React.useContext(ModelRegistryContext); + return ( + + + + + apiState.api + .patchRegisteredModel( + {}, + { + description: value, + }, + rm.id, + ) + .then(refresh) + } + /> + + apiState.api + .patchRegisteredModel( + {}, + { + customProperties: mergeUpdatedLabels(rm.customProperties, editedLabels), + }, + rm.id, + ) + .then(refresh) + } + /> + + apiState.api + .patchRegisteredModel( + {}, + { + customProperties: editedProperties, + }, + rm.id, + ) + .then(refresh) + } + /> + + + + + + + {rm.id} + + + + + {rm.owner || '-'} + + + + + + + + + + + + ); +}; + +export default ModelDetailsView; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx new file mode 100644 index 00000000..0f9db46c --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx @@ -0,0 +1,176 @@ +import * as React from 'react'; +import { + Button, + Dropdown, + DropdownItem, + DropdownList, + MenuToggle, + MenuToggleElement, + SearchInput, + ToolbarContent, + ToolbarFilter, + ToolbarGroup, + ToolbarItem, + ToolbarToggleGroup, +} from '@patternfly/react-core'; +import { EllipsisVIcon, FilterIcon } from '@patternfly/react-icons'; +import { useNavigate } from 'react-router'; +import { ModelVersion, RegisteredModel } from '~/app/types'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { SearchType } from '~/app/components/DashboardSearchField'; +import { + filterModelVersions, + sortModelVersionsByCreateTime, +} from '~/app/pages/modelRegistry/screens/utils'; +import EmptyModelRegistryState from '~/app/pages/modelRegistry/screens/components/EmptyModelRegistryState'; +import { ProjectObjectType, typedEmptyImage } from '~/app/components/design/utils'; +import { + modelVersionArchiveUrl, + registerVersionForModelUrl, +} from '~/app/pages/modelRegistry/screens/routeUtils'; +import { asEnumMember } from '~/app/utils'; +import ModelVersionsTable from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTable'; +import SimpleSelect from '~/app/components/SimpleSelect'; + +type ModelVersionListViewProps = { + modelVersions: ModelVersion[]; + registeredModel?: RegisteredModel; + refresh: () => void; +}; + +const ModelVersionListView: React.FC = ({ + modelVersions: unfilteredModelVersions, + registeredModel: rm, + refresh, +}) => { + const navigate = useNavigate(); + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + + const [searchType, setSearchType] = React.useState(SearchType.KEYWORD); + const [search, setSearch] = React.useState(''); + + const searchTypes = [SearchType.KEYWORD, SearchType.AUTHOR]; + + const [isArchivedModelVersionKebabOpen, setIsArchivedModelVersionKebabOpen] = + React.useState(false); + + const filteredModelVersions = filterModelVersions(unfilteredModelVersions, search, searchType); + + if (unfilteredModelVersions.length === 0) { + return ( + ( + missing version + )} + description={`${rm?.name} has no registered versions. Register a version to this model.`} + primaryActionText="Register new version" + secondaryActionText="View archived versions" + primaryActionOnClick={() => { + navigate(registerVersionForModelUrl(rm?.id, preferredModelRegistry?.name)); + }} + secondaryActionOnClick={() => { + navigate(modelVersionArchiveUrl(rm?.id, preferredModelRegistry?.name)); + }} + /> + ); + } + + return ( + setSearch('')} + modelVersions={sortModelVersionsByCreateTime(filteredModelVersions)} + toolbarContent={ + + } breakpoint="xl"> + + setSearch('')} + deleteLabelGroup={() => setSearch('')} + categoryName={searchType} + > + ({ + key, + label: key, + }))} + value={searchType} + onChange={(newSearchType) => { + const enumMember = asEnumMember(newSearchType, SearchType); + if (enumMember !== null) { + setSearchType(enumMember); + } + }} + icon={} + /> + + + { + setSearch(searchValue); + }} + onClear={() => setSearch('')} + style={{ minWidth: '200px' }} + data-testid="model-versions-table-search" + /> + + + + + + + + setIsArchivedModelVersionKebabOpen(false)} + onOpenChange={(isOpen: boolean) => setIsArchivedModelVersionKebabOpen(isOpen)} + toggle={(tr: React.Ref) => ( + + setIsArchivedModelVersionKebabOpen(!isArchivedModelVersionKebabOpen) + } + isExpanded={isArchivedModelVersionKebabOpen} + aria-label="View archived versions" + > + + + )} + shouldFocusToggleOnSelect + > + + + navigate(modelVersionArchiveUrl(rm?.id, preferredModelRegistry?.name)) + } + > + View archived versions + + + + + + } + /> + ); +}; + +export default ModelVersionListView; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx new file mode 100644 index 00000000..9e471aca --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx @@ -0,0 +1,64 @@ +import React from 'react'; +import { useParams } from 'react-router'; +import { Breadcrumb, BreadcrumbItem, Truncate } from '@patternfly/react-core'; +import { Link } from 'react-router-dom'; +import { ModelVersionsTab } from '~/app/pages/modelRegistry/screens/ModelVersions/const'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import useModelVersionsByRegisteredModel from '~/app/hooks/useModelVersionsByRegisteredModel'; +import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { filterLiveVersions } from '~/app/pages/modelRegistry/screens/utils'; +import ModelVersionsHeaderActions from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsHeaderActions'; +import ModelVersionsTabs from './ModelVersionsTabs'; + +type ModelVersionsProps = { + tab: ModelVersionsTab; +} & Omit< + React.ComponentProps, + 'breadcrumb' | 'title' | 'description' | 'loadError' | 'loaded' | 'provideChildrenPadding' +>; + +const ModelVersions: React.FC = ({ tab, ...pageProps }) => { + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + const { registeredModelId: rmId } = useParams(); + const [modelVersions, mvLoaded, mvLoadError, mvRefresh] = useModelVersionsByRegisteredModel(rmId); + const [rm, rmLoaded, rmLoadError, rmRefresh] = useRegisteredModelById(rmId); + const loadError = mvLoadError || rmLoadError; + const loaded = mvLoaded && rmLoaded; + + return ( + + ( + Model registry - {preferredModelRegistry?.name} + )} + /> + + {rm?.name || 'Loading...'} + + + } + title={rm?.name} + headerAction={rm && } + description={} + loadError={loadError} + loaded={loaded} + provideChildrenPadding + > + {rm !== null && ( + + )} + + ); +}; + +export default ModelVersions; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsHeaderActions.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsHeaderActions.tsx new file mode 100644 index 00000000..1a383430 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsHeaderActions.tsx @@ -0,0 +1,86 @@ +import * as React from 'react'; +import { + Dropdown, + DropdownList, + MenuToggle, + DropdownItem, + Flex, + FlexItem, +} from '@patternfly/react-core'; +import { useNavigate } from 'react-router'; +import { ModelState, RegisteredModel } from '~/app/types'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ArchiveRegisteredModelModal } from '~/app/pages/modelRegistry/screens/components/ArchiveRegisteredModelModal'; +import { registeredModelsUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; + +interface ModelVersionsHeaderActionsProps { + rm: RegisteredModel; +} + +const ModelVersionsHeaderActions: React.FC = ({ rm }) => { + const { apiState } = React.useContext(ModelRegistryContext); + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + + const navigate = useNavigate(); + const [isOpen, setOpen] = React.useState(false); + const tooltipRef = React.useRef(null); + const [isArchiveModalOpen, setIsArchiveModalOpen] = React.useState(false); + + return ( + <> + + + setOpen(false)} + onOpenChange={(open) => setOpen(open)} + popperProps={{ position: 'end' }} + toggle={(toggleRef) => ( + setOpen(!isOpen)} + isExpanded={isOpen} + aria-label="Model version action toggle" + data-testid="model-version-action-toggle" + > + Actions + + )} + > + + setIsArchiveModalOpen(true)} + ref={tooltipRef} + > + Archive model + + + + + + setIsArchiveModalOpen(false)} + onSubmit={() => + apiState.api + .patchRegisteredModel( + {}, + { + state: ModelState.ARCHIVED, + }, + rm.id, + ) + .then(() => navigate(registeredModelsUrl(preferredModelRegistry?.name))) + } + isOpen={isArchiveModalOpen} + registeredModelName={rm.name} + /> + + ); +}; + +export default ModelVersionsHeaderActions; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTable.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTable.tsx new file mode 100644 index 00000000..999dc2aa --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTable.tsx @@ -0,0 +1,34 @@ +import * as React from 'react'; +import { Table } from '~/app/components/table'; +import { ModelVersion } from '~/app/types'; +import { mvColumns } from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableColumns'; +import DashboardEmptyTableView from '~/app/components/DashboardEmptyTableView'; +import ModelVersionsTableRow from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow'; + +type ModelVersionsTableProps = { + clearFilters: () => void; + modelVersions: ModelVersion[]; + refresh: () => void; +} & Partial, 'toolbarContent'>>; + +const ModelVersionsTable: React.FC = ({ + clearFilters, + modelVersions, + toolbarContent, + refresh, +}) => ( + } + rowRenderer={(mv) => ( + + )} + /> +); + +export default ModelVersionsTable; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableColumns.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableColumns.ts new file mode 100644 index 00000000..f98ea912 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableColumns.ts @@ -0,0 +1,40 @@ +import { SortableData } from '~/app/components/table'; +import { ModelVersion } from '~/app/types'; + +export const mvColumns: SortableData[] = [ + { + field: 'version name', + label: 'Version name', + sortable: (a, b) => a.name.localeCompare(b.name), + width: 40, + }, + { + field: 'last_modified', + label: 'Last modified', + sortable: (a: ModelVersion, b: ModelVersion): number => { + const first = parseInt(a.lastUpdateTimeSinceEpoch); + const second = parseInt(b.lastUpdateTimeSinceEpoch); + return new Date(second).getTime() - new Date(first).getTime(); + }, + }, + { + field: 'author', + label: 'Author', + sortable: (a: ModelVersion, b: ModelVersion): number => { + const first = a.author || ''; + const second = b.author || ''; + return first.localeCompare(second); + }, + }, + { + field: 'labels', + label: 'Labels', + sortable: false, + width: 35, + }, + { + field: 'kebab', + label: '', + sortable: false, + }, +]; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx new file mode 100644 index 00000000..55a4d85b --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx @@ -0,0 +1,128 @@ +import * as React from 'react'; +import { ActionsColumn, Td, Tr } from '@patternfly/react-table'; +import { Content, ContentVariants, Truncate, FlexItem } from '@patternfly/react-core'; +import { Link, useNavigate } from 'react-router-dom'; +import { ModelState, ModelVersion } from '~/app/types'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import { + modelVersionArchiveDetailsUrl, + modelVersionUrl, +} from '~/app/pages/modelRegistry/screens/routeUtils'; +import ModelTimestamp from '~/app/pages/modelRegistry/screens/components/ModelTimestamp'; +import ModelLabels from '~/app/pages/modelRegistry/screens/components/ModelLabels'; +import { ArchiveModelVersionModal } from '~/app/pages/modelRegistry/screens/components/ArchiveModelVersionModal'; +import { RestoreModelVersionModal } from '~/app/pages/modelRegistry/screens/components/RestoreModelVersionModal'; + +type ModelVersionsTableRowProps = { + modelVersion: ModelVersion; + isArchiveRow?: boolean; + refresh: () => void; +}; + +const ModelVersionsTableRow: React.FC = ({ + modelVersion: mv, + isArchiveRow, + refresh, +}) => { + const navigate = useNavigate(); + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + const [isArchiveModalOpen, setIsArchiveModalOpen] = React.useState(false); + const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false); + const { apiState } = React.useContext(ModelRegistryContext); + + const actions = isArchiveRow + ? [ + { + title: 'Restore version', + onClick: () => setIsRestoreModalOpen(true), + }, + ] + : [ + { + title: 'Deploy', + onClick: () => setIsDeployModalOpen(true), + }, + { + title: 'Archive model version', + onClick: () => setIsArchiveModalOpen(true), + }, + ]; + + return ( + + + + + + + + ); +}; + +export default ModelVersionsTableRow; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx new file mode 100644 index 00000000..7460e663 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx @@ -0,0 +1,63 @@ +import * as React from 'react'; +import { useNavigate } from 'react-router-dom'; +import { PageSection, Tab, Tabs, TabTitleText } from '@patternfly/react-core'; +import ModelDetailsView from '~/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView'; +import { ModelVersion, RegisteredModel } from '~/app/types'; +import { + ModelVersionsTab, + ModelVersionsTabTitle, +} from '~/app/pages/modelRegistry/screens/ModelVersions/const'; +import ModelVersionListView from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView'; + +type ModelVersionsTabProps = { + tab: ModelVersionsTab; + registeredModel: RegisteredModel; + modelVersions: ModelVersion[]; + refresh: () => void; + mvRefresh: () => void; +}; + +const ModelVersionsTabs: React.FC = ({ + tab, + registeredModel: rm, + modelVersions, + refresh, + mvRefresh, +}) => { + const navigate = useNavigate(); + return ( + navigate(`../${eventKey}`, { relative: 'path' })} + > + {ModelVersionsTabTitle.VERSIONS}} + aria-label="Model versions tab" + data-testid="model-versions-tab" + > + + + + + {ModelVersionsTabTitle.DETAILS}} + aria-label="Model Details tab" + data-testid="model-details-tab" + > + + + + + + ); +}; +export default ModelVersionsTabs; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/const.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/const.ts new file mode 100644 index 00000000..42133ed9 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/const.ts @@ -0,0 +1,9 @@ +export enum ModelVersionsTab { + VERSIONS = 'versions', + DETAILS = 'details', +} + +export enum ModelVersionsTabTitle { + VERSIONS = 'Versions', + DETAILS = 'Details', +} diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx index 881e0033..07e0c722 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx @@ -13,6 +13,7 @@ import { registeredModelArchiveDetailsUrl, registeredModelUrl, } from '~/app/pages/modelRegistry/screens/routeUtils'; +import { ModelVersionsTab } from '~/app/pages/modelRegistry/screens/ModelVersions/const'; type RegisteredModelTableRowProps = { registeredModel: RegisteredModel; @@ -36,7 +37,7 @@ const RegisteredModelTableRow: React.FC = ({ { title: 'View details', // eslint-disable-next-line @typescript-eslint/no-empty-function - onClick: () => {}, // TODO: @Griffin-Sullivan uncomment this once model versions is active ---> navigate(`${rmUrl}/${ModelVersionsTab.DETAILS}`), + onClick: () => navigate(`${rmUrl}/${ModelVersionsTab.DETAILS}`), }, isArchiveRow ? { diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts new file mode 100644 index 00000000..d4a102c0 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts @@ -0,0 +1,348 @@ +/* eslint-disable camelcase */ +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { + ModelRegistryCustomProperties, + ModelRegistryMetadataType, + ModelRegistryStringCustomProperties, + ModelState, + ModelVersion, + RegisteredModel, +} from '~/app/types'; +import { + filterModelVersions, + filterRegisteredModels, + getLabels, + getProperties, + mergeUpdatedLabels, + mergeUpdatedProperty, + sortModelVersionsByCreateTime, +} from '~/app/pages/modelRegistry/screens/utils'; +import { SearchType } from '~/app/components/DashboardSearchField'; + +describe('getLabels', () => { + it('should return an empty array when customProperties is empty', () => { + const customProperties: ModelRegistryCustomProperties = {}; + const result = getLabels(customProperties); + expect(result).toEqual([]); + }); + + it('should return an array of keys with empty string values in customProperties', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { metadataType: ModelRegistryMetadataType.STRING, string_value: '' }, + label2: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty' }, + label3: { metadataType: ModelRegistryMetadataType.STRING, string_value: '' }, + }; + const result = getLabels(customProperties); + expect(result).toEqual(['label1', 'label3']); + }); + + it('should return an empty array when all values in customProperties are non-empty strings', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty' }, + label2: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'another-non-empty' }, + }; + const result = getLabels(customProperties); + expect(result).toEqual([]); + }); +}); + +describe('mergeUpdatedLabels', () => { + it('should return an empty object when customProperties and updatedLabels are empty', () => { + const customProperties: ModelRegistryCustomProperties = {}; + const result = mergeUpdatedLabels(customProperties, []); + expect(result).toEqual({}); + }); + + it('should return an unmodified object if updatedLabels match existing labels', () => { + const customProperties: ModelRegistryCustomProperties = { + someUnrelatedProp: { string_value: 'foo', metadataType: ModelRegistryMetadataType.STRING }, + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedLabels(customProperties, ['label1']); + expect(result).toEqual(customProperties); + }); + + it('should return an object with labels added', () => { + const customProperties: ModelRegistryCustomProperties = {}; + const result = mergeUpdatedLabels(customProperties, ['label1', 'label2']); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label2: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should return an object with labels removed', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label2: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label3: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label4: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedLabels(customProperties, ['label2', 'label4']); + expect(result).toEqual({ + label2: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label4: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should return an object with labels both added and removed', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label2: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label3: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedLabels(customProperties, ['label1', 'label3', 'label4']); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label3: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label4: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should not affect non-label properties on the object', () => { + const customProperties: ModelRegistryCustomProperties = { + someUnrelatedStrProp: { string_value: 'foo', metadataType: ModelRegistryMetadataType.STRING }, + someUnrelatedIntProp: { int_value: '3', metadataType: ModelRegistryMetadataType.INT }, + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label2: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedLabels(customProperties, ['label2', 'label3']); + expect(result).toEqual({ + someUnrelatedStrProp: { string_value: 'foo', metadataType: ModelRegistryMetadataType.STRING }, + someUnrelatedIntProp: { int_value: '3', metadataType: ModelRegistryMetadataType.INT }, + label2: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + label3: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); +}); + +describe('getProperties', () => { + it('should return an empty object when customProperties is empty', () => { + const customProperties: ModelRegistryCustomProperties = {}; + const result = getProperties(customProperties); + expect(result).toEqual({}); + }); + + it('should return a filtered object including only string properties with a non-empty value', () => { + const customProperties: ModelRegistryCustomProperties = { + property1: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty' }, + property2: { + metadataType: ModelRegistryMetadataType.STRING, + string_value: 'another-non-empty', + }, + label1: { metadataType: ModelRegistryMetadataType.STRING, string_value: '' }, + label2: { metadataType: ModelRegistryMetadataType.STRING, string_value: '' }, + int1: { metadataType: ModelRegistryMetadataType.INT, int_value: '1' }, + int2: { metadataType: ModelRegistryMetadataType.INT, int_value: '2' }, + }; + const result = getProperties(customProperties); + expect(result).toEqual({ + property1: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty' }, + property2: { + metadataType: ModelRegistryMetadataType.STRING, + string_value: 'another-non-empty', + }, + } satisfies ModelRegistryStringCustomProperties); + }); + + it('should return an empty object when all values in customProperties are empty strings or non-string values', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { metadataType: ModelRegistryMetadataType.STRING, string_value: '' }, + label2: { metadataType: ModelRegistryMetadataType.STRING, string_value: '' }, + int1: { metadataType: ModelRegistryMetadataType.INT, int_value: '1' }, + int2: { metadataType: ModelRegistryMetadataType.INT, int_value: '2' }, + }; + const result = getProperties(customProperties); + expect(result).toEqual({}); + }); +}); + +describe('mergeUpdatedProperty', () => { + it('should handle the create operation', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedProperty({ + customProperties, + op: 'create', + newPair: { key: 'prop2', value: 'val2' }, + }); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + prop2: { string_value: 'val2', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should handle the update operation without a key change', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedProperty({ + customProperties, + op: 'update', + oldKey: 'prop1', + newPair: { key: 'prop1', value: 'updatedVal1' }, + }); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'updatedVal1', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should handle the update operation with a key change', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedProperty({ + customProperties, + op: 'update', + oldKey: 'prop1', + newPair: { key: 'prop2', value: 'val2' }, + }); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop2: { string_value: 'val2', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should perform a create if using the update operation with an invalid oldKey', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedProperty({ + customProperties, + op: 'update', + oldKey: 'prop2', + newPair: { key: 'prop3', value: 'val3' }, + }); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + prop3: { string_value: 'val3', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should handle the delete operation', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + prop2: { string_value: 'val2', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedProperty({ + customProperties, + op: 'delete', + oldKey: 'prop2', + }); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); + + it('should do nothing if using the delete operation with an invalid oldKey', () => { + const customProperties: ModelRegistryCustomProperties = { + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + }; + const result = mergeUpdatedProperty({ + customProperties, + op: 'delete', + oldKey: 'prop2', + }); + expect(result).toEqual({ + label1: { string_value: '', metadataType: ModelRegistryMetadataType.STRING }, + prop1: { string_value: 'val1', metadataType: ModelRegistryMetadataType.STRING }, + } satisfies ModelRegistryCustomProperties); + }); +}); + +describe('filterModelVersions', () => { + const modelVersions: ModelVersion[] = [ + mockModelVersion({ name: 'Test 1', state: ModelState.ARCHIVED }), + mockModelVersion({ + name: 'Test 2', + description: 'Description2', + }), + mockModelVersion({ name: 'Test 3', author: 'Author3', state: ModelState.ARCHIVED }), + mockModelVersion({ name: 'Test 4', state: ModelState.ARCHIVED }), + mockModelVersion({ name: 'Test 5' }), + ]; + + test('filters by name', () => { + const filtered = filterModelVersions(modelVersions, 'Test 1', SearchType.KEYWORD); + expect(filtered).toEqual([modelVersions[0]]); + }); + + test('filters by description', () => { + const filtered = filterModelVersions(modelVersions, 'Description2', SearchType.KEYWORD); + expect(filtered).toEqual([modelVersions[1]]); + }); + + test('filters by author', () => { + const filtered = filterModelVersions(modelVersions, 'Author3', SearchType.AUTHOR); + expect(filtered).toEqual([modelVersions[2]]); + }); + + test('does not filter when search is empty', () => { + const filtered = filterModelVersions(modelVersions, '', SearchType.KEYWORD); + expect(filtered).toEqual(modelVersions); + }); +}); + +describe('filterRegisteredModels', () => { + const registeredModels: RegisteredModel[] = [ + mockRegisteredModel({ name: 'Test 1', state: ModelState.ARCHIVED }), + mockRegisteredModel({ + name: 'Test 2', + description: 'Description2', + }), + mockRegisteredModel({ name: 'Test 3', state: ModelState.ARCHIVED }), + mockRegisteredModel({ name: 'Test 4', state: ModelState.ARCHIVED }), + mockRegisteredModel({ name: 'Test 5' }), + ]; + + test('filters by name', () => { + const filtered = filterRegisteredModels(registeredModels, 'Test 1', SearchType.KEYWORD); + expect(filtered).toEqual([registeredModels[0]]); + }); + + test('filters by description', () => { + const filtered = filterRegisteredModels(registeredModels, 'Description2', SearchType.KEYWORD); + expect(filtered).toEqual([registeredModels[1]]); + }); + + test('does not filter when search is empty', () => { + const filtered = filterRegisteredModels(registeredModels, '', SearchType.KEYWORD); + expect(filtered).toEqual(registeredModels); + }); +}); + +describe('sortModelVersionsByCreateTime', () => { + it('should return list of sorted modelVersions by create time', () => { + const modelVersions: ModelVersion[] = [ + mockModelVersion({ + name: 'model version 1', + author: 'Author 1', + id: '1', + createTimeSinceEpoch: '1725018764650', + lastUpdateTimeSinceEpoch: '1725030215299', + }), + mockModelVersion({ + name: 'model version 1', + author: 'Author 1', + id: '1', + createTimeSinceEpoch: '1725028468207', + lastUpdateTimeSinceEpoch: '1725030142332', + }), + ]; + + const result = sortModelVersionsByCreateTime(modelVersions); + expect(result).toEqual([modelVersions[1], modelVersions[0]]); + }); +}); diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/ArchiveModelVersionModal.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/ArchiveModelVersionModal.tsx new file mode 100644 index 00000000..3260639a --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/ArchiveModelVersionModal.tsx @@ -0,0 +1,90 @@ +import * as React from 'react'; +import { Flex, FlexItem, Stack, StackItem, TextInput } from '@patternfly/react-core'; +import { Modal } from '@patternfly/react-core/deprecated'; +import DashboardModalFooter from '~/app/components/DashboardModalFooter'; + +interface ArchiveModelVersionModalProps { + onCancel: () => void; + onSubmit: () => void; + isOpen: boolean; + modelVersionName: string; +} + +export const ArchiveModelVersionModal: React.FC = ({ + onCancel, + onSubmit, + isOpen, + modelVersionName, +}) => { + const [isSubmitting, setIsSubmitting] = React.useState(false); + const [error, setError] = React.useState(); + const [confirmInputValue, setConfirmInputValue] = React.useState(''); + const isDisabled = confirmInputValue.trim() !== modelVersionName || isSubmitting; + + const onClose = React.useCallback(() => { + setConfirmInputValue(''); + onCancel(); + }, [onCancel]); + + const onConfirm = React.useCallback(async () => { + setIsSubmitting(true); + + try { + await onSubmit(); + onClose(); + } catch (e) { + if (e instanceof Error) { + setError(e); + } + } finally { + setIsSubmitting(false); + } + }, [onSubmit, onClose]); + + return ( + + } + data-testid="archive-model-version-modal" + > + + + {modelVersionName} will be archived and unavailable for use unless it is restored. + + + + + Type {modelVersionName} to confirm archiving: + + setConfirmInputValue(newValue)} + onKeyDown={(event) => { + if (event.key === 'Enter' && !isDisabled) { + onConfirm(); + } + }} + /> + + + + + ); +}; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/RestoreModelVersionModal.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/RestoreModelVersionModal.tsx new file mode 100644 index 00000000..b9c57d85 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/components/RestoreModelVersionModal.tsx @@ -0,0 +1,62 @@ +import * as React from 'react'; +import { Modal } from '@patternfly/react-core/deprecated'; +import DashboardModalFooter from '~/app/components/DashboardModalFooter'; + +interface RestoreModelVersionModalProps { + onCancel: () => void; + onSubmit: () => void; + isOpen: boolean; + modelVersionName: string; +} + +export const RestoreModelVersionModal: React.FC = ({ + onCancel, + onSubmit, + isOpen, + modelVersionName, +}) => { + const [isSubmitting, setIsSubmitting] = React.useState(false); + const [error, setError] = React.useState(); + + const onClose = React.useCallback(() => { + onCancel(); + }, [onCancel]); + + const onConfirm = React.useCallback(async () => { + setIsSubmitting(true); + + try { + await onSubmit(); + onClose(); + } catch (e) { + if (e instanceof Error) { + setError(e); + } + } finally { + setIsSubmitting(false); + } + }, [onSubmit, onClose]); + + return ( + + } + data-testid="restore-model-version-modal" + > + {modelVersionName} will be restored and returned to the versions list. + + ); +}; diff --git a/clients/ui/frontend/src/components/DashboardDescriptionListGroup.scss b/clients/ui/frontend/src/components/DashboardDescriptionListGroup.scss new file mode 100644 index 00000000..aed757d5 --- /dev/null +++ b/clients/ui/frontend/src/components/DashboardDescriptionListGroup.scss @@ -0,0 +1,4 @@ +.kubeflow-custom-description-list-term-with-action > span { + /* Workaround for missing functionality in PF DescriptionList, see https://github.com/patternfly/patternfly/issues/6583 */ + width: 100%; +} \ No newline at end of file diff --git a/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx b/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx new file mode 100644 index 00000000..4216c86b --- /dev/null +++ b/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx @@ -0,0 +1,120 @@ +import * as React from 'react'; +import { + ActionList, + ActionListItem, + Button, + DescriptionListDescription, + DescriptionListGroup, + DescriptionListTerm, + Flex, + FlexItem, + Split, + SplitItem, +} from '@patternfly/react-core'; +import text from '@patternfly/react-styles/css/utilities/Text/text'; +import { CheckIcon, PencilAltIcon, TimesIcon } from '@patternfly/react-icons'; + +import '~/components/DashboardDescriptionListGroup.scss'; + +type EditableProps = { + isEditing: boolean; + contentWhenEditing: React.ReactNode; + isSavingEdits?: boolean; + onEditClick: () => void; + onSaveEditsClick: () => void; + onDiscardEditsClick: () => void; +}; + +export type DashboardDescriptionListGroupProps = { + title: React.ReactNode; + tooltip?: React.ReactNode; + action?: React.ReactNode; + isEmpty?: boolean; + contentWhenEmpty?: React.ReactNode; + children: React.ReactNode; +} & (({ isEditable: true } & EditableProps) | ({ isEditable?: false } & Partial)); + +const DashboardDescriptionListGroup: React.FC = (props) => { + const { + title, + tooltip, + action, + isEmpty, + contentWhenEmpty, + isEditable = false, + isEditing, + contentWhenEditing, + isSavingEdits = false, + onEditClick, + onSaveEditsClick, + onDiscardEditsClick, + children, + } = props; + return ( + + {action || isEditable ? ( + + + {title} + + {action || + (isEditing ? ( + + + + + + + + + ) : ( + + ))} + + + + ) : ( + + + {title} + {tooltip} + + + )} + + {isEditing ? contentWhenEditing : isEmpty ? contentWhenEmpty : children} + + + ); +}; + +export default DashboardDescriptionListGroup; diff --git a/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx b/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx new file mode 100644 index 00000000..42cc9091 --- /dev/null +++ b/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx @@ -0,0 +1,218 @@ +import * as React from 'react'; +import { + Button, + Form, + FormGroup, + FormHelperText, + HelperText, + HelperTextItem, + Label, + LabelGroup, + TextInput, +} from '@patternfly/react-core'; +import { Modal } from '@patternfly/react-core/deprecated'; +import { ExclamationCircleIcon } from '@patternfly/react-icons'; +import DashboardDescriptionListGroup, { + DashboardDescriptionListGroupProps, +} from '~/components/DashboardDescriptionListGroup'; + +type EditableTextDescriptionListGroupProps = Partial< + Pick +> & { + labels: string[]; + saveEditedLabels: (labels: string[]) => Promise; + allExistingKeys?: string[]; +}; + +const EditableLabelsDescriptionListGroup: React.FC = ({ + title = 'Labels', + contentWhenEmpty = 'No labels', + labels, + saveEditedLabels, + allExistingKeys = labels, +}) => { + const [isEditing, setIsEditing] = React.useState(false); + const [unsavedLabels, setUnsavedLabels] = React.useState(labels); + const [isSavingEdits, setIsSavingEdits] = React.useState(false); + + const editUnsavedLabel = (newText: string, index: number) => { + if (isSavingEdits) { + return; + } + const copy = [...unsavedLabels]; + copy[index] = newText; + setUnsavedLabels(copy); + }; + const removeUnsavedLabel = (text: string) => { + if (isSavingEdits) { + return; + } + setUnsavedLabels(unsavedLabels.filter((label) => label !== text)); + }; + const addUnsavedLabel = (text: string) => { + if (isSavingEdits) { + return; + } + setUnsavedLabels([...unsavedLabels, text]); + }; + + // Don't allow a label that matches a non-label property key or another label (as they stand before saving) + // Note that this means if you remove a label and add it back before saving, that is valid + const reservedKeys = [ + ...allExistingKeys.filter((key) => !labels.includes(key)), + ...unsavedLabels, + ]; + + const [isAddLabelModalOpen, setIsAddLabelModalOpen] = React.useState(false); + const [addLabelInputValue, setAddLabelInputValue] = React.useState(''); + const addLabelInputRef = React.useRef(null); + let addLabelValidationError: string | null = null; + if (reservedKeys.includes(addLabelInputValue)) { + addLabelValidationError = 'Label must not match an existing label or property key'; + } else if (addLabelInputValue.length > 63) { + addLabelValidationError = "Label text can't exceed 63 characters"; + } + + const toggleAddLabelModal = () => { + setAddLabelInputValue(''); + setIsAddLabelModalOpen(!isAddLabelModalOpen); + }; + React.useEffect(() => { + if (isAddLabelModalOpen && addLabelInputRef.current) { + addLabelInputRef.current.focus(); + } + }, [isAddLabelModalOpen]); + + const addLabelModalSubmitDisabled = !addLabelInputValue || !!addLabelValidationError; + const submitAddLabelModal = (event?: React.FormEvent) => { + event?.preventDefault(); + if (!addLabelModalSubmitDisabled) { + addUnsavedLabel(addLabelInputValue); + toggleAddLabelModal(); + } + }; + + return ( + <> + + Add label + + ) + } + > + {unsavedLabels.map((label, index) => ( + + ))} + + } + onEditClick={() => { + setUnsavedLabels(labels); + setIsEditing(true); + }} + onSaveEditsClick={async () => { + setIsSavingEdits(true); + try { + await saveEditedLabels(unsavedLabels); + } finally { + setIsSavingEdits(false); + } + setIsEditing(false); + }} + onDiscardEditsClick={() => { + setUnsavedLabels(labels); + setIsEditing(false); + }} + > + + {labels.map((label) => ( + + ))} + + + + Save + , + , + ]} + > +
+ + , value: string) => + setAddLabelInputValue(value) + } + ref={addLabelInputRef} + isRequired + validated={addLabelValidationError ? 'error' : 'default'} + /> + {addLabelValidationError && ( + + + } variant="error"> + {addLabelValidationError} + + + + )} + + +
+ + ); +}; + +export default EditableLabelsDescriptionListGroup; diff --git a/clients/ui/frontend/src/components/EditableTextDescriptionListGroup.tsx b/clients/ui/frontend/src/components/EditableTextDescriptionListGroup.tsx new file mode 100644 index 00000000..1ad501ea --- /dev/null +++ b/clients/ui/frontend/src/components/EditableTextDescriptionListGroup.tsx @@ -0,0 +1,78 @@ +import * as React from 'react'; +import { ExpandableSection, TextArea } from '@patternfly/react-core'; +import DashboardDescriptionListGroup, { + DashboardDescriptionListGroupProps, +} from '~/components/DashboardDescriptionListGroup'; + +type EditableTextDescriptionListGroupProps = Pick< + DashboardDescriptionListGroupProps, + 'title' | 'contentWhenEmpty' +> & { + value: string; + saveEditedValue: (value: string) => Promise; + testid?: string; +}; + +const EditableTextDescriptionListGroup: React.FC = ({ + title, + contentWhenEmpty, + value, + saveEditedValue, + testid, +}) => { + const [isEditing, setIsEditing] = React.useState(false); + const [unsavedValue, setUnsavedValue] = React.useState(value); + const [isSavingEdits, setIsSavingEdits] = React.useState(false); + const [isTextExpanded, setIsTextExpanded] = React.useState(false); + return ( + setUnsavedValue(v)} + isDisabled={isSavingEdits} + rows={24} + resizeOrientation="vertical" + /> + } + onEditClick={() => { + setUnsavedValue(value); + setIsEditing(true); + }} + onSaveEditsClick={async () => { + setIsSavingEdits(true); + try { + await saveEditedValue(unsavedValue); + } finally { + setIsSavingEdits(false); + } + setIsEditing(false); + }} + onDiscardEditsClick={() => { + setUnsavedValue(value); + setIsEditing(false); + }} + > + setIsTextExpanded(isExpanded)} + isExpanded={isTextExpanded} + > + {value} + + + ); +}; + +export default EditableTextDescriptionListGroup; diff --git a/clients/ui/frontend/tsconfig.json b/clients/ui/frontend/tsconfig.json index 1ea9486c..674889be 100644 --- a/clients/ui/frontend/tsconfig.json +++ b/clients/ui/frontend/tsconfig.json @@ -6,6 +6,7 @@ "module": "esnext", "target": "es5", "lib": [ + "ESNext.Array", "es6", "dom" ], From 89fdaff532b50f4266d654a7f88c77efe2f28851 Mon Sep 17 00:00:00 2001 From: Lucas Fernandez Date: Wed, 25 Sep 2024 12:52:34 +0200 Subject: [PATCH 07/13] Add Model Registry Settings view (#423) Signed-off-by: lucferbux --- clients/ui/bff/internal/api/helpers.go | 2 +- .../ui/bff/internal/data/model_registry.go | 2 +- .../cypress/pages/modelRegistrySettings.ts | 116 ++++++++++++++++++ .../{ => modelRegistry}/modelRegistry.cy.ts | 0 .../modelRegistrySettings.cy.ts | 49 ++++++++ clients/ui/frontend/src/app/AppRoutes.tsx | 6 +- .../pages/settings/ModelRegistriesTable.tsx | 22 ++++ .../settings/ModelRegistriesTableRow.tsx | 22 ++++ .../pages/settings/ModelRegistrySettings.tsx | 17 ++- .../src/app/pages/settings/columns.ts | 23 ++++ 10 files changed, 251 insertions(+), 8 deletions(-) create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistrySettings.ts rename clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/{ => modelRegistry}/modelRegistry.cy.ts (100%) create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistrySettings/modelRegistrySettings.cy.ts create mode 100644 clients/ui/frontend/src/app/pages/settings/ModelRegistriesTable.tsx create mode 100644 clients/ui/frontend/src/app/pages/settings/ModelRegistriesTableRow.tsx create mode 100644 clients/ui/frontend/src/app/pages/settings/columns.ts diff --git a/clients/ui/bff/internal/api/helpers.go b/clients/ui/bff/internal/api/helpers.go index 6cb6329b..4f0fb58b 100644 --- a/clients/ui/bff/internal/api/helpers.go +++ b/clients/ui/bff/internal/api/helpers.go @@ -10,7 +10,7 @@ import ( ) type Envelope[D any, M any] struct { - Data D `json:"data,omitempty"` + Data D `json:"data"` Metadata M `json:"metadata,omitempty"` } diff --git a/clients/ui/bff/internal/data/model_registry.go b/clients/ui/bff/internal/data/model_registry.go index e3102883..e27c4790 100644 --- a/clients/ui/bff/internal/data/model_registry.go +++ b/clients/ui/bff/internal/data/model_registry.go @@ -18,7 +18,7 @@ func (m ModelRegistryModel) FetchAllModelRegistries(client k8s.KubernetesClientI return nil, fmt.Errorf("error fetching model registries: %w", err) } - var registries []ModelRegistryModel + var registries []ModelRegistryModel = []ModelRegistryModel{} for _, item := range resources { registry := ModelRegistryModel{ Name: item.Name, diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistrySettings.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistrySettings.ts new file mode 100644 index 00000000..1bac8692 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistrySettings.ts @@ -0,0 +1,116 @@ +import { appChrome } from './appChrome'; + +export enum FormFieldSelector { + NAME = '#mr-name', + RESOURCENAME = '#resource-mr-name', + HOST = '#mr-host', + PORT = '#mr-port', + USERNAME = '#mr-username', + PASSWORD = '#mr-password', + DATABASE = '#mr-database', +} + +export enum FormErrorTestId { + HOST = 'mr-host-error', + PORT = 'mr-port-error', + USERNAME = 'mr-username-error', + PASSWORD = 'mr-password-error', + DATABASE = 'mr-database-error', +} + +export enum DatabaseDetailsTestId { + HOST = 'mr-db-host', + PORT = 'mr-db-port', + USERNAME = 'mr-db-username', + PASSWORD = 'mr-db-password', + DATABASE = 'mr-db-database', +} + +class ModelRegistrySettings { + visit(wait = true) { + cy.visit('/modelRegistrySettings'); + if (wait) { + this.wait(); + } + } + + navigate() { + this.findNavItem().click(); + this.wait(); + } + + private wait() { + this.findHeading(); + cy.testA11y(); + } + + private findHeading() { + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Model Registry Settings'); + } + + findNavItem() { + return appChrome.findNavItem('Model registry settings', 'Settings'); + } + + findEmptyState() { + return cy.findByTestId('mr-settings-empty-state'); + } + + // findCreateButton() { + // return cy.findByText('Create model registry'); + // } + + findFormField(selector: FormFieldSelector) { + return cy.get(selector); + } + + clearFormFields() { + Object.values(FormFieldSelector).forEach((selector) => { + this.findFormField(selector).clear(); + this.findFormField(selector).blur(); + }); + } + + findFormError(testId: FormErrorTestId) { + return cy.findByTestId(testId); + } + + shouldHaveAllErrors() { + Object.values(FormErrorTestId).forEach((testId) => this.findFormError(testId).should('exist')); + } + + shouldHaveNoErrors() { + Object.values(FormErrorTestId).forEach((testId) => + this.findFormError(testId).should('not.exist'), + ); + } + + findSubmitButton() { + return cy.findByTestId('modal-submit-button'); + } + + findTable() { + return cy.findByTestId('model-registries-table'); + } + + findModelRegistryRow(registryName: string) { + return this.findTable().findByText(registryName).closest('tr'); + } + + findDatabaseDetail(testId: DatabaseDetailsTestId) { + return cy.findByTestId(testId); + } + + findDatabasePasswordHiddenButton() { + return this.findDatabaseDetail(DatabaseDetailsTestId.PASSWORD).findByTestId( + 'password-hidden-button', + ); + } + + findConfirmDeleteNameInput() { + return cy.findByTestId('confirm-delete-input'); + } +} + +export const modelRegistrySettings = new ModelRegistrySettings(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts similarity index 100% rename from clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry.cy.ts rename to clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistrySettings/modelRegistrySettings.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistrySettings/modelRegistrySettings.cy.ts new file mode 100644 index 00000000..22ed09af --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistrySettings/modelRegistrySettings.cy.ts @@ -0,0 +1,49 @@ +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import type { ModelRegistry } from '~/app/types'; +import { mockBFFResponse } from '~/__mocks__/mockBFFResponse'; +import { modelRegistrySettings } from '~/__tests__/cypress/cypress/pages/modelRegistrySettings'; + +type HandlersProps = { + modelRegistries?: ModelRegistry[]; +}; + +const MODEL_REGISTRY_API_VERSION = 'v1'; + +const initIntercepts = ({ + modelRegistries = [ + mockModelRegistry({ + name: 'modelregistry-sample', + description: 'New model registry', + displayName: 'Model Registry Sample', + }), + mockModelRegistry({ + name: 'modelregistry-sample-2', + description: 'New model registry 2', + displayName: 'Model Registry Sample 2', + }), + ], +}: HandlersProps) => { + cy.interceptApi( + `GET /api/:apiVersion/model_registry`, + { + path: { apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(modelRegistries), + ); +}; + +it('Shows empty state when there are no registries', () => { + initIntercepts({ modelRegistries: [] }); + modelRegistrySettings.visit(true); + modelRegistrySettings.findEmptyState().should('exist'); +}); + +describe('ModelRegistriesTable', () => { + it('Shows table when there are registries', () => { + initIntercepts({}); + modelRegistrySettings.visit(true); + modelRegistrySettings.findEmptyState().should('not.exist'); + modelRegistrySettings.findTable().should('exist'); + modelRegistrySettings.findModelRegistryRow('Model Registry Sample').should('exist'); + }); +}); diff --git a/clients/ui/frontend/src/app/AppRoutes.tsx b/clients/ui/frontend/src/app/AppRoutes.tsx index c2383da6..2852cc23 100644 --- a/clients/ui/frontend/src/app/AppRoutes.tsx +++ b/clients/ui/frontend/src/app/AppRoutes.tsx @@ -34,7 +34,7 @@ export const useAdminSettings = (): NavDataItem[] => { return [ { label: 'Settings', - children: [{ label: 'Model Registry', path: '/settings' }], + children: [{ label: 'Model Registry', path: '/modelRegistrySettings' }], }, ]; }; @@ -58,7 +58,9 @@ const AppRoutes: React.FC = () => { { // TODO: Remove the linter skip when we implement authentication // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition - isAdmin && } /> + isAdmin && ( + } /> + ) } ); diff --git a/clients/ui/frontend/src/app/pages/settings/ModelRegistriesTable.tsx b/clients/ui/frontend/src/app/pages/settings/ModelRegistriesTable.tsx new file mode 100644 index 00000000..04aeb608 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/settings/ModelRegistriesTable.tsx @@ -0,0 +1,22 @@ +import React from 'react'; +import { ModelRegistry } from '~/app/types'; +import { Table } from '~/app/components/table'; +import { modelRegistryColumns } from './columns'; +import ModelRegistriesTableRow from './ModelRegistriesTableRow'; + +type ModelRegistriesTableProps = { + modelRegistries: ModelRegistry[]; +}; + +const ModelRegistriesTable: React.FC = ({ modelRegistries }) => ( + // TODO: Add toolbar once we manage permissions +
+
+ + + + + +
+ {mv.description && ( + + + + )} +
+ + {mv.author} + + + + setIsArchiveModalOpen(false)} + onSubmit={() => + apiState.api + .patchModelVersion( + {}, + { + state: ModelState.ARCHIVED, + }, + mv.id, + ) + .then(refresh) + } + isOpen={isArchiveModalOpen} + modelVersionName={mv.name} + /> + setIsRestoreModalOpen(false)} + onSubmit={() => + apiState.api + .patchModelVersion( + {}, + { + state: ModelState.LIVE, + }, + mv.id, + ) + .then(() => + navigate( + modelVersionUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name), + ), + ) + } + isOpen={isRestoreModalOpen} + modelVersionName={mv.name} + /> +
} + variant="compact" + /> +); + +export default ModelRegistriesTable; diff --git a/clients/ui/frontend/src/app/pages/settings/ModelRegistriesTableRow.tsx b/clients/ui/frontend/src/app/pages/settings/ModelRegistriesTableRow.tsx new file mode 100644 index 00000000..2d599612 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/settings/ModelRegistriesTableRow.tsx @@ -0,0 +1,22 @@ +import React from 'react'; +import { Td, Tr } from '@patternfly/react-table'; +import { ModelRegistry } from '~/app/types'; + +type ModelRegistriesTableRowProps = { + modelRegistry: ModelRegistry; +}; + +const ModelRegistriesTableRow: React.FC = ({ modelRegistry: mr }) => ( + <> + + + + +); + +// TODO: Get rest of columns once we manage permissions + +export default ModelRegistriesTableRow; diff --git a/clients/ui/frontend/src/app/pages/settings/ModelRegistrySettings.tsx b/clients/ui/frontend/src/app/pages/settings/ModelRegistrySettings.tsx index f48f0bdf..9b3067c5 100644 --- a/clients/ui/frontend/src/app/pages/settings/ModelRegistrySettings.tsx +++ b/clients/ui/frontend/src/app/pages/settings/ModelRegistrySettings.tsx @@ -2,14 +2,23 @@ import React from 'react'; import { EmptyState, EmptyStateBody, EmptyStateVariant } from '@patternfly/react-core'; import { PlusCircleIcon } from '@patternfly/react-icons'; import ApplicationsPage from '~/app/components/ApplicationsPage'; +import useModelRegistries from '~/app/hooks/useModelRegistries'; +import TitleWithIcon from '~/app/components/design/TitleWithIcon'; +import { ProjectObjectType } from '~/app/components/design/utils'; +import ModelRegistriesTable from './ModelRegistriesTable'; const ModelRegistrySettings: React.FC = () => { - const [modelRegistries, loaded, loadError] = [[], true, undefined]; // TODO: change to real values + const [modelRegistries, loaded, loadError] = useModelRegistries(); return ( <> + } + description="List all the model registries deployed in your environment." loaded={loaded} loadError={loadError} errorMessage="Unable to load model registries." @@ -27,7 +36,7 @@ const ModelRegistrySettings: React.FC = () => { } provideChildrenPadding > - TODO: Add model registry settings + ); diff --git a/clients/ui/frontend/src/app/pages/settings/columns.ts b/clients/ui/frontend/src/app/pages/settings/columns.ts new file mode 100644 index 00000000..eb50403d --- /dev/null +++ b/clients/ui/frontend/src/app/pages/settings/columns.ts @@ -0,0 +1,23 @@ +import { SortableData } from '~/app/components/table'; +import { ModelRegistry } from '~/app/types'; + +export const modelRegistryColumns: SortableData[] = [ + { + field: 'model regisry name', + label: 'Model registry name', + sortable: (a, b) => a.name.localeCompare(b.name), + width: 30, + }, + // TODO: Add once we manage permissions + // { + // field: 'status', + // label: 'Status', + // sortable: false, + // }, + // { + // field: 'manage permissions', + // label: '', + // sortable: false, + // }, + // kebabTableColumn(), +]; From 3110d1bdc1d5ad086554866694f2ad13fb910773 Mon Sep 17 00:00:00 2001 From: Alessio Pragliola <83355398+Al-Pragliola@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:44:34 +0200 Subject: [PATCH 08/13] fix(csi): prevent race condition in ci tests step (#426) * fix(csi): prevent race condition in ci tests step Signed-off-by: Alessio Pragliola * chore(csi): reword repeat_cmd_until fail msg Signed-off-by: Alessio Pragliola * refactor(csi): move wait_for_port in test_utils Signed-off-by: Alessio Pragliola --------- Signed-off-by: Alessio Pragliola --- csi/test/e2e_test.sh | 12 +++--------- csi/test/test_utils.sh | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 csi/test/test_utils.sh diff --git a/csi/test/e2e_test.sh b/csi/test/e2e_test.sh index bcdc47c5..87ec06a4 100755 --- a/csi/test/e2e_test.sh +++ b/csi/test/e2e_test.sh @@ -6,16 +6,10 @@ set -o xtrace # This test assumes there is a Kubernetes environment up and running. # It could be either a remote one or a local one (e.g., using KinD or minikube). -# Function to check if the port is ready -wait_for_port() { - local port=$1 - while ! nc -z localhost $port; do - sleep 0.1 - done -} - DIR="$(dirname "$0")" +source ./${DIR}/test_utils.sh + KUBECTL=${KUBECTL:-"kubectl"} # You can provide a local version of the model registry storage initializer @@ -145,7 +139,7 @@ spec: EOF # wait for pod predictor to be initialized -sleep 2 +repeat_cmd_until "kubectl get pod -n $KSERVE_TEST_NAMESPACE --selector='component=predictor' | wc -l" "-gt 0" 60 predictor=$(kubectl get pod -n $KSERVE_TEST_NAMESPACE --selector="component=predictor" --output jsonpath='{.items[0].metadata.name}') kubectl wait --for=condition=Ready pod/$predictor -n $KSERVE_TEST_NAMESPACE --timeout=5m diff --git a/csi/test/test_utils.sh b/csi/test/test_utils.sh new file mode 100644 index 00000000..e0a4e4c0 --- /dev/null +++ b/csi/test/test_utils.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +set -e +set -o xtrace + +# Function to check if the port is ready +wait_for_port() { + local port=$1 + while ! nc -z localhost $port; do + sleep 0.1 + done +} + +repeat_cmd_until() { + local cmd=$1 + local condition=$2 + local max_wait_secs=$3 + local interval_secs=2 + local start_time=$(date +%s) + local output + + while true; do + + current_time=$(date +%s) + if (( (current_time - start_time) > max_wait_secs )); then + echo "Waited for expression "$1" to satisfy condition "$2" for $max_wait_secs seconds without luck. Returning with error." + return 1 + fi + + output=$(eval $cmd) + + if [ $output $condition ]; then + break + else + sleep $interval_secs + fi + done +} From 0d77878a5338320ab1710b3f44a51f768c0e66c0 Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Wed, 25 Sep 2024 12:04:34 -0300 Subject: [PATCH 09/13] OAS: add generic artifacts routes (#406) * OAS: add generic artifacts routes Signed-off-by: Isabella do Amaral * OAS: make discriminator explicit as model prop Signed-off-by: Isabella do Amaral * update existing artifact with custom helper Signed-off-by: Isabella do Amaral --------- Signed-off-by: Isabella do Amaral --- api/openapi/model-registry.yaml | 209 ++- clients/python/src/.openapi-generator/FILES | 4 + .../src/model_registry/types/artifacts.py | 2 + clients/python/src/mr_openapi/README.md | 20 +- clients/python/src/mr_openapi/__init__.py | 4 + .../api/model_registry_service_api.py | 1504 +++++++++++++++-- .../python/src/mr_openapi/models/__init__.py | 4 + .../src/mr_openapi/models/artifact_create.py | 174 ++ .../src/mr_openapi/models/artifact_update.py | 174 ++ .../src/mr_openapi/models/doc_artifact.py | 2 + .../mr_openapi/models/doc_artifact_create.py | 128 ++ .../mr_openapi/models/doc_artifact_update.py | 114 ++ .../src/mr_openapi/models/model_artifact.py | 2 + .../models/model_artifact_create.py | 3 + .../models/model_artifact_update.py | 3 +- .../generated/openapi_converter.gen.go | 119 ++ internal/converter/openapi_converter.go | 20 + internal/converter/openapi_converter_test.go | 5 +- internal/converter/openapi_converter_util.go | 3 +- internal/converter/openapi_reconciler_util.go | 23 + internal/server/openapi/api.go | 5 + .../openapi/api_model_registry_service.go | 126 ++ .../api_model_registry_service_service.go | 76 + internal/server/openapi/type_asserts.go | 76 + pkg/api/api.go | 2 + pkg/core/artifact.go | 45 + pkg/core/artifact_test.go | 53 + pkg/openapi/.openapi-generator/FILES | 4 + pkg/openapi/api_model_registry_service.go | 1039 ++++++++++-- pkg/openapi/model_artifact_create.go | 163 ++ pkg/openapi/model_artifact_update.go | 163 ++ pkg/openapi/model_doc_artifact_create.go | 341 ++++ pkg/openapi/model_doc_artifact_update.go | 304 ++++ pkg/openapi/model_model_artifact_create.go | 30 +- pkg/openapi/model_model_artifact_update.go | 32 +- pkg/openapi/model_registered_model.go | 2 +- pkg/openapi/model_registered_model_create.go | 2 +- 37 files changed, 4719 insertions(+), 261 deletions(-) create mode 100644 clients/python/src/mr_openapi/models/artifact_create.py create mode 100644 clients/python/src/mr_openapi/models/artifact_update.py create mode 100644 clients/python/src/mr_openapi/models/doc_artifact_create.py create mode 100644 clients/python/src/mr_openapi/models/doc_artifact_update.py create mode 100644 internal/converter/openapi_reconciler_util.go create mode 100644 pkg/openapi/model_artifact_create.go create mode 100644 pkg/openapi/model_artifact_update.go create mode 100644 pkg/openapi/model_doc_artifact_create.go create mode 100644 pkg/openapi/model_doc_artifact_update.go diff --git a/api/openapi/model-registry.yaml b/api/openapi/model-registry.yaml index 392cf116..4a56dacf 100644 --- a/api/openapi/model-registry.yaml +++ b/api/openapi/model-registry.yaml @@ -10,6 +10,129 @@ servers: - url: "https://localhost:8080" - url: "http://localhost:8080" paths: + /api/model_registry/v1alpha3/artifact: + summary: Path used to search for an artifact. + description: >- + The REST endpoint/path used to search for an `Artifact` entity. This path contains a `GET` operation to perform the find task. + get: + tags: + - ModelRegistryService + responses: + "200": + $ref: "#/components/responses/ArtifactResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: findArtifact + summary: Get an Artifact that matches search parameters. + description: Gets the details of a single instance of an `Artifact` that matches search parameters. + parameters: + - $ref: "#/components/parameters/name" + - $ref: "#/components/parameters/externalId" + - $ref: "#/components/parameters/parentResourceId" + /api/model_registry/v1alpha3/artifacts: + summary: Path used to manage the list of artifacts. + description: >- + The REST endpoint/path used to list and create zero or more `Artifact` entities. This path contains a `GET` and `POST` operation to perform the list and create tasks, respectively. + get: + tags: + - ModelRegistryService + parameters: + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/orderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" + responses: + "200": + $ref: "#/components/responses/ArtifactListResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: getArtifacts + summary: List All Artifacts + description: Gets a list of all `Artifact` entities. + post: + requestBody: + description: A new `Artifact` to be created. + content: + application/json: + schema: + $ref: "#/components/schemas/ArtifactCreate" + required: true + tags: + - ModelRegistryService + responses: + "201": + $ref: "#/components/responses/ArtifactResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: createArtifact + summary: Create an Artifact + description: Creates a new instance of an `Artifact`. + /api/model_registry/v1alpha3/artifacts/{id}: + summary: Path used to manage a single Artifact. + description: >- + The REST endpoint/path used to get and update single instances of an `Artifact`. This path contains `GET` and `PATCH` operations used to perform the get and update tasks, respectively. + get: + tags: + - ModelRegistryService + responses: + "200": + $ref: "#/components/responses/ArtifactResponse" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: getArtifact + summary: Get an Artifact + description: Gets the details of a single instance of an `Artifact`. + patch: + requestBody: + description: Updated `Artifact` information. + content: + application/json: + schema: + $ref: "#/components/schemas/ArtifactUpdate" + required: true + tags: + - ModelRegistryService + responses: + "200": + $ref: "#/components/responses/ArtifactResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "404": + $ref: "#/components/responses/NotFound" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: updateArtifact + summary: Update an Artifact + description: Updates an existing `Artifact`. + parameters: + - name: id + description: A unique identifier for an `Artifact`. + schema: + type: string + in: path + required: true /api/model_registry/v1alpha3/model_artifact: summary: Path used to search for a modelartifact. description: >- @@ -991,15 +1114,36 @@ components: default: "doc-artifact" allOf: - $ref: "#/components/schemas/BaseArtifact" + - $ref: "#/components/schemas/DocArtifactCreate" + DocArtifactCreate: + description: A document artifact to be created. + type: object + required: + - artifactType + properties: + artifactType: + type: string + default: "doc-artifact" + allOf: + - $ref: "#/components/schemas/BaseArtifactCreate" + - $ref: "#/components/schemas/DocArtifactUpdate" + DocArtifactUpdate: + description: A document artifact to be updated. + required: + - artifactType + properties: + artifactType: + type: string + default: "doc-artifact" + allOf: + - $ref: "#/components/schemas/BaseArtifactUpdate" RegisteredModel: description: A registered model in model registry. A registered model has ModelVersion children. allOf: - $ref: "#/components/schemas/BaseResource" - - type: object - $ref: "#/components/schemas/RegisteredModelCreate" ModelVersionList: description: List of ModelVersion entities. - type: object allOf: - type: object properties: @@ -1011,7 +1155,6 @@ components: - $ref: "#/components/schemas/BaseResourceList" ModelArtifactList: description: List of ModelArtifact entities. - type: object allOf: - type: object properties: @@ -1026,9 +1169,9 @@ components: required: - name allOf: - - type: object - $ref: "#/components/schemas/BaseResourceCreate" - $ref: "#/components/schemas/RegisteredModelUpdate" + - type: object properties: name: description: |- @@ -1098,12 +1241,10 @@ components: BaseExecution: allOf: - $ref: "#/components/schemas/BaseExecutionCreate" - - type: object - $ref: "#/components/schemas/BaseResource" BaseExecutionCreate: allOf: - $ref: "#/components/schemas/BaseExecutionUpdate" - - type: object - $ref: "#/components/schemas/BaseResourceCreate" BaseExecutionUpdate: type: object @@ -1286,7 +1427,6 @@ components: type: integer ArtifactList: description: A list of Artifact entities. - type: object allOf: - type: object properties: @@ -1297,7 +1437,13 @@ components: $ref: "#/components/schemas/Artifact" - $ref: "#/components/schemas/BaseResourceList" ModelArtifactUpdate: - description: An ML model artifact. + description: An ML model artifact to be updated. + required: + - artifactType + properties: + artifactType: + type: string + default: "model-artifact" allOf: - $ref: "#/components/schemas/BaseArtifactUpdate" - type: object @@ -1319,7 +1465,12 @@ components: type: string ModelArtifactCreate: description: An ML model artifact. - type: object + required: + - artifactType + properties: + artifactType: + type: string + default: "model-artifact" allOf: - $ref: "#/components/schemas/BaseArtifactCreate" - $ref: "#/components/schemas/ModelArtifactUpdate" @@ -1363,9 +1514,28 @@ components: allOf: - $ref: "#/components/schemas/BaseArtifactCreate" - $ref: "#/components/schemas/BaseResource" + ArtifactCreate: + description: An Artifact to be created. + oneOf: + - $ref: "#/components/schemas/ModelArtifactCreate" + - $ref: "#/components/schemas/DocArtifactCreate" + discriminator: + propertyName: artifactType + mapping: + model-artifact: "#/components/schemas/ModelArtifactCreate" + doc-artifact: "#/components/schemas/DocArtifactCreate" + ArtifactUpdate: + description: An Artifact to be updated. + oneOf: + - $ref: "#/components/schemas/ModelArtifactUpdate" + - $ref: "#/components/schemas/DocArtifactUpdate" + discriminator: + propertyName: artifactType + mapping: + model-artifact: "#/components/schemas/ModelArtifactUpdate" + doc-artifact: "#/components/schemas/DocArtifactUpdate" ServingEnvironmentList: description: List of ServingEnvironments. - type: object allOf: - type: object properties: @@ -1378,7 +1548,6 @@ components: - $ref: "#/components/schemas/BaseResourceList" RegisteredModelList: description: List of RegisteredModels. - type: object allOf: - type: object properties: @@ -1402,7 +1571,6 @@ components: ServingEnvironmentCreate: description: A Model Serving environment for serving `RegisteredModels`. allOf: - - type: object - $ref: "#/components/schemas/BaseResourceCreate" - $ref: "#/components/schemas/ServingEnvironmentUpdate" InferenceService: @@ -1413,7 +1581,6 @@ components: - $ref: "#/components/schemas/InferenceServiceCreate" InferenceServiceList: description: List of InferenceServices. - type: object allOf: - type: object properties: @@ -1426,7 +1593,6 @@ components: - $ref: "#/components/schemas/BaseResourceList" ServeModelList: description: List of ServeModel entities. - type: object allOf: - type: object properties: @@ -1438,7 +1604,6 @@ components: - $ref: "#/components/schemas/BaseResourceList" ServeModel: description: An ML model serving action. - type: object allOf: - $ref: "#/components/schemas/BaseExecution" - $ref: "#/components/schemas/ServeModelCreate" @@ -1448,12 +1613,12 @@ components: - $ref: "#/components/schemas/BaseExecutionUpdate" ServeModelCreate: description: An ML model serving action. + required: + - modelVersionId allOf: - $ref: "#/components/schemas/BaseExecutionCreate" - $ref: "#/components/schemas/ServeModelUpdate" - - required: - - modelVersionId - type: object + - type: object properties: modelVersionId: description: ID of the `ModelVersion` that was served in `InferenceService`. @@ -1477,13 +1642,13 @@ components: InferenceServiceCreate: description: >- An `InferenceService` entity in a `ServingEnvironment` represents a deployed `ModelVersion` from a `RegisteredModel` created by Model Serving. + required: + - registeredModelId + - servingEnvironmentId allOf: - $ref: "#/components/schemas/BaseResourceCreate" - $ref: "#/components/schemas/InferenceServiceUpdate" - - required: - - registeredModelId - - servingEnvironmentId - type: object + - type: object properties: registeredModelId: description: ID of the `RegisteredModel` to serve. diff --git a/clients/python/src/.openapi-generator/FILES b/clients/python/src/.openapi-generator/FILES index 0c6091c8..1d6f28a4 100644 --- a/clients/python/src/.openapi-generator/FILES +++ b/clients/python/src/.openapi-generator/FILES @@ -7,8 +7,10 @@ mr_openapi/configuration.py mr_openapi/exceptions.py mr_openapi/models/__init__.py mr_openapi/models/artifact.py +mr_openapi/models/artifact_create.py mr_openapi/models/artifact_list.py mr_openapi/models/artifact_state.py +mr_openapi/models/artifact_update.py mr_openapi/models/base_artifact.py mr_openapi/models/base_artifact_create.py mr_openapi/models/base_artifact_update.py @@ -20,6 +22,8 @@ mr_openapi/models/base_resource_create.py mr_openapi/models/base_resource_list.py mr_openapi/models/base_resource_update.py mr_openapi/models/doc_artifact.py +mr_openapi/models/doc_artifact_create.py +mr_openapi/models/doc_artifact_update.py mr_openapi/models/error.py mr_openapi/models/execution_state.py mr_openapi/models/inference_service.py diff --git a/clients/python/src/model_registry/types/artifacts.py b/clients/python/src/model_registry/types/artifacts.py index 64f5a974..a8cfd793 100644 --- a/clients/python/src/model_registry/types/artifacts.py +++ b/clients/python/src/model_registry/types/artifacts.py @@ -152,6 +152,7 @@ def create(self, **kwargs) -> ModelArtifactCreate: return ModelArtifactCreate( customProperties=self._map_custom_properties(), **self._props_as_dict(exclude=("id", "custom_properties")), + artifactType="model-artifact", **kwargs, ) @@ -161,6 +162,7 @@ def update(self, **kwargs) -> ModelArtifactUpdate: return ModelArtifactUpdate( customProperties=self._map_custom_properties(), **self._props_as_dict(exclude=("id", "name", "custom_properties")), + artifactType="model-artifact", **kwargs, ) diff --git a/clients/python/src/mr_openapi/README.md b/clients/python/src/mr_openapi/README.md index ffee94f0..c7f687c7 100644 --- a/clients/python/src/mr_openapi/README.md +++ b/clients/python/src/mr_openapi/README.md @@ -55,16 +55,15 @@ configuration = mr_openapi.Configuration( async with mr_openapi.ApiClient(configuration) as api_client: # Create an instance of the API class api_instance = mr_openapi.ModelRegistryServiceApi(api_client) - servingenvironment_id = 'servingenvironment_id_example' # str | A unique identifier for a `ServingEnvironment`. - inference_service_create = mr_openapi.InferenceServiceCreate() # InferenceServiceCreate | A new `InferenceService` to be created. + artifact_create = mr_openapi.ArtifactCreate() # ArtifactCreate | A new `Artifact` to be created. try: - # Create a InferenceService in ServingEnvironment - api_response = await api_instance.create_environment_inference_service(servingenvironment_id, inference_service_create) - print("The response of ModelRegistryServiceApi->create_environment_inference_service:\n") + # Create an Artifact + api_response = await api_instance.create_artifact(artifact_create) + print("The response of ModelRegistryServiceApi->create_artifact:\n") pprint(api_response) except ApiException as e: - print("Exception when calling ModelRegistryServiceApi->create_environment_inference_service: %s\n" % e) + print("Exception when calling ModelRegistryServiceApi->create_artifact: %s\n" % e) ``` @@ -74,6 +73,7 @@ All URIs are relative to *https://localhost:8080* Class | Method | HTTP request | Description ------------ | ------------- | ------------- | ------------- +*ModelRegistryServiceApi* | [**create_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#create_artifact) | **POST** /api/model_registry/v1alpha3/artifacts | Create an Artifact *ModelRegistryServiceApi* | [**create_environment_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#create_environment_inference_service) | **POST** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId}/inference_services | Create a InferenceService in ServingEnvironment *ModelRegistryServiceApi* | [**create_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#create_inference_service) | **POST** /api/model_registry/v1alpha3/inference_services | Create a InferenceService *ModelRegistryServiceApi* | [**create_inference_service_serve**](mr_openapi/docs/ModelRegistryServiceApi.md#create_inference_service_serve) | **POST** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId}/serves | Create a ServeModel action in a InferenceService @@ -82,11 +82,14 @@ Class | Method | HTTP request | Description *ModelRegistryServiceApi* | [**create_registered_model**](mr_openapi/docs/ModelRegistryServiceApi.md#create_registered_model) | **POST** /api/model_registry/v1alpha3/registered_models | Create a RegisteredModel *ModelRegistryServiceApi* | [**create_registered_model_version**](mr_openapi/docs/ModelRegistryServiceApi.md#create_registered_model_version) | **POST** /api/model_registry/v1alpha3/registered_models/{registeredmodelId}/versions | Create a ModelVersion in RegisteredModel *ModelRegistryServiceApi* | [**create_serving_environment**](mr_openapi/docs/ModelRegistryServiceApi.md#create_serving_environment) | **POST** /api/model_registry/v1alpha3/serving_environments | Create a ServingEnvironment +*ModelRegistryServiceApi* | [**find_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#find_artifact) | **GET** /api/model_registry/v1alpha3/artifact | Get an Artifact that matches search parameters. *ModelRegistryServiceApi* | [**find_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#find_inference_service) | **GET** /api/model_registry/v1alpha3/inference_service | Get an InferenceServices that matches search parameters. *ModelRegistryServiceApi* | [**find_model_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#find_model_artifact) | **GET** /api/model_registry/v1alpha3/model_artifact | Get a ModelArtifact that matches search parameters. *ModelRegistryServiceApi* | [**find_model_version**](mr_openapi/docs/ModelRegistryServiceApi.md#find_model_version) | **GET** /api/model_registry/v1alpha3/model_version | Get a ModelVersion that matches search parameters. *ModelRegistryServiceApi* | [**find_registered_model**](mr_openapi/docs/ModelRegistryServiceApi.md#find_registered_model) | **GET** /api/model_registry/v1alpha3/registered_model | Get a RegisteredModel that matches search parameters. *ModelRegistryServiceApi* | [**find_serving_environment**](mr_openapi/docs/ModelRegistryServiceApi.md#find_serving_environment) | **GET** /api/model_registry/v1alpha3/serving_environment | Find ServingEnvironment +*ModelRegistryServiceApi* | [**get_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#get_artifact) | **GET** /api/model_registry/v1alpha3/artifacts/{id} | Get an Artifact +*ModelRegistryServiceApi* | [**get_artifacts**](mr_openapi/docs/ModelRegistryServiceApi.md#get_artifacts) | **GET** /api/model_registry/v1alpha3/artifacts | List All Artifacts *ModelRegistryServiceApi* | [**get_environment_inference_services**](mr_openapi/docs/ModelRegistryServiceApi.md#get_environment_inference_services) | **GET** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId}/inference_services | List All ServingEnvironment's InferenceServices *ModelRegistryServiceApi* | [**get_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#get_inference_service) | **GET** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId} | Get a InferenceService *ModelRegistryServiceApi* | [**get_inference_service_model**](mr_openapi/docs/ModelRegistryServiceApi.md#get_inference_service_model) | **GET** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId}/model | Get InferenceService's RegisteredModel @@ -103,6 +106,7 @@ Class | Method | HTTP request | Description *ModelRegistryServiceApi* | [**get_registered_models**](mr_openapi/docs/ModelRegistryServiceApi.md#get_registered_models) | **GET** /api/model_registry/v1alpha3/registered_models | List All RegisteredModels *ModelRegistryServiceApi* | [**get_serving_environment**](mr_openapi/docs/ModelRegistryServiceApi.md#get_serving_environment) | **GET** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId} | Get a ServingEnvironment *ModelRegistryServiceApi* | [**get_serving_environments**](mr_openapi/docs/ModelRegistryServiceApi.md#get_serving_environments) | **GET** /api/model_registry/v1alpha3/serving_environments | List All ServingEnvironments +*ModelRegistryServiceApi* | [**update_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#update_artifact) | **PATCH** /api/model_registry/v1alpha3/artifacts/{id} | Update an Artifact *ModelRegistryServiceApi* | [**update_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#update_inference_service) | **PATCH** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId} | Update a InferenceService *ModelRegistryServiceApi* | [**update_model_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#update_model_artifact) | **PATCH** /api/model_registry/v1alpha3/model_artifacts/{modelartifactId} | Update a ModelArtifact *ModelRegistryServiceApi* | [**update_model_version**](mr_openapi/docs/ModelRegistryServiceApi.md#update_model_version) | **PATCH** /api/model_registry/v1alpha3/model_versions/{modelversionId} | Update a ModelVersion @@ -114,8 +118,10 @@ Class | Method | HTTP request | Description ## Documentation For Models - [Artifact](mr_openapi/docs/Artifact.md) + - [ArtifactCreate](mr_openapi/docs/ArtifactCreate.md) - [ArtifactList](mr_openapi/docs/ArtifactList.md) - [ArtifactState](mr_openapi/docs/ArtifactState.md) + - [ArtifactUpdate](mr_openapi/docs/ArtifactUpdate.md) - [BaseArtifact](mr_openapi/docs/BaseArtifact.md) - [BaseArtifactCreate](mr_openapi/docs/BaseArtifactCreate.md) - [BaseArtifactUpdate](mr_openapi/docs/BaseArtifactUpdate.md) @@ -127,6 +133,8 @@ Class | Method | HTTP request | Description - [BaseResourceList](mr_openapi/docs/BaseResourceList.md) - [BaseResourceUpdate](mr_openapi/docs/BaseResourceUpdate.md) - [DocArtifact](mr_openapi/docs/DocArtifact.md) + - [DocArtifactCreate](mr_openapi/docs/DocArtifactCreate.md) + - [DocArtifactUpdate](mr_openapi/docs/DocArtifactUpdate.md) - [Error](mr_openapi/docs/Error.md) - [ExecutionState](mr_openapi/docs/ExecutionState.md) - [InferenceService](mr_openapi/docs/InferenceService.md) diff --git a/clients/python/src/mr_openapi/__init__.py b/clients/python/src/mr_openapi/__init__.py index 3a01c51d..b022ce3f 100644 --- a/clients/python/src/mr_openapi/__init__.py +++ b/clients/python/src/mr_openapi/__init__.py @@ -32,8 +32,10 @@ # import models into sdk package from mr_openapi.models.artifact import Artifact +from mr_openapi.models.artifact_create import ArtifactCreate from mr_openapi.models.artifact_list import ArtifactList from mr_openapi.models.artifact_state import ArtifactState +from mr_openapi.models.artifact_update import ArtifactUpdate from mr_openapi.models.base_artifact import BaseArtifact from mr_openapi.models.base_artifact_create import BaseArtifactCreate from mr_openapi.models.base_artifact_update import BaseArtifactUpdate @@ -45,6 +47,8 @@ from mr_openapi.models.base_resource_list import BaseResourceList from mr_openapi.models.base_resource_update import BaseResourceUpdate from mr_openapi.models.doc_artifact import DocArtifact +from mr_openapi.models.doc_artifact_create import DocArtifactCreate +from mr_openapi.models.doc_artifact_update import DocArtifactUpdate from mr_openapi.models.error import Error from mr_openapi.models.execution_state import ExecutionState from mr_openapi.models.inference_service import InferenceService diff --git a/clients/python/src/mr_openapi/api/model_registry_service_api.py b/clients/python/src/mr_openapi/api/model_registry_service_api.py index 0e563a10..423c782c 100644 --- a/clients/python/src/mr_openapi/api/model_registry_service_api.py +++ b/clients/python/src/mr_openapi/api/model_registry_service_api.py @@ -15,7 +15,9 @@ from mr_openapi.api_client import ApiClient, RequestSerialized from mr_openapi.api_response import ApiResponse from mr_openapi.models.artifact import Artifact +from mr_openapi.models.artifact_create import ArtifactCreate from mr_openapi.models.artifact_list import ArtifactList +from mr_openapi.models.artifact_update import ArtifactUpdate from mr_openapi.models.inference_service import InferenceService from mr_openapi.models.inference_service_create import InferenceServiceCreate from mr_openapi.models.inference_service_list import InferenceServiceList @@ -56,6 +58,245 @@ def __init__(self, api_client=None) -> None: api_client = ApiClient.get_default() self.api_client = api_client + @validate_call + async def create_artifact( + self, + artifact_create: Annotated[ArtifactCreate, Field(description="A new `Artifact` to be created.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> Artifact: + """Create an Artifact. + + Creates a new instance of an `Artifact`. + + :param artifact_create: A new `Artifact` to be created. (required) + :type artifact_create: ArtifactCreate + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._create_artifact_serialize( + artifact_create=artifact_create, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "201": "Artifact", + "400": "Error", + "401": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ).data + + @validate_call + async def create_artifact_with_http_info( + self, + artifact_create: Annotated[ArtifactCreate, Field(description="A new `Artifact` to be created.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ApiResponse[Artifact]: + """Create an Artifact. + + Creates a new instance of an `Artifact`. + + :param artifact_create: A new `Artifact` to be created. (required) + :type artifact_create: ArtifactCreate + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._create_artifact_serialize( + artifact_create=artifact_create, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "201": "Artifact", + "400": "Error", + "401": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ) + + @validate_call + async def create_artifact_without_preload_content( + self, + artifact_create: Annotated[ArtifactCreate, Field(description="A new `Artifact` to be created.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> RESTResponseType: + """Create an Artifact. + + Creates a new instance of an `Artifact`. + + :param artifact_create: A new `Artifact` to be created. (required) + :type artifact_create: ArtifactCreate + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._create_artifact_serialize( + artifact_create=artifact_create, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "201": "Artifact", + "400": "Error", + "401": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + return response_data.response + + def _create_artifact_serialize( + self, + artifact_create, + _request_auth, + _content_type, + _headers, + _host_index, + ) -> RequestSerialized: + + _host = None + + _collection_formats: dict[str, str] = {} + + _path_params: dict[str, str] = {} + _query_params: list[tuple[str, str]] = [] + _header_params: dict[str, Optional[str]] = _headers or {} + _form_params: list[tuple[str, str]] = [] + _files: dict[str, Union[str, bytes]] = {} + _body_params: Optional[bytes] = None + + # process the path parameters + # process the query parameters + # process the header parameters + # process the form parameters + # process the body parameter + if artifact_create is not None: + _body_params = artifact_create + + # set the HTTP header `Accept` + _header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) + + # set the HTTP header `Content-Type` + if _content_type: + _header_params["Content-Type"] = _content_type + else: + _default_content_type = self.api_client.select_header_content_type(["application/json"]) + if _default_content_type is not None: + _header_params["Content-Type"] = _default_content_type + + # authentication setting + _auth_settings: list[str] = ["Bearer"] + + return self.api_client.param_serialize( + method="POST", + resource_path="/api/model_registry/v1alpha3/artifacts", + path_params=_path_params, + query_params=_query_params, + header_params=_header_params, + body=_body_params, + post_params=_form_params, + files=_files, + auth_settings=_auth_settings, + collection_formats=_collection_formats, + _host=_host, + _request_auth=_request_auth, + ) + @validate_call async def create_environment_inference_service( self, @@ -2065,7 +2306,7 @@ def _create_serving_environment_serialize( ) @validate_call - async def find_inference_service( + async def find_artifact( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2081,10 +2322,10 @@ async def find_inference_service( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> InferenceService: - """Get an InferenceServices that matches search parameters. + ) -> Artifact: + """Get an Artifact that matches search parameters. - Gets the details of a single instance of `InferenceService` that matches search parameters. + Gets the details of a single instance of an `Artifact` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2113,7 +2354,7 @@ async def find_inference_service( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_inference_service_serialize( + _param = self._find_artifact_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2124,7 +2365,7 @@ async def find_inference_service( ) _response_types_map: dict[str, Optional[str]] = { - "200": "InferenceService", + "200": "Artifact", "400": "Error", "401": "Error", "404": "Error", @@ -2138,7 +2379,7 @@ async def find_inference_service( ).data @validate_call - async def find_inference_service_with_http_info( + async def find_artifact_with_http_info( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2154,10 +2395,10 @@ async def find_inference_service_with_http_info( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ApiResponse[InferenceService]: - """Get an InferenceServices that matches search parameters. + ) -> ApiResponse[Artifact]: + """Get an Artifact that matches search parameters. - Gets the details of a single instance of `InferenceService` that matches search parameters. + Gets the details of a single instance of an `Artifact` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2186,7 +2427,7 @@ async def find_inference_service_with_http_info( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_inference_service_serialize( + _param = self._find_artifact_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2197,7 +2438,7 @@ async def find_inference_service_with_http_info( ) _response_types_map: dict[str, Optional[str]] = { - "200": "InferenceService", + "200": "Artifact", "400": "Error", "401": "Error", "404": "Error", @@ -2211,7 +2452,7 @@ async def find_inference_service_with_http_info( ) @validate_call - async def find_inference_service_without_preload_content( + async def find_artifact_without_preload_content( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2228,9 +2469,9 @@ async def find_inference_service_without_preload_content( _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: - """Get an InferenceServices that matches search parameters. + """Get an Artifact that matches search parameters. - Gets the details of a single instance of `InferenceService` that matches search parameters. + Gets the details of a single instance of an `Artifact` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2259,7 +2500,7 @@ async def find_inference_service_without_preload_content( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_inference_service_serialize( + _param = self._find_artifact_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2270,7 +2511,7 @@ async def find_inference_service_without_preload_content( ) _response_types_map: dict[str, Optional[str]] = { - "200": "InferenceService", + "200": "Artifact", "400": "Error", "401": "Error", "404": "Error", @@ -2279,7 +2520,7 @@ async def find_inference_service_without_preload_content( response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) return response_data.response - def _find_inference_service_serialize( + def _find_artifact_serialize( self, name, external_id, @@ -2327,7 +2568,7 @@ def _find_inference_service_serialize( return self.api_client.param_serialize( method="GET", - resource_path="/api/model_registry/v1alpha3/inference_service", + resource_path="/api/model_registry/v1alpha3/artifact", path_params=_path_params, query_params=_query_params, header_params=_header_params, @@ -2341,7 +2582,7 @@ def _find_inference_service_serialize( ) @validate_call - async def find_model_artifact( + async def find_inference_service( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2357,10 +2598,10 @@ async def find_model_artifact( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ModelArtifact: - """Get a ModelArtifact that matches search parameters. + ) -> InferenceService: + """Get an InferenceServices that matches search parameters. - Gets the details of a single instance of a `ModelArtifact` that matches search parameters. + Gets the details of a single instance of `InferenceService` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2389,7 +2630,7 @@ async def find_model_artifact( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_model_artifact_serialize( + _param = self._find_inference_service_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2400,7 +2641,7 @@ async def find_model_artifact( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ModelArtifact", + "200": "InferenceService", "400": "Error", "401": "Error", "404": "Error", @@ -2414,7 +2655,7 @@ async def find_model_artifact( ).data @validate_call - async def find_model_artifact_with_http_info( + async def find_inference_service_with_http_info( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2430,10 +2671,10 @@ async def find_model_artifact_with_http_info( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ApiResponse[ModelArtifact]: - """Get a ModelArtifact that matches search parameters. + ) -> ApiResponse[InferenceService]: + """Get an InferenceServices that matches search parameters. - Gets the details of a single instance of a `ModelArtifact` that matches search parameters. + Gets the details of a single instance of `InferenceService` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2462,7 +2703,7 @@ async def find_model_artifact_with_http_info( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_model_artifact_serialize( + _param = self._find_inference_service_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2473,7 +2714,7 @@ async def find_model_artifact_with_http_info( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ModelArtifact", + "200": "InferenceService", "400": "Error", "401": "Error", "404": "Error", @@ -2487,7 +2728,7 @@ async def find_model_artifact_with_http_info( ) @validate_call - async def find_model_artifact_without_preload_content( + async def find_inference_service_without_preload_content( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2504,9 +2745,9 @@ async def find_model_artifact_without_preload_content( _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: - """Get a ModelArtifact that matches search parameters. + """Get an InferenceServices that matches search parameters. - Gets the details of a single instance of a `ModelArtifact` that matches search parameters. + Gets the details of a single instance of `InferenceService` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2535,7 +2776,7 @@ async def find_model_artifact_without_preload_content( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_model_artifact_serialize( + _param = self._find_inference_service_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2546,7 +2787,7 @@ async def find_model_artifact_without_preload_content( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ModelArtifact", + "200": "InferenceService", "400": "Error", "401": "Error", "404": "Error", @@ -2555,7 +2796,7 @@ async def find_model_artifact_without_preload_content( response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) return response_data.response - def _find_model_artifact_serialize( + def _find_inference_service_serialize( self, name, external_id, @@ -2603,7 +2844,7 @@ def _find_model_artifact_serialize( return self.api_client.param_serialize( method="GET", - resource_path="/api/model_registry/v1alpha3/model_artifact", + resource_path="/api/model_registry/v1alpha3/inference_service", path_params=_path_params, query_params=_query_params, header_params=_header_params, @@ -2617,7 +2858,7 @@ def _find_model_artifact_serialize( ) @validate_call - async def find_model_version( + async def find_model_artifact( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2633,10 +2874,10 @@ async def find_model_version( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ModelVersion: - """Get a ModelVersion that matches search parameters. + ) -> ModelArtifact: + """Get a ModelArtifact that matches search parameters. - Gets the details of a single instance of a `ModelVersion` that matches search parameters. + Gets the details of a single instance of a `ModelArtifact` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2665,7 +2906,7 @@ async def find_model_version( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_model_version_serialize( + _param = self._find_model_artifact_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2676,7 +2917,7 @@ async def find_model_version( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ModelVersion", + "200": "ModelArtifact", "400": "Error", "401": "Error", "404": "Error", @@ -2690,7 +2931,7 @@ async def find_model_version( ).data @validate_call - async def find_model_version_with_http_info( + async def find_model_artifact_with_http_info( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2706,10 +2947,10 @@ async def find_model_version_with_http_info( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ApiResponse[ModelVersion]: - """Get a ModelVersion that matches search parameters. + ) -> ApiResponse[ModelArtifact]: + """Get a ModelArtifact that matches search parameters. - Gets the details of a single instance of a `ModelVersion` that matches search parameters. + Gets the details of a single instance of a `ModelArtifact` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2738,7 +2979,7 @@ async def find_model_version_with_http_info( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_model_version_serialize( + _param = self._find_model_artifact_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2749,7 +2990,7 @@ async def find_model_version_with_http_info( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ModelVersion", + "200": "ModelArtifact", "400": "Error", "401": "Error", "404": "Error", @@ -2763,7 +3004,7 @@ async def find_model_version_with_http_info( ) @validate_call - async def find_model_version_without_preload_content( + async def find_model_artifact_without_preload_content( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -2780,9 +3021,9 @@ async def find_model_version_without_preload_content( _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: - """Get a ModelVersion that matches search parameters. + """Get a ModelArtifact that matches search parameters. - Gets the details of a single instance of a `ModelVersion` that matches search parameters. + Gets the details of a single instance of a `ModelArtifact` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -2811,7 +3052,7 @@ async def find_model_version_without_preload_content( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_model_version_serialize( + _param = self._find_model_artifact_serialize( name=name, external_id=external_id, parent_resource_id=parent_resource_id, @@ -2822,7 +3063,7 @@ async def find_model_version_without_preload_content( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ModelVersion", + "200": "ModelArtifact", "400": "Error", "401": "Error", "404": "Error", @@ -2831,7 +3072,7 @@ async def find_model_version_without_preload_content( response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) return response_data.response - def _find_model_version_serialize( + def _find_model_artifact_serialize( self, name, external_id, @@ -2879,7 +3120,7 @@ def _find_model_version_serialize( return self.api_client.param_serialize( method="GET", - resource_path="/api/model_registry/v1alpha3/model_version", + resource_path="/api/model_registry/v1alpha3/model_artifact", path_params=_path_params, query_params=_query_params, header_params=_header_params, @@ -2893,10 +3134,13 @@ def _find_model_version_serialize( ) @validate_call - async def find_registered_model( + async def find_model_version( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, + parent_resource_id: Annotated[ + Optional[StrictStr], Field(description="ID of the parent resource to use for search.") + ] = None, _request_timeout: Union[ None, Annotated[StrictFloat, Field(gt=0)], @@ -2906,15 +3150,17 @@ async def find_registered_model( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> RegisteredModel: - """Get a RegisteredModel that matches search parameters. + ) -> ModelVersion: + """Get a ModelVersion that matches search parameters. - Gets the details of a single instance of a `RegisteredModel` that matches search parameters. + Gets the details of a single instance of a `ModelVersion` that matches search parameters. :param name: Name of entity to search. :type name: str :param external_id: External ID of entity to search. :type external_id: str + :param parent_resource_id: ID of the parent resource to use for search. + :type parent_resource_id: str :param _request_timeout: timeout setting for this request. If one number provided, it will be total request timeout. It can also be a pair (tuple) of @@ -2936,9 +3182,10 @@ async def find_registered_model( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_registered_model_serialize( + _param = self._find_model_version_serialize( name=name, external_id=external_id, + parent_resource_id=parent_resource_id, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, @@ -2946,7 +3193,8 @@ async def find_registered_model( ) _response_types_map: dict[str, Optional[str]] = { - "200": "RegisteredModel", + "200": "ModelVersion", + "400": "Error", "401": "Error", "404": "Error", "500": "Error", @@ -2959,10 +3207,13 @@ async def find_registered_model( ).data @validate_call - async def find_registered_model_with_http_info( + async def find_model_version_with_http_info( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, + parent_resource_id: Annotated[ + Optional[StrictStr], Field(description="ID of the parent resource to use for search.") + ] = None, _request_timeout: Union[ None, Annotated[StrictFloat, Field(gt=0)], @@ -2972,15 +3223,17 @@ async def find_registered_model_with_http_info( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ApiResponse[RegisteredModel]: - """Get a RegisteredModel that matches search parameters. + ) -> ApiResponse[ModelVersion]: + """Get a ModelVersion that matches search parameters. - Gets the details of a single instance of a `RegisteredModel` that matches search parameters. + Gets the details of a single instance of a `ModelVersion` that matches search parameters. :param name: Name of entity to search. :type name: str :param external_id: External ID of entity to search. :type external_id: str + :param parent_resource_id: ID of the parent resource to use for search. + :type parent_resource_id: str :param _request_timeout: timeout setting for this request. If one number provided, it will be total request timeout. It can also be a pair (tuple) of @@ -3002,9 +3255,10 @@ async def find_registered_model_with_http_info( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_registered_model_serialize( + _param = self._find_model_version_serialize( name=name, external_id=external_id, + parent_resource_id=parent_resource_id, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, @@ -3012,7 +3266,8 @@ async def find_registered_model_with_http_info( ) _response_types_map: dict[str, Optional[str]] = { - "200": "RegisteredModel", + "200": "ModelVersion", + "400": "Error", "401": "Error", "404": "Error", "500": "Error", @@ -3025,10 +3280,13 @@ async def find_registered_model_with_http_info( ) @validate_call - async def find_registered_model_without_preload_content( + async def find_model_version_without_preload_content( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, + parent_resource_id: Annotated[ + Optional[StrictStr], Field(description="ID of the parent resource to use for search.") + ] = None, _request_timeout: Union[ None, Annotated[StrictFloat, Field(gt=0)], @@ -3039,14 +3297,16 @@ async def find_registered_model_without_preload_content( _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: - """Get a RegisteredModel that matches search parameters. + """Get a ModelVersion that matches search parameters. - Gets the details of a single instance of a `RegisteredModel` that matches search parameters. + Gets the details of a single instance of a `ModelVersion` that matches search parameters. :param name: Name of entity to search. :type name: str :param external_id: External ID of entity to search. :type external_id: str + :param parent_resource_id: ID of the parent resource to use for search. + :type parent_resource_id: str :param _request_timeout: timeout setting for this request. If one number provided, it will be total request timeout. It can also be a pair (tuple) of @@ -3068,9 +3328,10 @@ async def find_registered_model_without_preload_content( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_registered_model_serialize( + _param = self._find_model_version_serialize( name=name, external_id=external_id, + parent_resource_id=parent_resource_id, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, @@ -3078,7 +3339,8 @@ async def find_registered_model_without_preload_content( ) _response_types_map: dict[str, Optional[str]] = { - "200": "RegisteredModel", + "200": "ModelVersion", + "400": "Error", "401": "Error", "404": "Error", "500": "Error", @@ -3086,10 +3348,11 @@ async def find_registered_model_without_preload_content( response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) return response_data.response - def _find_registered_model_serialize( + def _find_model_version_serialize( self, name, external_id, + parent_resource_id, _request_auth, _content_type, _headers, @@ -3117,6 +3380,10 @@ def _find_registered_model_serialize( _query_params.append(("externalId", external_id)) + if parent_resource_id is not None: + + _query_params.append(("parentResourceId", parent_resource_id)) + # process the header parameters # process the form parameters # process the body parameter @@ -3129,7 +3396,7 @@ def _find_registered_model_serialize( return self.api_client.param_serialize( method="GET", - resource_path="/api/model_registry/v1alpha3/registered_model", + resource_path="/api/model_registry/v1alpha3/model_version", path_params=_path_params, query_params=_query_params, header_params=_header_params, @@ -3143,7 +3410,7 @@ def _find_registered_model_serialize( ) @validate_call - async def find_serving_environment( + async def find_registered_model( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -3156,10 +3423,10 @@ async def find_serving_environment( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ServingEnvironment: - """Find ServingEnvironment. + ) -> RegisteredModel: + """Get a RegisteredModel that matches search parameters. - Finds a `ServingEnvironment` entity that matches query parameters. + Gets the details of a single instance of a `RegisteredModel` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -3186,7 +3453,7 @@ async def find_serving_environment( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_serving_environment_serialize( + _param = self._find_registered_model_serialize( name=name, external_id=external_id, _request_auth=_request_auth, @@ -3196,7 +3463,7 @@ async def find_serving_environment( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ServingEnvironment", + "200": "RegisteredModel", "401": "Error", "404": "Error", "500": "Error", @@ -3209,7 +3476,7 @@ async def find_serving_environment( ).data @validate_call - async def find_serving_environment_with_http_info( + async def find_registered_model_with_http_info( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -3222,10 +3489,10 @@ async def find_serving_environment_with_http_info( _content_type: Optional[StrictStr] = None, _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, - ) -> ApiResponse[ServingEnvironment]: - """Find ServingEnvironment. + ) -> ApiResponse[RegisteredModel]: + """Get a RegisteredModel that matches search parameters. - Finds a `ServingEnvironment` entity that matches query parameters. + Gets the details of a single instance of a `RegisteredModel` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -3252,7 +3519,7 @@ async def find_serving_environment_with_http_info( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_serving_environment_serialize( + _param = self._find_registered_model_serialize( name=name, external_id=external_id, _request_auth=_request_auth, @@ -3262,7 +3529,7 @@ async def find_serving_environment_with_http_info( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ServingEnvironment", + "200": "RegisteredModel", "401": "Error", "404": "Error", "500": "Error", @@ -3275,7 +3542,7 @@ async def find_serving_environment_with_http_info( ) @validate_call - async def find_serving_environment_without_preload_content( + async def find_registered_model_without_preload_content( self, name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, @@ -3289,9 +3556,9 @@ async def find_serving_environment_without_preload_content( _headers: Optional[dict[StrictStr, Any]] = None, _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, ) -> RESTResponseType: - """Find ServingEnvironment. + """Get a RegisteredModel that matches search parameters. - Finds a `ServingEnvironment` entity that matches query parameters. + Gets the details of a single instance of a `RegisteredModel` that matches search parameters. :param name: Name of entity to search. :type name: str @@ -3318,9 +3585,772 @@ async def find_serving_environment_without_preload_content( :type _host_index: int, optional :return: Returns the result object. """ # noqa: E501 - _param = self._find_serving_environment_serialize( - name=name, - external_id=external_id, + _param = self._find_registered_model_serialize( + name=name, + external_id=external_id, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "RegisteredModel", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + return response_data.response + + def _find_registered_model_serialize( + self, + name, + external_id, + _request_auth, + _content_type, + _headers, + _host_index, + ) -> RequestSerialized: + + _host = None + + _collection_formats: dict[str, str] = {} + + _path_params: dict[str, str] = {} + _query_params: list[tuple[str, str]] = [] + _header_params: dict[str, Optional[str]] = _headers or {} + _form_params: list[tuple[str, str]] = [] + _files: dict[str, Union[str, bytes]] = {} + _body_params: Optional[bytes] = None + + # process the path parameters + # process the query parameters + if name is not None: + + _query_params.append(("name", name)) + + if external_id is not None: + + _query_params.append(("externalId", external_id)) + + # process the header parameters + # process the form parameters + # process the body parameter + + # set the HTTP header `Accept` + _header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) + + # authentication setting + _auth_settings: list[str] = ["Bearer"] + + return self.api_client.param_serialize( + method="GET", + resource_path="/api/model_registry/v1alpha3/registered_model", + path_params=_path_params, + query_params=_query_params, + header_params=_header_params, + body=_body_params, + post_params=_form_params, + files=_files, + auth_settings=_auth_settings, + collection_formats=_collection_formats, + _host=_host, + _request_auth=_request_auth, + ) + + @validate_call + async def find_serving_environment( + self, + name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, + external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ServingEnvironment: + """Find ServingEnvironment. + + Finds a `ServingEnvironment` entity that matches query parameters. + + :param name: Name of entity to search. + :type name: str + :param external_id: External ID of entity to search. + :type external_id: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._find_serving_environment_serialize( + name=name, + external_id=external_id, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "ServingEnvironment", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ).data + + @validate_call + async def find_serving_environment_with_http_info( + self, + name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, + external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ApiResponse[ServingEnvironment]: + """Find ServingEnvironment. + + Finds a `ServingEnvironment` entity that matches query parameters. + + :param name: Name of entity to search. + :type name: str + :param external_id: External ID of entity to search. + :type external_id: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._find_serving_environment_serialize( + name=name, + external_id=external_id, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "ServingEnvironment", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ) + + @validate_call + async def find_serving_environment_without_preload_content( + self, + name: Annotated[Optional[StrictStr], Field(description="Name of entity to search.")] = None, + external_id: Annotated[Optional[StrictStr], Field(description="External ID of entity to search.")] = None, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> RESTResponseType: + """Find ServingEnvironment. + + Finds a `ServingEnvironment` entity that matches query parameters. + + :param name: Name of entity to search. + :type name: str + :param external_id: External ID of entity to search. + :type external_id: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._find_serving_environment_serialize( + name=name, + external_id=external_id, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "ServingEnvironment", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + return response_data.response + + def _find_serving_environment_serialize( + self, + name, + external_id, + _request_auth, + _content_type, + _headers, + _host_index, + ) -> RequestSerialized: + + _host = None + + _collection_formats: dict[str, str] = {} + + _path_params: dict[str, str] = {} + _query_params: list[tuple[str, str]] = [] + _header_params: dict[str, Optional[str]] = _headers or {} + _form_params: list[tuple[str, str]] = [] + _files: dict[str, Union[str, bytes]] = {} + _body_params: Optional[bytes] = None + + # process the path parameters + # process the query parameters + if name is not None: + + _query_params.append(("name", name)) + + if external_id is not None: + + _query_params.append(("externalId", external_id)) + + # process the header parameters + # process the form parameters + # process the body parameter + + # set the HTTP header `Accept` + _header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) + + # authentication setting + _auth_settings: list[str] = ["Bearer"] + + return self.api_client.param_serialize( + method="GET", + resource_path="/api/model_registry/v1alpha3/serving_environment", + path_params=_path_params, + query_params=_query_params, + header_params=_header_params, + body=_body_params, + post_params=_form_params, + files=_files, + auth_settings=_auth_settings, + collection_formats=_collection_formats, + _host=_host, + _request_auth=_request_auth, + ) + + @validate_call + async def get_artifact( + self, + id: Annotated[StrictStr, Field(description="A unique identifier for an `Artifact`.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> Artifact: + """Get an Artifact. + + Gets the details of a single instance of an `Artifact`. + + :param id: A unique identifier for an `Artifact`. (required) + :type id: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._get_artifact_serialize( + id=id, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "Artifact", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ).data + + @validate_call + async def get_artifact_with_http_info( + self, + id: Annotated[StrictStr, Field(description="A unique identifier for an `Artifact`.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ApiResponse[Artifact]: + """Get an Artifact. + + Gets the details of a single instance of an `Artifact`. + + :param id: A unique identifier for an `Artifact`. (required) + :type id: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._get_artifact_serialize( + id=id, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "Artifact", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ) + + @validate_call + async def get_artifact_without_preload_content( + self, + id: Annotated[StrictStr, Field(description="A unique identifier for an `Artifact`.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> RESTResponseType: + """Get an Artifact. + + Gets the details of a single instance of an `Artifact`. + + :param id: A unique identifier for an `Artifact`. (required) + :type id: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._get_artifact_serialize( + id=id, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, _host_index=_host_index + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "Artifact", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + return response_data.response + + def _get_artifact_serialize( + self, + id, + _request_auth, + _content_type, + _headers, + _host_index, + ) -> RequestSerialized: + + _host = None + + _collection_formats: dict[str, str] = {} + + _path_params: dict[str, str] = {} + _query_params: list[tuple[str, str]] = [] + _header_params: dict[str, Optional[str]] = _headers or {} + _form_params: list[tuple[str, str]] = [] + _files: dict[str, Union[str, bytes]] = {} + _body_params: Optional[bytes] = None + + # process the path parameters + if id is not None: + _path_params["id"] = id + # process the query parameters + # process the header parameters + # process the form parameters + # process the body parameter + + # set the HTTP header `Accept` + _header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) + + # authentication setting + _auth_settings: list[str] = ["Bearer"] + + return self.api_client.param_serialize( + method="GET", + resource_path="/api/model_registry/v1alpha3/artifacts/{id}", + path_params=_path_params, + query_params=_query_params, + header_params=_header_params, + body=_body_params, + post_params=_form_params, + files=_files, + auth_settings=_auth_settings, + collection_formats=_collection_formats, + _host=_host, + _request_auth=_request_auth, + ) + + @validate_call + async def get_artifacts( + self, + page_size: Annotated[Optional[StrictStr], Field(description="Number of entities in each page.")] = None, + order_by: Annotated[ + Optional[OrderByField], Field(description="Specifies the order by criteria for listing entities.") + ] = None, + sort_order: Annotated[ + Optional[SortOrder], Field(description="Specifies the sort order for listing entities, defaults to ASC.") + ] = None, + next_page_token: Annotated[ + Optional[StrictStr], Field(description="Token to use to retrieve next page of results.") + ] = None, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ArtifactList: + """List All Artifacts. + + Gets a list of all `Artifact` entities. + + :param page_size: Number of entities in each page. + :type page_size: str + :param order_by: Specifies the order by criteria for listing entities. + :type order_by: OrderByField + :param sort_order: Specifies the sort order for listing entities, defaults to ASC. + :type sort_order: SortOrder + :param next_page_token: Token to use to retrieve next page of results. + :type next_page_token: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._get_artifacts_serialize( + page_size=page_size, + order_by=order_by, + sort_order=sort_order, + next_page_token=next_page_token, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "ArtifactList", + "400": "Error", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ).data + + @validate_call + async def get_artifacts_with_http_info( + self, + page_size: Annotated[Optional[StrictStr], Field(description="Number of entities in each page.")] = None, + order_by: Annotated[ + Optional[OrderByField], Field(description="Specifies the order by criteria for listing entities.") + ] = None, + sort_order: Annotated[ + Optional[SortOrder], Field(description="Specifies the sort order for listing entities, defaults to ASC.") + ] = None, + next_page_token: Annotated[ + Optional[StrictStr], Field(description="Token to use to retrieve next page of results.") + ] = None, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ApiResponse[ArtifactList]: + """List All Artifacts. + + Gets a list of all `Artifact` entities. + + :param page_size: Number of entities in each page. + :type page_size: str + :param order_by: Specifies the order by criteria for listing entities. + :type order_by: OrderByField + :param sort_order: Specifies the sort order for listing entities, defaults to ASC. + :type sort_order: SortOrder + :param next_page_token: Token to use to retrieve next page of results. + :type next_page_token: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._get_artifacts_serialize( + page_size=page_size, + order_by=order_by, + sort_order=sort_order, + next_page_token=next_page_token, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "ArtifactList", + "400": "Error", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ) + + @validate_call + async def get_artifacts_without_preload_content( + self, + page_size: Annotated[Optional[StrictStr], Field(description="Number of entities in each page.")] = None, + order_by: Annotated[ + Optional[OrderByField], Field(description="Specifies the order by criteria for listing entities.") + ] = None, + sort_order: Annotated[ + Optional[SortOrder], Field(description="Specifies the sort order for listing entities, defaults to ASC.") + ] = None, + next_page_token: Annotated[ + Optional[StrictStr], Field(description="Token to use to retrieve next page of results.") + ] = None, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> RESTResponseType: + """List All Artifacts. + + Gets a list of all `Artifact` entities. + + :param page_size: Number of entities in each page. + :type page_size: str + :param order_by: Specifies the order by criteria for listing entities. + :type order_by: OrderByField + :param sort_order: Specifies the sort order for listing entities, defaults to ASC. + :type sort_order: SortOrder + :param next_page_token: Token to use to retrieve next page of results. + :type next_page_token: str + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._get_artifacts_serialize( + page_size=page_size, + order_by=order_by, + sort_order=sort_order, + next_page_token=next_page_token, _request_auth=_request_auth, _content_type=_content_type, _headers=_headers, @@ -3328,7 +4358,8 @@ async def find_serving_environment_without_preload_content( ) _response_types_map: dict[str, Optional[str]] = { - "200": "ServingEnvironment", + "200": "ArtifactList", + "400": "Error", "401": "Error", "404": "Error", "500": "Error", @@ -3336,10 +4367,12 @@ async def find_serving_environment_without_preload_content( response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) return response_data.response - def _find_serving_environment_serialize( + def _get_artifacts_serialize( self, - name, - external_id, + page_size, + order_by, + sort_order, + next_page_token, _request_auth, _content_type, _headers, @@ -3359,13 +4392,21 @@ def _find_serving_environment_serialize( # process the path parameters # process the query parameters - if name is not None: + if page_size is not None: - _query_params.append(("name", name)) + _query_params.append(("pageSize", page_size)) - if external_id is not None: + if order_by is not None: - _query_params.append(("externalId", external_id)) + _query_params.append(("orderBy", order_by.value)) + + if sort_order is not None: + + _query_params.append(("sortOrder", sort_order.value)) + + if next_page_token is not None: + + _query_params.append(("nextPageToken", next_page_token)) # process the header parameters # process the form parameters @@ -3379,7 +4420,7 @@ def _find_serving_environment_serialize( return self.api_client.param_serialize( method="GET", - resource_path="/api/model_registry/v1alpha3/serving_environment", + resource_path="/api/model_registry/v1alpha3/artifacts", path_params=_path_params, query_params=_query_params, header_params=_header_params, @@ -7926,6 +8967,263 @@ def _get_serving_environments_serialize( _request_auth=_request_auth, ) + @validate_call + async def update_artifact( + self, + id: Annotated[StrictStr, Field(description="A unique identifier for an `Artifact`.")], + artifact_update: Annotated[ArtifactUpdate, Field(description="Updated `Artifact` information.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> Artifact: + """Update an Artifact. + + Updates an existing `Artifact`. + + :param id: A unique identifier for an `Artifact`. (required) + :type id: str + :param artifact_update: Updated `Artifact` information. (required) + :type artifact_update: ArtifactUpdate + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._update_artifact_serialize( + id=id, + artifact_update=artifact_update, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "Artifact", + "400": "Error", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ).data + + @validate_call + async def update_artifact_with_http_info( + self, + id: Annotated[StrictStr, Field(description="A unique identifier for an `Artifact`.")], + artifact_update: Annotated[ArtifactUpdate, Field(description="Updated `Artifact` information.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ApiResponse[Artifact]: + """Update an Artifact. + + Updates an existing `Artifact`. + + :param id: A unique identifier for an `Artifact`. (required) + :type id: str + :param artifact_update: Updated `Artifact` information. (required) + :type artifact_update: ArtifactUpdate + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._update_artifact_serialize( + id=id, + artifact_update=artifact_update, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "Artifact", + "400": "Error", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + await response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ) + + @validate_call + async def update_artifact_without_preload_content( + self, + id: Annotated[StrictStr, Field(description="A unique identifier for an `Artifact`.")], + artifact_update: Annotated[ArtifactUpdate, Field(description="Updated `Artifact` information.")], + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + tuple[Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]], + ] = None, + _request_auth: Optional[dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> RESTResponseType: + """Update an Artifact. + + Updates an existing `Artifact`. + + :param id: A unique identifier for an `Artifact`. (required) + :type id: str + :param artifact_update: Updated `Artifact` information. (required) + :type artifact_update: ArtifactUpdate + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + _param = self._update_artifact_serialize( + id=id, + artifact_update=artifact_update, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index, + ) + + _response_types_map: dict[str, Optional[str]] = { + "200": "Artifact", + "400": "Error", + "401": "Error", + "404": "Error", + "500": "Error", + } + response_data = await self.api_client.call_api(*_param, _request_timeout=_request_timeout) + return response_data.response + + def _update_artifact_serialize( + self, + id, + artifact_update, + _request_auth, + _content_type, + _headers, + _host_index, + ) -> RequestSerialized: + + _host = None + + _collection_formats: dict[str, str] = {} + + _path_params: dict[str, str] = {} + _query_params: list[tuple[str, str]] = [] + _header_params: dict[str, Optional[str]] = _headers or {} + _form_params: list[tuple[str, str]] = [] + _files: dict[str, Union[str, bytes]] = {} + _body_params: Optional[bytes] = None + + # process the path parameters + if id is not None: + _path_params["id"] = id + # process the query parameters + # process the header parameters + # process the form parameters + # process the body parameter + if artifact_update is not None: + _body_params = artifact_update + + # set the HTTP header `Accept` + _header_params["Accept"] = self.api_client.select_header_accept(["application/json"]) + + # set the HTTP header `Content-Type` + if _content_type: + _header_params["Content-Type"] = _content_type + else: + _default_content_type = self.api_client.select_header_content_type(["application/json"]) + if _default_content_type is not None: + _header_params["Content-Type"] = _default_content_type + + # authentication setting + _auth_settings: list[str] = ["Bearer"] + + return self.api_client.param_serialize( + method="PATCH", + resource_path="/api/model_registry/v1alpha3/artifacts/{id}", + path_params=_path_params, + query_params=_query_params, + header_params=_header_params, + body=_body_params, + post_params=_form_params, + files=_files, + auth_settings=_auth_settings, + collection_formats=_collection_formats, + _host=_host, + _request_auth=_request_auth, + ) + @validate_call async def update_inference_service( self, diff --git a/clients/python/src/mr_openapi/models/__init__.py b/clients/python/src/mr_openapi/models/__init__.py index ceb6da35..c4e8bd49 100644 --- a/clients/python/src/mr_openapi/models/__init__.py +++ b/clients/python/src/mr_openapi/models/__init__.py @@ -15,8 +15,10 @@ # import models into model package from mr_openapi.models.artifact import Artifact +from mr_openapi.models.artifact_create import ArtifactCreate from mr_openapi.models.artifact_list import ArtifactList from mr_openapi.models.artifact_state import ArtifactState +from mr_openapi.models.artifact_update import ArtifactUpdate from mr_openapi.models.base_artifact import BaseArtifact from mr_openapi.models.base_artifact_create import BaseArtifactCreate from mr_openapi.models.base_artifact_update import BaseArtifactUpdate @@ -28,6 +30,8 @@ from mr_openapi.models.base_resource_list import BaseResourceList from mr_openapi.models.base_resource_update import BaseResourceUpdate from mr_openapi.models.doc_artifact import DocArtifact +from mr_openapi.models.doc_artifact_create import DocArtifactCreate +from mr_openapi.models.doc_artifact_update import DocArtifactUpdate from mr_openapi.models.error import Error from mr_openapi.models.execution_state import ExecutionState from mr_openapi.models.inference_service import InferenceService diff --git a/clients/python/src/mr_openapi/models/artifact_create.py b/clients/python/src/mr_openapi/models/artifact_create.py new file mode 100644 index 00000000..dfae6989 --- /dev/null +++ b/clients/python/src/mr_openapi/models/artifact_create.py @@ -0,0 +1,174 @@ +"""Model Registry REST API. + +REST API for Model Registry to create and manage ML model metadata + +The version of the OpenAPI document: v1alpha3 +Generated by OpenAPI Generator (https://openapi-generator.tech) + +Do not edit the class manually. +""" # noqa: E501 + +from __future__ import annotations + +import json +import pprint +from typing import Any + +from pydantic import ( + BaseModel, + ConfigDict, + ValidationError, + field_validator, +) +from typing_extensions import Self + +from mr_openapi.models.doc_artifact_create import DocArtifactCreate +from mr_openapi.models.model_artifact_create import ModelArtifactCreate + +ARTIFACTCREATE_ONE_OF_SCHEMAS = ["DocArtifactCreate", "ModelArtifactCreate"] + + +class ArtifactCreate(BaseModel): + """An Artifact to be created.""" + + # data type: ModelArtifactCreate + oneof_schema_1_validator: ModelArtifactCreate | None = None + # data type: DocArtifactCreate + oneof_schema_2_validator: DocArtifactCreate | None = None + actual_instance: DocArtifactCreate | ModelArtifactCreate | None = None + one_of_schemas: set[str] = {"DocArtifactCreate", "ModelArtifactCreate"} + + model_config = ConfigDict( + validate_assignment=True, + protected_namespaces=(), + ) + + discriminator_value_class_map: dict[str, str] = {} + + def __init__(self, *args, **kwargs) -> None: + if args: + if len(args) > 1: + msg = "If a position argument is used, only 1 is allowed to set `actual_instance`" + raise ValueError(msg) + if kwargs: + msg = "If a position argument is used, keyword arguments cannot be used." + raise ValueError(msg) + super().__init__(actual_instance=args[0]) + else: + super().__init__(**kwargs) + + @field_validator("actual_instance") + def actual_instance_must_validate_oneof(cls, v): + ArtifactCreate.model_construct() + error_messages = [] + match = 0 + # validate data type: ModelArtifactCreate + if not isinstance(v, ModelArtifactCreate): + error_messages.append(f"Error! Input type `{type(v)}` is not `ModelArtifactCreate`") + else: + match += 1 + # validate data type: DocArtifactCreate + if not isinstance(v, DocArtifactCreate): + error_messages.append(f"Error! Input type `{type(v)}` is not `DocArtifactCreate`") + else: + match += 1 + if match > 1: + # more than 1 match + raise ValueError( + "Multiple matches found when setting `actual_instance` in ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: " + + ", ".join(error_messages) + ) + if match == 0: + # no match + raise ValueError( + "No match found when setting `actual_instance` in ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: " + + ", ".join(error_messages) + ) + return v + + @classmethod + def from_dict(cls, obj: str | dict[str, Any]) -> Self: + return cls.from_json(json.dumps(obj)) + + @classmethod + def from_json(cls, json_str: str) -> Self: + """Returns the object represented by the json string.""" + instance = cls.model_construct() + error_messages = [] + match = 0 + + # use oneOf discriminator to lookup the data type + _data_type = json.loads(json_str).get("artifactType") + if not _data_type: + msg = "Failed to lookup data type from the field `artifactType` in the input." + raise ValueError(msg) + + # check if data type is `DocArtifactCreate` + if _data_type == "doc-artifact": + instance.actual_instance = DocArtifactCreate.from_json(json_str) + return instance + + # check if data type is `ModelArtifactCreate` + if _data_type == "model-artifact": + instance.actual_instance = ModelArtifactCreate.from_json(json_str) + return instance + + # check if data type is `DocArtifactCreate` + if _data_type == "DocArtifactCreate": + instance.actual_instance = DocArtifactCreate.from_json(json_str) + return instance + + # check if data type is `ModelArtifactCreate` + if _data_type == "ModelArtifactCreate": + instance.actual_instance = ModelArtifactCreate.from_json(json_str) + return instance + + # deserialize data into ModelArtifactCreate + try: + instance.actual_instance = ModelArtifactCreate.from_json(json_str) + match += 1 + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + # deserialize data into DocArtifactCreate + try: + instance.actual_instance = DocArtifactCreate.from_json(json_str) + match += 1 + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + + if match > 1: + # more than 1 match + raise ValueError( + "Multiple matches found when deserializing the JSON string into ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: " + + ", ".join(error_messages) + ) + if match == 0: + # no match + raise ValueError( + "No match found when deserializing the JSON string into ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: " + + ", ".join(error_messages) + ) + return instance + + def to_json(self) -> str: + """Returns the JSON representation of the actual instance.""" + if self.actual_instance is None: + return "null" + + if hasattr(self.actual_instance, "to_json") and callable(self.actual_instance.to_json): + return self.actual_instance.to_json() + return json.dumps(self.actual_instance) + + def to_dict(self) -> dict[str, Any] | DocArtifactCreate | ModelArtifactCreate | None: + """Returns the dict representation of the actual instance.""" + if self.actual_instance is None: + return None + + if hasattr(self.actual_instance, "to_dict") and callable(self.actual_instance.to_dict): + return self.actual_instance.to_dict() + # primitive type + return self.actual_instance + + def to_str(self) -> str: + """Returns the string representation of the actual instance.""" + return pprint.pformat(self.model_dump()) diff --git a/clients/python/src/mr_openapi/models/artifact_update.py b/clients/python/src/mr_openapi/models/artifact_update.py new file mode 100644 index 00000000..8484d0a2 --- /dev/null +++ b/clients/python/src/mr_openapi/models/artifact_update.py @@ -0,0 +1,174 @@ +"""Model Registry REST API. + +REST API for Model Registry to create and manage ML model metadata + +The version of the OpenAPI document: v1alpha3 +Generated by OpenAPI Generator (https://openapi-generator.tech) + +Do not edit the class manually. +""" # noqa: E501 + +from __future__ import annotations + +import json +import pprint +from typing import Any + +from pydantic import ( + BaseModel, + ConfigDict, + ValidationError, + field_validator, +) +from typing_extensions import Self + +from mr_openapi.models.doc_artifact_update import DocArtifactUpdate +from mr_openapi.models.model_artifact_update import ModelArtifactUpdate + +ARTIFACTUPDATE_ONE_OF_SCHEMAS = ["DocArtifactUpdate", "ModelArtifactUpdate"] + + +class ArtifactUpdate(BaseModel): + """An Artifact to be updated.""" + + # data type: ModelArtifactUpdate + oneof_schema_1_validator: ModelArtifactUpdate | None = None + # data type: DocArtifactUpdate + oneof_schema_2_validator: DocArtifactUpdate | None = None + actual_instance: DocArtifactUpdate | ModelArtifactUpdate | None = None + one_of_schemas: set[str] = {"DocArtifactUpdate", "ModelArtifactUpdate"} + + model_config = ConfigDict( + validate_assignment=True, + protected_namespaces=(), + ) + + discriminator_value_class_map: dict[str, str] = {} + + def __init__(self, *args, **kwargs) -> None: + if args: + if len(args) > 1: + msg = "If a position argument is used, only 1 is allowed to set `actual_instance`" + raise ValueError(msg) + if kwargs: + msg = "If a position argument is used, keyword arguments cannot be used." + raise ValueError(msg) + super().__init__(actual_instance=args[0]) + else: + super().__init__(**kwargs) + + @field_validator("actual_instance") + def actual_instance_must_validate_oneof(cls, v): + ArtifactUpdate.model_construct() + error_messages = [] + match = 0 + # validate data type: ModelArtifactUpdate + if not isinstance(v, ModelArtifactUpdate): + error_messages.append(f"Error! Input type `{type(v)}` is not `ModelArtifactUpdate`") + else: + match += 1 + # validate data type: DocArtifactUpdate + if not isinstance(v, DocArtifactUpdate): + error_messages.append(f"Error! Input type `{type(v)}` is not `DocArtifactUpdate`") + else: + match += 1 + if match > 1: + # more than 1 match + raise ValueError( + "Multiple matches found when setting `actual_instance` in ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: " + + ", ".join(error_messages) + ) + if match == 0: + # no match + raise ValueError( + "No match found when setting `actual_instance` in ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: " + + ", ".join(error_messages) + ) + return v + + @classmethod + def from_dict(cls, obj: str | dict[str, Any]) -> Self: + return cls.from_json(json.dumps(obj)) + + @classmethod + def from_json(cls, json_str: str) -> Self: + """Returns the object represented by the json string.""" + instance = cls.model_construct() + error_messages = [] + match = 0 + + # use oneOf discriminator to lookup the data type + _data_type = json.loads(json_str).get("artifactType") + if not _data_type: + msg = "Failed to lookup data type from the field `artifactType` in the input." + raise ValueError(msg) + + # check if data type is `DocArtifactUpdate` + if _data_type == "doc-artifact": + instance.actual_instance = DocArtifactUpdate.from_json(json_str) + return instance + + # check if data type is `ModelArtifactUpdate` + if _data_type == "model-artifact": + instance.actual_instance = ModelArtifactUpdate.from_json(json_str) + return instance + + # check if data type is `DocArtifactUpdate` + if _data_type == "DocArtifactUpdate": + instance.actual_instance = DocArtifactUpdate.from_json(json_str) + return instance + + # check if data type is `ModelArtifactUpdate` + if _data_type == "ModelArtifactUpdate": + instance.actual_instance = ModelArtifactUpdate.from_json(json_str) + return instance + + # deserialize data into ModelArtifactUpdate + try: + instance.actual_instance = ModelArtifactUpdate.from_json(json_str) + match += 1 + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + # deserialize data into DocArtifactUpdate + try: + instance.actual_instance = DocArtifactUpdate.from_json(json_str) + match += 1 + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + + if match > 1: + # more than 1 match + raise ValueError( + "Multiple matches found when deserializing the JSON string into ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: " + + ", ".join(error_messages) + ) + if match == 0: + # no match + raise ValueError( + "No match found when deserializing the JSON string into ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: " + + ", ".join(error_messages) + ) + return instance + + def to_json(self) -> str: + """Returns the JSON representation of the actual instance.""" + if self.actual_instance is None: + return "null" + + if hasattr(self.actual_instance, "to_json") and callable(self.actual_instance.to_json): + return self.actual_instance.to_json() + return json.dumps(self.actual_instance) + + def to_dict(self) -> dict[str, Any] | DocArtifactUpdate | ModelArtifactUpdate | None: + """Returns the dict representation of the actual instance.""" + if self.actual_instance is None: + return None + + if hasattr(self.actual_instance, "to_dict") and callable(self.actual_instance.to_dict): + return self.actual_instance.to_dict() + # primitive type + return self.actual_instance + + def to_str(self) -> str: + """Returns the string representation of the actual instance.""" + return pprint.pformat(self.model_dump()) diff --git a/clients/python/src/mr_openapi/models/doc_artifact.py b/clients/python/src/mr_openapi/models/doc_artifact.py index a45fdd54..631a7e0e 100644 --- a/clients/python/src/mr_openapi/models/doc_artifact.py +++ b/clients/python/src/mr_openapi/models/doc_artifact.py @@ -67,6 +67,7 @@ class DocArtifact(BaseModel): "id", "createTimeSinceEpoch", "lastUpdateTimeSinceEpoch", + "artifactType", ] model_config = ConfigDict( @@ -145,5 +146,6 @@ def from_dict(cls, obj: dict[str, Any] | None) -> Self | None: "id": obj.get("id"), "createTimeSinceEpoch": obj.get("createTimeSinceEpoch"), "lastUpdateTimeSinceEpoch": obj.get("lastUpdateTimeSinceEpoch"), + "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "doc-artifact", } ) diff --git a/clients/python/src/mr_openapi/models/doc_artifact_create.py b/clients/python/src/mr_openapi/models/doc_artifact_create.py new file mode 100644 index 00000000..6537960d --- /dev/null +++ b/clients/python/src/mr_openapi/models/doc_artifact_create.py @@ -0,0 +1,128 @@ +"""Model Registry REST API. + +REST API for Model Registry to create and manage ML model metadata + +The version of the OpenAPI document: v1alpha3 +Generated by OpenAPI Generator (https://openapi-generator.tech) + +Do not edit the class manually. +""" # noqa: E501 + +from __future__ import annotations + +import json +import pprint +import re # noqa: F401 +from typing import Any, ClassVar + +from pydantic import BaseModel, ConfigDict, Field, StrictStr +from typing_extensions import Self + +from mr_openapi.models.artifact_state import ArtifactState +from mr_openapi.models.metadata_value import MetadataValue + + +class DocArtifactCreate(BaseModel): + """A document artifact to be created.""" # noqa: E501 + + artifact_type: StrictStr = Field(alias="artifactType") + custom_properties: dict[str, MetadataValue] | None = Field( + default=None, + description="User provided custom properties which are not defined by its type.", + alias="customProperties", + ) + description: StrictStr | None = Field(default=None, description="An optional description about the resource.") + external_id: StrictStr | None = Field( + default=None, + description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.", + alias="externalId", + ) + uri: StrictStr | None = Field( + default=None, + description="The uniform resource identifier of the physical artifact. May be empty if there is no physical artifact.", + ) + state: ArtifactState | None = None + name: StrictStr | None = Field( + default=None, + description="The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set.", + ) + __properties: ClassVar[list[str]] = [ + "customProperties", + "description", + "externalId", + "uri", + "state", + "name", + "artifactType", + ] + + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + protected_namespaces=(), + ) + + def to_str(self) -> str: + """Returns the string representation of the model using alias.""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias.""" + # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> Self | None: + """Create an instance of DocArtifactCreate from a JSON string.""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. + """ + excluded_fields: set[str] = set() + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict) + _field_dict = {} + if self.custom_properties: + for _key in self.custom_properties: + if self.custom_properties[_key]: + _field_dict[_key] = self.custom_properties[_key].to_dict() + _dict["customProperties"] = _field_dict + return _dict + + @classmethod + def from_dict(cls, obj: dict[str, Any] | None) -> Self | None: + """Create an instance of DocArtifactCreate from a dict.""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + return cls.model_validate( + { + "customProperties": ( + {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()} + if obj.get("customProperties") is not None + else None + ), + "description": obj.get("description"), + "externalId": obj.get("externalId"), + "uri": obj.get("uri"), + "state": obj.get("state"), + "name": obj.get("name"), + "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "doc-artifact", + } + ) diff --git a/clients/python/src/mr_openapi/models/doc_artifact_update.py b/clients/python/src/mr_openapi/models/doc_artifact_update.py new file mode 100644 index 00000000..682b2960 --- /dev/null +++ b/clients/python/src/mr_openapi/models/doc_artifact_update.py @@ -0,0 +1,114 @@ +"""Model Registry REST API. + +REST API for Model Registry to create and manage ML model metadata + +The version of the OpenAPI document: v1alpha3 +Generated by OpenAPI Generator (https://openapi-generator.tech) + +Do not edit the class manually. +""" # noqa: E501 + +from __future__ import annotations + +import json +import pprint +import re # noqa: F401 +from typing import Any, ClassVar + +from pydantic import BaseModel, ConfigDict, Field, StrictStr +from typing_extensions import Self + +from mr_openapi.models.artifact_state import ArtifactState +from mr_openapi.models.metadata_value import MetadataValue + + +class DocArtifactUpdate(BaseModel): + """A document artifact to be updated.""" # noqa: E501 + + artifact_type: StrictStr = Field(alias="artifactType") + custom_properties: dict[str, MetadataValue] | None = Field( + default=None, + description="User provided custom properties which are not defined by its type.", + alias="customProperties", + ) + description: StrictStr | None = Field(default=None, description="An optional description about the resource.") + external_id: StrictStr | None = Field( + default=None, + description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.", + alias="externalId", + ) + uri: StrictStr | None = Field( + default=None, + description="The uniform resource identifier of the physical artifact. May be empty if there is no physical artifact.", + ) + state: ArtifactState | None = None + __properties: ClassVar[list[str]] = ["customProperties", "description", "externalId", "uri", "state"] + + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + protected_namespaces=(), + ) + + def to_str(self) -> str: + """Returns the string representation of the model using alias.""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias.""" + # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> Self | None: + """Create an instance of DocArtifactUpdate from a JSON string.""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. + """ + excluded_fields: set[str] = set() + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict) + _field_dict = {} + if self.custom_properties: + for _key in self.custom_properties: + if self.custom_properties[_key]: + _field_dict[_key] = self.custom_properties[_key].to_dict() + _dict["customProperties"] = _field_dict + return _dict + + @classmethod + def from_dict(cls, obj: dict[str, Any] | None) -> Self | None: + """Create an instance of DocArtifactUpdate from a dict.""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + return cls.model_validate( + { + "customProperties": ( + {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()} + if obj.get("customProperties") is not None + else None + ), + "description": obj.get("description"), + "externalId": obj.get("externalId"), + "uri": obj.get("uri"), + "state": obj.get("state"), + } + ) diff --git a/clients/python/src/mr_openapi/models/model_artifact.py b/clients/python/src/mr_openapi/models/model_artifact.py index 353aa88b..6c0fa522 100644 --- a/clients/python/src/mr_openapi/models/model_artifact.py +++ b/clients/python/src/mr_openapi/models/model_artifact.py @@ -80,6 +80,7 @@ class ModelArtifact(BaseModel): "id", "createTimeSinceEpoch", "lastUpdateTimeSinceEpoch", + "artifactType", "modelFormatName", "storageKey", "storagePath", @@ -163,6 +164,7 @@ def from_dict(cls, obj: dict[str, Any] | None) -> Self | None: "id": obj.get("id"), "createTimeSinceEpoch": obj.get("createTimeSinceEpoch"), "lastUpdateTimeSinceEpoch": obj.get("lastUpdateTimeSinceEpoch"), + "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "model-artifact", "modelFormatName": obj.get("modelFormatName"), "storageKey": obj.get("storageKey"), "storagePath": obj.get("storagePath"), diff --git a/clients/python/src/mr_openapi/models/model_artifact_create.py b/clients/python/src/mr_openapi/models/model_artifact_create.py index bc5a34d7..0ba57c0d 100644 --- a/clients/python/src/mr_openapi/models/model_artifact_create.py +++ b/clients/python/src/mr_openapi/models/model_artifact_create.py @@ -25,6 +25,7 @@ class ModelArtifactCreate(BaseModel): """An ML model artifact.""" # noqa: E501 + artifact_type: StrictStr = Field(alias="artifactType") custom_properties: dict[str, MetadataValue] | None = Field( default=None, description="User provided custom properties which are not defined by its type.", @@ -65,6 +66,7 @@ class ModelArtifactCreate(BaseModel): "uri", "state", "name", + "artifactType", "modelFormatName", "storageKey", "storagePath", @@ -139,6 +141,7 @@ def from_dict(cls, obj: dict[str, Any] | None) -> Self | None: "uri": obj.get("uri"), "state": obj.get("state"), "name": obj.get("name"), + "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "model-artifact", "modelFormatName": obj.get("modelFormatName"), "storageKey": obj.get("storageKey"), "storagePath": obj.get("storagePath"), diff --git a/clients/python/src/mr_openapi/models/model_artifact_update.py b/clients/python/src/mr_openapi/models/model_artifact_update.py index 6df6896d..6210b62d 100644 --- a/clients/python/src/mr_openapi/models/model_artifact_update.py +++ b/clients/python/src/mr_openapi/models/model_artifact_update.py @@ -23,8 +23,9 @@ class ModelArtifactUpdate(BaseModel): - """An ML model artifact.""" # noqa: E501 + """An ML model artifact to be updated.""" # noqa: E501 + artifact_type: StrictStr = Field(alias="artifactType") custom_properties: dict[str, MetadataValue] | None = Field( default=None, description="User provided custom properties which are not defined by its type.", diff --git a/internal/converter/generated/openapi_converter.gen.go b/internal/converter/generated/openapi_converter.gen.go index 7d1c8c14..d98c0916 100644 --- a/internal/converter/generated/openapi_converter.gen.go +++ b/internal/converter/generated/openapi_converter.gen.go @@ -11,6 +11,120 @@ import ( type OpenAPIConverterImpl struct{} +func (c *OpenAPIConverterImpl) ConvertArtifactCreate(source *openapi.ArtifactCreate) (*openapi.Artifact, error) { + var pOpenapiArtifact *openapi.Artifact + if source != nil { + var openapiArtifact openapi.Artifact + pOpenapiDocArtifact, err := c.ConvertDocArtifactCreate((*source).DocArtifactCreate) + if err != nil { + return nil, fmt.Errorf("error setting field DocArtifact: %w", err) + } + openapiArtifact.DocArtifact = pOpenapiDocArtifact + pOpenapiModelArtifact, err := c.ConvertModelArtifactCreate((*source).ModelArtifactCreate) + if err != nil { + return nil, fmt.Errorf("error setting field ModelArtifact: %w", err) + } + openapiArtifact.ModelArtifact = pOpenapiModelArtifact + pOpenapiArtifact = &openapiArtifact + } + return pOpenapiArtifact, nil +} +func (c *OpenAPIConverterImpl) ConvertArtifactUpdate(source *openapi.ArtifactUpdate) (*openapi.Artifact, error) { + var pOpenapiArtifact *openapi.Artifact + if source != nil { + var openapiArtifact openapi.Artifact + pOpenapiDocArtifact, err := c.ConvertDocArtifactUpdate((*source).DocArtifactUpdate) + if err != nil { + return nil, fmt.Errorf("error setting field DocArtifact: %w", err) + } + openapiArtifact.DocArtifact = pOpenapiDocArtifact + pOpenapiModelArtifact, err := c.ConvertModelArtifactUpdate((*source).ModelArtifactUpdate) + if err != nil { + return nil, fmt.Errorf("error setting field ModelArtifact: %w", err) + } + openapiArtifact.ModelArtifact = pOpenapiModelArtifact + pOpenapiArtifact = &openapiArtifact + } + return pOpenapiArtifact, nil +} +func (c *OpenAPIConverterImpl) ConvertDocArtifactCreate(source *openapi.DocArtifactCreate) (*openapi.DocArtifact, error) { + var pOpenapiDocArtifact *openapi.DocArtifact + if source != nil { + var openapiDocArtifact openapi.DocArtifact + if (*source).CustomProperties != nil { + var mapStringOpenapiMetadataValue map[string]openapi.MetadataValue + if (*(*source).CustomProperties) != nil { + mapStringOpenapiMetadataValue = make(map[string]openapi.MetadataValue, len((*(*source).CustomProperties))) + for key, value := range *(*source).CustomProperties { + mapStringOpenapiMetadataValue[key] = c.openapiMetadataValueToOpenapiMetadataValue(value) + } + } + openapiDocArtifact.CustomProperties = &mapStringOpenapiMetadataValue + } + if (*source).Description != nil { + xstring := *(*source).Description + openapiDocArtifact.Description = &xstring + } + if (*source).ExternalId != nil { + xstring2 := *(*source).ExternalId + openapiDocArtifact.ExternalId = &xstring2 + } + if (*source).Uri != nil { + xstring3 := *(*source).Uri + openapiDocArtifact.Uri = &xstring3 + } + if (*source).State != nil { + openapiArtifactState, err := c.openapiArtifactStateToOpenapiArtifactState(*(*source).State) + if err != nil { + return nil, fmt.Errorf("error setting field State: %w", err) + } + openapiDocArtifact.State = &openapiArtifactState + } + if (*source).Name != nil { + xstring4 := *(*source).Name + openapiDocArtifact.Name = &xstring4 + } + pOpenapiDocArtifact = &openapiDocArtifact + } + return pOpenapiDocArtifact, nil +} +func (c *OpenAPIConverterImpl) ConvertDocArtifactUpdate(source *openapi.DocArtifactUpdate) (*openapi.DocArtifact, error) { + var pOpenapiDocArtifact *openapi.DocArtifact + if source != nil { + var openapiDocArtifact openapi.DocArtifact + if (*source).CustomProperties != nil { + var mapStringOpenapiMetadataValue map[string]openapi.MetadataValue + if (*(*source).CustomProperties) != nil { + mapStringOpenapiMetadataValue = make(map[string]openapi.MetadataValue, len((*(*source).CustomProperties))) + for key, value := range *(*source).CustomProperties { + mapStringOpenapiMetadataValue[key] = c.openapiMetadataValueToOpenapiMetadataValue(value) + } + } + openapiDocArtifact.CustomProperties = &mapStringOpenapiMetadataValue + } + if (*source).Description != nil { + xstring := *(*source).Description + openapiDocArtifact.Description = &xstring + } + if (*source).ExternalId != nil { + xstring2 := *(*source).ExternalId + openapiDocArtifact.ExternalId = &xstring2 + } + if (*source).Uri != nil { + xstring3 := *(*source).Uri + openapiDocArtifact.Uri = &xstring3 + } + if (*source).State != nil { + openapiArtifactState, err := c.openapiArtifactStateToOpenapiArtifactState(*(*source).State) + if err != nil { + return nil, fmt.Errorf("error setting field State: %w", err) + } + openapiDocArtifact.State = &openapiArtifactState + } + pOpenapiDocArtifact = &openapiDocArtifact + } + return pOpenapiDocArtifact, nil +} func (c *OpenAPIConverterImpl) ConvertInferenceServiceCreate(source *openapi.InferenceServiceCreate) (*openapi.InferenceService, error) { var pOpenapiInferenceService *openapi.InferenceService if source != nil { @@ -495,6 +609,11 @@ func (c *OpenAPIConverterImpl) ConvertServingEnvironmentUpdate(source *openapi.S } return pOpenapiServingEnvironment, nil } +func (c *OpenAPIConverterImpl) OverrideNotEditableForArtifact(source converter.OpenapiUpdateWrapper[openapi.Artifact]) (openapi.Artifact, error) { + openapiArtifact := converter.InitWithUpdate(source) + _ = source + return openapiArtifact, nil +} func (c *OpenAPIConverterImpl) OverrideNotEditableForDocArtifact(source converter.OpenapiUpdateWrapper[openapi.DocArtifact]) (openapi.DocArtifact, error) { openapiDocArtifact := converter.InitWithUpdate(source) _ = source diff --git a/internal/converter/openapi_converter.go b/internal/converter/openapi_converter.go index 60a4c41d..df704ddd 100644 --- a/internal/converter/openapi_converter.go +++ b/internal/converter/openapi_converter.go @@ -26,6 +26,20 @@ type OpenAPIConverter interface { // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Name RegisteredModelId ConvertModelVersionUpdate(source *openapi.ModelVersionUpdate) (*openapi.ModelVersion, error) + // goverter:map DocArtifactCreate DocArtifact + // goverter:map ModelArtifactCreate ModelArtifact + ConvertArtifactCreate(source *openapi.ArtifactCreate) (*openapi.Artifact, error) + + // goverter:map DocArtifactUpdate DocArtifact + // goverter:map ModelArtifactUpdate ModelArtifact + ConvertArtifactUpdate(source *openapi.ArtifactUpdate) (*openapi.Artifact, error) + + // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType + ConvertDocArtifactCreate(source *openapi.DocArtifactCreate) (*openapi.DocArtifact, error) + + // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType Name + ConvertDocArtifactUpdate(source *openapi.DocArtifactUpdate) (*openapi.DocArtifact, error) + // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch ArtifactType ConvertModelArtifactCreate(source *openapi.ModelArtifactCreate) (*openapi.ModelArtifact, error) @@ -62,6 +76,12 @@ type OpenAPIConverter interface { // goverter:ignore Id CreateTimeSinceEpoch LastUpdateTimeSinceEpoch Description ExternalId CustomProperties State Author OverrideNotEditableForModelVersion(source OpenapiUpdateWrapper[openapi.ModelVersion]) (openapi.ModelVersion, error) + // Ignore all fields that ARE editable + // goverter:default InitWithUpdate + // goverter:autoMap Existing + // goverter:ignore DocArtifact ModelArtifact + OverrideNotEditableForArtifact(source OpenapiUpdateWrapper[openapi.Artifact]) (openapi.Artifact, error) + // Ignore all fields that ARE editable // goverter:default InitWithUpdate // goverter:autoMap Existing diff --git a/internal/converter/openapi_converter_test.go b/internal/converter/openapi_converter_test.go index 4c33c27e..e9869e11 100644 --- a/internal/converter/openapi_converter_test.go +++ b/internal/converter/openapi_converter_test.go @@ -20,7 +20,7 @@ type visitor struct { entities map[string]*oapiEntity } -func newVisitor(t *testing.T, f *ast.File) visitor { +func newVisitor(t *testing.T, _ *ast.File) visitor { return visitor{ t: t, entities: map[string]*oapiEntity{ @@ -45,6 +45,9 @@ func newVisitor(t *testing.T, f *ast.File) visitor { "ServeModel": { obj: openapi.ServeModel{}, }, + "Artifact": { + obj: openapi.Artifact{}, + }, }, } } diff --git a/internal/converter/openapi_converter_util.go b/internal/converter/openapi_converter_util.go index 816fe111..b7a8964f 100644 --- a/internal/converter/openapi_converter_util.go +++ b/internal/converter/openapi_converter_util.go @@ -3,7 +3,8 @@ package converter import "github.com/kubeflow/model-registry/pkg/openapi" type OpenAPIModel interface { - openapi.RegisteredModel | + openapi.Artifact | + openapi.RegisteredModel | openapi.ModelVersion | openapi.ModelArtifact | openapi.DocArtifact | diff --git a/internal/converter/openapi_reconciler_util.go b/internal/converter/openapi_reconciler_util.go new file mode 100644 index 00000000..e2b8544c --- /dev/null +++ b/internal/converter/openapi_reconciler_util.go @@ -0,0 +1,23 @@ +package converter + +import ( + "github.com/kubeflow/model-registry/pkg/openapi" +) + +func UpdateExistingArtifact(genc OpenAPIReconciler, source OpenapiUpdateWrapper[openapi.Artifact]) (openapi.Artifact, error) { + art := InitWithExisting(source) + if source.Update == nil { + return art, nil + } + ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact}) + if err != nil { + return art, err + } + da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact}) + if err != nil { + return art, err + } + art.DocArtifact = &da + art.ModelArtifact = &ma + return art, nil +} diff --git a/internal/server/openapi/api.go b/internal/server/openapi/api.go index c62155d2..cc6be125 100644 --- a/internal/server/openapi/api.go +++ b/internal/server/openapi/api.go @@ -65,12 +65,14 @@ type ModelRegistryServiceAPIServicer interface { CreateEnvironmentInferenceService(context.Context, string, model.InferenceServiceCreate) (ImplResponse, error) CreateInferenceService(context.Context, model.InferenceServiceCreate) (ImplResponse, error) CreateInferenceServiceServe(context.Context, string, model.ServeModelCreate) (ImplResponse, error) + CreateArtifact(context.Context, model.ArtifactCreate) (ImplResponse, error) CreateModelArtifact(context.Context, model.ModelArtifactCreate) (ImplResponse, error) CreateModelVersion(context.Context, model.ModelVersionCreate) (ImplResponse, error) CreateRegisteredModel(context.Context, model.RegisteredModelCreate) (ImplResponse, error) CreateRegisteredModelVersion(context.Context, string, model.ModelVersion) (ImplResponse, error) CreateServingEnvironment(context.Context, model.ServingEnvironmentCreate) (ImplResponse, error) FindInferenceService(context.Context, string, string, string) (ImplResponse, error) + FindArtifact(context.Context, string, string, string) (ImplResponse, error) FindModelArtifact(context.Context, string, string, string) (ImplResponse, error) FindModelVersion(context.Context, string, string, string) (ImplResponse, error) FindRegisteredModel(context.Context, string, string) (ImplResponse, error) @@ -81,6 +83,8 @@ type ModelRegistryServiceAPIServicer interface { GetInferenceServiceServes(context.Context, string, string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) GetInferenceServiceVersion(context.Context, string) (ImplResponse, error) GetInferenceServices(context.Context, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) + GetArtifact(context.Context, string) (ImplResponse, error) + GetArtifacts(context.Context, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) GetModelArtifact(context.Context, string) (ImplResponse, error) GetModelArtifacts(context.Context, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) GetModelVersion(context.Context, string) (ImplResponse, error) @@ -92,6 +96,7 @@ type ModelRegistryServiceAPIServicer interface { GetServingEnvironment(context.Context, string) (ImplResponse, error) GetServingEnvironments(context.Context, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) UpdateInferenceService(context.Context, string, model.InferenceServiceUpdate) (ImplResponse, error) + UpdateArtifact(context.Context, string, model.ArtifactUpdate) (ImplResponse, error) UpdateModelArtifact(context.Context, string, model.ModelArtifactUpdate) (ImplResponse, error) UpdateModelVersion(context.Context, string, model.ModelVersionUpdate) (ImplResponse, error) UpdateRegisteredModel(context.Context, string, model.RegisteredModelUpdate) (ImplResponse, error) diff --git a/internal/server/openapi/api_model_registry_service.go b/internal/server/openapi/api_model_registry_service.go index e4134e30..58448254 100644 --- a/internal/server/openapi/api_model_registry_service.go +++ b/internal/server/openapi/api_model_registry_service.go @@ -52,6 +52,11 @@ func NewModelRegistryServiceAPIController(s ModelRegistryServiceAPIServicer, opt // Routes returns all the api routes for the ModelRegistryServiceAPIController func (c *ModelRegistryServiceAPIController) Routes() Routes { return Routes{ + "CreateArtifact": Route{ + strings.ToUpper("Post"), + "/api/model_registry/v1alpha3/artifacts", + c.CreateArtifact, + }, "CreateEnvironmentInferenceService": Route{ strings.ToUpper("Post"), "/api/model_registry/v1alpha3/serving_environments/{servingenvironmentId}/inference_services", @@ -92,6 +97,11 @@ func (c *ModelRegistryServiceAPIController) Routes() Routes { "/api/model_registry/v1alpha3/serving_environments", c.CreateServingEnvironment, }, + "FindArtifact": Route{ + strings.ToUpper("Get"), + "/api/model_registry/v1alpha3/artifact", + c.FindArtifact, + }, "FindInferenceService": Route{ strings.ToUpper("Get"), "/api/model_registry/v1alpha3/inference_service", @@ -117,6 +127,16 @@ func (c *ModelRegistryServiceAPIController) Routes() Routes { "/api/model_registry/v1alpha3/serving_environment", c.FindServingEnvironment, }, + "GetArtifact": Route{ + strings.ToUpper("Get"), + "/api/model_registry/v1alpha3/artifacts/{id}", + c.GetArtifact, + }, + "GetArtifacts": Route{ + strings.ToUpper("Get"), + "/api/model_registry/v1alpha3/artifacts", + c.GetArtifacts, + }, "GetEnvironmentInferenceServices": Route{ strings.ToUpper("Get"), "/api/model_registry/v1alpha3/serving_environments/{servingenvironmentId}/inference_services", @@ -197,6 +217,11 @@ func (c *ModelRegistryServiceAPIController) Routes() Routes { "/api/model_registry/v1alpha3/serving_environments", c.GetServingEnvironments, }, + "UpdateArtifact": Route{ + strings.ToUpper("Patch"), + "/api/model_registry/v1alpha3/artifacts/{id}", + c.UpdateArtifact, + }, "UpdateInferenceService": Route{ strings.ToUpper("Patch"), "/api/model_registry/v1alpha3/inference_services/{inferenceserviceId}", @@ -230,6 +255,33 @@ func (c *ModelRegistryServiceAPIController) Routes() Routes { } } +// CreateArtifact - Create an Artifact +func (c *ModelRegistryServiceAPIController) CreateArtifact(w http.ResponseWriter, r *http.Request) { + artifactCreateParam := model.ArtifactCreate{} + d := json.NewDecoder(r.Body) + d.DisallowUnknownFields() + if err := d.Decode(&artifactCreateParam); err != nil { + c.errorHandler(w, r, &ParsingError{Err: err}, nil) + return + } + if err := AssertArtifactCreateRequired(artifactCreateParam); err != nil { + c.errorHandler(w, r, err, nil) + return + } + if err := AssertArtifactCreateConstraints(artifactCreateParam); err != nil { + c.errorHandler(w, r, err, nil) + return + } + result, err := c.service.CreateArtifact(r.Context(), artifactCreateParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + EncodeJSONResponse(result.Body, &result.Code, w) +} + // CreateEnvironmentInferenceService - Create a InferenceService in ServingEnvironment func (c *ModelRegistryServiceAPIController) CreateEnvironmentInferenceService(w http.ResponseWriter, r *http.Request) { servingenvironmentIdParam := chi.URLParam(r, "servingenvironmentId") @@ -449,6 +501,22 @@ func (c *ModelRegistryServiceAPIController) CreateServingEnvironment(w http.Resp EncodeJSONResponse(result.Body, &result.Code, w) } +// FindArtifact - Get an Artifact that matches search parameters. +func (c *ModelRegistryServiceAPIController) FindArtifact(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + nameParam := query.Get("name") + externalIdParam := query.Get("externalId") + parentResourceIdParam := query.Get("parentResourceId") + result, err := c.service.FindArtifact(r.Context(), nameParam, externalIdParam, parentResourceIdParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + EncodeJSONResponse(result.Body, &result.Code, w) +} + // FindInferenceService - Get an InferenceServices that matches search parameters. func (c *ModelRegistryServiceAPIController) FindInferenceService(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() @@ -527,6 +595,36 @@ func (c *ModelRegistryServiceAPIController) FindServingEnvironment(w http.Respon EncodeJSONResponse(result.Body, &result.Code, w) } +// GetArtifact - Get an Artifact +func (c *ModelRegistryServiceAPIController) GetArtifact(w http.ResponseWriter, r *http.Request) { + idParam := chi.URLParam(r, "id") + result, err := c.service.GetArtifact(r.Context(), idParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + EncodeJSONResponse(result.Body, &result.Code, w) +} + +// GetArtifacts - List All Artifacts +func (c *ModelRegistryServiceAPIController) GetArtifacts(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + pageSizeParam := query.Get("pageSize") + orderByParam := query.Get("orderBy") + sortOrderParam := query.Get("sortOrder") + nextPageTokenParam := query.Get("nextPageToken") + result, err := c.service.GetArtifacts(r.Context(), pageSizeParam, model.OrderByField(orderByParam), model.SortOrder(sortOrderParam), nextPageTokenParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + EncodeJSONResponse(result.Body, &result.Code, w) +} + // GetEnvironmentInferenceServices - List All ServingEnvironment's InferenceServices func (c *ModelRegistryServiceAPIController) GetEnvironmentInferenceServices(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() @@ -783,6 +881,34 @@ func (c *ModelRegistryServiceAPIController) GetServingEnvironments(w http.Respon EncodeJSONResponse(result.Body, &result.Code, w) } +// UpdateArtifact - Update an Artifact +func (c *ModelRegistryServiceAPIController) UpdateArtifact(w http.ResponseWriter, r *http.Request) { + idParam := chi.URLParam(r, "id") + artifactUpdateParam := model.ArtifactUpdate{} + d := json.NewDecoder(r.Body) + d.DisallowUnknownFields() + if err := d.Decode(&artifactUpdateParam); err != nil { + c.errorHandler(w, r, &ParsingError{Err: err}, nil) + return + } + if err := AssertArtifactUpdateRequired(artifactUpdateParam); err != nil { + c.errorHandler(w, r, err, nil) + return + } + if err := AssertArtifactUpdateConstraints(artifactUpdateParam); err != nil { + c.errorHandler(w, r, err, nil) + return + } + result, err := c.service.UpdateArtifact(r.Context(), idParam, artifactUpdateParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + EncodeJSONResponse(result.Body, &result.Code, w) +} + // UpdateInferenceService - Update a InferenceService func (c *ModelRegistryServiceAPIController) UpdateInferenceService(w http.ResponseWriter, r *http.Request) { inferenceserviceIdParam := chi.URLParam(r, "inferenceserviceId") diff --git a/internal/server/openapi/api_model_registry_service_service.go b/internal/server/openapi/api_model_registry_service_service.go index c0d0e6bd..d22b38c2 100644 --- a/internal/server/openapi/api_model_registry_service_service.go +++ b/internal/server/openapi/api_model_registry_service_service.go @@ -75,6 +75,21 @@ func (s *ModelRegistryServiceAPIService) CreateInferenceServiceServe(ctx context // TODO: return Response(http.StatusUnauthorized, Error{}), nil } +// CreateArtifact - Create an Artifact +func (s *ModelRegistryServiceAPIService) CreateArtifact(ctx context.Context, artifactCreate model.ArtifactCreate) (ImplResponse, error) { + entity, err := s.converter.ConvertArtifactCreate(&artifactCreate) + if err != nil { + return ErrorResponse(http.StatusBadRequest, err), err + } + + result, err := s.coreApi.UpsertArtifact(entity) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + return Response(http.StatusCreated, result), nil + // TODO: return Response(http.StatusUnauthorized, Error{}), nil +} + // CreateModelArtifact - Create a ModelArtifact func (s *ModelRegistryServiceAPIService) CreateModelArtifact(ctx context.Context, modelArtifactCreate model.ModelArtifactCreate) (ImplResponse, error) { entity, err := s.converter.ConvertModelArtifactCreate(&modelArtifactCreate) @@ -172,6 +187,16 @@ func (s *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Contex // TODO return Response(http.StatusUnauthorized, Error{}), nil } +// FindArtifact - Get an Artifact that matches search parameters. +func (s *ModelRegistryServiceAPIService) FindArtifact(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { + result, err := s.coreApi.GetArtifactByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + return Response(http.StatusOK, result), nil + // TODO return Response(http.StatusUnauthorized, Error{}), nil +} + // FindModelArtifact - Get a ModelArtifact that matches search parameters. func (s *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { result, err := s.coreApi.GetModelArtifactByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) @@ -284,6 +309,30 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServices(ctx context.Contex // TODO return Response(http.StatusUnauthorized, Error{}), nil } +// GetArtifact - Get a Artifact +func (s *ModelRegistryServiceAPIService) GetArtifact(ctx context.Context, artifactId string) (ImplResponse, error) { + result, err := s.coreApi.GetArtifactById(artifactId) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + return Response(http.StatusOK, result), nil + // TODO: return Response(http.StatusUnauthorized, Error{}), nil +} + +// GetArtifacts - List All Artifacts +func (s *ModelRegistryServiceAPIService) GetArtifacts(ctx context.Context, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { + listOpts, err := apiutils.BuildListOption(pageSize, orderBy, sortOrder, nextPageToken) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + result, err := s.coreApi.GetArtifacts(listOpts, nil) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + return Response(http.StatusOK, result), nil + // TODO return Response(http.StatusUnauthorized, Error{}), nil +} + // GetModelArtifact - Get a ModelArtifact func (s *ModelRegistryServiceAPIService) GetModelArtifact(ctx context.Context, modelartifactId string) (ImplResponse, error) { result, err := s.coreApi.GetModelArtifactById(modelartifactId) @@ -435,6 +484,33 @@ func (s *ModelRegistryServiceAPIService) UpdateInferenceService(ctx context.Cont // TODO return Response(http.StatusUnauthorized, Error{}), nil } +// UpdateArtifact - Update a Artifact +func (s *ModelRegistryServiceAPIService) UpdateArtifact(ctx context.Context, artifactId string, artifactUpdate model.ArtifactUpdate) (ImplResponse, error) { + entity, err := s.converter.ConvertArtifactUpdate(&artifactUpdate) + if err != nil { + return ErrorResponse(http.StatusBadRequest, err), err + } + if artifactUpdate.DocArtifactUpdate != nil { + entity.DocArtifact.Id = &artifactId + } else { + entity.ModelArtifact.Id = &artifactId + } + existing, err := s.coreApi.GetArtifactById(artifactId) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + update, err := converter.UpdateExistingArtifact(s.reconciler, converter.NewOpenapiUpdateWrapper(existing, entity)) + if err != nil { + return ErrorResponse(http.StatusBadRequest, err), err + } + result, err := s.coreApi.UpsertArtifact(&update) + if err != nil { + return ErrorResponse(api.ErrToStatus(err), err), err + } + return Response(http.StatusOK, result), nil + // TODO return Response(http.StatusUnauthorized, Error{}), nil +} + // UpdateModelArtifact - Update a ModelArtifact func (s *ModelRegistryServiceAPIService) UpdateModelArtifact(ctx context.Context, modelartifactId string, modelArtifactUpdate model.ModelArtifactUpdate) (ImplResponse, error) { modelArtifact, err := s.converter.ConvertModelArtifactUpdate(&modelArtifactUpdate) diff --git a/internal/server/openapi/type_asserts.go b/internal/server/openapi/type_asserts.go index fa7cde46..63395c6a 100644 --- a/internal/server/openapi/type_asserts.go +++ b/internal/server/openapi/type_asserts.go @@ -21,6 +21,16 @@ func AssertArtifactConstraints(obj model.Artifact) error { return nil } +// AssertArtifactCreateConstraints checks if the values respects the defined constraints +func AssertArtifactCreateConstraints(obj model.ArtifactCreate) error { + return nil +} + +// AssertArtifactCreateRequired checks if the required fields are not zero-ed +func AssertArtifactCreateRequired(obj model.ArtifactCreate) error { + return nil +} + // AssertArtifactListConstraints checks if the values respects the defined constraints func AssertArtifactListConstraints(obj model.ArtifactList) error { return nil @@ -62,6 +72,16 @@ func AssertArtifactStateRequired(obj model.ArtifactState) error { return nil } +// AssertArtifactUpdateConstraints checks if the values respects the defined constraints +func AssertArtifactUpdateConstraints(obj model.ArtifactUpdate) error { + return nil +} + +// AssertArtifactUpdateRequired checks if the required fields are not zero-ed +func AssertArtifactUpdateRequired(obj model.ArtifactUpdate) error { + return nil +} + // AssertBaseArtifactConstraints checks if the values respects the defined constraints func AssertBaseArtifactConstraints(obj model.BaseArtifact) error { return nil @@ -178,6 +198,25 @@ func AssertDocArtifactConstraints(obj model.DocArtifact) error { return nil } +// AssertDocArtifactCreateConstraints checks if the values respects the defined constraints +func AssertDocArtifactCreateConstraints(obj model.DocArtifactCreate) error { + return nil +} + +// AssertDocArtifactCreateRequired checks if the required fields are not zero-ed +func AssertDocArtifactCreateRequired(obj model.DocArtifactCreate) error { + elements := map[string]interface{}{ + "artifactType": obj.ArtifactType, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + + return nil +} + // AssertDocArtifactRequired checks if the required fields are not zero-ed func AssertDocArtifactRequired(obj model.DocArtifact) error { elements := map[string]interface{}{ @@ -192,6 +231,25 @@ func AssertDocArtifactRequired(obj model.DocArtifact) error { return nil } +// AssertDocArtifactUpdateConstraints checks if the values respects the defined constraints +func AssertDocArtifactUpdateConstraints(obj model.DocArtifactUpdate) error { + return nil +} + +// AssertDocArtifactUpdateRequired checks if the required fields are not zero-ed +func AssertDocArtifactUpdateRequired(obj model.DocArtifactUpdate) error { + elements := map[string]interface{}{ + "artifactType": obj.ArtifactType, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + + return nil +} + // AssertErrorConstraints checks if the values respects the defined constraints func AssertErrorConstraints(obj model.Error) error { return nil @@ -468,6 +526,15 @@ func AssertModelArtifactCreateConstraints(obj model.ModelArtifactCreate) error { // AssertModelArtifactCreateRequired checks if the required fields are not zero-ed func AssertModelArtifactCreateRequired(obj model.ModelArtifactCreate) error { + elements := map[string]interface{}{ + "artifactType": obj.ArtifactType, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + return nil } @@ -518,6 +585,15 @@ func AssertModelArtifactUpdateConstraints(obj model.ModelArtifactUpdate) error { // AssertModelArtifactUpdateRequired checks if the required fields are not zero-ed func AssertModelArtifactUpdateRequired(obj model.ModelArtifactUpdate) error { + elements := map[string]interface{}{ + "artifactType": obj.ArtifactType, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + return nil } diff --git a/pkg/api/api.go b/pkg/api/api.go index 96662da6..a3dd3325 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -58,6 +58,8 @@ type ModelRegistryApi interface { GetArtifactById(id string) (*openapi.Artifact, error) + GetArtifactByParams(artifactName *string, modelVersionId *string, externalId *string) (*openapi.Artifact, error) + GetArtifacts(listOptions ListOptions, modelVersionId *string) (*openapi.ArtifactList, error) // MODEL ARTIFACT diff --git a/pkg/core/artifact.go b/pkg/core/artifact.go index 3c06a25b..ad5ccc07 100644 --- a/pkg/core/artifact.go +++ b/pkg/core/artifact.go @@ -18,6 +18,9 @@ import ( // ID is provided. // Upon creation, new artifacts will be associated with their corresponding model version. func (serv *ModelRegistryService) UpsertModelVersionArtifact(artifact *openapi.Artifact, modelVersionId string) (*openapi.Artifact, error) { + if artifact == nil { + return nil, fmt.Errorf("invalid artifact pointer, can't upsert nil: %w", api.ErrBadRequest) + } art, err := serv.upsertArtifact(artifact, &modelVersionId) if err != nil { return nil, err @@ -153,6 +156,48 @@ func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, return serv.mapper.MapToArtifact(artifactsResp.Artifacts[0]) } +// GetArtifactByParams retrieves an artifact based on specified parameters, such as (artifact name and model version ID), or external ID. +// If multiple or no model artifacts are found, an error is returned. +func (serv *ModelRegistryService) GetArtifactByParams(artifactName *string, modelVersionId *string, externalId *string) (*openapi.Artifact, error) { + var artifact0 *proto.Artifact + + filterQuery := "" + if externalId != nil { + filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) + } else if artifactName != nil && modelVersionId != nil { + filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(modelVersionId, *artifactName)) + } else { + return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: %w", api.ErrBadRequest) + } + glog.Info("filterQuery ", filterQuery) + + artifactsResponse, err := serv.mlmdClient.GetArtifacts(context.Background(), &proto.GetArtifactsRequest{ + Options: &proto.ListOperationOptions{ + FilterQuery: &filterQuery, + }, + }) + if err != nil { + return nil, err + } + + if len(artifactsResponse.Artifacts) > 1 { + return nil, fmt.Errorf("multiple model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + if len(artifactsResponse.Artifacts) == 0 { + return nil, fmt.Errorf("no model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) + } + + artifact0 = artifactsResponse.Artifacts[0] + + result, err := serv.mapper.MapToArtifact(artifact0) + if err != nil { + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) + } + + return result, nil +} + // GetArtifacts retrieves a list of artifacts based on the provided list options and optional model version ID. func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ArtifactList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) diff --git a/pkg/core/artifact_test.go b/pkg/core/artifact_test.go index 12b7f90d..df78ec0c 100644 --- a/pkg/core/artifact_test.go +++ b/pkg/core/artifact_test.go @@ -412,6 +412,59 @@ func (suite *CoreTestSuite) TestGetArtifactById() { suite.Equal(*createdArtifact, *getById, "artifacts returned during creation and on get by id should be equal") } +func (suite *CoreTestSuite) TestGetArtifactByParams() { + // create mode registry service + service := suite.setupModelRegistryService() + + modelVersionId := suite.registerModelVersion(service, nil, nil, nil, nil) + + docArtifact := &openapi.DocArtifact{ + Name: &artifactName, + State: (*openapi.ArtifactState)(&artifactState), + Uri: &artifactUri, + ExternalId: &artifactExtId, + CustomProperties: &map[string]openapi.MetadataValue{ + "custom_string_prop": { + MetadataStringValue: converter.NewMetadataStringValue(customString), + }, + }, + } + + art, err := service.UpsertModelVersionArtifact(&openapi.Artifact{DocArtifact: docArtifact}, modelVersionId) + suite.Nilf(err, "error creating new model artifact: %v", err) + da := art.DocArtifact + + createdArtifactId, _ := converter.StringToInt64(da.Id) + + state, _ := openapi.NewArtifactStateFromValue(artifactState) + + artByName, err := service.GetArtifactByParams(&artifactName, &modelVersionId, nil) + suite.Nilf(err, "error getting model artifact by id %s: %v", *createdArtifactId, err) + daByName := artByName.DocArtifact + + suite.NotNil(da.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *daByName.Name) + suite.Equal(artifactExtId, *daByName.ExternalId) + suite.Equal(*state, *daByName.State) + suite.Equal(artifactUri, *daByName.Uri) + suite.Equal(customString, (*daByName.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*da, *daByName, "artifacts returned during creation and on get by name should be equal") + + getByExtId, err := service.GetArtifactByParams(nil, nil, &artifactExtId) + suite.Nilf(err, "error getting model artifact by id %s: %v", *createdArtifactId, err) + daByExtId := getByExtId.DocArtifact + + suite.NotNil(da.Id, "created artifact id should not be nil") + suite.Equal(artifactName, *daByExtId.Name) + suite.Equal(artifactExtId, *daByExtId.ExternalId) + suite.Equal(*state, *daByExtId.State) + suite.Equal(artifactUri, *daByExtId.Uri) + suite.Equal(customString, (*daByExtId.CustomProperties)["custom_string_prop"].MetadataStringValue.StringValue) + + suite.Equal(*da, *daByExtId, "artifacts returned during creation and on get by ext id should be equal") +} + func (suite *CoreTestSuite) TestGetArtifacts() { // create mode registry service service := suite.setupModelRegistryService() diff --git a/pkg/openapi/.openapi-generator/FILES b/pkg/openapi/.openapi-generator/FILES index ba72484a..8a56f21b 100644 --- a/pkg/openapi/.openapi-generator/FILES +++ b/pkg/openapi/.openapi-generator/FILES @@ -2,8 +2,10 @@ api_model_registry_service.go client.go configuration.go model_artifact.go +model_artifact_create.go model_artifact_list.go model_artifact_state.go +model_artifact_update.go model_base_artifact.go model_base_artifact_create.go model_base_artifact_update.go @@ -15,6 +17,8 @@ model_base_resource_create.go model_base_resource_list.go model_base_resource_update.go model_doc_artifact.go +model_doc_artifact_create.go +model_doc_artifact_update.go model_error.go model_execution_state.go model_inference_service.go diff --git a/pkg/openapi/api_model_registry_service.go b/pkg/openapi/api_model_registry_service.go index 646fb530..4904ad21 100644 --- a/pkg/openapi/api_model_registry_service.go +++ b/pkg/openapi/api_model_registry_service.go @@ -22,6 +22,150 @@ import ( // ModelRegistryServiceAPIService ModelRegistryServiceAPI service type ModelRegistryServiceAPIService service +type ApiCreateArtifactRequest struct { + ctx context.Context + ApiService *ModelRegistryServiceAPIService + artifactCreate *ArtifactCreate +} + +// A new `Artifact` to be created. +func (r ApiCreateArtifactRequest) ArtifactCreate(artifactCreate ArtifactCreate) ApiCreateArtifactRequest { + r.artifactCreate = &artifactCreate + return r +} + +func (r ApiCreateArtifactRequest) Execute() (*Artifact, *http.Response, error) { + return r.ApiService.CreateArtifactExecute(r) +} + +/* +CreateArtifact Create an Artifact + +Creates a new instance of an `Artifact`. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). + @return ApiCreateArtifactRequest +*/ +func (a *ModelRegistryServiceAPIService) CreateArtifact(ctx context.Context) ApiCreateArtifactRequest { + return ApiCreateArtifactRequest{ + ApiService: a, + ctx: ctx, + } +} + +// Execute executes the request +// +// @return Artifact +func (a *ModelRegistryServiceAPIService) CreateArtifactExecute(r ApiCreateArtifactRequest) (*Artifact, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodPost + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *Artifact + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.CreateArtifact") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_registry/v1alpha3/artifacts" + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + if r.artifactCreate == nil { + return localVarReturnValue, nil, reportError("artifactCreate is required and must be specified") + } + + // to determine the Content-Type header + localVarHTTPContentTypes := []string{"application/json"} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + // body params + localVarPostBody = r.artifactCreate + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} + type ApiCreateEnvironmentInferenceServiceRequest struct { ctx context.Context ApiService *ModelRegistryServiceAPIService @@ -1219,7 +1363,7 @@ func (a *ModelRegistryServiceAPIService) CreateServingEnvironmentExecute(r ApiCr return localVarReturnValue, localVarHTTPResponse, nil } -type ApiFindInferenceServiceRequest struct { +type ApiFindArtifactRequest struct { ctx context.Context ApiService *ModelRegistryServiceAPIService name *string @@ -1228,37 +1372,37 @@ type ApiFindInferenceServiceRequest struct { } // Name of entity to search. -func (r ApiFindInferenceServiceRequest) Name(name string) ApiFindInferenceServiceRequest { +func (r ApiFindArtifactRequest) Name(name string) ApiFindArtifactRequest { r.name = &name return r } // External ID of entity to search. -func (r ApiFindInferenceServiceRequest) ExternalId(externalId string) ApiFindInferenceServiceRequest { +func (r ApiFindArtifactRequest) ExternalId(externalId string) ApiFindArtifactRequest { r.externalId = &externalId return r } // ID of the parent resource to use for search. -func (r ApiFindInferenceServiceRequest) ParentResourceId(parentResourceId string) ApiFindInferenceServiceRequest { +func (r ApiFindArtifactRequest) ParentResourceId(parentResourceId string) ApiFindArtifactRequest { r.parentResourceId = &parentResourceId return r } -func (r ApiFindInferenceServiceRequest) Execute() (*InferenceService, *http.Response, error) { - return r.ApiService.FindInferenceServiceExecute(r) +func (r ApiFindArtifactRequest) Execute() (*Artifact, *http.Response, error) { + return r.ApiService.FindArtifactExecute(r) } /* -FindInferenceService Get an InferenceServices that matches search parameters. +FindArtifact Get an Artifact that matches search parameters. -Gets the details of a single instance of `InferenceService` that matches search parameters. +Gets the details of a single instance of an `Artifact` that matches search parameters. @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). - @return ApiFindInferenceServiceRequest + @return ApiFindArtifactRequest */ -func (a *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Context) ApiFindInferenceServiceRequest { - return ApiFindInferenceServiceRequest{ +func (a *ModelRegistryServiceAPIService) FindArtifact(ctx context.Context) ApiFindArtifactRequest { + return ApiFindArtifactRequest{ ApiService: a, ctx: ctx, } @@ -1266,21 +1410,21 @@ func (a *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Contex // Execute executes the request // -// @return InferenceService -func (a *ModelRegistryServiceAPIService) FindInferenceServiceExecute(r ApiFindInferenceServiceRequest) (*InferenceService, *http.Response, error) { +// @return Artifact +func (a *ModelRegistryServiceAPIService) FindArtifactExecute(r ApiFindArtifactRequest) (*Artifact, *http.Response, error) { var ( localVarHTTPMethod = http.MethodGet localVarPostBody interface{} formFiles []formFile - localVarReturnValue *InferenceService + localVarReturnValue *Artifact ) - localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindInferenceService") + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindArtifact") if err != nil { return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} } - localVarPath := localBasePath + "/api/model_registry/v1alpha3/inference_service" + localVarPath := localBasePath + "/api/model_registry/v1alpha3/artifact" localVarHeaderParams := make(map[string]string) localVarQueryParams := url.Values{} @@ -1392,7 +1536,7 @@ func (a *ModelRegistryServiceAPIService) FindInferenceServiceExecute(r ApiFindIn return localVarReturnValue, localVarHTTPResponse, nil } -type ApiFindModelArtifactRequest struct { +type ApiFindInferenceServiceRequest struct { ctx context.Context ApiService *ModelRegistryServiceAPIService name *string @@ -1401,37 +1545,37 @@ type ApiFindModelArtifactRequest struct { } // Name of entity to search. -func (r ApiFindModelArtifactRequest) Name(name string) ApiFindModelArtifactRequest { +func (r ApiFindInferenceServiceRequest) Name(name string) ApiFindInferenceServiceRequest { r.name = &name return r } // External ID of entity to search. -func (r ApiFindModelArtifactRequest) ExternalId(externalId string) ApiFindModelArtifactRequest { +func (r ApiFindInferenceServiceRequest) ExternalId(externalId string) ApiFindInferenceServiceRequest { r.externalId = &externalId return r } // ID of the parent resource to use for search. -func (r ApiFindModelArtifactRequest) ParentResourceId(parentResourceId string) ApiFindModelArtifactRequest { +func (r ApiFindInferenceServiceRequest) ParentResourceId(parentResourceId string) ApiFindInferenceServiceRequest { r.parentResourceId = &parentResourceId return r } -func (r ApiFindModelArtifactRequest) Execute() (*ModelArtifact, *http.Response, error) { - return r.ApiService.FindModelArtifactExecute(r) +func (r ApiFindInferenceServiceRequest) Execute() (*InferenceService, *http.Response, error) { + return r.ApiService.FindInferenceServiceExecute(r) } /* -FindModelArtifact Get a ModelArtifact that matches search parameters. +FindInferenceService Get an InferenceServices that matches search parameters. -Gets the details of a single instance of a `ModelArtifact` that matches search parameters. +Gets the details of a single instance of `InferenceService` that matches search parameters. @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). - @return ApiFindModelArtifactRequest + @return ApiFindInferenceServiceRequest */ -func (a *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context) ApiFindModelArtifactRequest { - return ApiFindModelArtifactRequest{ +func (a *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Context) ApiFindInferenceServiceRequest { + return ApiFindInferenceServiceRequest{ ApiService: a, ctx: ctx, } @@ -1439,21 +1583,21 @@ func (a *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context) // Execute executes the request // -// @return ModelArtifact -func (a *ModelRegistryServiceAPIService) FindModelArtifactExecute(r ApiFindModelArtifactRequest) (*ModelArtifact, *http.Response, error) { +// @return InferenceService +func (a *ModelRegistryServiceAPIService) FindInferenceServiceExecute(r ApiFindInferenceServiceRequest) (*InferenceService, *http.Response, error) { var ( localVarHTTPMethod = http.MethodGet localVarPostBody interface{} formFiles []formFile - localVarReturnValue *ModelArtifact + localVarReturnValue *InferenceService ) - localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindModelArtifact") + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindInferenceService") if err != nil { return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} } - localVarPath := localBasePath + "/api/model_registry/v1alpha3/model_artifact" + localVarPath := localBasePath + "/api/model_registry/v1alpha3/inference_service" localVarHeaderParams := make(map[string]string) localVarQueryParams := url.Values{} @@ -1565,7 +1709,7 @@ func (a *ModelRegistryServiceAPIService) FindModelArtifactExecute(r ApiFindModel return localVarReturnValue, localVarHTTPResponse, nil } -type ApiFindModelVersionRequest struct { +type ApiFindModelArtifactRequest struct { ctx context.Context ApiService *ModelRegistryServiceAPIService name *string @@ -1574,37 +1718,37 @@ type ApiFindModelVersionRequest struct { } // Name of entity to search. -func (r ApiFindModelVersionRequest) Name(name string) ApiFindModelVersionRequest { +func (r ApiFindModelArtifactRequest) Name(name string) ApiFindModelArtifactRequest { r.name = &name return r } // External ID of entity to search. -func (r ApiFindModelVersionRequest) ExternalId(externalId string) ApiFindModelVersionRequest { +func (r ApiFindModelArtifactRequest) ExternalId(externalId string) ApiFindModelArtifactRequest { r.externalId = &externalId return r } // ID of the parent resource to use for search. -func (r ApiFindModelVersionRequest) ParentResourceId(parentResourceId string) ApiFindModelVersionRequest { +func (r ApiFindModelArtifactRequest) ParentResourceId(parentResourceId string) ApiFindModelArtifactRequest { r.parentResourceId = &parentResourceId return r } -func (r ApiFindModelVersionRequest) Execute() (*ModelVersion, *http.Response, error) { - return r.ApiService.FindModelVersionExecute(r) +func (r ApiFindModelArtifactRequest) Execute() (*ModelArtifact, *http.Response, error) { + return r.ApiService.FindModelArtifactExecute(r) } /* -FindModelVersion Get a ModelVersion that matches search parameters. +FindModelArtifact Get a ModelArtifact that matches search parameters. -Gets the details of a single instance of a `ModelVersion` that matches search parameters. +Gets the details of a single instance of a `ModelArtifact` that matches search parameters. @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). - @return ApiFindModelVersionRequest + @return ApiFindModelArtifactRequest */ -func (a *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context) ApiFindModelVersionRequest { - return ApiFindModelVersionRequest{ +func (a *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context) ApiFindModelArtifactRequest { + return ApiFindModelArtifactRequest{ ApiService: a, ctx: ctx, } @@ -1612,21 +1756,512 @@ func (a *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context) A // Execute executes the request // -// @return ModelVersion -func (a *ModelRegistryServiceAPIService) FindModelVersionExecute(r ApiFindModelVersionRequest) (*ModelVersion, *http.Response, error) { +// @return ModelArtifact +func (a *ModelRegistryServiceAPIService) FindModelArtifactExecute(r ApiFindModelArtifactRequest) (*ModelArtifact, *http.Response, error) { var ( localVarHTTPMethod = http.MethodGet localVarPostBody interface{} formFiles []formFile - localVarReturnValue *ModelVersion + localVarReturnValue *ModelArtifact ) - localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindModelVersion") + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindModelArtifact") if err != nil { return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} } - localVarPath := localBasePath + "/api/model_registry/v1alpha3/model_version" + localVarPath := localBasePath + "/api/model_registry/v1alpha3/model_artifact" + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + + if r.name != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "name", r.name, "") + } + if r.externalId != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") + } + if r.parentResourceId != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "parentResourceId", r.parentResourceId, "") + } + // to determine the Content-Type header + localVarHTTPContentTypes := []string{} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 404 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} + +type ApiFindModelVersionRequest struct { + ctx context.Context + ApiService *ModelRegistryServiceAPIService + name *string + externalId *string + parentResourceId *string +} + +// Name of entity to search. +func (r ApiFindModelVersionRequest) Name(name string) ApiFindModelVersionRequest { + r.name = &name + return r +} + +// External ID of entity to search. +func (r ApiFindModelVersionRequest) ExternalId(externalId string) ApiFindModelVersionRequest { + r.externalId = &externalId + return r +} + +// ID of the parent resource to use for search. +func (r ApiFindModelVersionRequest) ParentResourceId(parentResourceId string) ApiFindModelVersionRequest { + r.parentResourceId = &parentResourceId + return r +} + +func (r ApiFindModelVersionRequest) Execute() (*ModelVersion, *http.Response, error) { + return r.ApiService.FindModelVersionExecute(r) +} + +/* +FindModelVersion Get a ModelVersion that matches search parameters. + +Gets the details of a single instance of a `ModelVersion` that matches search parameters. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). + @return ApiFindModelVersionRequest +*/ +func (a *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context) ApiFindModelVersionRequest { + return ApiFindModelVersionRequest{ + ApiService: a, + ctx: ctx, + } +} + +// Execute executes the request +// +// @return ModelVersion +func (a *ModelRegistryServiceAPIService) FindModelVersionExecute(r ApiFindModelVersionRequest) (*ModelVersion, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodGet + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *ModelVersion + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindModelVersion") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_registry/v1alpha3/model_version" + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + + if r.name != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "name", r.name, "") + } + if r.externalId != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") + } + if r.parentResourceId != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "parentResourceId", r.parentResourceId, "") + } + // to determine the Content-Type header + localVarHTTPContentTypes := []string{} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 404 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} + +type ApiFindRegisteredModelRequest struct { + ctx context.Context + ApiService *ModelRegistryServiceAPIService + name *string + externalId *string +} + +// Name of entity to search. +func (r ApiFindRegisteredModelRequest) Name(name string) ApiFindRegisteredModelRequest { + r.name = &name + return r +} + +// External ID of entity to search. +func (r ApiFindRegisteredModelRequest) ExternalId(externalId string) ApiFindRegisteredModelRequest { + r.externalId = &externalId + return r +} + +func (r ApiFindRegisteredModelRequest) Execute() (*RegisteredModel, *http.Response, error) { + return r.ApiService.FindRegisteredModelExecute(r) +} + +/* +FindRegisteredModel Get a RegisteredModel that matches search parameters. + +Gets the details of a single instance of a `RegisteredModel` that matches search parameters. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). + @return ApiFindRegisteredModelRequest +*/ +func (a *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context) ApiFindRegisteredModelRequest { + return ApiFindRegisteredModelRequest{ + ApiService: a, + ctx: ctx, + } +} + +// Execute executes the request +// +// @return RegisteredModel +func (a *ModelRegistryServiceAPIService) FindRegisteredModelExecute(r ApiFindRegisteredModelRequest) (*RegisteredModel, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodGet + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *RegisteredModel + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindRegisteredModel") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_registry/v1alpha3/registered_model" + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + + if r.name != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "name", r.name, "") + } + if r.externalId != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") + } + // to determine the Content-Type header + localVarHTTPContentTypes := []string{} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 404 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} + +type ApiFindServingEnvironmentRequest struct { + ctx context.Context + ApiService *ModelRegistryServiceAPIService + name *string + externalId *string +} + +// Name of entity to search. +func (r ApiFindServingEnvironmentRequest) Name(name string) ApiFindServingEnvironmentRequest { + r.name = &name + return r +} + +// External ID of entity to search. +func (r ApiFindServingEnvironmentRequest) ExternalId(externalId string) ApiFindServingEnvironmentRequest { + r.externalId = &externalId + return r +} + +func (r ApiFindServingEnvironmentRequest) Execute() (*ServingEnvironment, *http.Response, error) { + return r.ApiService.FindServingEnvironmentExecute(r) +} + +/* +FindServingEnvironment Find ServingEnvironment + +Finds a `ServingEnvironment` entity that matches query parameters. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). + @return ApiFindServingEnvironmentRequest +*/ +func (a *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Context) ApiFindServingEnvironmentRequest { + return ApiFindServingEnvironmentRequest{ + ApiService: a, + ctx: ctx, + } +} + +// Execute executes the request +// +// @return ServingEnvironment +func (a *ModelRegistryServiceAPIService) FindServingEnvironmentExecute(r ApiFindServingEnvironmentRequest) (*ServingEnvironment, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodGet + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *ServingEnvironment + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindServingEnvironment") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_registry/v1alpha3/serving_environment" localVarHeaderParams := make(map[string]string) localVarQueryParams := url.Values{} @@ -1638,9 +2273,6 @@ func (a *ModelRegistryServiceAPIService) FindModelVersionExecute(r ApiFindModelV if r.externalId != nil { parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") } - if r.parentResourceId != nil { - parameterAddToHeaderOrQuery(localVarQueryParams, "parentResourceId", r.parentResourceId, "") - } // to determine the Content-Type header localVarHTTPContentTypes := []string{} @@ -1680,17 +2312,6 @@ func (a *ModelRegistryServiceAPIService) FindModelVersionExecute(r ApiFindModelV body: localVarBody, error: localVarHTTPResponse.Status, } - if localVarHTTPResponse.StatusCode == 400 { - var v Error - err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) - if err != nil { - newErr.error = err.Error() - return localVarReturnValue, localVarHTTPResponse, newErr - } - newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) - newErr.model = v - return localVarReturnValue, localVarHTTPResponse, newErr - } if localVarHTTPResponse.StatusCode == 401 { var v Error err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) @@ -1738,72 +2359,56 @@ func (a *ModelRegistryServiceAPIService) FindModelVersionExecute(r ApiFindModelV return localVarReturnValue, localVarHTTPResponse, nil } -type ApiFindRegisteredModelRequest struct { +type ApiGetArtifactRequest struct { ctx context.Context ApiService *ModelRegistryServiceAPIService - name *string - externalId *string -} - -// Name of entity to search. -func (r ApiFindRegisteredModelRequest) Name(name string) ApiFindRegisteredModelRequest { - r.name = &name - return r -} - -// External ID of entity to search. -func (r ApiFindRegisteredModelRequest) ExternalId(externalId string) ApiFindRegisteredModelRequest { - r.externalId = &externalId - return r + id string } -func (r ApiFindRegisteredModelRequest) Execute() (*RegisteredModel, *http.Response, error) { - return r.ApiService.FindRegisteredModelExecute(r) +func (r ApiGetArtifactRequest) Execute() (*Artifact, *http.Response, error) { + return r.ApiService.GetArtifactExecute(r) } /* -FindRegisteredModel Get a RegisteredModel that matches search parameters. +GetArtifact Get an Artifact -Gets the details of a single instance of a `RegisteredModel` that matches search parameters. +Gets the details of a single instance of an `Artifact`. @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). - @return ApiFindRegisteredModelRequest + @param id A unique identifier for an `Artifact`. + @return ApiGetArtifactRequest */ -func (a *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context) ApiFindRegisteredModelRequest { - return ApiFindRegisteredModelRequest{ +func (a *ModelRegistryServiceAPIService) GetArtifact(ctx context.Context, id string) ApiGetArtifactRequest { + return ApiGetArtifactRequest{ ApiService: a, ctx: ctx, + id: id, } } // Execute executes the request // -// @return RegisteredModel -func (a *ModelRegistryServiceAPIService) FindRegisteredModelExecute(r ApiFindRegisteredModelRequest) (*RegisteredModel, *http.Response, error) { +// @return Artifact +func (a *ModelRegistryServiceAPIService) GetArtifactExecute(r ApiGetArtifactRequest) (*Artifact, *http.Response, error) { var ( localVarHTTPMethod = http.MethodGet localVarPostBody interface{} formFiles []formFile - localVarReturnValue *RegisteredModel + localVarReturnValue *Artifact ) - localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindRegisteredModel") + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.GetArtifact") if err != nil { return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} } - localVarPath := localBasePath + "/api/model_registry/v1alpha3/registered_model" + localVarPath := localBasePath + "/api/model_registry/v1alpha3/artifacts/{id}" + localVarPath = strings.Replace(localVarPath, "{"+"id"+"}", url.PathEscape(parameterValueToString(r.id, "id")), -1) localVarHeaderParams := make(map[string]string) localVarQueryParams := url.Values{} localVarFormParams := url.Values{} - if r.name != nil { - parameterAddToHeaderOrQuery(localVarQueryParams, "name", r.name, "") - } - if r.externalId != nil { - parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") - } // to determine the Content-Type header localVarHTTPContentTypes := []string{} @@ -1890,39 +2495,53 @@ func (a *ModelRegistryServiceAPIService) FindRegisteredModelExecute(r ApiFindReg return localVarReturnValue, localVarHTTPResponse, nil } -type ApiFindServingEnvironmentRequest struct { - ctx context.Context - ApiService *ModelRegistryServiceAPIService - name *string - externalId *string +type ApiGetArtifactsRequest struct { + ctx context.Context + ApiService *ModelRegistryServiceAPIService + pageSize *string + orderBy *OrderByField + sortOrder *SortOrder + nextPageToken *string } -// Name of entity to search. -func (r ApiFindServingEnvironmentRequest) Name(name string) ApiFindServingEnvironmentRequest { - r.name = &name +// Number of entities in each page. +func (r ApiGetArtifactsRequest) PageSize(pageSize string) ApiGetArtifactsRequest { + r.pageSize = &pageSize return r } -// External ID of entity to search. -func (r ApiFindServingEnvironmentRequest) ExternalId(externalId string) ApiFindServingEnvironmentRequest { - r.externalId = &externalId +// Specifies the order by criteria for listing entities. +func (r ApiGetArtifactsRequest) OrderBy(orderBy OrderByField) ApiGetArtifactsRequest { + r.orderBy = &orderBy return r } -func (r ApiFindServingEnvironmentRequest) Execute() (*ServingEnvironment, *http.Response, error) { - return r.ApiService.FindServingEnvironmentExecute(r) +// Specifies the sort order for listing entities, defaults to ASC. +func (r ApiGetArtifactsRequest) SortOrder(sortOrder SortOrder) ApiGetArtifactsRequest { + r.sortOrder = &sortOrder + return r +} + +// Token to use to retrieve next page of results. +func (r ApiGetArtifactsRequest) NextPageToken(nextPageToken string) ApiGetArtifactsRequest { + r.nextPageToken = &nextPageToken + return r +} + +func (r ApiGetArtifactsRequest) Execute() (*ArtifactList, *http.Response, error) { + return r.ApiService.GetArtifactsExecute(r) } /* -FindServingEnvironment Find ServingEnvironment +GetArtifacts List All Artifacts -Finds a `ServingEnvironment` entity that matches query parameters. +Gets a list of all `Artifact` entities. @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). - @return ApiFindServingEnvironmentRequest + @return ApiGetArtifactsRequest */ -func (a *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Context) ApiFindServingEnvironmentRequest { - return ApiFindServingEnvironmentRequest{ +func (a *ModelRegistryServiceAPIService) GetArtifacts(ctx context.Context) ApiGetArtifactsRequest { + return ApiGetArtifactsRequest{ ApiService: a, ctx: ctx, } @@ -1930,31 +2549,37 @@ func (a *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Cont // Execute executes the request // -// @return ServingEnvironment -func (a *ModelRegistryServiceAPIService) FindServingEnvironmentExecute(r ApiFindServingEnvironmentRequest) (*ServingEnvironment, *http.Response, error) { +// @return ArtifactList +func (a *ModelRegistryServiceAPIService) GetArtifactsExecute(r ApiGetArtifactsRequest) (*ArtifactList, *http.Response, error) { var ( localVarHTTPMethod = http.MethodGet localVarPostBody interface{} formFiles []formFile - localVarReturnValue *ServingEnvironment + localVarReturnValue *ArtifactList ) - localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.FindServingEnvironment") + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.GetArtifacts") if err != nil { return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} } - localVarPath := localBasePath + "/api/model_registry/v1alpha3/serving_environment" + localVarPath := localBasePath + "/api/model_registry/v1alpha3/artifacts" localVarHeaderParams := make(map[string]string) localVarQueryParams := url.Values{} localVarFormParams := url.Values{} - if r.name != nil { - parameterAddToHeaderOrQuery(localVarQueryParams, "name", r.name, "") + if r.pageSize != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "pageSize", r.pageSize, "") } - if r.externalId != nil { - parameterAddToHeaderOrQuery(localVarQueryParams, "externalId", r.externalId, "") + if r.orderBy != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "orderBy", r.orderBy, "") + } + if r.sortOrder != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "sortOrder", r.sortOrder, "") + } + if r.nextPageToken != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "nextPageToken", r.nextPageToken, "") } // to determine the Content-Type header localVarHTTPContentTypes := []string{} @@ -1995,6 +2620,17 @@ func (a *ModelRegistryServiceAPIService) FindServingEnvironmentExecute(r ApiFind body: localVarBody, error: localVarHTTPResponse.Status, } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } if localVarHTTPResponse.StatusCode == 401 { var v Error err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) @@ -4625,6 +5261,165 @@ func (a *ModelRegistryServiceAPIService) GetServingEnvironmentsExecute(r ApiGetS return localVarReturnValue, localVarHTTPResponse, nil } +type ApiUpdateArtifactRequest struct { + ctx context.Context + ApiService *ModelRegistryServiceAPIService + id string + artifactUpdate *ArtifactUpdate +} + +// Updated `Artifact` information. +func (r ApiUpdateArtifactRequest) ArtifactUpdate(artifactUpdate ArtifactUpdate) ApiUpdateArtifactRequest { + r.artifactUpdate = &artifactUpdate + return r +} + +func (r ApiUpdateArtifactRequest) Execute() (*Artifact, *http.Response, error) { + return r.ApiService.UpdateArtifactExecute(r) +} + +/* +UpdateArtifact Update an Artifact + +Updates an existing `Artifact`. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). + @param id A unique identifier for an `Artifact`. + @return ApiUpdateArtifactRequest +*/ +func (a *ModelRegistryServiceAPIService) UpdateArtifact(ctx context.Context, id string) ApiUpdateArtifactRequest { + return ApiUpdateArtifactRequest{ + ApiService: a, + ctx: ctx, + id: id, + } +} + +// Execute executes the request +// +// @return Artifact +func (a *ModelRegistryServiceAPIService) UpdateArtifactExecute(r ApiUpdateArtifactRequest) (*Artifact, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodPatch + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *Artifact + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelRegistryServiceAPIService.UpdateArtifact") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_registry/v1alpha3/artifacts/{id}" + localVarPath = strings.Replace(localVarPath, "{"+"id"+"}", url.PathEscape(parameterValueToString(r.id, "id")), -1) + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + if r.artifactUpdate == nil { + return localVarReturnValue, nil, reportError("artifactUpdate is required and must be specified") + } + + // to determine the Content-Type header + localVarHTTPContentTypes := []string{"application/json"} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + // body params + localVarPostBody = r.artifactUpdate + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 404 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} + type ApiUpdateInferenceServiceRequest struct { ctx context.Context ApiService *ModelRegistryServiceAPIService diff --git a/pkg/openapi/model_artifact_create.go b/pkg/openapi/model_artifact_create.go new file mode 100644 index 00000000..88755d3b --- /dev/null +++ b/pkg/openapi/model_artifact_create.go @@ -0,0 +1,163 @@ +/* +Model Registry REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha3 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" + "fmt" +) + +// ArtifactCreate - An Artifact to be created. +type ArtifactCreate struct { + DocArtifactCreate *DocArtifactCreate + ModelArtifactCreate *ModelArtifactCreate +} + +// DocArtifactCreateAsArtifactCreate is a convenience function that returns DocArtifactCreate wrapped in ArtifactCreate +func DocArtifactCreateAsArtifactCreate(v *DocArtifactCreate) ArtifactCreate { + return ArtifactCreate{ + DocArtifactCreate: v, + } +} + +// ModelArtifactCreateAsArtifactCreate is a convenience function that returns ModelArtifactCreate wrapped in ArtifactCreate +func ModelArtifactCreateAsArtifactCreate(v *ModelArtifactCreate) ArtifactCreate { + return ArtifactCreate{ + ModelArtifactCreate: v, + } +} + +// Unmarshal JSON data into one of the pointers in the struct +func (dst *ArtifactCreate) UnmarshalJSON(data []byte) error { + var err error + // use discriminator value to speed up the lookup + var jsonDict map[string]interface{} + err = newStrictDecoder(data).Decode(&jsonDict) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON into map for the discriminator lookup") + } + + // check if the discriminator value is 'DocArtifactCreate' + if jsonDict["artifactType"] == "DocArtifactCreate" { + // try to unmarshal JSON data into DocArtifactCreate + err = json.Unmarshal(data, &dst.DocArtifactCreate) + if err == nil { + return nil // data stored in dst.DocArtifactCreate, return on the first match + } else { + dst.DocArtifactCreate = nil + return fmt.Errorf("failed to unmarshal ArtifactCreate as DocArtifactCreate: %s", err.Error()) + } + } + + // check if the discriminator value is 'ModelArtifactCreate' + if jsonDict["artifactType"] == "ModelArtifactCreate" { + // try to unmarshal JSON data into ModelArtifactCreate + err = json.Unmarshal(data, &dst.ModelArtifactCreate) + if err == nil { + return nil // data stored in dst.ModelArtifactCreate, return on the first match + } else { + dst.ModelArtifactCreate = nil + return fmt.Errorf("failed to unmarshal ArtifactCreate as ModelArtifactCreate: %s", err.Error()) + } + } + + // check if the discriminator value is 'doc-artifact' + if jsonDict["artifactType"] == "doc-artifact" { + // try to unmarshal JSON data into DocArtifactCreate + err = json.Unmarshal(data, &dst.DocArtifactCreate) + if err == nil { + return nil // data stored in dst.DocArtifactCreate, return on the first match + } else { + dst.DocArtifactCreate = nil + return fmt.Errorf("failed to unmarshal ArtifactCreate as DocArtifactCreate: %s", err.Error()) + } + } + + // check if the discriminator value is 'model-artifact' + if jsonDict["artifactType"] == "model-artifact" { + // try to unmarshal JSON data into ModelArtifactCreate + err = json.Unmarshal(data, &dst.ModelArtifactCreate) + if err == nil { + return nil // data stored in dst.ModelArtifactCreate, return on the first match + } else { + dst.ModelArtifactCreate = nil + return fmt.Errorf("failed to unmarshal ArtifactCreate as ModelArtifactCreate: %s", err.Error()) + } + } + + return nil +} + +// Marshal data from the first non-nil pointers in the struct to JSON +func (src ArtifactCreate) MarshalJSON() ([]byte, error) { + if src.DocArtifactCreate != nil { + return json.Marshal(&src.DocArtifactCreate) + } + + if src.ModelArtifactCreate != nil { + return json.Marshal(&src.ModelArtifactCreate) + } + + return nil, nil // no data in oneOf schemas +} + +// Get the actual instance +func (obj *ArtifactCreate) GetActualInstance() interface{} { + if obj == nil { + return nil + } + if obj.DocArtifactCreate != nil { + return obj.DocArtifactCreate + } + + if obj.ModelArtifactCreate != nil { + return obj.ModelArtifactCreate + } + + // all schemas are nil + return nil +} + +type NullableArtifactCreate struct { + value *ArtifactCreate + isSet bool +} + +func (v NullableArtifactCreate) Get() *ArtifactCreate { + return v.value +} + +func (v *NullableArtifactCreate) Set(val *ArtifactCreate) { + v.value = val + v.isSet = true +} + +func (v NullableArtifactCreate) IsSet() bool { + return v.isSet +} + +func (v *NullableArtifactCreate) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableArtifactCreate(val *ArtifactCreate) *NullableArtifactCreate { + return &NullableArtifactCreate{value: val, isSet: true} +} + +func (v NullableArtifactCreate) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableArtifactCreate) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/pkg/openapi/model_artifact_update.go b/pkg/openapi/model_artifact_update.go new file mode 100644 index 00000000..2c9b1a03 --- /dev/null +++ b/pkg/openapi/model_artifact_update.go @@ -0,0 +1,163 @@ +/* +Model Registry REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha3 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" + "fmt" +) + +// ArtifactUpdate - An Artifact to be updated. +type ArtifactUpdate struct { + DocArtifactUpdate *DocArtifactUpdate + ModelArtifactUpdate *ModelArtifactUpdate +} + +// DocArtifactUpdateAsArtifactUpdate is a convenience function that returns DocArtifactUpdate wrapped in ArtifactUpdate +func DocArtifactUpdateAsArtifactUpdate(v *DocArtifactUpdate) ArtifactUpdate { + return ArtifactUpdate{ + DocArtifactUpdate: v, + } +} + +// ModelArtifactUpdateAsArtifactUpdate is a convenience function that returns ModelArtifactUpdate wrapped in ArtifactUpdate +func ModelArtifactUpdateAsArtifactUpdate(v *ModelArtifactUpdate) ArtifactUpdate { + return ArtifactUpdate{ + ModelArtifactUpdate: v, + } +} + +// Unmarshal JSON data into one of the pointers in the struct +func (dst *ArtifactUpdate) UnmarshalJSON(data []byte) error { + var err error + // use discriminator value to speed up the lookup + var jsonDict map[string]interface{} + err = newStrictDecoder(data).Decode(&jsonDict) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON into map for the discriminator lookup") + } + + // check if the discriminator value is 'DocArtifactUpdate' + if jsonDict["artifactType"] == "DocArtifactUpdate" { + // try to unmarshal JSON data into DocArtifactUpdate + err = json.Unmarshal(data, &dst.DocArtifactUpdate) + if err == nil { + return nil // data stored in dst.DocArtifactUpdate, return on the first match + } else { + dst.DocArtifactUpdate = nil + return fmt.Errorf("failed to unmarshal ArtifactUpdate as DocArtifactUpdate: %s", err.Error()) + } + } + + // check if the discriminator value is 'ModelArtifactUpdate' + if jsonDict["artifactType"] == "ModelArtifactUpdate" { + // try to unmarshal JSON data into ModelArtifactUpdate + err = json.Unmarshal(data, &dst.ModelArtifactUpdate) + if err == nil { + return nil // data stored in dst.ModelArtifactUpdate, return on the first match + } else { + dst.ModelArtifactUpdate = nil + return fmt.Errorf("failed to unmarshal ArtifactUpdate as ModelArtifactUpdate: %s", err.Error()) + } + } + + // check if the discriminator value is 'doc-artifact' + if jsonDict["artifactType"] == "doc-artifact" { + // try to unmarshal JSON data into DocArtifactUpdate + err = json.Unmarshal(data, &dst.DocArtifactUpdate) + if err == nil { + return nil // data stored in dst.DocArtifactUpdate, return on the first match + } else { + dst.DocArtifactUpdate = nil + return fmt.Errorf("failed to unmarshal ArtifactUpdate as DocArtifactUpdate: %s", err.Error()) + } + } + + // check if the discriminator value is 'model-artifact' + if jsonDict["artifactType"] == "model-artifact" { + // try to unmarshal JSON data into ModelArtifactUpdate + err = json.Unmarshal(data, &dst.ModelArtifactUpdate) + if err == nil { + return nil // data stored in dst.ModelArtifactUpdate, return on the first match + } else { + dst.ModelArtifactUpdate = nil + return fmt.Errorf("failed to unmarshal ArtifactUpdate as ModelArtifactUpdate: %s", err.Error()) + } + } + + return nil +} + +// Marshal data from the first non-nil pointers in the struct to JSON +func (src ArtifactUpdate) MarshalJSON() ([]byte, error) { + if src.DocArtifactUpdate != nil { + return json.Marshal(&src.DocArtifactUpdate) + } + + if src.ModelArtifactUpdate != nil { + return json.Marshal(&src.ModelArtifactUpdate) + } + + return nil, nil // no data in oneOf schemas +} + +// Get the actual instance +func (obj *ArtifactUpdate) GetActualInstance() interface{} { + if obj == nil { + return nil + } + if obj.DocArtifactUpdate != nil { + return obj.DocArtifactUpdate + } + + if obj.ModelArtifactUpdate != nil { + return obj.ModelArtifactUpdate + } + + // all schemas are nil + return nil +} + +type NullableArtifactUpdate struct { + value *ArtifactUpdate + isSet bool +} + +func (v NullableArtifactUpdate) Get() *ArtifactUpdate { + return v.value +} + +func (v *NullableArtifactUpdate) Set(val *ArtifactUpdate) { + v.value = val + v.isSet = true +} + +func (v NullableArtifactUpdate) IsSet() bool { + return v.isSet +} + +func (v *NullableArtifactUpdate) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableArtifactUpdate(val *ArtifactUpdate) *NullableArtifactUpdate { + return &NullableArtifactUpdate{value: val, isSet: true} +} + +func (v NullableArtifactUpdate) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableArtifactUpdate) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/pkg/openapi/model_doc_artifact_create.go b/pkg/openapi/model_doc_artifact_create.go new file mode 100644 index 00000000..07e4240a --- /dev/null +++ b/pkg/openapi/model_doc_artifact_create.go @@ -0,0 +1,341 @@ +/* +Model Registry REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha3 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" +) + +// checks if the DocArtifactCreate type satisfies the MappedNullable interface at compile time +var _ MappedNullable = &DocArtifactCreate{} + +// DocArtifactCreate A document artifact to be created. +type DocArtifactCreate struct { + ArtifactType string `json:"artifactType"` + // User provided custom properties which are not defined by its type. + CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` + // An optional description about the resource. + Description *string `json:"description,omitempty"` + // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. + ExternalId *string `json:"externalId,omitempty"` + // The uniform resource identifier of the physical artifact. May be empty if there is no physical artifact. + Uri *string `json:"uri,omitempty"` + State *ArtifactState `json:"state,omitempty"` + // The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set. + Name *string `json:"name,omitempty"` +} + +// NewDocArtifactCreate instantiates a new DocArtifactCreate object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewDocArtifactCreate(artifactType string) *DocArtifactCreate { + this := DocArtifactCreate{} + var state ArtifactState = ARTIFACTSTATE_UNKNOWN + this.State = &state + return &this +} + +// NewDocArtifactCreateWithDefaults instantiates a new DocArtifactCreate object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewDocArtifactCreateWithDefaults() *DocArtifactCreate { + this := DocArtifactCreate{} + var artifactType string = "doc-artifact" + this.ArtifactType = artifactType + var state ArtifactState = ARTIFACTSTATE_UNKNOWN + this.State = &state + return &this +} + +// GetArtifactType returns the ArtifactType field value +func (o *DocArtifactCreate) GetArtifactType() string { + if o == nil { + var ret string + return ret + } + + return o.ArtifactType +} + +// GetArtifactTypeOk returns a tuple with the ArtifactType field value +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetArtifactTypeOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.ArtifactType, true +} + +// SetArtifactType sets field value +func (o *DocArtifactCreate) SetArtifactType(v string) { + o.ArtifactType = v +} + +// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. +func (o *DocArtifactCreate) GetCustomProperties() map[string]MetadataValue { + if o == nil || IsNil(o.CustomProperties) { + var ret map[string]MetadataValue + return ret + } + return *o.CustomProperties +} + +// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { + if o == nil || IsNil(o.CustomProperties) { + return nil, false + } + return o.CustomProperties, true +} + +// HasCustomProperties returns a boolean if a field has been set. +func (o *DocArtifactCreate) HasCustomProperties() bool { + if o != nil && !IsNil(o.CustomProperties) { + return true + } + + return false +} + +// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. +func (o *DocArtifactCreate) SetCustomProperties(v map[string]MetadataValue) { + o.CustomProperties = &v +} + +// GetDescription returns the Description field value if set, zero value otherwise. +func (o *DocArtifactCreate) GetDescription() string { + if o == nil || IsNil(o.Description) { + var ret string + return ret + } + return *o.Description +} + +// GetDescriptionOk returns a tuple with the Description field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetDescriptionOk() (*string, bool) { + if o == nil || IsNil(o.Description) { + return nil, false + } + return o.Description, true +} + +// HasDescription returns a boolean if a field has been set. +func (o *DocArtifactCreate) HasDescription() bool { + if o != nil && !IsNil(o.Description) { + return true + } + + return false +} + +// SetDescription gets a reference to the given string and assigns it to the Description field. +func (o *DocArtifactCreate) SetDescription(v string) { + o.Description = &v +} + +// GetExternalId returns the ExternalId field value if set, zero value otherwise. +func (o *DocArtifactCreate) GetExternalId() string { + if o == nil || IsNil(o.ExternalId) { + var ret string + return ret + } + return *o.ExternalId +} + +// GetExternalIdOk returns a tuple with the ExternalId field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetExternalIdOk() (*string, bool) { + if o == nil || IsNil(o.ExternalId) { + return nil, false + } + return o.ExternalId, true +} + +// HasExternalId returns a boolean if a field has been set. +func (o *DocArtifactCreate) HasExternalId() bool { + if o != nil && !IsNil(o.ExternalId) { + return true + } + + return false +} + +// SetExternalId gets a reference to the given string and assigns it to the ExternalId field. +func (o *DocArtifactCreate) SetExternalId(v string) { + o.ExternalId = &v +} + +// GetUri returns the Uri field value if set, zero value otherwise. +func (o *DocArtifactCreate) GetUri() string { + if o == nil || IsNil(o.Uri) { + var ret string + return ret + } + return *o.Uri +} + +// GetUriOk returns a tuple with the Uri field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetUriOk() (*string, bool) { + if o == nil || IsNil(o.Uri) { + return nil, false + } + return o.Uri, true +} + +// HasUri returns a boolean if a field has been set. +func (o *DocArtifactCreate) HasUri() bool { + if o != nil && !IsNil(o.Uri) { + return true + } + + return false +} + +// SetUri gets a reference to the given string and assigns it to the Uri field. +func (o *DocArtifactCreate) SetUri(v string) { + o.Uri = &v +} + +// GetState returns the State field value if set, zero value otherwise. +func (o *DocArtifactCreate) GetState() ArtifactState { + if o == nil || IsNil(o.State) { + var ret ArtifactState + return ret + } + return *o.State +} + +// GetStateOk returns a tuple with the State field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetStateOk() (*ArtifactState, bool) { + if o == nil || IsNil(o.State) { + return nil, false + } + return o.State, true +} + +// HasState returns a boolean if a field has been set. +func (o *DocArtifactCreate) HasState() bool { + if o != nil && !IsNil(o.State) { + return true + } + + return false +} + +// SetState gets a reference to the given ArtifactState and assigns it to the State field. +func (o *DocArtifactCreate) SetState(v ArtifactState) { + o.State = &v +} + +// GetName returns the Name field value if set, zero value otherwise. +func (o *DocArtifactCreate) GetName() string { + if o == nil || IsNil(o.Name) { + var ret string + return ret + } + return *o.Name +} + +// GetNameOk returns a tuple with the Name field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactCreate) GetNameOk() (*string, bool) { + if o == nil || IsNil(o.Name) { + return nil, false + } + return o.Name, true +} + +// HasName returns a boolean if a field has been set. +func (o *DocArtifactCreate) HasName() bool { + if o != nil && !IsNil(o.Name) { + return true + } + + return false +} + +// SetName gets a reference to the given string and assigns it to the Name field. +func (o *DocArtifactCreate) SetName(v string) { + o.Name = &v +} + +func (o DocArtifactCreate) MarshalJSON() ([]byte, error) { + toSerialize, err := o.ToMap() + if err != nil { + return []byte{}, err + } + return json.Marshal(toSerialize) +} + +func (o DocArtifactCreate) ToMap() (map[string]interface{}, error) { + toSerialize := map[string]interface{}{} + toSerialize["artifactType"] = o.ArtifactType + if !IsNil(o.CustomProperties) { + toSerialize["customProperties"] = o.CustomProperties + } + if !IsNil(o.Description) { + toSerialize["description"] = o.Description + } + if !IsNil(o.ExternalId) { + toSerialize["externalId"] = o.ExternalId + } + if !IsNil(o.Uri) { + toSerialize["uri"] = o.Uri + } + if !IsNil(o.State) { + toSerialize["state"] = o.State + } + if !IsNil(o.Name) { + toSerialize["name"] = o.Name + } + return toSerialize, nil +} + +type NullableDocArtifactCreate struct { + value *DocArtifactCreate + isSet bool +} + +func (v NullableDocArtifactCreate) Get() *DocArtifactCreate { + return v.value +} + +func (v *NullableDocArtifactCreate) Set(val *DocArtifactCreate) { + v.value = val + v.isSet = true +} + +func (v NullableDocArtifactCreate) IsSet() bool { + return v.isSet +} + +func (v *NullableDocArtifactCreate) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableDocArtifactCreate(val *DocArtifactCreate) *NullableDocArtifactCreate { + return &NullableDocArtifactCreate{value: val, isSet: true} +} + +func (v NullableDocArtifactCreate) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableDocArtifactCreate) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/pkg/openapi/model_doc_artifact_update.go b/pkg/openapi/model_doc_artifact_update.go new file mode 100644 index 00000000..cee15785 --- /dev/null +++ b/pkg/openapi/model_doc_artifact_update.go @@ -0,0 +1,304 @@ +/* +Model Registry REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha3 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" +) + +// checks if the DocArtifactUpdate type satisfies the MappedNullable interface at compile time +var _ MappedNullable = &DocArtifactUpdate{} + +// DocArtifactUpdate A document artifact to be updated. +type DocArtifactUpdate struct { + ArtifactType string `json:"artifactType"` + // User provided custom properties which are not defined by its type. + CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` + // An optional description about the resource. + Description *string `json:"description,omitempty"` + // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. + ExternalId *string `json:"externalId,omitempty"` + // The uniform resource identifier of the physical artifact. May be empty if there is no physical artifact. + Uri *string `json:"uri,omitempty"` + State *ArtifactState `json:"state,omitempty"` +} + +// NewDocArtifactUpdate instantiates a new DocArtifactUpdate object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewDocArtifactUpdate(artifactType string) *DocArtifactUpdate { + this := DocArtifactUpdate{} + var state ArtifactState = ARTIFACTSTATE_UNKNOWN + this.State = &state + return &this +} + +// NewDocArtifactUpdateWithDefaults instantiates a new DocArtifactUpdate object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewDocArtifactUpdateWithDefaults() *DocArtifactUpdate { + this := DocArtifactUpdate{} + var artifactType string = "doc-artifact" + this.ArtifactType = artifactType + var state ArtifactState = ARTIFACTSTATE_UNKNOWN + this.State = &state + return &this +} + +// GetArtifactType returns the ArtifactType field value +func (o *DocArtifactUpdate) GetArtifactType() string { + if o == nil { + var ret string + return ret + } + + return o.ArtifactType +} + +// GetArtifactTypeOk returns a tuple with the ArtifactType field value +// and a boolean to check if the value has been set. +func (o *DocArtifactUpdate) GetArtifactTypeOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.ArtifactType, true +} + +// SetArtifactType sets field value +func (o *DocArtifactUpdate) SetArtifactType(v string) { + o.ArtifactType = v +} + +// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. +func (o *DocArtifactUpdate) GetCustomProperties() map[string]MetadataValue { + if o == nil || IsNil(o.CustomProperties) { + var ret map[string]MetadataValue + return ret + } + return *o.CustomProperties +} + +// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactUpdate) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { + if o == nil || IsNil(o.CustomProperties) { + return nil, false + } + return o.CustomProperties, true +} + +// HasCustomProperties returns a boolean if a field has been set. +func (o *DocArtifactUpdate) HasCustomProperties() bool { + if o != nil && !IsNil(o.CustomProperties) { + return true + } + + return false +} + +// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. +func (o *DocArtifactUpdate) SetCustomProperties(v map[string]MetadataValue) { + o.CustomProperties = &v +} + +// GetDescription returns the Description field value if set, zero value otherwise. +func (o *DocArtifactUpdate) GetDescription() string { + if o == nil || IsNil(o.Description) { + var ret string + return ret + } + return *o.Description +} + +// GetDescriptionOk returns a tuple with the Description field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactUpdate) GetDescriptionOk() (*string, bool) { + if o == nil || IsNil(o.Description) { + return nil, false + } + return o.Description, true +} + +// HasDescription returns a boolean if a field has been set. +func (o *DocArtifactUpdate) HasDescription() bool { + if o != nil && !IsNil(o.Description) { + return true + } + + return false +} + +// SetDescription gets a reference to the given string and assigns it to the Description field. +func (o *DocArtifactUpdate) SetDescription(v string) { + o.Description = &v +} + +// GetExternalId returns the ExternalId field value if set, zero value otherwise. +func (o *DocArtifactUpdate) GetExternalId() string { + if o == nil || IsNil(o.ExternalId) { + var ret string + return ret + } + return *o.ExternalId +} + +// GetExternalIdOk returns a tuple with the ExternalId field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactUpdate) GetExternalIdOk() (*string, bool) { + if o == nil || IsNil(o.ExternalId) { + return nil, false + } + return o.ExternalId, true +} + +// HasExternalId returns a boolean if a field has been set. +func (o *DocArtifactUpdate) HasExternalId() bool { + if o != nil && !IsNil(o.ExternalId) { + return true + } + + return false +} + +// SetExternalId gets a reference to the given string and assigns it to the ExternalId field. +func (o *DocArtifactUpdate) SetExternalId(v string) { + o.ExternalId = &v +} + +// GetUri returns the Uri field value if set, zero value otherwise. +func (o *DocArtifactUpdate) GetUri() string { + if o == nil || IsNil(o.Uri) { + var ret string + return ret + } + return *o.Uri +} + +// GetUriOk returns a tuple with the Uri field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactUpdate) GetUriOk() (*string, bool) { + if o == nil || IsNil(o.Uri) { + return nil, false + } + return o.Uri, true +} + +// HasUri returns a boolean if a field has been set. +func (o *DocArtifactUpdate) HasUri() bool { + if o != nil && !IsNil(o.Uri) { + return true + } + + return false +} + +// SetUri gets a reference to the given string and assigns it to the Uri field. +func (o *DocArtifactUpdate) SetUri(v string) { + o.Uri = &v +} + +// GetState returns the State field value if set, zero value otherwise. +func (o *DocArtifactUpdate) GetState() ArtifactState { + if o == nil || IsNil(o.State) { + var ret ArtifactState + return ret + } + return *o.State +} + +// GetStateOk returns a tuple with the State field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *DocArtifactUpdate) GetStateOk() (*ArtifactState, bool) { + if o == nil || IsNil(o.State) { + return nil, false + } + return o.State, true +} + +// HasState returns a boolean if a field has been set. +func (o *DocArtifactUpdate) HasState() bool { + if o != nil && !IsNil(o.State) { + return true + } + + return false +} + +// SetState gets a reference to the given ArtifactState and assigns it to the State field. +func (o *DocArtifactUpdate) SetState(v ArtifactState) { + o.State = &v +} + +func (o DocArtifactUpdate) MarshalJSON() ([]byte, error) { + toSerialize, err := o.ToMap() + if err != nil { + return []byte{}, err + } + return json.Marshal(toSerialize) +} + +func (o DocArtifactUpdate) ToMap() (map[string]interface{}, error) { + toSerialize := map[string]interface{}{} + toSerialize["artifactType"] = o.ArtifactType + if !IsNil(o.CustomProperties) { + toSerialize["customProperties"] = o.CustomProperties + } + if !IsNil(o.Description) { + toSerialize["description"] = o.Description + } + if !IsNil(o.ExternalId) { + toSerialize["externalId"] = o.ExternalId + } + if !IsNil(o.Uri) { + toSerialize["uri"] = o.Uri + } + if !IsNil(o.State) { + toSerialize["state"] = o.State + } + return toSerialize, nil +} + +type NullableDocArtifactUpdate struct { + value *DocArtifactUpdate + isSet bool +} + +func (v NullableDocArtifactUpdate) Get() *DocArtifactUpdate { + return v.value +} + +func (v *NullableDocArtifactUpdate) Set(val *DocArtifactUpdate) { + v.value = val + v.isSet = true +} + +func (v NullableDocArtifactUpdate) IsSet() bool { + return v.isSet +} + +func (v *NullableDocArtifactUpdate) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableDocArtifactUpdate(val *DocArtifactUpdate) *NullableDocArtifactUpdate { + return &NullableDocArtifactUpdate{value: val, isSet: true} +} + +func (v NullableDocArtifactUpdate) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableDocArtifactUpdate) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/pkg/openapi/model_model_artifact_create.go b/pkg/openapi/model_model_artifact_create.go index 7a07fdb1..b4498baa 100644 --- a/pkg/openapi/model_model_artifact_create.go +++ b/pkg/openapi/model_model_artifact_create.go @@ -19,6 +19,7 @@ var _ MappedNullable = &ModelArtifactCreate{} // ModelArtifactCreate An ML model artifact. type ModelArtifactCreate struct { + ArtifactType string `json:"artifactType"` // User provided custom properties which are not defined by its type. CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` // An optional description about the resource. @@ -46,7 +47,7 @@ type ModelArtifactCreate struct { // This constructor will assign default values to properties that have it defined, // and makes sure properties required by API are set, but the set of arguments // will change when the set of required properties is changed -func NewModelArtifactCreate() *ModelArtifactCreate { +func NewModelArtifactCreate(artifactType string) *ModelArtifactCreate { this := ModelArtifactCreate{} var state ArtifactState = ARTIFACTSTATE_UNKNOWN this.State = &state @@ -58,11 +59,37 @@ func NewModelArtifactCreate() *ModelArtifactCreate { // but it doesn't guarantee that properties required by API are set func NewModelArtifactCreateWithDefaults() *ModelArtifactCreate { this := ModelArtifactCreate{} + var artifactType string = "model-artifact" + this.ArtifactType = artifactType var state ArtifactState = ARTIFACTSTATE_UNKNOWN this.State = &state return &this } +// GetArtifactType returns the ArtifactType field value +func (o *ModelArtifactCreate) GetArtifactType() string { + if o == nil { + var ret string + return ret + } + + return o.ArtifactType +} + +// GetArtifactTypeOk returns a tuple with the ArtifactType field value +// and a boolean to check if the value has been set. +func (o *ModelArtifactCreate) GetArtifactTypeOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.ArtifactType, true +} + +// SetArtifactType sets field value +func (o *ModelArtifactCreate) SetArtifactType(v string) { + o.ArtifactType = v +} + // GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. func (o *ModelArtifactCreate) GetCustomProperties() map[string]MetadataValue { if o == nil || IsNil(o.CustomProperties) { @@ -425,6 +452,7 @@ func (o ModelArtifactCreate) MarshalJSON() ([]byte, error) { func (o ModelArtifactCreate) ToMap() (map[string]interface{}, error) { toSerialize := map[string]interface{}{} + toSerialize["artifactType"] = o.ArtifactType if !IsNil(o.CustomProperties) { toSerialize["customProperties"] = o.CustomProperties } diff --git a/pkg/openapi/model_model_artifact_update.go b/pkg/openapi/model_model_artifact_update.go index dd154aa0..a555d498 100644 --- a/pkg/openapi/model_model_artifact_update.go +++ b/pkg/openapi/model_model_artifact_update.go @@ -17,8 +17,9 @@ import ( // checks if the ModelArtifactUpdate type satisfies the MappedNullable interface at compile time var _ MappedNullable = &ModelArtifactUpdate{} -// ModelArtifactUpdate An ML model artifact. +// ModelArtifactUpdate An ML model artifact to be updated. type ModelArtifactUpdate struct { + ArtifactType string `json:"artifactType"` // User provided custom properties which are not defined by its type. CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` // An optional description about the resource. @@ -44,7 +45,7 @@ type ModelArtifactUpdate struct { // This constructor will assign default values to properties that have it defined, // and makes sure properties required by API are set, but the set of arguments // will change when the set of required properties is changed -func NewModelArtifactUpdate() *ModelArtifactUpdate { +func NewModelArtifactUpdate(artifactType string) *ModelArtifactUpdate { this := ModelArtifactUpdate{} var state ArtifactState = ARTIFACTSTATE_UNKNOWN this.State = &state @@ -56,11 +57,37 @@ func NewModelArtifactUpdate() *ModelArtifactUpdate { // but it doesn't guarantee that properties required by API are set func NewModelArtifactUpdateWithDefaults() *ModelArtifactUpdate { this := ModelArtifactUpdate{} + var artifactType string = "model-artifact" + this.ArtifactType = artifactType var state ArtifactState = ARTIFACTSTATE_UNKNOWN this.State = &state return &this } +// GetArtifactType returns the ArtifactType field value +func (o *ModelArtifactUpdate) GetArtifactType() string { + if o == nil { + var ret string + return ret + } + + return o.ArtifactType +} + +// GetArtifactTypeOk returns a tuple with the ArtifactType field value +// and a boolean to check if the value has been set. +func (o *ModelArtifactUpdate) GetArtifactTypeOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.ArtifactType, true +} + +// SetArtifactType sets field value +func (o *ModelArtifactUpdate) SetArtifactType(v string) { + o.ArtifactType = v +} + // GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. func (o *ModelArtifactUpdate) GetCustomProperties() map[string]MetadataValue { if o == nil || IsNil(o.CustomProperties) { @@ -391,6 +418,7 @@ func (o ModelArtifactUpdate) MarshalJSON() ([]byte, error) { func (o ModelArtifactUpdate) ToMap() (map[string]interface{}, error) { toSerialize := map[string]interface{}{} + toSerialize["artifactType"] = o.ArtifactType if !IsNil(o.CustomProperties) { toSerialize["customProperties"] = o.CustomProperties } diff --git a/pkg/openapi/model_registered_model.go b/pkg/openapi/model_registered_model.go index 372fb3e8..1806e314 100644 --- a/pkg/openapi/model_registered_model.go +++ b/pkg/openapi/model_registered_model.go @@ -25,7 +25,7 @@ type RegisteredModel struct { Description *string `json:"description,omitempty"` // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. ExternalId *string `json:"externalId,omitempty"` - // The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set. + // The client provided name of the model. It must be unique among all the RegisteredModels of the same type within a Model Registry instance and cannot be changed once set. Name string `json:"name"` // The unique server generated id of the resource. Id *string `json:"id,omitempty"` diff --git a/pkg/openapi/model_registered_model_create.go b/pkg/openapi/model_registered_model_create.go index 49edf663..8594df57 100644 --- a/pkg/openapi/model_registered_model_create.go +++ b/pkg/openapi/model_registered_model_create.go @@ -25,7 +25,7 @@ type RegisteredModelCreate struct { Description *string `json:"description,omitempty"` // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. ExternalId *string `json:"externalId,omitempty"` - // The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set. + // The client provided name of the model. It must be unique among all the RegisteredModels of the same type within a Model Registry instance and cannot be changed once set. Name string `json:"name"` Owner *string `json:"owner,omitempty"` State *RegisteredModelState `json:"state,omitempty"` From 68b5dc7ef88935183bc0043b70216343bc0da5be Mon Sep 17 00:00:00 2001 From: Eder Ignatowicz Date: Wed, 25 Sep 2024 12:03:34 -0400 Subject: [PATCH 10/13] temporarly bump on path-to-regexp (#427) Signed-off-by: Eder Ignatowicz --- clients/ui/frontend/package-lock.json | 6 +++--- clients/ui/frontend/package.json | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/clients/ui/frontend/package-lock.json b/clients/ui/frontend/package-lock.json index 3975f834..971998f4 100644 --- a/clients/ui/frontend/package-lock.json +++ b/clients/ui/frontend/package-lock.json @@ -17425,9 +17425,9 @@ } }, "node_modules/path-to-regexp": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-2.2.1.tgz", - "integrity": "sha512-gu9bD6Ta5bwGrrU8muHzVOBFFREpp2iRkVfhBJahwJ6p6Xw20SjT0MxLnwkjOibQmGSYhiUnf2FLe7k+jcFmGQ==", + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-3.3.0.tgz", + "integrity": "sha512-qyCH421YQPS2WFDxDjftfc1ZR5WKQzVzqsp4n9M2kQhVOo/ByahFoUNJfl58kOcEGfQ//7weFTDhm+ss8Ecxgw==", "dev": true, "license": "MIT" }, diff --git a/clients/ui/frontend/package.json b/clients/ui/frontend/package.json index 8748f1b6..271f1b30 100644 --- a/clients/ui/frontend/package.json +++ b/clients/ui/frontend/package.json @@ -119,5 +119,10 @@ "eslint-plugin-prettier": "^5.0.0", "eslint-plugin-react": "^7.32.2", "eslint-plugin-react-hooks": "^4.6.0" + }, + "overrides": { + "serve": { + "path-to-regexp": "3.3.0" + } } } From 3761a6ac04b0c4c56ac36234c88feec799116b27 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 18:40:34 +0000 Subject: [PATCH 11/13] build(deps-dev): bump @types/jest from 29.5.12 to 29.5.13 in /clients/ui/frontend (#419) Bumps [@types/jest](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/jest) from 29.5.12 to 29.5.13. - [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases) - [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/jest) --- updated-dependencies: - dependency-name: "@types/jest" dependency-type: direct:development update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- clients/ui/frontend/package-lock.json | 9 ++++----- clients/ui/frontend/package.json | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/clients/ui/frontend/package-lock.json b/clients/ui/frontend/package-lock.json index 971998f4..57bb275f 100644 --- a/clients/ui/frontend/package-lock.json +++ b/clients/ui/frontend/package-lock.json @@ -35,7 +35,7 @@ "@testing-library/user-event": "14.5.2", "@types/classnames": "^2.3.1", "@types/dompurify": "^3.0.5", - "@types/jest": "^29.5.12", + "@types/jest": "^29.5.13", "@types/lodash-es": "^4.17.8", "@types/react-router-dom": "^5.3.3", "@types/showdown": "^2.0.3", @@ -4258,11 +4258,10 @@ } }, "node_modules/@types/jest": { - "version": "29.5.12", - "resolved": "https://registry.npmjs.org/@types/jest/-/jest-29.5.12.tgz", - "integrity": "sha512-eDC8bTvT/QhYdxJAulQikueigY5AsdBRH2yDKW3yveW7svY3+DzN84/2NUgkw10RTiJbWqZrTtoGVdYlvFJdLw==", + "version": "29.5.13", + "resolved": "https://registry.npmjs.org/@types/jest/-/jest-29.5.13.tgz", + "integrity": "sha512-wd+MVEZCHt23V0/L642O5APvspWply/rGY5BcW4SUETo2UzPU3Z26qr8jC2qxpimI2jjx9h7+2cj2FwIr01bXg==", "dev": true, - "license": "MIT", "dependencies": { "expect": "^29.0.0", "pretty-format": "^29.0.0" diff --git a/clients/ui/frontend/package.json b/clients/ui/frontend/package.json index 271f1b30..99b0c619 100644 --- a/clients/ui/frontend/package.json +++ b/clients/ui/frontend/package.json @@ -42,7 +42,7 @@ "@testing-library/user-event": "14.5.2", "@types/classnames": "^2.3.1", "@types/dompurify": "^3.0.5", - "@types/jest": "^29.5.12", + "@types/jest": "^29.5.13", "@types/lodash-es": "^4.17.8", "@types/react-router-dom": "^5.3.3", "@types/showdown": "^2.0.3", From 97f575137408b5407275b41270365b98c8ccbb2d Mon Sep 17 00:00:00 2001 From: Griffin Sullivan <48397354+Griffin-Sullivan@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:47:36 -0400 Subject: [PATCH 12/13] Add the Model Version Details and Model Version Archive pages (#428) Signed-off-by: Griffin-Sullivan --- .../ui/bff/internal/mocks/static_data_mock.go | 15 +- .../src/__mocks__/mockModelArtifact.ts | 45 +-- .../src/__mocks__/mockModelArtifactList.ts | 11 + .../cypress/cypress/pages/modelRegistry.ts | 4 + .../modelRegistryView/modelVersionArchive.ts | 128 +++++++ .../modelRegistryView/modelVersionDetails.ts | 73 ++++ .../registeredModelArchive.ts | 128 +++++++ .../cypress/cypress/support/commands/api.ts | 9 +- .../mocked/modelRegistry/modelRegistry.cy.ts | 3 +- .../tests/mocked/modelVersionArchive.cy.ts | 303 +++++++++++++++ .../tests/mocked/modelVersionDetails.cy.ts | 185 +++++++++ .../cypress/tests/mocked/modelVersions.cy.ts | 21 +- .../tests/mocked/registeredModelArchive.cy.ts | 355 ++++++++++++++++++ .../modelRegistry/ModelRegistryRoutes.tsx | 54 +++ .../ModelPropertiesDescriptionListGroup.tsx | 27 +- .../screens/ModelPropertiesTableRow.tsx | 78 ++-- .../ModelVersionDetails.tsx | 108 ++++++ .../ModelVersionDetailsHeaderActions.tsx | 86 +++++ .../ModelVersionDetailsTabs.tsx | 52 +++ .../ModelVersionDetailsView.tsx | 193 ++++++++++ .../ModelVersionSelector.tsx | 111 ++++++ .../screens/ModelVersionDetails/const.ts | 7 + .../ModelVersions/ModelDetailsView.tsx | 10 +- .../ModelVersions/ModelVersionListView.tsx | 216 ++++++----- .../screens/ModelVersions/ModelVersions.tsx | 13 +- .../ModelVersions/ModelVersionsTable.tsx | 9 +- .../ModelVersions/ModelVersionsTableRow.tsx | 103 ++--- .../ModelVersions/ModelVersionsTabs.tsx | 9 +- .../ArchiveModelVersionDetails.tsx | 84 +++++ .../ArchiveModelVersionDetailsBreadcrumb.tsx | 41 ++ .../ModelVersionArchiveDetails.tsx | 114 ++++++ .../ModelVersionArchiveDetailsBreadcrumb.tsx | 43 +++ .../ModelVersionsArchive.tsx | 60 +++ .../ModelVersionsArchiveListView.tsx | 95 +++++ .../ModelVersionsArchiveTable.tsx | 34 ++ .../RegisteredModelListView.tsx | 2 +- .../RegisteredModelTableRow.tsx | 18 +- .../RegisteredModelArchiveDetails.tsx | 109 ++++++ ...egisteredModelArchiveDetailsBreadcrumb.tsx | 28 ++ .../RegisteredModelsArchiveListView.tsx | 2 +- .../pages/modelRegistry/screens/routeUtils.ts | 9 +- .../app/pages/modelRegistry/screens/utils.ts | 27 ++ .../DashboardDescriptionListGroup.tsx | 4 +- .../src/components/DashboardHelpTooltip.tsx | 15 + .../EditableLabelsDescriptionListGroup.tsx | 4 +- .../EditableTextDescriptionListGroup.tsx | 4 +- .../InlineTruncatedClipboardCopy.scss | 4 + .../InlineTruncatedClipboardCopy.tsx | 31 ++ 48 files changed, 2835 insertions(+), 249 deletions(-) create mode 100644 clients/ui/frontend/src/__mocks__/mockModelArtifactList.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionArchive.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionDetails.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registeredModelArchive.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionArchive.cy.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionDetails.cy.ts create mode 100644 clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/registeredModelArchive.cy.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetails.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsTabs.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionSelector.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/const.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetails.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetailsBreadcrumb.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetails.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetailsBreadcrumb.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchive.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveListView.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveTable.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetails.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetailsBreadcrumb.tsx create mode 100644 clients/ui/frontend/src/components/DashboardHelpTooltip.tsx create mode 100644 clients/ui/frontend/src/components/InlineTruncatedClipboardCopy.scss create mode 100644 clients/ui/frontend/src/components/InlineTruncatedClipboardCopy.tsx diff --git a/clients/ui/bff/internal/mocks/static_data_mock.go b/clients/ui/bff/internal/mocks/static_data_mock.go index a156c8d5..05474fd8 100644 --- a/clients/ui/bff/internal/mocks/static_data_mock.go +++ b/clients/ui/bff/internal/mocks/static_data_mock.go @@ -82,7 +82,20 @@ func GetModelVersionMocks() []openapi.ModelVersion { State: stateToPointer(openapi.MODELVERSIONSTATE_LIVE), } - return []openapi.ModelVersion{model1, model2} + model3 := openapi.ModelVersion{ + CustomProperties: newCustomProperties(), + Name: "Version Three", + Description: stringToPointer("This version didn't improve stuff and things"), + ExternalId: stringToPointer("934589791"), + Id: stringToPointer("3"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + RegisteredModelId: "3", + Author: stringToPointer("Sherlock Holmes"), + State: stateToPointer(openapi.MODELVERSIONSTATE_ARCHIVED), + } + + return []openapi.ModelVersion{model1, model2, model3} } func GetModelVersionListMock() openapi.ModelVersionList { diff --git a/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts b/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts index 8f2bb628..8f5ac45b 100644 --- a/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts +++ b/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts @@ -1,34 +1,17 @@ -import { ModelArtifact, ModelArtifactState } from '~/app/types'; +import { ModelArtifact } from '~/app/types'; -type MockModelArtifact = { - id?: string; - name?: string; - uri?: string; - state?: ModelArtifactState; - author?: string; -}; - -export const mockModelArtifact = ({ - id = '1', - name = 'test', - uri = 'test', - state = ModelArtifactState.LIVE, - author = 'Author 1', -}: MockModelArtifact): ModelArtifact => ({ - id, - name, - externalID: '1234132asdfasdf', - description: '', - createTimeSinceEpoch: '1710404288975', - lastUpdateTimeSinceEpoch: '1710404288975', +export const mockModelArtifact = (partial?: Partial): ModelArtifact => ({ + createTimeSinceEpoch: '1712234877179', + id: '1', + lastUpdateTimeSinceEpoch: '1712234877179', + name: 'fraud detection model version 1', + description: 'Description of model version', + artifactType: 'model-artifact', customProperties: {}, - uri, - state, - author, - modelFormatName: 'test', - storageKey: 'test', - storagePath: 'test', - modelFormatVersion: 'test', - serviceAccountName: 'test', - artifactType: 'test', + storageKey: 'test storage key', + storagePath: 'test path', + uri: 's3://test-bucket/demo-models/test-path?endpoint=test-endpoint&defaultRegion=test-region', + modelFormatName: 'test model format', + modelFormatVersion: 'test version 1', + ...partial, }); diff --git a/clients/ui/frontend/src/__mocks__/mockModelArtifactList.ts b/clients/ui/frontend/src/__mocks__/mockModelArtifactList.ts new file mode 100644 index 00000000..8bf2ca20 --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelArtifactList.ts @@ -0,0 +1,11 @@ +/* eslint-disable camelcase */ +import { ModelArtifactList } from '~/app/types'; + +export const mockModelArtifactList = ({ + items = [], +}: Partial): ModelArtifactList => ({ + items, + nextPageToken: '', + pageSize: 0, + size: 1, +}); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts index bffc77fe..8af39893 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts @@ -103,6 +103,10 @@ class ModelRegistry { cy.findByTestId('registered-models-table-toolbar').should('exist'); } + shouldArchiveModelVersionsEmpty() { + cy.findByTestId('empty-archive-model-versions').should('exist'); + } + tabEnabled() { appChrome.findNavItem('Model Registry').should('exist'); return this; diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionArchive.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionArchive.ts new file mode 100644 index 00000000..81afa063 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionArchive.ts @@ -0,0 +1,128 @@ +import { TableRow } from '~/__tests__/cypress/cypress/pages/components/table'; +import { Modal } from '~/__tests__/cypress/cypress/pages/components/Modal'; + +class ArchiveVersionTableRow extends TableRow { + findName() { + return this.find().findByTestId('model-version-name'); + } + + findDescription() { + return this.find().findByTestId('model-version-description'); + } + + findLabelPopoverText() { + return this.find().findByTestId('popover-label-text'); + } + + findLabelModalText() { + return this.find().findByTestId('modal-label-text'); + } + + shouldContainsPopoverLabels(labels: string[]) { + cy.findByTestId('popover-label-group').within(() => labels.map((label) => cy.contains(label))); + return this; + } +} + +class RestoreVersionModal extends Modal { + constructor() { + super('Restore version?'); + } + + findRestoreButton() { + return cy.findByTestId('modal-submit-button'); + } +} + +class ArchiveVersionModal extends Modal { + constructor() { + super('Archive version?'); + } + + findArchiveButton() { + return cy.findByTestId('modal-submit-button'); + } + + findModalTextInput() { + return cy.findByTestId('confirm-archive-input'); + } +} + +class ModelVersionArchive { + private wait() { + cy.findByTestId('app-page-title').should('exist'); + cy.testA11y(); + } + + visit() { + const rmId = '1'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/${rmId}/versions/archive`); + this.wait(); + } + + visitArchiveVersionDetail() { + const mvId = '2'; + const rmId = '1'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit( + `/modelRegistry/${preferredModelRegistry}/registeredModels/${rmId}/versions/archive/${mvId}`, + ); + } + + visitModelVersionList() { + const rmId = '1'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/${rmId}/versions`); + this.wait(); + } + + visitModelVersionDetails() { + const mvId = '3'; + const rmId = '1'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/${rmId}/versions/${mvId}`); + this.wait(); + } + + findModelVersionsTableKebab() { + return cy.findByTestId('model-versions-table-kebab-action'); + } + + shouldArchiveVersionsEmpty() { + cy.findByTestId('empty-archive-state').should('exist'); + } + + findArchiveVersionBreadcrumbItem() { + return cy.findByTestId('archive-version-page-breadcrumb'); + } + + findArchiveVersionTable() { + return cy.findByTestId('model-versions-archive-table'); + } + + findArchiveVersionsTableRows() { + return this.findArchiveVersionTable().find('tbody tr'); + } + + findRestoreButton() { + return cy.findByTestId('restore-button'); + } + + getRow(name: string) { + return new ArchiveVersionTableRow(() => + this.findArchiveVersionTable() + .find(`[data-label="Version name"]`) + .contains(name) + .parents('tr'), + ); + } + + findModelVersionsDetailsHeaderAction() { + return cy.findByTestId('model-version-details-action-button'); + } +} + +export const modelVersionArchive = new ModelVersionArchive(); +export const restoreVersionModal = new RestoreVersionModal(); +export const archiveVersionModal = new ArchiveVersionModal(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionDetails.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionDetails.ts new file mode 100644 index 00000000..b311fa89 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionDetails.ts @@ -0,0 +1,73 @@ +class ModelVersionDetails { + visit() { + const preferredModelRegistry = 'modelregistry-sample'; + const rmId = '1'; + const mvId = '1'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/${rmId}/versions/${mvId}`); + this.wait(); + } + + private wait() { + cy.findByTestId('app-page-title').should('exist'); + cy.testA11y(); + } + + findVersionId() { + return cy.findByTestId('model-version-id'); + } + + findDescription() { + return cy.findByTestId('model-version-description'); + } + + findMoreLabelsButton() { + return cy.findByTestId('label-group').find('button'); + } + + findStorageURI() { + return cy.findByTestId('storage-uri'); + } + + findStorageEndpoint() { + return cy.findByTestId('storage-endpoint'); + } + + findStorageRegion() { + return cy.findByTestId('storage-region'); + } + + findStorageBucket() { + return cy.findByTestId('storage-bucket'); + } + + findStoragePath() { + return cy.findByTestId('storage-path'); + } + + shouldContainsModalLabels(labels: string[]) { + cy.findByTestId('label-group').within(() => labels.map((label) => cy.contains(label))); + return this; + } + + findModelVersionDropdownButton() { + return cy.findByTestId('model-version-toggle-button'); + } + + findModelVersionDropdownSearch() { + return cy.findByTestId('search-input'); + } + + findModelVersionDropdownItem(name: string) { + return cy.findByTestId('model-version-selector-list').find('li').contains(name); + } + + findDetailsTab() { + return cy.findByTestId('model-versions-details-tab'); + } + + findRegisteredDeploymentsTab() { + return cy.findByTestId('deployments-tab'); + } +} + +export const modelVersionDetails = new ModelVersionDetails(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registeredModelArchive.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registeredModelArchive.ts new file mode 100644 index 00000000..05b1186e --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registeredModelArchive.ts @@ -0,0 +1,128 @@ +import { TableRow } from '~/__tests__/cypress/cypress/pages/components/table'; +import { Modal } from '~/__tests__/cypress/cypress/pages/components/Modal'; + +class ArchiveModelTableRow extends TableRow { + findName() { + return this.find().findByTestId('model-name'); + } + + findDescription() { + return this.find().findByTestId('description'); + } + + findLabelPopoverText() { + return this.find().findByTestId('popover-label-text'); + } + + findLabelModalText() { + return this.find().findByTestId('modal-label-text'); + } + + shouldContainsPopoverLabels(labels: string[]) { + cy.findByTestId('popover-label-group').within(() => labels.map((label) => cy.contains(label))); + return this; + } +} + +class RestoreModelModal extends Modal { + constructor() { + super('Restore model?'); + } + + findRestoreButton() { + return cy.findByTestId('modal-submit-button'); + } +} + +class ArchiveModelModal extends Modal { + constructor() { + super('Archive model?'); + } + + findArchiveButton() { + return cy.findByTestId('modal-submit-button'); + } + + findModalTextInput() { + return cy.findByTestId('confirm-archive-input'); + } +} + +class ModelArchive { + private wait() { + cy.findByTestId('app-page-title').should('exist'); + cy.testA11y(); + } + + visit() { + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/archive`); + this.wait(); + } + + visitArchiveModelDetail() { + const rmId = '2'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/archive/${rmId}`); + } + + visitArchiveModelVersionList() { + const rmId = '2'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/archive/${rmId}/versions`); + } + + visitModelList() { + cy.visit('/modelRegistry/modelregistry-sample'); + this.wait(); + } + + visitModelDetails() { + const rmId = '2'; + const preferredModelRegistry = 'modelregistry-sample'; + cy.visit(`/modelRegistry/${preferredModelRegistry}/registeredModels/${rmId}`); + this.wait(); + } + + findTableKebabMenu() { + return cy.findByTestId('registered-models-table-kebab-action'); + } + + shouldArchiveVersionsEmpty() { + cy.findByTestId('empty-archive-model-state').should('exist'); + } + + findArchiveModelBreadcrumbItem() { + return cy.findByTestId('archive-model-page-breadcrumb'); + } + + findRegisteredModelsArchiveTableHeaderButton(name: string) { + return this.findArchiveModelTable().find('thead').findByRole('button', { name }); + } + + findArchiveModelTable() { + return cy.findByTestId('registered-models-archive-table'); + } + + findArchiveModelsTableRows() { + return this.findArchiveModelTable().find('tbody tr'); + } + + findRestoreButton() { + return cy.findByTestId('restore-button'); + } + + getRow(name: string) { + return new ArchiveModelTableRow(() => + this.findArchiveModelTable().find(`[data-label="Model name"]`).contains(name).parents('tr'), + ); + } + + findModelVersionsDetailsHeaderAction() { + return cy.findByTestId('model-version-action-toggle'); + } +} + +export const registeredModelArchive = new ModelArchive(); +export const restoreModelModal = new RestoreModelModal(); +export const archiveModelModal = new ArchiveModelModal(); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts index edf7c007..1b07dfb8 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts @@ -10,6 +10,9 @@ import type { RegisteredModelList, } from '~/app/types'; +const MODEL_REGISTRY_API_VERSION = 'v1'; +export { MODEL_REGISTRY_API_VERSION }; + type SuccessErrorResponse = { success: boolean; error?: string; @@ -65,21 +68,21 @@ declare global { options: { path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; }, - response: ApiResponse, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId', options: { path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; }, - response: ApiResponse, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts', options: { path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; }, - response: ApiResponse, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts', diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts index 91592ae1..dc9940b2 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts @@ -8,8 +8,7 @@ import { labelModal, modelRegistry } from '~/__tests__/cypress/cypress/pages/mod import { mockBFFResponse } from '~/__mocks__/mockBFFResponse'; import type { ModelRegistry, ModelVersion, RegisteredModel } from '~/app/types'; import { be } from '~/__tests__/cypress/cypress/utils/should'; - -const MODEL_REGISTRY_API_VERSION = 'v1'; +import { MODEL_REGISTRY_API_VERSION } from '~/__tests__/cypress/cypress/support/commands/api'; type HandlersProps = { modelRegistries?: ModelRegistry[]; diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionArchive.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionArchive.cy.ts new file mode 100644 index 00000000..a1a3cef4 --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionArchive.cy.ts @@ -0,0 +1,303 @@ +/* eslint-disable camelcase */ +import { mockRegisteredModelList } from '~/__mocks__/mockRegisteredModelsList'; +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { verifyRelativeURL } from '~/__tests__/cypress/cypress/utils/url'; +import { labelModal } from '~/__tests__/cypress/cypress/pages/modelRegistry'; +import type { ModelRegistry, ModelVersion } from '~/app/types'; +import { ModelState } from '~/app/types'; +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import { mockBFFResponse } from '~/__mocks__/utils'; +import { modelVersionArchive } from '~/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionArchive'; +import { MODEL_REGISTRY_API_VERSION } from '~/__tests__/cypress/cypress/support/commands/api'; + +type HandlersProps = { + registeredModelsSize?: number; + modelVersions?: ModelVersion[]; + modelRegistries?: ModelRegistry[]; +}; + +const initIntercepts = ({ + registeredModelsSize = 4, + modelVersions = [ + mockModelVersion({ + name: 'model version 1', + author: 'Author 1', + id: '1', + labels: [ + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + 'Test label x', + 'Test label y', + 'Test label z', + ], + state: ModelState.ARCHIVED, + }), + mockModelVersion({ id: '2', name: 'model version 2', state: ModelState.ARCHIVED }), + mockModelVersion({ id: '3', name: 'model version 3' }), + ], + modelRegistries = [ + mockModelRegistry({ + name: 'modelregistry-sample', + description: 'New model registry', + displayName: 'Model Registry Sample', + }), + mockModelRegistry({ + name: 'modelregistry-sample-2', + description: 'New model registry 2', + displayName: 'Model Registry Sample 2', + }), + ], +}: HandlersProps) => { + cy.interceptApi( + `GET /api/:apiVersion/model_registry`, + { + path: { apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(modelRegistries), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models`, + { + path: { modelRegistryName: 'modelregistry-sample', apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(mockRegisteredModelList({ size: registeredModelsSize })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockBFFResponse( + mockModelVersionList({ + items: modelVersions, + }), + ), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockBFFResponse(mockRegisteredModel({ name: 'test-1' })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 2, + }, + }, + mockBFFResponse( + mockModelVersion({ id: '2', name: 'model version 2', state: ModelState.ARCHIVED }), + ), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 3, + }, + }, + mockBFFResponse(mockModelVersion({ id: '3', name: 'model version 3', state: ModelState.LIVE })), + ); +}; + +describe('Model version archive list', () => { + it('No archive versions in the selected registered model', () => { + initIntercepts({ modelVersions: [mockModelVersion({ id: '3', name: 'model version 2' })] }); + modelVersionArchive.visitModelVersionList(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + // TODO: Uncomment when dropdowns are fixed and remove the visit after the comments + // modelVersionArchive + // .findModelVersionsTableKebab() + // .findDropdownItem('View archived versions') + // .click(); + modelVersionArchive.visit(); + modelVersionArchive.shouldArchiveVersionsEmpty(); + }); + + it('Archived version details browser back button should lead to archived versions table', () => { + initIntercepts({}); + modelVersionArchive.visit(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions/archive'); + modelVersionArchive.findArchiveVersionBreadcrumbItem().contains('Archived version'); + const archiveVersionRow = modelVersionArchive.getRow('model version 2'); + archiveVersionRow.findName().contains('model version 2').click(); + verifyRelativeURL( + '/modelRegistry/modelregistry-sample/registeredModels/1/versions/archive/2/details', + ); + cy.go('back'); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions/archive'); + modelVersionArchive.findArchiveVersionBreadcrumbItem().contains('Archived version'); + archiveVersionRow.findName().contains('model version 2').should('exist'); + }); + + it('Archive version list', () => { + initIntercepts({}); + modelVersionArchive.visit(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions/archive'); + + //breadcrumb + modelVersionArchive.findArchiveVersionBreadcrumbItem().contains('Archived version'); + + // name, last modified, owner, labels modal + modelVersionArchive.findArchiveVersionTable().should('be.visible'); + modelVersionArchive.findArchiveVersionsTableRows().should('have.length', 2); + + const archiveVersionRow = modelVersionArchive.getRow('model version 1'); + + archiveVersionRow.findLabelModalText().contains('5 more'); + archiveVersionRow.findLabelModalText().click(); + labelModal.shouldContainsModalLabels([ + 'Financial', + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + 'Test label x', + 'Test label y', + 'Test label y', + ]); + labelModal.findCloseModal().click(); + }); +}); + +// TODO: Uncomment when we have restoring and archiving mocked +// describe('Restoring archive version', () => { +// it('Restore from archive table', () => { +// cy.interceptApi( +// 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId', +// { +// path: { +// modelRegistryName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// modelVersionId: 2, +// }, +// }, +// mockModelVersion({}), +// ).as('versionRestored'); + +// initIntercepts({}); +// modelVersionArchive.visit(); + +// const archiveVersionRow = modelVersionArchive.getRow('model version 2'); +// archiveVersionRow.findKebabAction('Restore version').click(); + +// restoreVersionModal.findRestoreButton().click(); + +// cy.wait('@versionRestored').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'LIVE', +// }); +// }); +// }); + +// it('Restore from archive version details', () => { +// cy.interceptApi( +// 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId', +// { +// path: { +// modelRegistryName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// modelVersionId: 2, +// }, +// }, +// mockModelVersion({}), +// ).as('versionRestored'); + +// initIntercepts({}); +// modelVersionArchive.visitArchiveVersionDetail(); + +// modelVersionArchive.findRestoreButton().click(); +// restoreVersionModal.findRestoreButton().click(); + +// cy.wait('@versionRestored').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'LIVE', +// }); +// }); +// }); +// }); + +// describe('Archiving version', () => { +// it('Archive version from versions table', () => { +// cy.interceptApi( +// 'PATCH /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId', +// { +// path: { +// serviceName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// modelVersionId: 3, +// }, +// }, +// mockModelVersion({}), +// ).as('versionArchived'); + +// initIntercepts({}); +// modelVersionArchive.visitModelVersionList(); + +// const modelVersionRow = modelRegistry.getModelVersionRow('model version 3'); +// modelVersionRow.findKebabAction('Archive model version').click(); +// archiveVersionModal.findArchiveButton().should('be.disabled'); +// archiveVersionModal.findModalTextInput().fill('model version 3'); +// archiveVersionModal.findArchiveButton().should('be.enabled').click(); +// cy.wait('@versionArchived').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'ARCHIVED', +// }); +// }); +// }); + +// it('Archive version from versions details', () => { +// cy.interceptApi( +// 'PATCH /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId', +// { +// path: { +// serviceName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// modelVersionId: 3, +// }, +// }, +// mockModelVersion({}), +// ).as('versionArchived'); + +// initIntercepts({}); +// modelVersionArchive.visitModelVersionDetails(); +// modelVersionArchive +// .findModelVersionsDetailsHeaderAction() +// .findDropdownItem('Archive version') +// .click(); + +// archiveVersionModal.findArchiveButton().should('be.disabled'); +// archiveVersionModal.findModalTextInput().fill('model version 3'); +// archiveVersionModal.findArchiveButton().should('be.enabled').click(); +// cy.wait('@versionArchived').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'ARCHIVED', +// }); +// }); +// }); +// }); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionDetails.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionDetails.cy.ts new file mode 100644 index 00000000..3bcf9ece --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersionDetails.cy.ts @@ -0,0 +1,185 @@ +/* eslint-disable camelcase */ +import { verifyRelativeURL } from '~/__tests__/cypress/cypress/utils/url'; +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import { mockBFFResponse } from '~/__mocks__/utils'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockModelArtifactList } from '~/__mocks__/mockModelArtifactList'; +import { mockModelArtifact } from '~/__mocks__/mockModelArtifact'; +import type { ModelRegistry } from '~/app/types'; +import { MODEL_REGISTRY_API_VERSION } from '~/__tests__/cypress/cypress/support/commands/api'; +import { modelVersionDetails } from '~/__tests__/cypress/cypress/pages/modelRegistryView/modelVersionDetails'; + +type HandlersProps = { + modelRegistries?: ModelRegistry[]; +}; + +const initIntercepts = ({ + modelRegistries = [ + mockModelRegistry({ + name: 'modelregistry-sample', + description: 'New model registry', + displayName: 'Model Registry Sample', + }), + mockModelRegistry({ + name: 'modelregistry-sample-2', + description: 'New model registry 2', + displayName: 'Model Registry Sample 2', + }), + ], +}: HandlersProps) => { + cy.interceptApi( + `GET /api/:apiVersion/model_registry`, + { + path: { apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(modelRegistries), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockBFFResponse(mockRegisteredModel({})), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockBFFResponse( + mockModelVersionList({ + items: [ + mockModelVersion({ name: 'Version 1', author: 'Author 1', registeredModelId: '1' }), + mockModelVersion({ + author: 'Author 2', + registeredModelId: '1', + id: '2', + name: 'Version 2', + }), + ], + }), + ), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 1, + }, + }, + mockBFFResponse( + mockModelVersion({ + id: '1', + name: 'Version 1', + labels: [ + 'Testing label', + 'Financial data', + 'Fraud detection', + 'Long label data to be truncated abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc abc', + 'Machine learning', + 'Next data to be overflow', + 'Label x', + 'Label y', + 'Label z', + ], + }), + ), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 2, + }, + }, + mockBFFResponse(mockModelVersion({ id: '2', name: 'Version 2' })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 1, + }, + }, + mockBFFResponse( + mockModelArtifactList({ + items: [ + mockModelArtifact({}), + mockModelArtifact({ + author: 'Author 2', + id: '2', + name: 'Artifact 2', + }), + ], + }), + ), + ); +}; + +describe('Model version details', () => { + describe('Details tab', () => { + beforeEach(() => { + initIntercepts({}); + modelVersionDetails.visit(); + }); + + it('Model version details page header', () => { + verifyRelativeURL( + '/modelRegistry/modelregistry-sample/registeredModels/1/versions/1/details', + ); + cy.findByTestId('app-page-title').should('have.text', 'Version 1'); + cy.findByTestId('breadcrumb-version-name').should('have.text', 'Version 1'); + }); + + it('Model version details tab', () => { + modelVersionDetails.findVersionId().contains('1'); + modelVersionDetails.findDescription().should('have.text', 'Description of model version'); + modelVersionDetails.findMoreLabelsButton().contains('6 more'); + modelVersionDetails.findMoreLabelsButton().click(); + modelVersionDetails.shouldContainsModalLabels([ + 'Testing label', + 'Financial', + 'Financial data', + 'Fraud detection', + 'Machine learning', + 'Next data to be overflow', + 'Label x', + 'Label y', + 'Label z', + ]); + modelVersionDetails.findStorageEndpoint().contains('test-endpoint'); + modelVersionDetails.findStorageRegion().contains('test-region'); + modelVersionDetails.findStorageBucket().contains('test-bucket'); + modelVersionDetails.findStoragePath().contains('demo-models/test-path'); + }); + + it('Switching model versions', () => { + modelVersionDetails.findVersionId().contains('1'); + modelVersionDetails.findModelVersionDropdownButton().click(); + modelVersionDetails.findModelVersionDropdownSearch().fill('Version 2'); + modelVersionDetails.findModelVersionDropdownItem('Version 2').click(); + modelVersionDetails.findVersionId().contains('2'); + }); + }); +}); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts index 31a7ab37..4fca0caa 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelVersions.cy.ts @@ -9,8 +9,7 @@ import { verifyRelativeURL } from '~/__tests__/cypress/cypress/utils/url'; import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; import { mockModelVersion } from '~/__mocks__/mockModelVersion'; import { mockBFFResponse } from '~/__mocks__/utils'; - -const MODEL_REGISTRY_API_VERSION = 'v1'; +import { MODEL_REGISTRY_API_VERSION } from '~/__tests__/cypress/cypress/support/commands/api'; type HandlersProps = { registeredModelsSize?: number; @@ -99,7 +98,7 @@ const initIntercepts = ({ modelVersionId: 1, }, }, - mockModelVersion({ id: '1', name: 'model version' }), + mockBFFResponse(mockModelVersion({ id: '1', name: 'model version' })), ); }; @@ -209,14 +208,12 @@ describe('Model Versions', () => { modelRegistry.visit(); const registeredModelRow = modelRegistry.getRow('Fraud detection model'); registeredModelRow.findName().contains('Fraud detection model').click(); - verifyRelativeURL(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); - // TODO: Uncomment when we have model version details - // const modelVersionRow = modelRegistry.getModelVersionRow('model version'); - // modelVersionRow.findModelVersionName().contains('model version').click(); - // verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions/1/details'); - // cy.findByTestId('app-page-title').should('have.text', 'model version'); - // cy.findByTestId('breadcrumb-version-name').should('have.text', 'model version'); - // cy.go('back'); - // verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + const modelVersionRow = modelRegistry.getModelVersionRow('model version'); + modelVersionRow.findModelVersionName().contains('model version').click(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions/1/details'); + cy.findByTestId('app-page-title').should('have.text', 'model version'); + cy.findByTestId('breadcrumb-version-name').should('have.text', 'model version'); + cy.go('back'); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/1/versions'); }); }); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/registeredModelArchive.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/registeredModelArchive.cy.ts new file mode 100644 index 00000000..5abd877d --- /dev/null +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/registeredModelArchive.cy.ts @@ -0,0 +1,355 @@ +/* eslint-disable camelcase */ +import { mockRegisteredModelList } from '~/__mocks__/mockRegisteredModelsList'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { verifyRelativeURL } from '~/__tests__/cypress/cypress/utils/url'; +import { labelModal, modelRegistry } from '~/__tests__/cypress/cypress/pages/modelRegistry'; +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { be } from '~/__tests__/cypress/cypress/utils/should'; +import type { ModelRegistry, ModelVersion, RegisteredModel } from '~/app/types'; +import { ModelState } from '~/app/types'; +import { mockBFFResponse } from '~/__mocks__/utils'; +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import { MODEL_REGISTRY_API_VERSION } from '~/__tests__/cypress/cypress/support/commands/api'; +import { registeredModelArchive } from '~/__tests__/cypress/cypress/pages/modelRegistryView/registeredModelArchive'; + +type HandlersProps = { + registeredModels?: RegisteredModel[]; + modelVersions?: ModelVersion[]; + modelRegistries?: ModelRegistry[]; +}; + +const initIntercepts = ({ + registeredModels = [ + mockRegisteredModel({ + name: 'model 1', + id: '1', + labels: [ + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + 'Test label x', + 'Test label y', + 'Test label z', + ], + state: ModelState.ARCHIVED, + }), + mockRegisteredModel({ id: '2', name: 'model 2', state: ModelState.ARCHIVED }), + mockRegisteredModel({ id: '3', name: 'model 3' }), + mockRegisteredModel({ id: '4', name: 'model 4' }), + ], + modelVersions = [ + mockModelVersion({ author: 'Author 1', registeredModelId: '2' }), + mockModelVersion({ name: 'model version' }), + ], + modelRegistries = [ + mockModelRegistry({ + name: 'modelregistry-sample', + description: 'New model registry', + displayName: 'Model Registry Sample', + }), + mockModelRegistry({ + name: 'modelregistry-sample-2', + description: 'New model registry 2', + displayName: 'Model Registry Sample 2', + }), + ], +}: HandlersProps) => { + cy.interceptApi( + `GET /api/:apiVersion/model_registry`, + { + path: { apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(modelRegistries), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models`, + { + path: { modelRegistryName: 'modelregistry-sample', apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockBFFResponse(mockRegisteredModelList({ items: registeredModels })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 1, + }, + }, + mockBFFResponse(mockModelVersion({ id: '1', name: 'Version 2' })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 2, + }, + }, + mockBFFResponse(mockModelVersionList({ items: modelVersions })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 2, + }, + }, + mockBFFResponse(mockRegisteredModel({ id: '2', name: 'model 2', state: ModelState.ARCHIVED })), + ); + + cy.interceptApi( + `GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId`, + { + path: { + modelRegistryName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 3, + }, + }, + mockBFFResponse(mockRegisteredModel({ id: '3', name: 'model 3' })), + ); +}; + +describe('Model archive list', () => { + it('No archive models in the selected model registry', () => { + initIntercepts({ + registeredModels: [], + }); + registeredModelArchive.visit(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive'); + registeredModelArchive.shouldArchiveVersionsEmpty(); + }); + + it('Archived model details browser back button should lead to archived models table', () => { + initIntercepts({}); + registeredModelArchive.visit(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive'); + registeredModelArchive.findArchiveModelBreadcrumbItem().contains('Archived models'); + const archiveModelRow = registeredModelArchive.getRow('model 2'); + archiveModelRow.findName().contains('model 2').click(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive/2/versions'); + cy.findByTestId('app-page-title').should('have.text', 'model 2Archived'); + cy.go('back'); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive'); + registeredModelArchive.findArchiveModelTable().should('be.visible'); + }); + + it('Archived model with no versions', () => { + initIntercepts({ modelVersions: [] }); + registeredModelArchive.visit(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive'); + registeredModelArchive.findArchiveModelBreadcrumbItem().contains('Archived models'); + const archiveModelRow = registeredModelArchive.getRow('model 2'); + archiveModelRow.findName().contains('model 2').click(); + modelRegistry.shouldArchiveModelVersionsEmpty(); + }); + + it('Archived model flow', () => { + initIntercepts({}); + registeredModelArchive.visitArchiveModelVersionList(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive/2/versions'); + + modelRegistry.findModelVersionsTable().should('be.visible'); + modelRegistry.findModelVersionsTableRows().should('have.length', 2); + const version = modelRegistry.getModelVersionRow('model version'); + version.findModelVersionName().contains('model version').click(); + verifyRelativeURL( + '/modelRegistry/modelregistry-sample/registeredModels/archive/2/versions/1/details', + ); + cy.go('back'); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive/2/versions'); + }); + + it('Archive models list', () => { + initIntercepts({}); + registeredModelArchive.visit(); + verifyRelativeURL('/modelRegistry/modelregistry-sample/registeredModels/archive'); + + //breadcrumb + registeredModelArchive.findArchiveModelBreadcrumbItem().contains('Archived models'); + + // name, last modified, owner, labels modal + registeredModelArchive.findArchiveModelTable().should('be.visible'); + registeredModelArchive.findArchiveModelsTableRows().should('have.length', 2); + + const archiveModelRow = registeredModelArchive.getRow('model 1'); + + archiveModelRow.findLabelModalText().contains('5 more'); + archiveModelRow.findLabelModalText().click(); + labelModal.shouldContainsModalLabels([ + 'Financial', + 'Financial data', + 'Fraud detection', + 'Test label', + 'Machine learning', + 'Next data to be overflow', + 'Test label x', + 'Test label y', + 'Test label y', + ]); + labelModal.findCloseModal().click(); + + // sort by Last modified + registeredModelArchive + .findRegisteredModelsArchiveTableHeaderButton('Last modified') + .should(be.sortAscending); + registeredModelArchive.findRegisteredModelsArchiveTableHeaderButton('Last modified').click(); + registeredModelArchive + .findRegisteredModelsArchiveTableHeaderButton('Last modified') + .should(be.sortDescending); + + // sort by Model name + registeredModelArchive.findRegisteredModelsArchiveTableHeaderButton('Model name').click(); + registeredModelArchive + .findRegisteredModelsArchiveTableHeaderButton('Model name') + .should(be.sortAscending); + registeredModelArchive.findRegisteredModelsArchiveTableHeaderButton('Model name').click(); + registeredModelArchive + .findRegisteredModelsArchiveTableHeaderButton('Model name') + .should(be.sortDescending); + }); +}); + +// TODO: Uncomment when dropdowns are fixed +// it('Opens the detail page when we select "View Details" from action menu', () => { +// initIntercepts({}); +// registeredModelArchive.visit(); +// const archiveModelRow = registeredModelArchive.getRow('model 2'); +// archiveModelRow.findKebabAction('View details').click(); +// cy.location('pathname').should( +// 'be.equals', +// '/modelRegistry/modelregistry-sample/registeredModels/archive/2/details', +// ); +// }); + +// TODO: Uncomment when we have mock data for restoring and archiving +// describe('Restoring archive model', () => { +// it('Restore from archive models table', () => { +// cy.interceptApi( +// 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId', +// { +// path: { +// modelRegistryName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// registeredModelId: 2, +// }, +// }, +// mockBFFResponse(mockRegisteredModel({ id: '2', name: 'model 2', state: ModelState.LIVE })), +// ).as('modelRestored'); + +// initIntercepts({}); +// registeredModelArchive.visit(); + +// const archiveModelRow = registeredModelArchive.getRow('model 2'); +// archiveModelRow.findKebabAction('Restore model').click(); + +// restoreModelModal.findRestoreButton().click(); + +// cy.wait('@modelRestored').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'LIVE', +// }); +// }); +// }); + +// it('Restore from archive model details', () => { +// cy.interceptApi( +// 'PATCH /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId', +// { +// path: { +// serviceName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// registeredModelId: 2, +// }, +// }, +// mockRegisteredModel({ id: '2', name: 'model 2', state: ModelState.LIVE }), +// ).as('modelRestored'); + +// initIntercepts({}); +// registeredModelArchive.visitArchiveModelDetail(); + +// registeredModelArchive.findRestoreButton().click(); +// restoreModelModal.findRestoreButton().click(); + +// cy.wait('@modelRestored').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'LIVE', +// }); +// }); +// }); +// }); + +// describe('Archiving model', () => { +// it('Archive model from registered models table', () => { +// cy.interceptApi( +// 'PATCH /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId', +// { +// path: { +// serviceName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// registeredModelId: 3, +// }, +// }, +// mockRegisteredModel({ id: '3', name: 'model 3', state: ModelState.ARCHIVED }), +// ).as('modelArchived'); + +// initIntercepts({}); +// registeredModelArchive.visitModelList(); + +// const modelRow = modelRegistry.getRow('model 3'); +// modelRow.findKebabAction('Archive model').click(); +// archiveModelModal.findArchiveButton().should('be.disabled'); +// archiveModelModal.findModalTextInput().fill('model 3'); +// archiveModelModal.findArchiveButton().should('be.enabled').click(); +// cy.wait('@modelArchived').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'ARCHIVED', +// }); +// }); +// }); + +// it('Archive model from model details', () => { +// cy.interceptApi( +// 'PATCH /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId', +// { +// path: { +// serviceName: 'modelregistry-sample', +// apiVersion: MODEL_REGISTRY_API_VERSION, +// registeredModelId: 3, +// }, +// }, +// mockRegisteredModel({ id: '3', name: 'model 3', state: ModelState.ARCHIVED }), +// ).as('modelArchived'); + +// initIntercepts({}); +// registeredModelArchive.visitModelList(); + +// const modelRow = modelRegistry.getRow('model 3'); +// modelRow.findName().contains('model 3').click(); +// registeredModelArchive +// .findModelVersionsDetailsHeaderAction() +// .findDropdownItem('Archive model') +// .click(); + +// archiveModelModal.findArchiveButton().should('be.disabled'); +// archiveModelModal.findModalTextInput().fill('model 3'); +// archiveModelModal.findArchiveButton().should('be.enabled').click(); +// cy.wait('@modelArchived').then((interception) => { +// expect(interception.request.body).to.eql({ +// state: 'ARCHIVED', +// }); +// }); +// }); +//}); diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx index c7b78d10..c64a58c7 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx @@ -6,6 +6,12 @@ import { modelRegistryUrl } from './screens/routeUtils'; import RegisteredModelsArchive from './screens/RegisteredModelsArchive/RegisteredModelsArchive'; import { ModelVersionsTab } from './screens/ModelVersions/const'; import ModelVersions from './screens/ModelVersions/ModelVersions'; +import { ModelVersionDetailsTab } from './screens/ModelVersionDetails/const'; +import ModelVersionsDetails from './screens/ModelVersionDetails/ModelVersionDetails'; +import ModelVersionsArchive from './screens/ModelVersionsArchive/ModelVersionsArchive'; +import ModelVersionsArchiveDetails from './screens/ModelVersionsArchive/ModelVersionArchiveDetails'; +import ArchiveModelVersionDetails from './screens/ModelVersionsArchive/ArchiveModelVersionDetails'; +import RegisteredModelsArchiveDetails from './screens/RegisteredModelsArchive/RegisteredModelArchiveDetails'; const ModelRegistryRoutes: React.FC = () => ( @@ -28,10 +34,58 @@ const ModelRegistryRoutes: React.FC = () => ( path={ModelVersionsTab.DETAILS} element={} /> + + } /> + } + /> + } /> + + + } /> + + } /> + + } + /> + } /> + + } /> + } /> } /> + + } /> + + } + /> + + } + /> + + } /> + + } + /> + } /> + + } /> + } /> } /> diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx index aeb98464..8416f393 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup.tsx @@ -11,11 +11,13 @@ import ModelPropertiesTableRow from '~/app/pages/modelRegistry/screens/ModelProp type ModelPropertiesDescriptionListGroupProps = { customProperties: ModelRegistryCustomProperties; + isArchive?: boolean; saveEditedCustomProperties: (properties: ModelRegistryCustomProperties) => Promise; }; const ModelPropertiesDescriptionListGroup: React.FC = ({ customProperties = {}, + isArchive, saveEditedCustomProperties, }) => { const [editingPropertyKeys, setEditingPropertyKeys] = React.useState([]); @@ -51,16 +53,18 @@ const ModelPropertiesDescriptionListGroup: React.FC} - iconPosition="start" - isDisabled={isAdding || isSavingEdits} - onClick={() => setIsAdding(true)} - > - Add property - + !isArchive && ( + + ) } isEmpty={!isAdding && keys.length === 0} contentWhenEmpty="No properties" @@ -70,13 +74,14 @@ const ModelPropertiesDescriptionListGroup: React.FC - {shownKeys.map((key) => ( void; isSavingEdits: boolean; + isArchive?: boolean; setIsSavingEdits: (isSaving: boolean) => void; saveEditedProperty: (oldKey: string, newPair: KeyValuePair) => Promise; } & EitherNotBoth< @@ -38,6 +39,7 @@ const ModelPropertiesTableRow: React.FC = ({ setIsEditing, isSavingEdits, setIsSavingEdits, + isArchive, saveEditedProperty, }) => { const { key, value } = keyValuePair; @@ -143,43 +145,45 @@ const ModelPropertiesTableRow: React.FC = ({ )} - + {!isArchive && ( + + )} ); }; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetails.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetails.tsx new file mode 100644 index 00000000..fe60a7b6 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetails.tsx @@ -0,0 +1,108 @@ +import React, { useEffect } from 'react'; +import { useNavigate, useParams } from 'react-router'; +import { Breadcrumb, BreadcrumbItem, Flex, FlexItem, Truncate } from '@patternfly/react-core'; +import { Link } from 'react-router-dom'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; +import useModelVersionById from '~/app/hooks/useModelVersionById'; +import { ModelState } from '~/app/types'; +import { + archiveModelVersionDetailsUrl, + modelVersionArchiveDetailsUrl, + modelVersionUrl, + registeredModelUrl, +} from '~/app/pages/modelRegistry/screens/routeUtils'; +import { ModelVersionDetailsTab } from './const'; +import ModelVersionSelector from './ModelVersionSelector'; +import ModelVersionDetailsTabs from './ModelVersionDetailsTabs'; +import ModelVersionsDetailsHeaderActions from './ModelVersionDetailsHeaderActions'; + +type ModelVersionsDetailProps = { + tab: ModelVersionDetailsTab; +} & Omit< + React.ComponentProps, + 'breadcrumb' | 'title' | 'description' | 'loadError' | 'loaded' | 'provideChildrenPadding' +>; + +const ModelVersionsDetails: React.FC = ({ tab, ...pageProps }) => { + const navigate = useNavigate(); + + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + + const { modelVersionId: mvId, registeredModelId: rmId } = useParams(); + const [rm] = useRegisteredModelById(rmId); + const [mv, mvLoaded, mvLoadError, refreshModelVersion] = useModelVersionById(mvId); + + const refresh = React.useCallback(() => { + refreshModelVersion(); + }, [refreshModelVersion]); + + useEffect(() => { + if (rm?.state === ModelState.ARCHIVED && mv?.id) { + navigate( + archiveModelVersionDetailsUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name), + ); + } else if (mv?.state === ModelState.ARCHIVED) { + navigate( + modelVersionArchiveDetailsUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name), + ); + } + }, [rm?.state, mv?.id, mv?.state, mv?.registeredModelId, preferredModelRegistry?.name, navigate]); + + return ( + + ( + Model registry - {preferredModelRegistry?.name} + )} + /> + ( + + {rm?.name || 'Loading...'} + + )} + /> + + {mv?.name || 'Loading...'} + + + } + title={mv?.name} + headerAction={ + mvLoaded && + mv && ( + + + + navigate(modelVersionUrl(modelVersionId, rmId, preferredModelRegistry?.name)) + } + /> + + + + + + ) + } + description={} + loadError={mvLoadError} + loaded={mvLoaded} + provideChildrenPadding + > + {mv !== null && } + + ); +}; + +export default ModelVersionsDetails; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx new file mode 100644 index 00000000..e969c36f --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx @@ -0,0 +1,86 @@ +import * as React from 'react'; +import { Dropdown, DropdownList, MenuToggle, DropdownItem } from '@patternfly/react-core'; +import { useNavigate } from 'react-router'; +import { ModelState, ModelVersion } from '~/app/types'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ArchiveModelVersionModal } from '~/app/pages/modelRegistry/screens/components/ArchiveModelVersionModal'; +import { modelVersionArchiveDetailsUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; + +interface ModelVersionsDetailsHeaderActionsProps { + mv: ModelVersion; + refresh: () => void; +} + +const ModelVersionsDetailsHeaderActions: React.FC = ({ + mv, +}) => { + const { apiState } = React.useContext(ModelRegistryContext); + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + + const navigate = useNavigate(); + const [isOpenActionDropdown, setOpenActionDropdown] = React.useState(false); + const [isArchiveModalOpen, setIsArchiveModalOpen] = React.useState(false); + const tooltipRef = React.useRef(null); + + return ( + <> + setOpenActionDropdown(false)} + onOpenChange={(open) => setOpenActionDropdown(open)} + popperProps={{ position: 'right' }} + toggle={(toggleRef) => ( + setOpenActionDropdown(!isOpenActionDropdown)} + isExpanded={isOpenActionDropdown} + aria-label="Model version details action toggle" + data-testid="model-version-details-action-button" + > + Actions + + )} + > + + setIsArchiveModalOpen(true)} + ref={tooltipRef} + > + Archive version + + + + setIsArchiveModalOpen(false)} + onSubmit={() => + apiState.api + .patchModelVersion( + {}, + { + state: ModelState.ARCHIVED, + }, + mv.id, + ) + .then(() => + navigate( + modelVersionArchiveDetailsUrl( + mv.id, + mv.registeredModelId, + preferredModelRegistry?.name, + ), + ), + ) + } + isOpen={isArchiveModalOpen} + modelVersionName={mv.name} + /> + + ); +}; + +export default ModelVersionsDetailsHeaderActions; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsTabs.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsTabs.tsx new file mode 100644 index 00000000..2747ce08 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsTabs.tsx @@ -0,0 +1,52 @@ +import * as React from 'react'; +import { useNavigate } from 'react-router-dom'; +import { PageSection, Tab, Tabs, TabTitleText } from '@patternfly/react-core'; +import { ModelVersion } from '~/app/types'; +import { ModelVersionDetailsTabTitle, ModelVersionDetailsTab } from './const'; +import ModelVersionDetailsView from './ModelVersionDetailsView'; + +type ModelVersionDetailTabsProps = { + tab: ModelVersionDetailsTab; + modelVersion: ModelVersion; + isArchiveVersion?: boolean; + refresh: () => void; +}; + +const ModelVersionDetailsTabs: React.FC = ({ + tab, + modelVersion: mv, + isArchiveVersion, + refresh, +}) => { + const navigate = useNavigate(); + return ( + navigate(`../${eventKey}`, { relative: 'path' })} + > + {ModelVersionDetailsTabTitle.DETAILS}} + aria-label="Model versions details tab" + data-testid="model-versions-details-tab" + > + + + + + + ); +}; + +export default ModelVersionDetailsTabs; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx new file mode 100644 index 00000000..c541ea4f --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx @@ -0,0 +1,193 @@ +import * as React from 'react'; +import { DescriptionList, Flex, FlexItem, ContentVariants, Title } from '@patternfly/react-core'; +import DashboardDescriptionListGroup from '~/components/DashboardDescriptionListGroup'; +import EditableTextDescriptionListGroup from '~/components/EditableTextDescriptionListGroup'; +import EditableLabelsDescriptionListGroup from '~/components/EditableLabelsDescriptionListGroup'; +import { ModelVersion } from '~/app/types'; +import useModelArtifactsByVersionId from '~/app/hooks/useModelArtifactsByVersionId'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import InlineTruncatedClipboardCopy from '~/components/InlineTruncatedClipboardCopy'; +import DashboardHelpTooltip from '~/components/DashboardHelpTooltip'; +import { + getLabels, + mergeUpdatedLabels, + uriToObjectStorageFields, +} from '~/app/pages/modelRegistry/screens/utils'; +import ModelPropertiesDescriptionListGroup from '~/app/pages/modelRegistry/screens/ModelPropertiesDescriptionListGroup'; +import ModelTimestamp from '~/app/pages/modelRegistry/screens/components/ModelTimestamp'; + +type ModelVersionDetailsViewProps = { + modelVersion: ModelVersion; + isArchiveVersion?: boolean; + refresh: () => void; +}; + +const ModelVersionDetailsView: React.FC = ({ + modelVersion: mv, + isArchiveVersion, + refresh, +}) => { + const [modelArtifact] = useModelArtifactsByVersionId(mv.id); + const { apiState } = React.useContext(ModelRegistryContext); + const storageFields = uriToObjectStorageFields(modelArtifact.items[0]?.uri || ''); + + return ( + + + + + apiState.api + .patchModelVersion( + {}, + { + description: value, + }, + mv.id, + ) + .then(refresh) + } + /> + + apiState.api + .patchModelVersion( + {}, + { + customProperties: mergeUpdatedLabels(mv.customProperties, editedLabels), + }, + mv.id, + ) + .then(refresh) + } + /> + + apiState.api + .patchModelVersion({}, { customProperties: editedProperties }, mv.id) + .then(refresh) + } + /> + + + + + + + + + + Model location + + + {storageFields && ( + <> + + + + + + + + + + + + + + )} + {!storageFields && ( + <> + + + + + )} + + {modelArtifact.items[0]?.modelFormatName} + + + } + > + {mv.author} + + + + + + + + + + + ); +}; +export default ModelVersionDetailsView; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionSelector.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionSelector.tsx new file mode 100644 index 00000000..119b9a84 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionSelector.tsx @@ -0,0 +1,111 @@ +import * as React from 'react'; +import { + HelperText, + HelperTextItem, + Menu, + MenuContainer, + MenuContent, + MenuItem, + MenuList, + MenuSearch, + MenuSearchInput, + MenuToggle, + SearchInput, +} from '@patternfly/react-core'; +import { ModelVersion } from '~/app/types'; +import useModelVersionsByRegisteredModel from '~/app/hooks/useModelVersionsByRegisteredModel'; + +type ModelVersionSelectorProps = { + rmId?: string; + selection: ModelVersion; + onSelect: (versionId: string) => void; +}; + +const ModelVersionSelector: React.FC = ({ + rmId, + selection, + onSelect, +}) => { + const [isOpen, setOpen] = React.useState(false); + const [input, setInput] = React.useState(''); + + const toggleRef = React.useRef(null); + const menuRef = React.useRef(null); + + const [modelVersions] = useModelVersionsByRegisteredModel(rmId); + + const menuListItems = modelVersions.items + .filter((item) => !input || item.name.toLowerCase().includes(input.toString().toLowerCase())) + .map((mv, index) => ( + + {mv.name} + + )); + + if (input && modelVersions.size === 0) { + menuListItems.push( + + No results found + , + ); + } + + const menu = ( + { + if (typeof itemId === 'string') { + onSelect(itemId); + setOpen(false); + } + }} + data-id="model-version-selector-menu" + ref={menuRef} + isScrollable + activeItemId={selection.id} + > + + + + setInput(value)} + /> + + + + {`Type a name to search your ${modelVersions.size} versions.`} + + + + {menuListItems} + + + ); + + return ( + setOpen(!isOpen)} + isExpanded={isOpen} + isFullWidth + data-testid="model-version-toggle-button" + > + {selection.name} + + } + menu={menu} + menuRef={menuRef} + popperProps={{ maxWidth: 'trigger' }} + onOpenChange={(open) => setOpen(open)} + /> + ); +}; + +export default ModelVersionSelector; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/const.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/const.ts new file mode 100644 index 00000000..ded505d1 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/const.ts @@ -0,0 +1,7 @@ +export enum ModelVersionDetailsTab { + DETAILS = 'details', +} + +export enum ModelVersionDetailsTabTitle { + DETAILS = 'Details', +} diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx index ec145465..f0360589 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsView.tsx @@ -12,9 +12,14 @@ import ModelTimestamp from '~/app/pages/modelRegistry/screens/components/ModelTi type ModelDetailsViewProps = { registeredModel: RegisteredModel; refresh: () => void; + isArchiveModel?: boolean; }; -const ModelDetailsView: React.FC = ({ registeredModel: rm, refresh }) => { +const ModelDetailsView: React.FC = ({ + registeredModel: rm, + refresh, + isArchiveModel, +}) => { const { apiState } = React.useContext(ModelRegistryContext); return ( = ({ registeredModel: rm @@ -42,6 +48,7 @@ const ModelDetailsView: React.FC = ({ registeredModel: rm /> apiState.api @@ -56,6 +63,7 @@ const ModelDetailsView: React.FC = ({ registeredModel: rm } /> apiState.api diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx index 0f9db46c..4373e935 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx @@ -1,5 +1,6 @@ import * as React from 'react'; import { + Alert, Button, Dropdown, DropdownItem, @@ -35,12 +36,14 @@ import SimpleSelect from '~/app/components/SimpleSelect'; type ModelVersionListViewProps = { modelVersions: ModelVersion[]; registeredModel?: RegisteredModel; + isArchiveModel?: boolean; refresh: () => void; }; const ModelVersionListView: React.FC = ({ modelVersions: unfilteredModelVersions, registeredModel: rm, + isArchiveModel, refresh, }) => { const navigate = useNavigate(); @@ -55,8 +58,24 @@ const ModelVersionListView: React.FC = ({ React.useState(false); const filteredModelVersions = filterModelVersions(unfilteredModelVersions, search, searchType); + const date = rm?.lastUpdateTimeSinceEpoch && new Date(parseInt(rm.lastUpdateTimeSinceEpoch)); if (unfilteredModelVersions.length === 0) { + if (isArchiveModel) { + return ( + ( + missing version + )} + description={`${rm?.name} has no registered versions.`} + /> + ); + } return ( = ({ } return ( - setSearch('')} - modelVersions={sortModelVersionsByCreateTime(filteredModelVersions)} - toolbarContent={ - - } breakpoint="xl"> - - setSearch('')} - deleteLabelGroup={() => setSearch('')} - categoryName={searchType} - > - ({ - key, - label: key, - }))} - value={searchType} - onChange={(newSearchType) => { - const enumMember = asEnumMember(newSearchType, SearchType); - if (enumMember !== null) { - setSearchType(enumMember); - } - }} - icon={} - /> - - - { - setSearch(searchValue); - }} - onClear={() => setSearch('')} - style={{ minWidth: '200px' }} - data-testid="model-versions-table-search" - /> - - - - - - - - setIsArchivedModelVersionKebabOpen(false)} - onOpenChange={(isOpen: boolean) => setIsArchivedModelVersionKebabOpen(isOpen)} - toggle={(tr: React.Ref) => ( - - setIsArchivedModelVersionKebabOpen(!isArchivedModelVersionKebabOpen) - } - isExpanded={isArchivedModelVersionKebabOpen} - aria-label="View archived versions" + <> + {isArchiveModel && ( + + )} + setSearch('')} + modelVersions={sortModelVersionsByCreateTime(filteredModelVersions)} + toolbarContent={ + + } breakpoint="xl"> + + setSearch('')} + deleteLabelGroup={() => setSearch('')} + categoryName={searchType} > - - - )} - shouldFocusToggleOnSelect - > - - - navigate(modelVersionArchiveUrl(rm?.id, preferredModelRegistry?.name)) - } - > - View archived versions - - - - - - } - /> + ({ + key, + label: key, + }))} + value={searchType} + onChange={(newSearchType) => { + const enumMember = asEnumMember(newSearchType, SearchType); + if (enumMember !== null) { + setSearchType(enumMember); + } + }} + icon={} + /> + + + { + setSearch(searchValue); + }} + onClear={() => setSearch('')} + style={{ minWidth: '200px' }} + data-testid="model-versions-table-search" + /> + + + + {!isArchiveModel && ( + <> + + + + + setIsArchivedModelVersionKebabOpen(false)} + onOpenChange={(isOpen: boolean) => setIsArchivedModelVersionKebabOpen(isOpen)} + toggle={(tr: React.Ref) => ( + + setIsArchivedModelVersionKebabOpen(!isArchivedModelVersionKebabOpen) + } + isExpanded={isArchivedModelVersionKebabOpen} + aria-label="View archived versions" + > + + + )} + shouldFocusToggleOnSelect + > + + + navigate(modelVersionArchiveUrl(rm?.id, preferredModelRegistry?.name)) + } + > + View archived versions + + + + + + )} + + } + /> + ); }; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx index 9e471aca..1e1d4872 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersions.tsx @@ -1,5 +1,5 @@ -import React from 'react'; -import { useParams } from 'react-router'; +import React, { useEffect } from 'react'; +import { useNavigate, useParams } from 'react-router'; import { Breadcrumb, BreadcrumbItem, Truncate } from '@patternfly/react-core'; import { Link } from 'react-router-dom'; import { ModelVersionsTab } from '~/app/pages/modelRegistry/screens/ModelVersions/const'; @@ -9,6 +9,8 @@ import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; import { filterLiveVersions } from '~/app/pages/modelRegistry/screens/utils'; import ModelVersionsHeaderActions from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsHeaderActions'; +import { ModelState } from '~/app/types'; +import { registeredModelArchiveDetailsUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; import ModelVersionsTabs from './ModelVersionsTabs'; type ModelVersionsProps = { @@ -25,6 +27,13 @@ const ModelVersions: React.FC = ({ tab, ...pageProps }) => { const [rm, rmLoaded, rmLoadError, rmRefresh] = useRegisteredModelById(rmId); const loadError = mvLoadError || rmLoadError; const loaded = mvLoaded && rmLoaded; + const navigate = useNavigate(); + + useEffect(() => { + if (rm?.state === ModelState.ARCHIVED) { + navigate(registeredModelArchiveDetailsUrl(rm.id, preferredModelRegistry?.name)); + } + }, [rm?.state, rm?.id, preferredModelRegistry?.name, navigate]); return ( void; modelVersions: ModelVersion[]; + isArchiveModel?: boolean; refresh: () => void; } & Partial, 'toolbarContent'>>; @@ -15,6 +16,7 @@ const ModelVersionsTable: React.FC = ({ clearFilters, modelVersions, toolbarContent, + isArchiveModel, refresh, }) => (
+ {mr.displayName || mr.name} + {mr.description &&

{mr.description}

} +
Key {isEditingSomeRow && requiredAsterisk} Value {isEditingSomeRow && requiredAsterisk} +
- {isEditing ? ( - - - - - - - - - ) : ( - - )} - + {isEditing ? ( + + + + + + + + + ) : ( + + )} +
= ({ enablePagination emptyTableView={} rowRenderer={(mv) => ( - + )} /> ); diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx index 55a4d85b..743adc4c 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx @@ -1,11 +1,12 @@ import * as React from 'react'; -import { ActionsColumn, Td, Tr } from '@patternfly/react-table'; +import { ActionsColumn, IAction, Td, Tr } from '@patternfly/react-table'; import { Content, ContentVariants, Truncate, FlexItem } from '@patternfly/react-core'; import { Link, useNavigate } from 'react-router-dom'; import { ModelState, ModelVersion } from '~/app/types'; import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; import { + archiveModelVersionDetailsUrl, modelVersionArchiveDetailsUrl, modelVersionUrl, } from '~/app/pages/modelRegistry/screens/routeUtils'; @@ -17,12 +18,14 @@ import { RestoreModelVersionModal } from '~/app/pages/modelRegistry/screens/comp type ModelVersionsTableRowProps = { modelVersion: ModelVersion; isArchiveRow?: boolean; + isArchiveModel?: boolean; refresh: () => void; }; const ModelVersionsTableRow: React.FC = ({ modelVersion: mv, isArchiveRow, + isArchiveModel, refresh, }) => { const navigate = useNavigate(); @@ -31,7 +34,7 @@ const ModelVersionsTableRow: React.FC = ({ const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false); const { apiState } = React.useContext(ModelRegistryContext); - const actions = isArchiveRow + const actions: IAction[] = isArchiveRow ? [ { title: 'Restore version', @@ -39,10 +42,6 @@ const ModelVersionsTableRow: React.FC = ({ }, ] : [ - { - title: 'Deploy', - onClick: () => setIsDeployModalOpen(true), - }, { title: 'Archive model version', onClick: () => setIsArchiveModalOpen(true), @@ -56,13 +55,19 @@ const ModelVersionsTableRow: React.FC = ({ @@ -82,45 +87,47 @@ const ModelVersionsTableRow: React.FC = ({ - + {!isArchiveModel && ( + + )} ); }; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx index 7460e663..2bf57bd0 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs.tsx @@ -13,6 +13,7 @@ type ModelVersionsTabProps = { tab: ModelVersionsTab; registeredModel: RegisteredModel; modelVersions: ModelVersion[]; + isArchiveModel?: boolean; refresh: () => void; mvRefresh: () => void; }; @@ -22,6 +23,7 @@ const ModelVersionsTabs: React.FC = ({ registeredModel: rm, modelVersions, refresh, + isArchiveModel, mvRefresh, }) => { const navigate = useNavigate(); @@ -41,6 +43,7 @@ const ModelVersionsTabs: React.FC = ({ > = ({ data-testid="model-details-tab" > - + diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetails.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetails.tsx new file mode 100644 index 00000000..8a595c47 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetails.tsx @@ -0,0 +1,84 @@ +import React, { useEffect } from 'react'; +import { useNavigate, useParams } from 'react-router'; +import { Button, Flex, FlexItem, Label, Content, Tooltip, Truncate } from '@patternfly/react-core'; + +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; +import useModelVersionById from '~/app/hooks/useModelVersionById'; +import { ModelState } from '~/app/types'; +import { modelVersionUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; +import ModelVersionDetailsTabs from '~/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsTabs'; +import { ModelVersionDetailsTab } from '~/app/pages/modelRegistry/screens/ModelVersionDetails/const'; +import ArchiveModelVersionDetailsBreadcrumb from './ArchiveModelVersionDetailsBreadcrumb'; + +type ArchiveModelVersionDetailsProps = { + tab: ModelVersionDetailsTab; +} & Omit< + React.ComponentProps, + 'breadcrumb' | 'title' | 'description' | 'loadError' | 'loaded' | 'provideChildrenPadding' +>; + +const ArchiveModelVersionDetails: React.FC = ({ + tab, + ...pageProps +}) => { + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + const { modelVersionId: mvId, registeredModelId: rmId } = useParams(); + const [rm] = useRegisteredModelById(rmId); + const [mv, mvLoaded, mvLoadError, refreshModelVersion] = useModelVersionById(mvId); + const navigate = useNavigate(); + + useEffect(() => { + if (rm?.state === ModelState.LIVE && mv?.id) { + navigate(modelVersionUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name)); + } + }, [rm?.state, mv?.id, mv?.registeredModelId, preferredModelRegistry?.name, navigate]); + + return ( + + } + title={ + mv && ( + + + {mv.name} + + + + + + ) + } + headerAction={ + + + + } + description={} + loadError={mvLoadError} + loaded={mvLoaded} + provideChildrenPadding + > + {mv !== null && ( + + )} + + ); +}; + +export default ArchiveModelVersionDetails; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetailsBreadcrumb.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetailsBreadcrumb.tsx new file mode 100644 index 00000000..8356589b --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ArchiveModelVersionDetailsBreadcrumb.tsx @@ -0,0 +1,41 @@ +import { Breadcrumb, BreadcrumbItem } from '@patternfly/react-core'; +import React from 'react'; +import { Link } from 'react-router-dom'; +import { RegisteredModel } from '~/app/types'; +import { + registeredModelArchiveDetailsUrl, + registeredModelArchiveUrl, +} from '~/app/pages/modelRegistry/screens/routeUtils'; + +type ArchiveModelVersionDetailsBreadcrumbProps = { + preferredModelRegistry?: string; + registeredModel: RegisteredModel | null; + modelVersionName?: string; +}; + +const ArchiveModelVersionDetailsBreadcrumb: React.FC = ({ + preferredModelRegistry, + registeredModel, + modelVersionName, +}) => ( + + Model registry - {preferredModelRegistry}} + /> + ( + Archived models + )} + /> + ( + + {registeredModel?.name || 'Loading...'} + + )} + /> + {modelVersionName || 'Loading...'} + +); + +export default ArchiveModelVersionDetailsBreadcrumb; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetails.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetails.tsx new file mode 100644 index 00000000..fdb1f80f --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetails.tsx @@ -0,0 +1,114 @@ +import React, { useEffect } from 'react'; +import { useNavigate, useParams } from 'react-router'; +import { Button, Flex, FlexItem, Label, Content, Truncate } from '@patternfly/react-core'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; +import useModelVersionById from '~/app/hooks/useModelVersionById'; +import { ModelState } from '~/app/types'; +import { + archiveModelVersionDetailsUrl, + modelVersionUrl, +} from '~/app/pages/modelRegistry/screens/routeUtils'; +import ModelVersionDetailsTabs from '~/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsTabs'; +import { RestoreModelVersionModal } from '~/app/pages/modelRegistry/screens/components/RestoreModelVersionModal'; +import { ModelVersionDetailsTab } from '~/app/pages/modelRegistry/screens/ModelVersionDetails/const'; +import ModelVersionArchiveDetailsBreadcrumb from './ModelVersionArchiveDetailsBreadcrumb'; + +type ModelVersionsArchiveDetailsProps = { + tab: ModelVersionDetailsTab; +} & Omit< + React.ComponentProps, + 'breadcrumb' | 'title' | 'description' | 'loadError' | 'loaded' | 'provideChildrenPadding' +>; + +const ModelVersionsArchiveDetails: React.FC = ({ + tab, + ...pageProps +}) => { + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + const { apiState } = React.useContext(ModelRegistryContext); + + const navigate = useNavigate(); + + const { modelVersionId: mvId, registeredModelId: rmId } = useParams(); + const [rm] = useRegisteredModelById(rmId); + const [mv, mvLoaded, mvLoadError, refreshModelVersion] = useModelVersionById(mvId); + const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false); + + useEffect(() => { + if (rm?.state === ModelState.ARCHIVED && mv?.id) { + navigate( + archiveModelVersionDetailsUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name), + ); + } else if (mv?.state === ModelState.LIVE) { + navigate(modelVersionUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name)); + } + }, [rm?.state, mv?.state, mv?.id, mv?.registeredModelId, preferredModelRegistry?.name, navigate]); + + return ( + <> + + } + title={ + mv && ( + + + {mv.name} + + + + + + ) + } + headerAction={ + + } + description={} + loadError={mvLoadError} + loaded={mvLoaded} + provideChildrenPadding + > + {mv !== null && ( + + )} + + {mv !== null && ( + setIsRestoreModalOpen(false)} + onSubmit={() => + apiState.api + .patchModelVersion( + {}, + { + state: ModelState.LIVE, + }, + mv.id, + ) + .then(() => navigate(modelVersionUrl(mv.id, rm?.id, preferredModelRegistry?.name))) + } + isOpen={isRestoreModalOpen} + modelVersionName={mv.name} + /> + )} + + ); +}; + +export default ModelVersionsArchiveDetails; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetailsBreadcrumb.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetailsBreadcrumb.tsx new file mode 100644 index 00000000..c706fef0 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionArchiveDetailsBreadcrumb.tsx @@ -0,0 +1,43 @@ +import { Breadcrumb, BreadcrumbItem } from '@patternfly/react-core'; +import React from 'react'; +import { Link } from 'react-router-dom'; +import { RegisteredModel } from '~/app/types'; +import { + modelVersionArchiveUrl, + registeredModelUrl, +} from '~/app/pages/modelRegistry/screens/routeUtils'; + +type ModelVersionArchiveDetailsBreadcrumbProps = { + preferredModelRegistry?: string; + registeredModel: RegisteredModel | null; + modelVersionName?: string; +}; + +const ModelVersionArchiveDetailsBreadcrumb: React.FC = ({ + preferredModelRegistry, + registeredModel, + modelVersionName, +}) => ( + + Model registry - {preferredModelRegistry}} + /> + ( + + {registeredModel?.name || 'Loading...'} + + )} + /> + ( + + Archived versions + + )} + /> + {modelVersionName || 'Loading...'} + +); + +export default ModelVersionArchiveDetailsBreadcrumb; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchive.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchive.tsx new file mode 100644 index 00000000..2245187a --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchive.tsx @@ -0,0 +1,60 @@ +import React from 'react'; +import { useParams } from 'react-router'; +import { Breadcrumb, BreadcrumbItem } from '@patternfly/react-core'; +import { Link } from 'react-router-dom'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; +import useModelVersionsByRegisteredModel from '~/app/hooks/useModelVersionsByRegisteredModel'; +import { registeredModelUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; +import { filterArchiveVersions } from '~/app/pages/modelRegistry/screens/utils'; +import ModelVersionsArchiveListView from './ModelVersionsArchiveListView'; + +type ModelVersionsArchiveProps = Omit< + React.ComponentProps, + 'breadcrumb' | 'title' | 'description' | 'loadError' | 'loaded' | 'provideChildrenPadding' +>; + +const ModelVersionsArchive: React.FC = ({ ...pageProps }) => { + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + + const { registeredModelId: rmId } = useParams(); + const [rm] = useRegisteredModelById(rmId); + const [modelVersions, mvLoaded, mvLoadError, refresh] = useModelVersionsByRegisteredModel(rmId); + + return ( + + ( + Model registry - {preferredModelRegistry?.name} + )} + /> + ( + + {rm?.name || 'Loading...'} + + )} + /> + + Archived versions + + + } + title={rm ? `Archived versions of ${rm.name}` : 'Archived versions'} + loadError={mvLoadError} + loaded={mvLoaded} + provideChildrenPadding + > + + + ); +}; + +export default ModelVersionsArchive; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveListView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveListView.tsx new file mode 100644 index 00000000..b8e419e4 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveListView.tsx @@ -0,0 +1,95 @@ +import * as React from 'react'; +import { + SearchInput, + ToolbarContent, + ToolbarFilter, + ToolbarGroup, + ToolbarItem, + ToolbarToggleGroup, +} from '@patternfly/react-core'; +import { FilterIcon } from '@patternfly/react-icons'; +import { ModelVersion } from '~/app/types'; +import { SearchType } from '~/app/components/DashboardSearchField'; +import SimpleSelect from '~/app/components/SimpleSelect'; +import { asEnumMember } from '~/app/utils'; +import { filterModelVersions } from '~/app/pages/modelRegistry/screens/utils'; +import EmptyModelRegistryState from '~/app/pages/modelRegistry/screens/components/EmptyModelRegistryState'; +import ModelVersionsArchiveTable from './ModelVersionsArchiveTable'; + +type ModelVersionsArchiveListViewProps = { + modelVersions: ModelVersion[]; + refresh: () => void; +}; + +const ModelVersionsArchiveListView: React.FC = ({ + modelVersions: unfilteredmodelVersions, + refresh, +}) => { + const [searchType, setSearchType] = React.useState(SearchType.KEYWORD); + const [search, setSearch] = React.useState(''); + + const searchTypes = [SearchType.KEYWORD, SearchType.AUTHOR]; + + const filteredModelVersions = filterModelVersions(unfilteredmodelVersions, search, searchType); + + if (unfilteredmodelVersions.length === 0) { + return ( + + ); + } + + return ( + setSearch('')} + modelVersions={filteredModelVersions} + toolbarContent={ + + } breakpoint="xl"> + + setSearch('')} + deleteLabelGroup={() => setSearch('')} + categoryName="Keyword" + > + ({ + key, + label: key, + }))} + value={searchType} + onChange={(newSearchType) => { + const enumMember = asEnumMember(newSearchType, SearchType); + if (enumMember) { + setSearchType(enumMember); + } + }} + icon={} + /> + + + { + setSearch(searchValue); + }} + onClear={() => setSearch('')} + style={{ minWidth: '200px' }} + data-testid="model-versions-archive-table-search" + /> + + + + + } + /> + ); +}; + +export default ModelVersionsArchiveListView; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveTable.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveTable.tsx new file mode 100644 index 00000000..c611ede0 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionsArchive/ModelVersionsArchiveTable.tsx @@ -0,0 +1,34 @@ +import * as React from 'react'; +import { Table } from '~/app/components/table'; +import { ModelVersion } from '~/app/types'; +import DashboardEmptyTableView from '~/app/components/DashboardEmptyTableView'; +import ModelVersionsTableRow from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow'; +import { mvColumns } from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableColumns'; + +type ModelVersionsArchiveTableProps = { + clearFilters: () => void; + modelVersions: ModelVersion[]; + refresh: () => void; +} & Partial, 'toolbarContent'>>; + +const ModelVersionsArchiveTable: React.FC = ({ + clearFilters, + modelVersions, + toolbarContent, + refresh, +}) => ( +
- - setIsArchiveModalOpen(false)} - onSubmit={() => - apiState.api - .patchModelVersion( - {}, - { - state: ModelState.ARCHIVED, - }, - mv.id, - ) - .then(refresh) - } - isOpen={isArchiveModalOpen} - modelVersionName={mv.name} - /> - setIsRestoreModalOpen(false)} - onSubmit={() => - apiState.api - .patchModelVersion( - {}, - { - state: ModelState.LIVE, - }, - mv.id, - ) - .then(() => - navigate( - modelVersionUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name), - ), - ) - } - isOpen={isRestoreModalOpen} - modelVersionName={mv.name} - /> - + + setIsArchiveModalOpen(false)} + onSubmit={() => + apiState.api + .patchModelVersion( + {}, + { + state: ModelState.ARCHIVED, + }, + mv.id, + ) + .then(refresh) + } + isOpen={isArchiveModalOpen} + modelVersionName={mv.name} + /> + setIsRestoreModalOpen(false)} + onSubmit={() => + apiState.api + .patchModelVersion( + {}, + { + state: ModelState.LIVE, + }, + mv.id, + ) + .then(() => + navigate( + modelVersionUrl(mv.id, mv.registeredModelId, preferredModelRegistry?.name), + ), + ) + } + isOpen={isRestoreModalOpen} + modelVersionName={mv.name} + /> +
} + defaultSortColumn={1} + rowRenderer={(mv: ModelVersion) => ( + + )} + /> +); + +export default ModelVersionsArchiveTable; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx index d2a6367d..52f676fe 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx @@ -89,7 +89,7 @@ const RegisteredModelListView: React.FC = ({ icon={} /> - + = ({ const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false); const rmUrl = registeredModelUrl(rm.id, preferredModelRegistry?.name); - const actions = [ + const actions: IAction[] = [ { title: 'View details', - // eslint-disable-next-line @typescript-eslint/no-empty-function - onClick: () => navigate(`${rmUrl}/${ModelVersionsTab.DETAILS}`), + onClick: () => { + if (isArchiveRow) { + navigate( + `${registeredModelArchiveUrl(preferredModelRegistry?.name)}/${rm.id}/${ + ModelVersionsTab.DETAILS + }`, + ); + } else { + navigate(`${rmUrl}/${ModelVersionsTab.DETAILS}`); + } + }, }, isArchiveRow ? { diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetails.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetails.tsx new file mode 100644 index 00000000..711469e5 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetails.tsx @@ -0,0 +1,109 @@ +import React, { useEffect } from 'react'; +import { useNavigate, useParams } from 'react-router'; +import { Button, Flex, FlexItem, Label, Content, Truncate } from '@patternfly/react-core'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { ModelRegistrySelectorContext } from '~/app/context/ModelRegistrySelectorContext'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; +import useRegisteredModelById from '~/app/hooks/useRegisteredModelById'; +import useModelVersionsByRegisteredModel from '~/app/hooks/useModelVersionsByRegisteredModel'; +import { ModelState } from '~/app/types'; +import { ModelVersionsTab } from '~/app/pages/modelRegistry/screens/ModelVersions/const'; +import ModelVersionsTabs from '~/app/pages/modelRegistry/screens/ModelVersions/ModelVersionsTabs'; +import { RestoreRegisteredModelModal } from '~/app/pages/modelRegistry/screens/components/RestoreRegisteredModel'; +import { registeredModelUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; +import RegisteredModelArchiveDetailsBreadcrumb from './RegisteredModelArchiveDetailsBreadcrumb'; + +type RegisteredModelsArchiveDetailsProps = { + tab: ModelVersionsTab; +} & Omit< + React.ComponentProps, + 'breadcrumb' | 'title' | 'description' | 'loadError' | 'loaded' | 'provideChildrenPadding' +>; + +const RegisteredModelsArchiveDetails: React.FC = ({ + tab, + ...pageProps +}) => { + const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); + const { apiState } = React.useContext(ModelRegistryContext); + + const navigate = useNavigate(); + + const { registeredModelId: rmId } = useParams(); + const [rm, rmLoaded, rmLoadError, rmRefresh] = useRegisteredModelById(rmId); + const [modelVersions, mvLoaded, mvLoadError, refresh] = useModelVersionsByRegisteredModel(rmId); + const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false); + + useEffect(() => { + if (rm?.state === ModelState.LIVE) { + navigate(registeredModelUrl(rm.id, preferredModelRegistry?.name)); + } + }, [rm?.state, preferredModelRegistry?.name, rm?.id, navigate]); + + return ( + <> + + } + title={ + rm && ( + + + {rm.name} + + + + + + ) + } + headerAction={ + + } + description={} + loadError={rmLoadError} + loaded={rmLoaded} + provideChildrenPadding + > + {rm !== null && mvLoaded && !mvLoadError && ( + + )} + + + {rm !== null && ( + setIsRestoreModalOpen(false)} + onSubmit={() => + apiState.api + .patchRegisteredModel( + {}, + { + state: ModelState.LIVE, + }, + rm.id, + ) + .then(() => navigate(registeredModelUrl(rm.id, preferredModelRegistry?.name))) + } + isOpen={isRestoreModalOpen} + registeredModelName={rm.name} + /> + )} + + ); +}; + +export default RegisteredModelsArchiveDetails; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetailsBreadcrumb.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetailsBreadcrumb.tsx new file mode 100644 index 00000000..e72161ec --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelArchiveDetailsBreadcrumb.tsx @@ -0,0 +1,28 @@ +import { Breadcrumb, BreadcrumbItem } from '@patternfly/react-core'; +import React from 'react'; +import { Link } from 'react-router-dom'; +import { RegisteredModel } from '~/app/types'; +import { registeredModelArchiveUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; + +type RegisteredModelArchiveDetailsBreadcrumbProps = { + preferredModelRegistry?: string; + registeredModel: RegisteredModel | null; +}; + +const RegisteredModelArchiveDetailsBreadcrumb: React.FC< + RegisteredModelArchiveDetailsBreadcrumbProps +> = ({ preferredModelRegistry, registeredModel }) => ( + + Model registry - {preferredModelRegistry}} + /> + ( + Archived models + )} + /> + {registeredModel?.name || 'Loading...'} + +); + +export default RegisteredModelArchiveDetailsBreadcrumb; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx index cd419d03..d3a8acda 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx @@ -77,7 +77,7 @@ const RegisteredModelsArchiveListView: React.FC} /> - + `${registeredModelUrl(rmId, preferredModelRegistry)}/registerVersion`; -export const modelVersionDeploymentsUrl = ( +export const archiveModelVersionListUrl = ( + rmId?: string, + preferredModelRegistry?: string, +): string => `${registeredModelArchiveDetailsUrl(rmId, preferredModelRegistry)}/versions`; + +export const archiveModelVersionDetailsUrl = ( mvId: string, rmId?: string, preferredModelRegistry?: string, -): string => `${modelVersionUrl(mvId, rmId, preferredModelRegistry)}/deployments`; +): string => `${archiveModelVersionListUrl(rmId, preferredModelRegistry)}/${mvId}`; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts index ff6d4ef4..6c04ab1d 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts @@ -9,6 +9,13 @@ import { } from '~/app/types'; import { KeyValuePair } from '~/types'; +export type ObjectStorageFields = { + endpoint: string; + bucket: string; + region?: string; + path: string; +}; + // Retrieves the labels from customProperties that have non-empty string_value. export const getLabels = (customProperties: T): string[] => Object.keys(customProperties).filter((key) => { @@ -148,3 +155,23 @@ export const filterArchiveModels = (registeredModels: RegisteredModel[]): Regist export const filterLiveModels = (registeredModels: RegisteredModel[]): RegisteredModel[] => registeredModels.filter((rm) => rm.state === ModelState.LIVE); + +export const uriToObjectStorageFields = (uri: string): ObjectStorageFields | null => { + try { + const urlObj = new URL(uri); + // Some environments include the first token after the protocol (our bucket) in the pathname and some have it as the hostname + const [bucket, ...pathSplit] = `${urlObj.hostname}/${urlObj.pathname}` + .split('/') + .filter(Boolean); + const path = pathSplit.join('/'); + const searchParams = new URLSearchParams(urlObj.search); + const endpoint = searchParams.get('endpoint'); + const region = searchParams.get('defaultRegion'); + if (endpoint && bucket && path) { + return { endpoint, bucket, region: region || undefined, path }; + } + return null; + } catch { + return null; + } +}; diff --git a/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx b/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx index 4216c86b..7fb6ae63 100644 --- a/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx +++ b/clients/ui/frontend/src/components/DashboardDescriptionListGroup.tsx @@ -110,7 +110,9 @@ const DashboardDescriptionListGroup: React.FC )} - + {/* The text color below is a hack for a11y. + PF6 team needs to update their disabled color to work for white backgrounds */} + {isEditing ? contentWhenEditing : isEmpty ? contentWhenEmpty : children} diff --git a/clients/ui/frontend/src/components/DashboardHelpTooltip.tsx b/clients/ui/frontend/src/components/DashboardHelpTooltip.tsx new file mode 100644 index 00000000..033d1d65 --- /dev/null +++ b/clients/ui/frontend/src/components/DashboardHelpTooltip.tsx @@ -0,0 +1,15 @@ +import * as React from 'react'; +import { Tooltip } from '@patternfly/react-core'; +import { OutlinedQuestionCircleIcon } from '@patternfly/react-icons'; + +type DashboardHelpTooltipProps = { + content: React.ReactNode; +}; + +const DashboardHelpTooltip: React.FC = ({ content }) => ( + + + +); + +export default DashboardHelpTooltip; diff --git a/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx b/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx index 42cc9091..095bfdd7 100644 --- a/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx +++ b/clients/ui/frontend/src/components/EditableLabelsDescriptionListGroup.tsx @@ -22,6 +22,7 @@ type EditableTextDescriptionListGroupProps = Partial< labels: string[]; saveEditedLabels: (labels: string[]) => Promise; allExistingKeys?: string[]; + isArchive?: boolean; }; const EditableLabelsDescriptionListGroup: React.FC = ({ @@ -29,6 +30,7 @@ const EditableLabelsDescriptionListGroup: React.FC { const [isEditing, setIsEditing] = React.useState(false); @@ -98,7 +100,7 @@ const EditableLabelsDescriptionListGroup: React.FC Promise; testid?: string; + isArchive?: boolean; }; const EditableTextDescriptionListGroup: React.FC = ({ title, contentWhenEmpty, value, + isArchive, saveEditedValue, testid, }) => { @@ -29,7 +31,7 @@ const EditableTextDescriptionListGroup: React.FC = ({ textToCopy, testId }) => ( + // @ts-expect-error ClipboardCopy expects children of type string in PF v6 + { + navigator.clipboard.writeText(textToCopy); + }} + data-testid={testId} + > + + +); + +export default InlineTruncatedClipboardCopy; From 7990b69b9c545cc97f1aa5281ceeabd501be35c7 Mon Sep 17 00:00:00 2001 From: Lucas Fernandez Date: Fri, 27 Sep 2024 15:18:36 +0200 Subject: [PATCH 13/13] Add Register Model and Register Model Version form (#431) Signed-off-by: lucferbux --- clients/ui/frontend/package-lock.json | 18 ++ clients/ui/frontend/package.json | 1 + .../frontend/src/__mocks__/mockBFFResponse.ts | 4 +- clients/ui/frontend/src/__mocks__/utils.ts | 4 +- .../cypress/cypress/support/commands/api.ts | 16 +- .../src/app/api/__tests__/service.spec.ts | 17 +- clients/ui/frontend/src/app/api/apiUtils.ts | 14 +- clients/ui/frontend/src/app/api/service.ts | 109 +++++--- .../src/app/components/MarkdownView.scss | 46 ++-- .../src/app/components/SimpleSelect.scss | 4 +- .../app/components/design/DividedGallery.scss | 16 +- .../app/components/design/InfoGalleryItem.tsx | 4 +- .../app/components/design/ScrolledGallery.tsx | 4 +- .../app/components/design/TypeBorderCard.scss | 2 +- .../components/pf-overrides/FormSection.scss | 8 + .../components/pf-overrides/FormSection.tsx | 35 +++ .../modelRegistry/ModelRegistryRoutes.tsx | 5 + .../modelRegistry/screens/ModelRegistry.tsx | 2 +- .../PrefilledModelRegistryField.tsx | 14 + .../screens/RegisterModel/RegisterModel.tsx | 116 ++++++++ .../screens/RegisterModel/RegisterVersion.tsx | 146 ++++++++++ .../RegisterModel/RegisteredModelSelector.tsx | 61 +++++ .../RegistrationCommonFormSections.tsx | 249 ++++++++++++++++++ .../RegisterModel/RegistrationFormFooter.tsx | 66 +++++ .../usePrefillRegisterVersionFields.ts | 88 +++++++ .../RegisterModel/useRegisterModelData.ts | 63 +++++ .../useRegistrationCommonState.ts | 40 +++ .../screens/RegisterModel/utils.ts | 111 ++++++++ .../app/pages/modelRegistry/screens/utils.ts | 25 ++ clients/ui/frontend/src/app/types.ts | 3 +- clients/ui/frontend/src/app/utils.ts | 5 +- clients/ui/frontend/src/types.ts | 4 + .../src/utilities/useGenericObjectState.ts | 25 ++ 33 files changed, 1217 insertions(+), 108 deletions(-) create mode 100644 clients/ui/frontend/src/app/components/pf-overrides/FormSection.scss create mode 100644 clients/ui/frontend/src/app/components/pf-overrides/FormSection.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/PrefilledModelRegistryField.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterVersion.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisteredModelSelector.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegistrationCommonFormSections.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegistrationFormFooter.tsx create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/usePrefillRegisterVersionFields.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/useRegisterModelData.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/useRegistrationCommonState.ts create mode 100644 clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/utils.ts create mode 100644 clients/ui/frontend/src/utilities/useGenericObjectState.ts diff --git a/clients/ui/frontend/package-lock.json b/clients/ui/frontend/package-lock.json index 57bb275f..6f7bfc21 100644 --- a/clients/ui/frontend/package-lock.json +++ b/clients/ui/frontend/package-lock.json @@ -13,6 +13,7 @@ "@patternfly/react-icons": "6.0.0-alpha.37", "@patternfly/react-styles": "6.0.0-alpha.35", "@patternfly/react-table": "6.0.0-alpha.101", + "@patternfly/react-templates": "6.0.0-alpha.50", "classnames": "^2.2.6", "dompurify": "^3.1.6", "lodash-es": "^4.17.15", @@ -3647,6 +3648,23 @@ "react-dom": "^17 || ^18" } }, + "node_modules/@patternfly/react-templates": { + "version": "6.0.0-alpha.50", + "resolved": "https://registry.npmjs.org/@patternfly/react-templates/-/react-templates-6.0.0-alpha.50.tgz", + "integrity": "sha512-YmP9iYcejDrnGPadi5Y/qZWG4xmANZe3fB8HMhSWI+CewOAWCifkWC/gv0oaB3eDXCopAnsf0Y6oXR7CNdWgyQ==", + "license": "MIT", + "dependencies": { + "@patternfly/react-core": "^6.0.0-alpha.100", + "@patternfly/react-icons": "^6.0.0-alpha.35", + "@patternfly/react-styles": "^6.0.0-alpha.34", + "@patternfly/react-tokens": "^6.0.0-alpha.34", + "tslib": "^2.6.3" + }, + "peerDependencies": { + "react": "^17 || ^18", + "react-dom": "^17 || ^18" + } + }, "node_modules/@patternfly/react-tokens": { "version": "6.0.0-prerelease.4", "resolved": "https://registry.npmjs.org/@patternfly/react-tokens/-/react-tokens-6.0.0-prerelease.4.tgz", diff --git a/clients/ui/frontend/package.json b/clients/ui/frontend/package.json index 99b0c619..56b47de5 100644 --- a/clients/ui/frontend/package.json +++ b/clients/ui/frontend/package.json @@ -94,6 +94,7 @@ "@patternfly/react-icons": "6.0.0-alpha.37", "@patternfly/react-styles": "6.0.0-alpha.35", "@patternfly/react-table": "6.0.0-alpha.101", + "@patternfly/react-templates": "6.0.0-alpha.50", "lodash-es": "^4.17.15", "npm-run-all": "^4.1.5", "react": "^18", diff --git a/clients/ui/frontend/src/__mocks__/mockBFFResponse.ts b/clients/ui/frontend/src/__mocks__/mockBFFResponse.ts index 8b2f910b..f1599462 100644 --- a/clients/ui/frontend/src/__mocks__/mockBFFResponse.ts +++ b/clients/ui/frontend/src/__mocks__/mockBFFResponse.ts @@ -1,5 +1,5 @@ -import { ModelRegistryResponse } from '~/app/types'; +import { ModelRegistryBody } from '~/app/types'; -export const mockBFFResponse = (data: T): ModelRegistryResponse => ({ +export const mockBFFResponse = (data: T): ModelRegistryBody => ({ data, }); diff --git a/clients/ui/frontend/src/__mocks__/utils.ts b/clients/ui/frontend/src/__mocks__/utils.ts index 1500da7d..43d1768e 100644 --- a/clients/ui/frontend/src/__mocks__/utils.ts +++ b/clients/ui/frontend/src/__mocks__/utils.ts @@ -1,6 +1,6 @@ import { ModelRegistryMetadataType, - ModelRegistryResponse, + ModelRegistryBody, ModelRegistryStringCustomProperties, } from '~/app/types'; @@ -16,6 +16,6 @@ export const createModelRegistryLabelsObject = ( return acc; }, {} as ModelRegistryStringCustomProperties); -export const mockBFFResponse = (data: T): ModelRegistryResponse => ({ +export const mockBFFResponse = (data: T): ModelRegistryBody => ({ data, }); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts index 1b07dfb8..ceaef6fd 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/support/commands/api.ts @@ -3,7 +3,7 @@ import type { ModelArtifact, ModelArtifactList, ModelRegistry, - ModelRegistryResponse, + ModelRegistryBody, ModelVersion, ModelVersionList, RegisteredModel, @@ -35,7 +35,7 @@ declare global { interceptApi: (( type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/registered_models', options: { path: { modelRegistryName: string; apiVersion: string } }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/registered_models', @@ -47,7 +47,7 @@ declare global { options: { path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId/versions', @@ -61,28 +61,28 @@ declare global { options: { path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'PATCH /api/:apiVersion/model_registry/:modelRegistryName/registered_models/:registeredModelId', options: { path: { modelRegistryName: string; apiVersion: string; registeredModelId: number }; }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId', options: { path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'GET /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts', options: { path: { modelRegistryName: string; apiVersion: string; modelVersionId: number }; }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable) & (( type: 'POST /api/:apiVersion/model_registry/:modelRegistryName/model_versions/:modelVersionId/artifacts', @@ -101,7 +101,7 @@ declare global { (( type: 'GET /api/:apiVersion/model_registry', options: { path: { apiVersion: string } }, - response: ApiResponse>, + response: ApiResponse>, ) => Cypress.Chainable); } } diff --git a/clients/ui/frontend/src/app/api/__tests__/service.spec.ts b/clients/ui/frontend/src/app/api/__tests__/service.spec.ts index 7da6df75..85be9763 100644 --- a/clients/ui/frontend/src/app/api/__tests__/service.spec.ts +++ b/clients/ui/frontend/src/app/api/__tests__/service.spec.ts @@ -28,6 +28,7 @@ jest.mock('~/app/api/apiUtils', () => ({ restCREATE: jest.fn(() => mockRestPromise), restGET: jest.fn(() => mockRestPromise), restPATCH: jest.fn(() => mockRestPromise), + assembleModelRegistryBody: jest.fn(() => ({})), isModelRegistryResponse: jest.fn(() => true), })); @@ -61,7 +62,7 @@ describe('createRegisteredModel', () => { expect(restCREATEMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/registered_models`, - mockData, + {}, {}, APIOptionsMock, ); @@ -89,7 +90,7 @@ describe('createModelVersion', () => { expect(restCREATEMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/model_versions`, - mockData, + {}, {}, APIOptionsMock, ); @@ -117,7 +118,7 @@ describe('createModelVersionForRegisteredModel', () => { expect(restCREATEMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/registered_models/1/versions`, - mockData, + {}, {}, APIOptionsMock, ); @@ -150,7 +151,7 @@ describe('createModelArtifact', () => { expect(restCREATEMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/model_artifacts`, - mockData, + {}, {}, APIOptionsMock, ); @@ -183,7 +184,7 @@ describe('createModelArtifactForModelVersion', () => { expect(restCREATEMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/model_versions/2/artifacts`, - mockData, + {}, {}, APIOptionsMock, ); @@ -347,7 +348,7 @@ describe('patchRegisteredModel', () => { expect(restPATCHMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/registered_models/1`, - mockData, + {}, APIOptionsMock, ); expect(handleRestFailuresMock).toHaveBeenCalledTimes(1); @@ -366,7 +367,7 @@ describe('patchModelVersion', () => { expect(restPATCHMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/model_versions/1`, - mockData, + {}, APIOptionsMock, ); expect(handleRestFailuresMock).toHaveBeenCalledTimes(1); @@ -385,7 +386,7 @@ describe('patchModelArtifact', () => { expect(restPATCHMock).toHaveBeenCalledWith( `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, `/model_artifacts/1`, - mockData, + {}, APIOptionsMock, ); expect(handleRestFailuresMock).toHaveBeenCalledTimes(1); diff --git a/clients/ui/frontend/src/app/api/apiUtils.ts b/clients/ui/frontend/src/app/api/apiUtils.ts index 93e16681..d14733c7 100644 --- a/clients/ui/frontend/src/app/api/apiUtils.ts +++ b/clients/ui/frontend/src/app/api/apiUtils.ts @@ -1,6 +1,6 @@ import { APIOptions } from '~/app/api/types'; import { EitherOrNone } from '~/typeHelpers'; -import { ModelRegistryResponse } from '~/app/types'; +import { ModelRegistryBody } from '~/app/types'; export const mergeRequestInit = ( opts: APIOptions = {}, @@ -163,14 +163,16 @@ export const restDELETE = ( parseJSON: options?.parseJSON, }); -export const isModelRegistryResponse = ( - response: unknown, -): response is ModelRegistryResponse => { +export const isModelRegistryResponse = (response: unknown): response is ModelRegistryBody => { if (typeof response === 'object' && response !== null) { // eslint-disable-next-line @typescript-eslint/consistent-type-assertions - const modelRegistryResponse = response as { data?: T }; + const modelRegistryBody = response as { data?: T }; // TODO: Check if data is conforming any type so we have a proper check - return modelRegistryResponse.data !== undefined; + return modelRegistryBody.data !== undefined; } return false; }; + +export const assembleModelRegistryBody = (data: T): ModelRegistryBody => ({ + data, +}); diff --git a/clients/ui/frontend/src/app/api/service.ts b/clients/ui/frontend/src/app/api/service.ts index b894c6df..3579877d 100644 --- a/clients/ui/frontend/src/app/api/service.ts +++ b/clients/ui/frontend/src/app/api/service.ts @@ -9,26 +9,34 @@ import { RegisteredModelList, RegisteredModel, } from '~/app/types'; -import { isModelRegistryResponse, restCREATE, restGET, restPATCH } from '~/app/api/apiUtils'; +import { + assembleModelRegistryBody, + isModelRegistryResponse, + restCREATE, + restGET, + restPATCH, +} from '~/app/api/apiUtils'; import { APIOptions } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; export const createRegisteredModel = (hostPath: string) => (opts: APIOptions, data: CreateRegisteredModelData): Promise => - handleRestFailures(restCREATE(hostPath, `/registered_models`, data, {}, opts)).then( - (response) => { - if (isModelRegistryResponse(response)) { - return response.data; - } - throw new Error('Invalid response format'); - }, - ); + handleRestFailures( + restCREATE(hostPath, `/registered_models`, assembleModelRegistryBody(data), {}, opts), + ).then((response) => { + if (isModelRegistryResponse(response)) { + return response.data; + } + throw new Error('Invalid response format'); + }); export const createModelVersion = (hostPath: string) => (opts: APIOptions, data: CreateModelVersionData): Promise => - handleRestFailures(restCREATE(hostPath, `/model_versions`, data, {}, opts)).then((response) => { + handleRestFailures( + restCREATE(hostPath, `/model_versions`, assembleModelRegistryBody(data), {}, opts), + ).then((response) => { if (isModelRegistryResponse(response)) { return response.data; } @@ -43,7 +51,13 @@ export const createModelVersionForRegisteredModel = data: CreateModelVersionData, ): Promise => handleRestFailures( - restCREATE(hostPath, `/registered_models/${registeredModelId}/versions`, data, {}, opts), + restCREATE( + hostPath, + `/registered_models/${registeredModelId}/versions`, + assembleModelRegistryBody(data), + {}, + opts, + ), ).then((response) => { if (isModelRegistryResponse(response)) { return response.data; @@ -54,14 +68,14 @@ export const createModelVersionForRegisteredModel = export const createModelArtifact = (hostPath: string) => (opts: APIOptions, data: CreateModelArtifactData): Promise => - handleRestFailures(restCREATE(hostPath, `/model_artifacts`, data, {}, opts)).then( - (response) => { - if (isModelRegistryResponse(response)) { - return response.data; - } - throw new Error('Invalid response format'); - }, - ); + handleRestFailures( + restCREATE(hostPath, `/model_artifacts`, assembleModelRegistryBody(data), {}, opts), + ).then((response) => { + if (isModelRegistryResponse(response)) { + return response.data; + } + throw new Error('Invalid response format'); + }); export const createModelArtifactForModelVersion = (hostPath: string) => @@ -71,7 +85,13 @@ export const createModelArtifactForModelVersion = data: CreateModelArtifactData, ): Promise => handleRestFailures( - restCREATE(hostPath, `/model_versions/${modelVersionId}/artifacts`, data, {}, opts), + restCREATE( + hostPath, + `/model_versions/${modelVersionId}/artifacts`, + assembleModelRegistryBody(data), + {}, + opts, + ), ).then((response) => { if (isModelRegistryResponse(response)) { return response.data; @@ -177,7 +197,12 @@ export const patchRegisteredModel = registeredModelId: string, ): Promise => handleRestFailures( - restPATCH(hostPath, `/registered_models/${registeredModelId}`, data, opts), + restPATCH( + hostPath, + `/registered_models/${registeredModelId}`, + assembleModelRegistryBody(data), + opts, + ), ).then((response) => { if (isModelRegistryResponse(response)) { return response.data; @@ -188,14 +213,19 @@ export const patchRegisteredModel = export const patchModelVersion = (hostPath: string) => (opts: APIOptions, data: Partial, modelversionId: string): Promise => - handleRestFailures(restPATCH(hostPath, `/model_versions/${modelversionId}`, data, opts)).then( - (response) => { - if (isModelRegistryResponse(response)) { - return response.data; - } - throw new Error('Invalid response format'); - }, - ); + handleRestFailures( + restPATCH( + hostPath, + `/model_versions/${modelversionId}`, + assembleModelRegistryBody(data), + opts, + ), + ).then((response) => { + if (isModelRegistryResponse(response)) { + return response.data; + } + throw new Error('Invalid response format'); + }); export const patchModelArtifact = (hostPath: string) => @@ -204,11 +234,16 @@ export const patchModelArtifact = data: Partial, modelartifactId: string, ): Promise => - handleRestFailures(restPATCH(hostPath, `/model_artifacts/${modelartifactId}`, data, opts)).then( - (response) => { - if (isModelRegistryResponse(response)) { - return response.data; - } - throw new Error('Invalid response format'); - }, - ); + handleRestFailures( + restPATCH( + hostPath, + `/model_artifacts/${modelartifactId}`, + assembleModelRegistryBody(data), + opts, + ), + ).then((response) => { + if (isModelRegistryResponse(response)) { + return response.data; + } + throw new Error('Invalid response format'); + }); diff --git a/clients/ui/frontend/src/app/components/MarkdownView.scss b/clients/ui/frontend/src/app/components/MarkdownView.scss index aa367b23..7192c5c2 100644 --- a/clients/ui/frontend/src/app/components/MarkdownView.scss +++ b/clients/ui/frontend/src/app/components/MarkdownView.scss @@ -2,10 +2,10 @@ word-break: break-word; &--with-padding { - padding-bottom: var(--pf-v5-global--spacer--md); + padding-bottom: var(--pf-v6-global--spacer--md); p { - margin-bottom: var(--pf-v5-global--spacer--sm); + margin-bottom: var(--pf-v6-global--spacer--sm); } } @@ -15,35 +15,35 @@ h4, h5, h6 { - font-family: var(--pf-v5-global--FontFamily--heading--sans-serif); - font-weight: var(--pf-v5-global--FontWeight--normal); - margin-top: var(--pf-v5-global--spacer--md); - margin-bottom: var(--pf-v5-global--spacer--sm); + font-family: var(--pf-v6-global--FontFamily--heading--sans-serif); + font-weight: var(--pf-v6-global--FontWeight--normal); + margin-top: var(--pf-v6-global--spacer--md); + margin-bottom: var(--pf-v6-global--spacer--sm); } h1 { - font-size: var(--pf-v5-global--FontSize--2xl); + font-size: var(--pf-v6-global--FontSize--2xl); } h2 { - font-size: var(--pf-v5-global--FontSize--xl); + font-size: var(--pf-v6-global--FontSize--xl); } h3 { - font-size: var(--pf-v5-global--FontSize--lg); + font-size: var(--pf-v6-global--FontSize--lg); } h4, h5, h6 { - font-size: var(--pf-v5-global--FontSize--md); - margin-top: var(--pf-v5-global--spacer--sm); + font-size: var(--pf-v6-global--FontSize--md); + margin-top: var(--pf-v6-global--spacer--sm); } ul, ol { margin-top: 0; - margin-bottom: var(--pf-v5-global--spacer--sm); + margin-bottom: var(--pf-v6-global--spacer--sm); } ul { @@ -51,33 +51,33 @@ } li { - margin-left: var(--pf-v5-global--spacer--lg); + margin-left: var(--pf-v6-global--spacer--lg); } code, pre { font-family: Menlo, Monaco, Consolas, 'Courier New', monospace; - background-color: var(--pf-v5-global--BackgroundColor--200); - border-radius: var(--pf-v5-global--BorderRadius--sm); + background-color: var(--pf-v6-global--BackgroundColor--200); + border-radius: var(--pf-v6-global--BorderRadius--sm); } code { padding: 2px 4px; font-size: 85%; - color: var(--pf-v5-global--danger-color--100); + color: var(--pf-v6-global--danger-color--100); } pre { display: block; - padding: var(--pf-v5-global--spacer--sm); - margin-bottom: var(--pf-v5-global--spacer--sm); - font-size: var(--pf-v5-global--FontSize--sm); - color: var(--pf-v5-global--Color--300); + padding: var(--pf-v6-global--spacer--sm); + margin-bottom: var(--pf-v6-global--spacer--sm); + font-size: var(--pf-v6-global--FontSize--sm); + color: var(--pf-v6-global--Color--300); word-break: break-all; word-wrap: break-word; - background-color: var(--pf-v5-global--BackgroundColor--200); - border: var(--pf-v5-global--BorderWidth--sm) solid var(--pf-v5-global--Color--light-300); - border-radius: var(--pf-v5-global--BorderRadius--sm); + background-color: var(--pf-v6-global--BackgroundColor--200); + border: var(--pf-v6-global--BorderWidth--sm) solid var(--pf-v6-global--Color--light-300); + border-radius: var(--pf-v6-global--BorderRadius--sm); code { padding: 0; font-size: inherit; diff --git a/clients/ui/frontend/src/app/components/SimpleSelect.scss b/clients/ui/frontend/src/app/components/SimpleSelect.scss index 30dd3051..31c2011c 100644 --- a/clients/ui/frontend/src/app/components/SimpleSelect.scss +++ b/clients/ui/frontend/src/app/components/SimpleSelect.scss @@ -4,6 +4,6 @@ // remove this file when https://github.com/patternfly/patternfly/issues/6062 is solved .truncate-no-min-width { - --pf-v5-c-truncate--MinWidth: 0; - --pf-v5-c-truncate__start--MinWidth: 0; + --pf-v6-c-truncate--MinWidth: 0; + --pf-v6-c-truncate__start--MinWidth: 0; } diff --git a/clients/ui/frontend/src/app/components/design/DividedGallery.scss b/clients/ui/frontend/src/app/components/design/DividedGallery.scss index 002cf987..31662f09 100644 --- a/clients/ui/frontend/src/app/components/design/DividedGallery.scss +++ b/clients/ui/frontend/src/app/components/design/DividedGallery.scss @@ -1,27 +1,27 @@ .kubeflowdivided-gallery { position: relative; - background-color: var(--pf-v5-global--BackgroundColor--100); + background-color: var(--pf-v6-global--BackgroundColor--100); &__border { width: 2px; position: absolute; - top: var(--pf-v5-global--spacer--md); - bottom: var(--pf-v5-global--spacer--md); + top: var(--pf-v6-global--spacer--md); + bottom: var(--pf-v6-global--spacer--md); left: 0; background-color: white; content: ' '; } &__item { - border-left: 1px solid var(--pf-v5-global--BorderColor--100); - padding: 0 var(--pf-v5-global--spacer--xl); - margin: var(--pf-v5-global--spacer--md) 0; + border-left: 1px solid var(--pf-v6-global--BorderColor--100); + padding: 0 var(--pf-v6-global--spacer--xl); + margin: var(--pf-v6-global--spacer--md) 0; } - .pf-v5-l-gallery { + .pf-v6-l-gallery { position: relative; } &__close { position: absolute; - top: var(--pf-v5-global--spacer--xs); + top: var(--pf-v6-global--spacer--xs); right: 0; } } diff --git a/clients/ui/frontend/src/app/components/design/InfoGalleryItem.tsx b/clients/ui/frontend/src/app/components/design/InfoGalleryItem.tsx index 3ce5f3f1..eaf8d9f9 100644 --- a/clients/ui/frontend/src/app/components/design/InfoGalleryItem.tsx +++ b/clients/ui/frontend/src/app/components/design/InfoGalleryItem.tsx @@ -66,8 +66,8 @@ const InfoGalleryItem: React.FC = ({ isInline onClick={onClick} style={{ - fontSize: 'var(--pf-v5-global--FontSize--md)', - fontWeight: 'var(--pf-v5-global--FontWeight--bold)', + fontSize: 'var(--pf-v6-global--FontSize--md)', + fontWeight: 'var(--pf-v6-global--FontWeight--bold)', }} > {title} diff --git a/clients/ui/frontend/src/app/components/design/ScrolledGallery.tsx b/clients/ui/frontend/src/app/components/design/ScrolledGallery.tsx index 3d619970..4c85f141 100644 --- a/clients/ui/frontend/src/app/components/design/ScrolledGallery.tsx +++ b/clients/ui/frontend/src/app/components/design/ScrolledGallery.tsx @@ -22,8 +22,8 @@ const ScrolledGallery: React.FC = ({ display: 'grid', gridAutoFlow: 'column', overflowY: 'auto', - gap: 'var(--pf-v5-global--spacer--md)', - paddingBottom: 'var(--pf-v5-global--spacer--sm)', + gap: 'var(--pf-v6-global--spacer--md)', + paddingBottom: 'var(--pf-v6-global--spacer--sm)', }} {...rest} > diff --git a/clients/ui/frontend/src/app/components/design/TypeBorderCard.scss b/clients/ui/frontend/src/app/components/design/TypeBorderCard.scss index 799ac8c3..ea20bbf8 100644 --- a/clients/ui/frontend/src/app/components/design/TypeBorderCard.scss +++ b/clients/ui/frontend/src/app/components/design/TypeBorderCard.scss @@ -1,7 +1,7 @@ .kubeflowtype-bordered-card { position: relative; border-radius: 16px; - border: 1px solid var(--pf-v5-global--BorderColor--100); + border: 1px solid var(--pf-v6-global--BorderColor--100); padding: 1px; &:after { diff --git a/clients/ui/frontend/src/app/components/pf-overrides/FormSection.scss b/clients/ui/frontend/src/app/components/pf-overrides/FormSection.scss new file mode 100644 index 00000000..351e535b --- /dev/null +++ b/clients/ui/frontend/src/app/components/pf-overrides/FormSection.scss @@ -0,0 +1,8 @@ +.kf-form-section { + &__desc { + margin-top: var(--pf-v6-global--spacer--sm); + font-size: var(--pf-v6-global--FontSize--sm); + color: var(--pf-v6-global--Color--200); + font-weight: initial; + } +} diff --git a/clients/ui/frontend/src/app/components/pf-overrides/FormSection.tsx b/clients/ui/frontend/src/app/components/pf-overrides/FormSection.tsx new file mode 100644 index 00000000..17fc71b1 --- /dev/null +++ b/clients/ui/frontend/src/app/components/pf-overrides/FormSection.tsx @@ -0,0 +1,35 @@ +import * as React from 'react'; +import { FormSection as PFFormSection, FormSectionProps, Content } from '@patternfly/react-core'; + +import './FormSection.scss'; + +type Props = FormSectionProps & { + description?: React.ReactNode; +}; + +// Remove once https://github.com/patternfly/patternfly/issues/6663 is fixed +const FormSection: React.FC = ({ + description, + title, + titleElement: TitleElement = 'div', + ...props +}) => ( + + {title} + + {description} + + + ) : ( + title + ) + } + /> +); + +export default FormSection; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx index c64a58c7..c37d93a2 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/ModelRegistryRoutes.tsx @@ -12,6 +12,8 @@ import ModelVersionsArchive from './screens/ModelVersionsArchive/ModelVersionsAr import ModelVersionsArchiveDetails from './screens/ModelVersionsArchive/ModelVersionArchiveDetails'; import ArchiveModelVersionDetails from './screens/ModelVersionsArchive/ArchiveModelVersionDetails'; import RegisteredModelsArchiveDetails from './screens/RegisteredModelsArchive/RegisteredModelArchiveDetails'; +import RegisterModel from './screens/RegisterModel/RegisterModel'; +import RegisterVersion from './screens/RegisterModel/RegisterVersion'; const ModelRegistryRoutes: React.FC = () => ( @@ -34,6 +36,7 @@ const ModelRegistryRoutes: React.FC = () => ( path={ModelVersionsTab.DETAILS} element={} /> + } /> } /> ( } /> + } /> + } /> } /> diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx index ddc33484..75cf9f11 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelRegistry.tsx @@ -3,7 +3,7 @@ import ApplicationsPage from '~/app/components/ApplicationsPage'; import TitleWithIcon from '~/app/components/design/TitleWithIcon'; import { ProjectObjectType } from '~/app/components/design/utils'; import useRegisteredModels from '~/app/hooks/useRegisteredModels'; -import { filterLiveModels } from '~/app/utils'; +import { filterLiveModels } from '~/app/pages/modelRegistry/screens/utils'; import ModelRegistrySelectorNavigator from './ModelRegistrySelectorNavigator'; import RegisteredModelListView from './RegisteredModels/RegisteredModelListView'; import { modelRegistryUrl } from './routeUtils'; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/PrefilledModelRegistryField.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/PrefilledModelRegistryField.tsx new file mode 100644 index 00000000..980da8d3 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/PrefilledModelRegistryField.tsx @@ -0,0 +1,14 @@ +import React from 'react'; +import { FormGroup, TextInput } from '@patternfly/react-core'; + +type PrefilledModelRegistryFieldProps = { + mrName?: string; +}; + +const PrefilledModelRegistryField: React.FC = ({ mrName }) => ( + + + +); + +export default PrefilledModelRegistryField; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx new file mode 100644 index 00000000..c6cffb9d --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx @@ -0,0 +1,116 @@ +import React from 'react'; +import { + Breadcrumb, + BreadcrumbItem, + Form, + FormGroup, + PageSection, + Stack, + StackItem, + TextArea, + TextInput, +} from '@patternfly/react-core'; +import spacing from '@patternfly/react-styles/css/utilities/Spacing/spacing'; +import { useParams, useNavigate } from 'react-router'; +import { Link } from 'react-router-dom'; +import FormSection from '~/app/components/pf-overrides/FormSection'; +import ApplicationsPage from '~/app/components/ApplicationsPage'; +import { modelRegistryUrl, registeredModelUrl } from '~/app/pages/modelRegistry/screens/routeUtils'; +import { ValueOf } from '~/typeHelpers'; +import { useRegisterModelData, RegistrationCommonFormData } from './useRegisterModelData'; +import { isRegisterModelSubmitDisabled, registerModel } from './utils'; +import { useRegistrationCommonState } from './useRegistrationCommonState'; +import RegistrationCommonFormSections from './RegistrationCommonFormSections'; +import PrefilledModelRegistryField from './PrefilledModelRegistryField'; +import RegistrationFormFooter from './RegistrationFormFooter'; + +const RegisterModel: React.FC = () => { + const { modelRegistry: mrName } = useParams(); + const navigate = useNavigate(); + + const { isSubmitting, submitError, setSubmitError, handleSubmit, apiState, author } = + useRegistrationCommonState(); + + const [formData, setData] = useRegisterModelData(); + const isSubmitDisabled = isSubmitting || isRegisterModelSubmitDisabled(formData); + const { modelName, modelDescription } = formData; + + const onSubmit = () => + handleSubmit(async () => { + const { registeredModel } = await registerModel(apiState, formData, author); + navigate(registeredModelUrl(registeredModel.id, mrName)); + }); + const onCancel = () => navigate(modelRegistryUrl(mrName)); + + return ( + + Model registry - {mrName}} + /> + Register model + + } + loaded + empty={false} + > + +
+ + + + + + + + setData('modelName', value)} + /> + + +