Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from kubeflow:main #121

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,23 @@ for version in registry.get_model_versions("my-model"):
...
```

To customize sorting order or query limits you can also use
You can also use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order

```py
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20)
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending()
for version in latest_updates:
...
```

You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order.
By default, all queries will be `ascending`, but this method is also available for explicitness.

> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1
> you will still get all the models, with one query each.
> Note: You can also set the `page_size()` that you want the Pager to use when invoking the Model Registry backend.
> When using it as an iterator, it will automatically manage pages for you.

#### Implementation notes

The pager will manage pages for you in order to prevent infinite looping.
Currently, the Model Registry backend treats model lists as a circular buffer, and **will not end iteration** for you.

## Development

Expand Down
3 changes: 2 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from warnings import warn
Expand Down Expand Up @@ -138,7 +139,7 @@ def register_model(
author: str | None = None,
owner: str | None = None,
description: str | None = None,
metadata: dict[str, SupportedTypes] | None = None,
metadata: Mapping[str, SupportedTypes] | None = None,
) -> RegisteredModel:
"""Register a model.

Expand Down
4 changes: 2 additions & 2 deletions clients/python/src/model_registry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Union, get_args

from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -35,7 +35,7 @@ class BaseResourceModel(BaseModel, ABC):
external_id: str | None = None
create_time_since_epoch: str | None = None
last_update_time_since_epoch: str | None = None
custom_properties: dict[str, SupportedTypes] | None = None
custom_properties: Mapping[str, SupportedTypes] | None = None

@abstractmethod
def create(self, **kwargs) -> Any:
Expand Down
9 changes: 6 additions & 3 deletions clients/python/src/model_registry/types/pager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ def order_by_id(self) -> Pager[T]:
self.options.order_by = OrderByField.ID
return self.restart()

def limit(self, limit: int) -> Pager[T]:
"""Limit the number of items to return.
def page_size(self, n: int) -> Pager[T]:
"""Set the page size for each request.

This resets the pager.
"""
self.options.limit = limit
if n < 1:
msg = f"Page size must be at least 1, got {n}"
raise ValueError(msg)
self.options.limit = n
return self.restart()

def ascending(self) -> Pager[T]:
Expand Down
49 changes: 40 additions & 9 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def test_register_existing_version(client: ModelRegistry):
"model_format_version": "test_version",
"version": "1.0.0",
}
client.register_model(**params)
client.register_model(**params, metadata=None)

with pytest.raises(StoreError):
client.register_model(**params)
client.register_model(**params, metadata=None)


@pytest.mark.e2e
Expand Down Expand Up @@ -124,8 +124,10 @@ async def test_update_logical_model_with_labels(client: ModelRegistry):
)
assert rm.id
mv = client.get_model_version(name, version)
assert mv
assert mv.id
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id

rm_labels = {
Expand All @@ -149,9 +151,15 @@ async def test_update_logical_model_with_labels(client: ModelRegistry):
ma.custom_properties = ma_labels
client.update(ma)

assert client.get_registered_model(name).custom_properties == rm_labels
assert client.get_model_version(name, version).custom_properties == mv_labels
assert client.get_model_artifact(name, version).custom_properties == ma_labels
rm = client.get_registered_model(name)
assert rm
assert rm.custom_properties == rm_labels
mv = client.get_model_version(name, version)
assert mv
assert mv.custom_properties == mv_labels
ma = client.get_model_artifact(name, version)
assert ma
assert ma.custom_properties == ma_labels


@pytest.mark.e2e
Expand Down Expand Up @@ -232,7 +240,7 @@ def test_get_registered_models(client: ModelRegistry):
version="1.0.0",
)

rm_iter = client.get_registered_models().limit(10)
rm_iter = client.get_registered_models().page_size(10)
i = 0
prev_tok = None
changes = 0
Expand Down Expand Up @@ -315,6 +323,17 @@ def test_get_registered_models_order_by(client: ModelRegistry):

assert i == models

# or if descending is explicitly set
i = 0
for rm, by_update in zip(
rms,
client.get_registered_models().order_by_update_time().descending(),
):
assert rm.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_registered_models_and_reset(client: ModelRegistry):
Expand All @@ -330,7 +349,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry):
version="1.0.0",
)

rm_iter = client.get_registered_models().limit(model_count - 1)
rm_iter = client.get_registered_models().page_size(model_count - 1)
models = []
for rm in islice(rm_iter, page):
models.append(rm)
Expand All @@ -355,7 +374,7 @@ def test_get_model_versions(client: ModelRegistry):
version=v,
)

mv_iter = client.get_model_versions(name).limit(10)
mv_iter = client.get_model_versions(name).page_size(10)
i = 0
prev_tok = None
changes = 0
Expand Down Expand Up @@ -430,6 +449,18 @@ def test_get_model_versions_order_by(client: ModelRegistry):
assert mv.id == by_update.id
i += 1

assert i == models

i = 0
for mv, by_update in zip(
mvs,
client.get_model_versions(name).order_by_update_time().descending(),
):
assert mv.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_model_versions_and_reset(client: ModelRegistry):
Expand All @@ -447,7 +478,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry):
version=v,
)

mv_iter = client.get_model_versions(name).limit(model_count - 1)
mv_iter = client.get_model_versions(name).page_size(model_count - 1)
models = []
for rm in islice(mv_iter, page):
models.append(rm)
Expand Down
10 changes: 6 additions & 4 deletions clients/python/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def test_get_registered_model_by_external_id(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
):
assert registered_model.external_id
assert (
rm := await client.get_registered_model_by_params(
external_id=registered_model.external_id
Expand All @@ -99,7 +100,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient):
models = 6
for i in range(models):
await client.upsert_registered_model(RegisteredModel(name=f"rm{i}"))
pager = Pager(client.get_registered_models).limit(5)
pager = Pager(client.get_registered_models).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand Down Expand Up @@ -205,7 +206,7 @@ async def test_page_through_model_versions(
)
pager = Pager(
lambda o: client.get_model_versions(str(registered_model.id), o)
).limit(5)
).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand All @@ -227,7 +228,8 @@ async def test_insert_model_artifact(
"service_account_name": "test service account",
}
ma = await client.upsert_model_artifact(
ModelArtifact(**props), str(model_version.id)
ModelArtifact(**props), # type: ignore
str(model_version.id),
)
assert ma.id
assert ma.name == "test model"
Expand Down Expand Up @@ -340,7 +342,7 @@ async def test_page_through_model_version_artifacts(
await client.create_model_version_artifact(art, str(model_version.id))
pager = Pager(
lambda o: client.get_model_version_artifacts(str(model_version.id), o)
).limit(5)
).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand Down
28 changes: 14 additions & 14 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,29 @@ make docker-build
| URL Pattern | Handler | Action |
|------------------------------------------------------------------------------------|-------------------------|----------------------------------------------|
| GET /v1/healthcheck | HealthcheckHandler | Show application information. |
| GET /v1/model-registry | ModelRegistryHandler | Get all model registries, |
| GET /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. |
| GET /v1/model-registry/{model_registry_id}/registered_models/{registered_model_id} | RegisteredModelHandler | Get a RegisteredModel entity by ID |
| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, |
| GET /v1/model_registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model_registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | RegisteredModelHandler | Get a RegisteredModel entity by ID |

### Sample local calls
```
# GET /v1/healthcheck
curl -i localhost:4000/api/v1/healthcheck
```
```
# GET /v1/model-registry
curl -i localhost:4000/api/v1/model-registry
# GET /v1/model_registry
curl -i localhost:4000/api/v1/model_registry
```
```
# GET /v1/model-registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models
# GET /v1/model_registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models
```
```
#POST /v1/model-registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/registered_models" \
#POST /v1/model_registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \
-H "Content-Type: application/json" \
-d '{
-d '{ "data": {
"customProperties": {
"my-label9": {
"metadataType": "MetadataStringValue",
Expand All @@ -91,9 +91,9 @@ curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/regi
"name": "bella",
"owner": "eder",
"state": "LIVE"
}'
}}'
```
```
# GET /v1/model-registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models/1
# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1
```
7 changes: 4 additions & 3 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package api

import (
"fmt"
"log/slog"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/config"
"github.com/kubeflow/model-registry/ui/bff/data"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"github.com/kubeflow/model-registry/ui/bff/internals/mocks"
"log/slog"
"net/http"
)

const (
Expand All @@ -17,7 +18,7 @@ const (
ModelRegistryId = "model_registry_id"
RegisteredModelId = "registered_model_id"
HealthCheckPath = PathPrefix + "/healthcheck"
ModelRegistry = PathPrefix + "/model-registry"
ModelRegistry = PathPrefix + "/model_registry"
RegisteredModelsPath = ModelRegistry + "/:" + ModelRegistryId + "/registered_models"
RegisteredModelPath = RegisteredModelsPath + "/:" + RegisteredModelId
)
Expand Down
6 changes: 5 additions & 1 deletion clients/ui/bff/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ type ErrorResponse struct {
Message string `json:"message"`
}

type ErrorEnvelope struct {
Error *integrations.HTTPError `json:"error"`
}

func (app *App) LogError(r *http.Request, err error) {
var (
method = r.Method
Expand All @@ -40,7 +44,7 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e

func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) {

env := Envelope{"error": error}
env := ErrorEnvelope{Error: error}

err := app.WriteJSON(w, error.StatusCode, env, nil)

Expand Down
13 changes: 10 additions & 3 deletions clients/ui/bff/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
"strings"
)

type Envelope map[string]interface{}
type Envelope[D any, M any] struct {
Data D `json:"data,omitempty"`
Metadata M `json:"metadata,omitempty"`
}

type TypedEnvelope[T any] map[string]T
type None *struct{}

func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error {

Expand All @@ -29,7 +32,11 @@ func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers h

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write(js)
_, err = w.Write(js)

if err != nil {
return err
}

return nil
}
Expand Down
7 changes: 5 additions & 2 deletions clients/ui/bff/api/model_registry_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package api

import (
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/data"
"net/http"
)

type ModelRegistryListEnvelope Envelope[[]data.ModelRegistryModel, None]

func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {

registries, err := app.models.ModelRegistry.FetchAllModelRegistries(app.kubernetesClient)
Expand All @@ -13,8 +16,8 @@ func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps
return
}

modelRegistryRes := Envelope{
"model_registry": registries,
modelRegistryRes := ModelRegistryListEnvelope{
Data: registries,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
Expand Down
Loading