-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e7a3e0a
commit 86fbc4c
Showing
14 changed files
with
420 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import sys | ||
import uuid | ||
from typing import Any, Iterable | ||
|
||
import click | ||
|
||
from ai.backend.cli.main import main | ||
from ai.backend.cli.types import ExitCode | ||
from ai.backend.client.cli.extensions import pass_ctx_obj | ||
from ai.backend.client.cli.types import CLIContext | ||
from ai.backend.client.exceptions import BackendAPIError | ||
from ai.backend.client.session import Session | ||
|
||
from ..output.fields import network_fields | ||
from .pretty import print_done | ||
|
||
_default_list_fields = ( | ||
network_fields["name"], | ||
network_fields["ref_name"], | ||
network_fields["driver"], | ||
network_fields["created_at"], | ||
) | ||
|
||
|
||
@main.group() | ||
def network(): | ||
"""Set of inter-container network operations""" | ||
|
||
|
||
@network.command() | ||
@pass_ctx_obj | ||
@click.argument("project", type=str, metavar="PROJECT_ID_OR_NAME") | ||
@click.argument("name", type=str, metavar="NAME") | ||
@click.option("-d", "--driver", default=None, help="Set the network driver.") | ||
def create(ctx: CLIContext, project, name, driver): | ||
"""Create a new network interface.""" | ||
|
||
with Session() as session: | ||
proj_id: str | None = None | ||
|
||
try: | ||
uuid.UUID(project) | ||
except ValueError: | ||
pass | ||
else: | ||
if session.Group.detail(project): | ||
proj_id = project | ||
|
||
if not proj_id: | ||
projects = session.Group.from_name(project) | ||
if not projects: | ||
ctx.output.print_fail(f"Project '{project}' not found.") | ||
sys.exit(ExitCode.FAILURE) | ||
proj_id = projects[0]["id"] | ||
|
||
try: | ||
network = session.Network.create(proj_id, name, driver=driver) | ||
print_done(f"Network {name} (ID {network.network_id}) created.") | ||
except Exception as e: | ||
ctx.output.print_error(e) | ||
sys.exit(ExitCode.FAILURE) | ||
|
||
|
||
@network.command() | ||
@pass_ctx_obj | ||
@click.option( | ||
"-f", | ||
"--format", | ||
default=None, | ||
help="Display only specified fields. When specifying multiple fields separate them with comma (,).", | ||
) | ||
@click.option("--filter", "filter_", default=None, help="Set the query filter expression.") | ||
@click.option("--order", default=None, help="Set the query ordering expression.") | ||
@click.option("--offset", default=0, help="The index of the current page start for pagination.") | ||
@click.option("--limit", type=int, default=None, help="The page size for pagination.") | ||
def list(ctx: CLIContext, format, filter_, order, offset, limit): | ||
"""List all available network interfaces.""" | ||
|
||
if format: | ||
try: | ||
fields = [network_fields[f.strip()] for f in format.split(",")] | ||
except KeyError as e: | ||
ctx.output.print_fail(f"Field {str(e)} not found") | ||
sys.exit(ExitCode.FAILURE) | ||
else: | ||
fields = None | ||
with Session() as session: | ||
try: | ||
fetch_func = lambda pg_offset, pg_size: session.Network.paginated_list( | ||
page_offset=pg_offset, | ||
page_size=pg_size, | ||
filter=filter_, | ||
order=order, | ||
fields=fields, | ||
) | ||
ctx.output.print_paginated_list( | ||
fetch_func, | ||
initial_page_offset=offset, | ||
page_size=limit, | ||
) | ||
except Exception as e: | ||
ctx.output.print_error(e) | ||
sys.exit(ExitCode.FAILURE) | ||
|
||
|
||
@network.command() | ||
@pass_ctx_obj | ||
@click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME") | ||
@click.option( | ||
"-f", | ||
"--format", | ||
default=None, | ||
help="Display only specified fields. When specifying multiple fields separate them with comma (,).", | ||
) | ||
def get(ctx: CLIContext, network, format): | ||
fields: Iterable[Any] | ||
if format: | ||
try: | ||
fields = [network_fields[f.strip()] for f in format.split(",")] | ||
except KeyError as e: | ||
ctx.output.print_fail(f"Field {str(e)} not found") | ||
sys.exit(ExitCode.FAILURE) | ||
else: | ||
fields = _default_list_fields | ||
|
||
with Session() as session: | ||
try: | ||
network_info = session.Network(uuid.UUID(network)).get(fields=fields) | ||
except (ValueError, BackendAPIError): | ||
networks = session.Network.paginated_list(filter=f'name == "{network}"', fields=fields) | ||
if networks.total_count == 0: | ||
ctx.output.print_fail(f"Network {network} not found.") | ||
sys.exit(ExitCode.FAILURE) | ||
if networks.total_count > 1: | ||
ctx.output.print_fail( | ||
f"One or more networks found with name {network}. Try mentioning network ID instead of name to resolve the issue." | ||
) | ||
sys.exit(ExitCode.FAILURE) | ||
network_info = networks.items[0] | ||
|
||
ctx.output.print_item(network_info, fields) | ||
|
||
|
||
@network.command() | ||
@pass_ctx_obj | ||
@click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME") | ||
def delete(ctx: CLIContext, network): | ||
with Session() as session: | ||
try: | ||
network_info = session.Network(uuid.UUID(network)).get(fields=[network_fields["id"]]) | ||
except (ValueError, BackendAPIError): | ||
networks = session.Network.paginated_list( | ||
filter=f'name == "{network}"', fields=[network_fields["id"]] | ||
) | ||
if networks.total_count == 0: | ||
ctx.output.print_fail(f"Network {network} not found.") | ||
sys.exit(ExitCode.FAILURE) | ||
if networks.total_count > 1: | ||
ctx.output.print_fail( | ||
f"One or more networks found with name {network}. Try mentioning network ID instead of name to resolve the issue." | ||
) | ||
sys.exit(ExitCode.FAILURE) | ||
network_info = networks.items[0] | ||
|
||
try: | ||
session.Network(uuid.UUID(network_info["row_id"])).delete() | ||
print_done(f"Network {network} has been deleted.") | ||
except BackendAPIError as e: | ||
ctx.output.print_fail(f"Failed to delete network {network}:") | ||
ctx.output.print_error(e) | ||
sys.exit(ExitCode.FAILURE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from typing import Sequence | ||
from uuid import UUID | ||
|
||
from ..output.fields import network_fields | ||
from ..output.types import FieldSpec, RelayPaginatedResult | ||
from ..pagination import execute_paginated_relay_query | ||
from ..session import api_session | ||
from .base import BaseFunction, api_function | ||
|
||
__all__ = ("Network",) | ||
|
||
_default_list_fields = ( | ||
network_fields["name"], | ||
network_fields["ref_name"], | ||
network_fields["driver"], | ||
network_fields["created_at"], | ||
) | ||
|
||
|
||
class Network(BaseFunction): | ||
@api_function | ||
@classmethod | ||
async def paginated_list( | ||
cls, | ||
*, | ||
fields: Sequence[FieldSpec] | None = None, | ||
page_offset: int = 0, | ||
page_size: int = 20, | ||
filter: str | None = None, | ||
order: str | None = None, | ||
) -> RelayPaginatedResult[dict]: | ||
""" | ||
Fetches the list of created networks in this cluster. | ||
""" | ||
return await execute_paginated_relay_query( | ||
"networks", | ||
{ | ||
"filter": (filter, "String"), | ||
"order": (order, "String"), | ||
}, | ||
fields or _default_list_fields, | ||
limit=page_size, | ||
offset=page_offset, | ||
) | ||
|
||
@api_function | ||
@classmethod | ||
async def create( | ||
cls, | ||
project_id: str, | ||
name: str, | ||
*, | ||
driver: str | None = None, | ||
) -> "Network": | ||
""" | ||
Creates a new network. | ||
:param project_id: The ID of the project to which the network belongs. | ||
:param name: The name of the network. | ||
:param driver: (Optional) The driver of the network. If not specified, the default driver will be used. | ||
:return: The created network. | ||
""" | ||
q = ( | ||
"mutation($name: String!, $project_id: UUID!, $driver: String) {" | ||
" create_network(name: $name, project_id: $project_id, driver: $driver) {" | ||
" network {" | ||
" row_id" | ||
" }" | ||
" }" | ||
"}" | ||
) | ||
data = await api_session.get().Admin._query( | ||
q, | ||
{ | ||
"name": name, | ||
"project_id": project_id, | ||
"driver": driver, | ||
}, | ||
) | ||
return cls(network_id=UUID(data["create_network"]["network"]["row_id"])) | ||
|
||
def __init__(self, network_id: UUID) -> None: | ||
""" | ||
:param network_id: The ID of the network. Pass `row_id` value (not `id`) of the network info fetched by `paginated_list`. | ||
""" | ||
super().__init__() | ||
self.network_id = network_id | ||
|
||
@api_function | ||
async def get( | ||
self, | ||
fields: Sequence[FieldSpec] | None = None, | ||
) -> dict: | ||
""" | ||
Fetches the information of the network. | ||
""" | ||
q = "query($id: String!) {" " network(id: $id) {" " $fields" " }" "}" | ||
q = q.replace("$fields", " ".join(f.field_ref for f in (fields or _default_list_fields))) | ||
data = await api_session.get().Admin._query(q, {"id": str(self.network_id)}) | ||
return data["images"] | ||
|
||
@api_function | ||
async def update(self, name: str) -> None: | ||
""" | ||
Updates network. | ||
""" | ||
q = ( | ||
"mutation($network: String!, $props: UpdateNetworkInput!) {" | ||
" modify_network(network: $network, props: $props) {" | ||
" ok msg" | ||
" }" | ||
"}" | ||
) | ||
variables = { | ||
"network": str(self.network_id), | ||
"props": {"name": name}, | ||
} | ||
data = await api_session.get().Admin._query(q, variables) | ||
return data["modify_network"] | ||
|
||
@api_function | ||
async def delete(self) -> None: | ||
""" | ||
Deletes network. Delete only works for networks that are not attached to active session. | ||
""" | ||
q = ( | ||
"mutation($network: String!) {" | ||
" delete_network(network: $network) {" | ||
" ok msg" | ||
" }" | ||
"}" | ||
) | ||
variables = { | ||
"network": str(self.network_id), | ||
} | ||
data = await api_session.get().Admin._query(q, variables) | ||
return data["delete_network"] |
Oops, something went wrong.