Skip to content

Commit

Permalink
Merge pull request #121 from kubeflow/main
Browse files Browse the repository at this point in the history
[pull] main from kubeflow:main
  • Loading branch information
openshift-merge-bot[bot] committed Sep 13, 2024
2 parents 108f697 + 95e6b7f commit 0cd206f
Show file tree
Hide file tree
Showing 63 changed files with 2,476 additions and 379 deletions.
15 changes: 10 additions & 5 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,23 @@ for version in registry.get_model_versions("my-model"):
...
```

To customize sorting order or query limits you can also use
You can also use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order

```py
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20)
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending()
for version in latest_updates:
...
```

You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order.
By default, all queries will be `ascending`, but this method is also available for explicitness.

> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1
> you will still get all the models, with one query each.
> Note: You can also set the `page_size()` that you want the Pager to use when invoking the Model Registry backend.
> When using it as an iterator, it will automatically manage pages for you.
#### Implementation notes

The pager will manage pages for you in order to prevent infinite looping.
Currently, the Model Registry backend treats model lists as a circular buffer, and **will not end iteration** for you.

## Development

Expand Down
3 changes: 2 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from warnings import warn
Expand Down Expand Up @@ -138,7 +139,7 @@ def register_model(
author: str | None = None,
owner: str | None = None,
description: str | None = None,
metadata: dict[str, SupportedTypes] | None = None,
metadata: Mapping[str, SupportedTypes] | None = None,
) -> RegisteredModel:
"""Register a model.
Expand Down
4 changes: 2 additions & 2 deletions clients/python/src/model_registry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Union, get_args

from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -35,7 +35,7 @@ class BaseResourceModel(BaseModel, ABC):
external_id: str | None = None
create_time_since_epoch: str | None = None
last_update_time_since_epoch: str | None = None
custom_properties: dict[str, SupportedTypes] | None = None
custom_properties: Mapping[str, SupportedTypes] | None = None

@abstractmethod
def create(self, **kwargs) -> Any:
Expand Down
9 changes: 6 additions & 3 deletions clients/python/src/model_registry/types/pager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ def order_by_id(self) -> Pager[T]:
self.options.order_by = OrderByField.ID
return self.restart()

def limit(self, limit: int) -> Pager[T]:
"""Limit the number of items to return.
def page_size(self, n: int) -> Pager[T]:
"""Set the page size for each request.
This resets the pager.
"""
self.options.limit = limit
if n < 1:
msg = f"Page size must be at least 1, got {n}"
raise ValueError(msg)
self.options.limit = n
return self.restart()

def ascending(self) -> Pager[T]:
Expand Down
49 changes: 40 additions & 9 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def test_register_existing_version(client: ModelRegistry):
"model_format_version": "test_version",
"version": "1.0.0",
}
client.register_model(**params)
client.register_model(**params, metadata=None)

with pytest.raises(StoreError):
client.register_model(**params)
client.register_model(**params, metadata=None)


@pytest.mark.e2e
Expand Down Expand Up @@ -124,8 +124,10 @@ async def test_update_logical_model_with_labels(client: ModelRegistry):
)
assert rm.id
mv = client.get_model_version(name, version)
assert mv
assert mv.id
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id

rm_labels = {
Expand All @@ -149,9 +151,15 @@ async def test_update_logical_model_with_labels(client: ModelRegistry):
ma.custom_properties = ma_labels
client.update(ma)

assert client.get_registered_model(name).custom_properties == rm_labels
assert client.get_model_version(name, version).custom_properties == mv_labels
assert client.get_model_artifact(name, version).custom_properties == ma_labels
rm = client.get_registered_model(name)
assert rm
assert rm.custom_properties == rm_labels
mv = client.get_model_version(name, version)
assert mv
assert mv.custom_properties == mv_labels
ma = client.get_model_artifact(name, version)
assert ma
assert ma.custom_properties == ma_labels


@pytest.mark.e2e
Expand Down Expand Up @@ -232,7 +240,7 @@ def test_get_registered_models(client: ModelRegistry):
version="1.0.0",
)

rm_iter = client.get_registered_models().limit(10)
rm_iter = client.get_registered_models().page_size(10)
i = 0
prev_tok = None
changes = 0
Expand Down Expand Up @@ -315,6 +323,17 @@ def test_get_registered_models_order_by(client: ModelRegistry):

assert i == models

# or if descending is explicitly set
i = 0
for rm, by_update in zip(
rms,
client.get_registered_models().order_by_update_time().descending(),
):
assert rm.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_registered_models_and_reset(client: ModelRegistry):
Expand All @@ -330,7 +349,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry):
version="1.0.0",
)

rm_iter = client.get_registered_models().limit(model_count - 1)
rm_iter = client.get_registered_models().page_size(model_count - 1)
models = []
for rm in islice(rm_iter, page):
models.append(rm)
Expand All @@ -355,7 +374,7 @@ def test_get_model_versions(client: ModelRegistry):
version=v,
)

mv_iter = client.get_model_versions(name).limit(10)
mv_iter = client.get_model_versions(name).page_size(10)
i = 0
prev_tok = None
changes = 0
Expand Down Expand Up @@ -430,6 +449,18 @@ def test_get_model_versions_order_by(client: ModelRegistry):
assert mv.id == by_update.id
i += 1

assert i == models

i = 0
for mv, by_update in zip(
mvs,
client.get_model_versions(name).order_by_update_time().descending(),
):
assert mv.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_model_versions_and_reset(client: ModelRegistry):
Expand All @@ -447,7 +478,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry):
version=v,
)

mv_iter = client.get_model_versions(name).limit(model_count - 1)
mv_iter = client.get_model_versions(name).page_size(model_count - 1)
models = []
for rm in islice(mv_iter, page):
models.append(rm)
Expand Down
10 changes: 6 additions & 4 deletions clients/python/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def test_get_registered_model_by_external_id(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
):
assert registered_model.external_id
assert (
rm := await client.get_registered_model_by_params(
external_id=registered_model.external_id
Expand All @@ -99,7 +100,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient):
models = 6
for i in range(models):
await client.upsert_registered_model(RegisteredModel(name=f"rm{i}"))
pager = Pager(client.get_registered_models).limit(5)
pager = Pager(client.get_registered_models).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand Down Expand Up @@ -205,7 +206,7 @@ async def test_page_through_model_versions(
)
pager = Pager(
lambda o: client.get_model_versions(str(registered_model.id), o)
).limit(5)
).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand All @@ -227,7 +228,8 @@ async def test_insert_model_artifact(
"service_account_name": "test service account",
}
ma = await client.upsert_model_artifact(
ModelArtifact(**props), str(model_version.id)
ModelArtifact(**props), # type: ignore
str(model_version.id),
)
assert ma.id
assert ma.name == "test model"
Expand Down Expand Up @@ -340,7 +342,7 @@ async def test_page_through_model_version_artifacts(
await client.create_model_version_artifact(art, str(model_version.id))
pager = Pager(
lambda o: client.get_model_version_artifacts(str(model_version.id), o)
).limit(5)
).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand Down
28 changes: 14 additions & 14 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,29 @@ make docker-build
| URL Pattern | Handler | Action |
|------------------------------------------------------------------------------------|-------------------------|----------------------------------------------|
| GET /v1/healthcheck | HealthcheckHandler | Show application information. |
| GET /v1/model-registry | ModelRegistryHandler | Get all model registries, |
| GET /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. |
| GET /v1/model-registry/{model_registry_id}/registered_models/{registered_model_id} | RegisteredModelHandler | Get a RegisteredModel entity by ID |
| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, |
| GET /v1/model_registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model_registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | RegisteredModelHandler | Get a RegisteredModel entity by ID |

### Sample local calls
```
# GET /v1/healthcheck
curl -i localhost:4000/api/v1/healthcheck
```
```
# GET /v1/model-registry
curl -i localhost:4000/api/v1/model-registry
# GET /v1/model_registry
curl -i localhost:4000/api/v1/model_registry
```
```
# GET /v1/model-registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models
# GET /v1/model_registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models
```
```
#POST /v1/model-registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/registered_models" \
#POST /v1/model_registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \
-H "Content-Type: application/json" \
-d '{
-d '{ "data": {
"customProperties": {
"my-label9": {
"metadataType": "MetadataStringValue",
Expand All @@ -91,9 +91,9 @@ curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/regi
"name": "bella",
"owner": "eder",
"state": "LIVE"
}'
}}'
```
```
# GET /v1/model-registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models/1
# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1
```
7 changes: 4 additions & 3 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package api

import (
"fmt"
"log/slog"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/config"
"github.com/kubeflow/model-registry/ui/bff/data"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"github.com/kubeflow/model-registry/ui/bff/internals/mocks"
"log/slog"
"net/http"
)

const (
Expand All @@ -17,7 +18,7 @@ const (
ModelRegistryId = "model_registry_id"
RegisteredModelId = "registered_model_id"
HealthCheckPath = PathPrefix + "/healthcheck"
ModelRegistry = PathPrefix + "/model-registry"
ModelRegistry = PathPrefix + "/model_registry"
RegisteredModelsPath = ModelRegistry + "/:" + ModelRegistryId + "/registered_models"
RegisteredModelPath = RegisteredModelsPath + "/:" + RegisteredModelId
)
Expand Down
6 changes: 5 additions & 1 deletion clients/ui/bff/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ type ErrorResponse struct {
Message string `json:"message"`
}

type ErrorEnvelope struct {
Error *integrations.HTTPError `json:"error"`
}

func (app *App) LogError(r *http.Request, err error) {
var (
method = r.Method
Expand All @@ -40,7 +44,7 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e

func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) {

env := Envelope{"error": error}
env := ErrorEnvelope{Error: error}

err := app.WriteJSON(w, error.StatusCode, env, nil)

Expand Down
13 changes: 10 additions & 3 deletions clients/ui/bff/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
"strings"
)

type Envelope map[string]interface{}
type Envelope[D any, M any] struct {
Data D `json:"data,omitempty"`
Metadata M `json:"metadata,omitempty"`
}

type TypedEnvelope[T any] map[string]T
type None *struct{}

func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error {

Expand All @@ -29,7 +32,11 @@ func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers h

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write(js)
_, err = w.Write(js)

if err != nil {
return err
}

return nil
}
Expand Down
7 changes: 5 additions & 2 deletions clients/ui/bff/api/model_registry_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package api

import (
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/data"
"net/http"
)

type ModelRegistryListEnvelope Envelope[[]data.ModelRegistryModel, None]

func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {

registries, err := app.models.ModelRegistry.FetchAllModelRegistries(app.kubernetesClient)
Expand All @@ -13,8 +16,8 @@ func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps
return
}

modelRegistryRes := Envelope{
"model_registry": registries,
modelRegistryRes := ModelRegistryListEnvelope{
Data: registries,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
Expand Down
Loading

0 comments on commit 0cd206f

Please sign in to comment.