Skip to content

Commit

Permalink
Merge sync remote-tracking branch 'upstream/main' into tarilabs-20240…
Browse files Browse the repository at this point in the history
…909-sync

Signed-off-by: Matteo Mortari <matteo.mortari@gmail.com>
  • Loading branch information
tarilabs committed Sep 9, 2024
2 parents 56a2713 + 12f6cb9 commit d15c643
Show file tree
Hide file tree
Showing 77 changed files with 3,564 additions and 415 deletions.
33 changes: 33 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"Area/UI":
- changed-files:
- any-glob-to-any-file: "clients/ui/**"

"Area/MR Python client":
- changed-files:
- any-glob-to-any-file: "clients/python/**"

"Area/Go REST server":
- changed-files:
- any-glob-to-any-file:
- "api/**"
- "cmd/**"
- "internal/**"
- "patches/**"
- "pkg/**"
- "templates/go-server/**"

"Area/CSI":
- changed-files:
- any-glob-to-any-file: "csi/**"

"Area/Manifests":
- changed-files:
- any-glob-to-any-file: "manifests/**"

"Area/Documentation":
- changed-files:
- any-glob-to-any-file: "docs/**"

"Area/GitHub":
- changed-files:
- any-glob-to-any-file: ".github/**"
12 changes: 12 additions & 0 deletions .github/workflows/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
labeler:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v5
9 changes: 8 additions & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
if [[ ${{ matrix.session }} == "tests" ]]; then
make build-mr
nox --python=${{ matrix.python }} -- --cov-report=xml
poetry build
elif [[ ${{ matrix.session }} == "mypy" ]]; then
nox --python=${{ matrix.python }} ||\
echo "::error title='mypy failure'::Check the logs for more details"
Expand All @@ -80,9 +81,15 @@ jobs:
files: coverage.xml
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload dist
if: matrix.session == 'tests' && matrix.python == '3.12'
uses: actions/upload-artifact@v4
with:
name: py-dist
path: clients/python/dist
- name: Upload documentation
if: matrix.session == 'docs-build'
uses: actions/upload-artifact@v4
with:
name: docs
name: py-docs
path: clients/python/docs/_build
4 changes: 2 additions & 2 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ components:
enum:
- CREATE_TIME
- LAST_UPDATE_TIME
- Id
- ID
type: string
Artifact:
oneOf:
Expand Down Expand Up @@ -1661,7 +1661,7 @@ components:
explode: true
examples:
orderBy:
value: Id
value: ID
name: orderBy
description: Specifies the order by criteria for listing entities.
schema:
Expand Down
6 changes: 6 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ model = registry.get_registered_model("my-model")
version = registry.get_model_version("my-model", "2.0.0")

experiment = registry.get_model_artifact("my-model", "2.0.0")

# change is not reflected on pushed model version
version.description = "Updated model version"

# you can update it using
registry.update(version)
```

### Importing from S3
Expand Down
5 changes: 4 additions & 1 deletion clients/python/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def lint(session: Session) -> None:
def mypy(session: Session) -> None:
"""Type check using mypy."""
session.install(".")
session.install("mypy")
session.install(
"mypy",
"types-python-dateutil",
)

session.run("mypy", "src/model_registry")

Expand Down
14 changes: 12 additions & 2 deletions clients/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mypy = "^1.7.0"
pytest-asyncio = ">=0.23.7,<0.25.0"
requests = "^2.32.2"
black = "^24.4.2"
types-python-dateutil = "^2.9.0.20240906"

[tool.coverage.run]
branch = true
Expand Down
27 changes: 22 additions & 5 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Any, get_args
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .core import ModelRegistryAPIClient
Expand All @@ -18,6 +18,9 @@
SupportedTypes,
)

ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)


class ModelRegistry:
"""Model registry client."""
Expand All @@ -29,7 +32,7 @@ def __init__(
*,
author: str,
is_secure: bool = True,
user_token: bytes | None = None,
user_token: str | None = None,
custom_ca: str | None = None,
):
"""Constructor.
Expand All @@ -41,8 +44,8 @@ def __init__(
Keyword Args:
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a byte string. Defaults to path on envvar CERT.
user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.
"""
import nest_asyncio

Expand All @@ -55,7 +58,7 @@ def __init__(
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
user_token = Path(sa_token).read_text()
else:
warn("User access token is missing", stacklevel=2)

Expand Down Expand Up @@ -191,6 +194,20 @@ def register_model(

return rm

def update(self, model: TModel) -> TModel:
"""Update a model."""
if not model.id:
msg = "Model must have an ID"
raise StoreError(msg)
if not isinstance(model, get_args(ModelTypes)):
msg = f"Model must be one of {get_args(ModelTypes)}"
raise StoreError(msg)
if isinstance(model, RegisteredModel):
return self.async_runner(self._api.upsert_registered_model(model))
if isinstance(model, ModelVersion):
return self.async_runner(self._api.upsert_model_version(model, model.id))
return self.async_runner(self._api.upsert_model_artifact(model, model.id))

def register_hf_model(
self,
repo: str,
Expand Down
8 changes: 4 additions & 4 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def secure_connection(
server_address: str,
port: int = 443,
*,
user_token: bytes,
user_token: str,
custom_ca: str | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Expand All @@ -52,7 +52,7 @@ def secure_connection(
port: Server port. Defaults to 443.
Keyword Args:
user_token: The PEM-encoded user token as a byte string.
user_token: The PEM-encoded user token as a string.
custom_ca: The path to a PEM-
"""
return cls(
Expand All @@ -68,14 +68,14 @@ def insecure_connection(
cls,
server_address: str,
port: int,
user_token: bytes | None = None,
user_token: str | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Args:
server_address: Server address.
port: Server port.
user_token: The PEM-encoded user token as a byte string.
user_token: The PEM-encoded user token as a string.
"""
return cls(
Configuration(host=f"{server_address}:{port}", access_token=user_token)
Expand Down
16 changes: 11 additions & 5 deletions clients/python/src/model_registry/types/pager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def restart(self) -> Pager[T]:
This keeps the current options and page function, but resets the internal state.
"""
# as MLMD loops over pages, we need to keep track of the first page or we'll loop forever
self._start = None
self._current_page = None
self._start: str | None = None
self._current_page: list[T] | None = None
# tracks the next item on the current page
self._i = 0
self.options.next_page_token = None
Expand Down Expand Up @@ -112,7 +112,9 @@ async def _anext_page(self) -> list[T]:
return await cast(Awaitable[list[T]], self.page_fn(self.options))

def _needs_fetch(self) -> bool:
return not self._current_page or self._i >= len(self._current_page)
return not self._current_page or (
self._i >= len(self._current_page) and self._start is not None
)

def _next_item(self) -> T:
"""Get the next item in the pager.
Expand All @@ -126,6 +128,8 @@ def _next_item(self) -> T:
self._current_page = self._next_page()
self._i = 0
assert self._current_page
if self._i >= len(self._current_page):
raise StopIteration

item = self._current_page[self._i]
self._i += 1
Expand All @@ -143,6 +147,8 @@ async def _anext_item(self) -> T:
self._current_page = await self._anext_page()
self._i = 0
assert self._current_page
if self._i >= len(self._current_page):
raise StopIteration

item = self._current_page[self._i]
self._i += 1
Expand All @@ -153,7 +159,7 @@ def __next__(self) -> T:

item = self._next_item()

if not self._start:
if self._start is None:
self._start = self.options.next_page_token
elif check_looping and self.options.next_page_token == self._start:
raise StopIteration
Expand All @@ -165,7 +171,7 @@ async def __anext__(self) -> T:

item = await self._anext_item()

if not self._start:
if self._start is None:
self._start = self.options.next_page_token
elif check_looping and self.options.next_page_token == self._start:
raise StopAsyncIteration
Expand Down
2 changes: 1 addition & 1 deletion clients/python/src/mr_openapi/models/order_by_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OrderByField(str, Enum):
"""
CREATE_TIME = "CREATE_TIME"
LAST_UPDATE_TIME = "LAST_UPDATE_TIME"
ID = "Id"
ID = "ID"

@classmethod
def from_json(cls, json_str: str) -> Self:
Expand Down
16 changes: 16 additions & 0 deletions clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import os
import subprocess
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
Expand Down Expand Up @@ -133,3 +134,18 @@ def event_loop():
@cleanup
def client() -> ModelRegistry:
return ModelRegistry(REGISTRY_HOST, REGISTRY_PORT, author="author", is_secure=False)

@pytest.fixture(scope="module")
def setup_env_user_token():
with tempfile.NamedTemporaryFile(delete=False) as token_file:
token_file.write(b"Token")
old_token_path = os.getenv("KF_PIPELINES_SA_TOKEN_PATH")
os.environ["KF_PIPELINES_SA_TOKEN_PATH"] = token_file.name

yield token_file.name

if old_token_path is None:
del os.environ["KF_PIPELINES_SA_TOKEN_PATH"]
else:
os.environ["KF_PIPELINES_SA_TOKEN_PATH"] = old_token_path
os.remove(token_file.name)
Loading

0 comments on commit d15c643

Please sign in to comment.