Skip to content

Commit

Permalink
add Bytes scalar type
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Oct 22, 2024
1 parent 0dbb47c commit abffa81
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
23 changes: 23 additions & 0 deletions src/ai/backend/manager/models/gql_models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Any, Optional

import graphene
import graphql
from graphene.types import Scalar
from graphene.types.scalars import MAX_INT, MIN_INT
from graphql.language.ast import IntValueNode
Expand Down Expand Up @@ -62,6 +65,26 @@ def parse_literal(node):
return num


class Bytes(Scalar):
class Meta:
description = "Added in 24.09.1."

@staticmethod
def serialize(val: bytes) -> str:
return val.hex()

@staticmethod
def parse_literal(node: Any, _variables=None) -> Optional[bytes]:
if isinstance(node, graphql.language.ast.StringValueNode):
assert isinstance(node, str)
return bytes.fromhex(node)
return None

@staticmethod
def parse_value(value: str) -> bytes:
return bytes.fromhex(value)


class ImageRefType(graphene.InputObjectType):
name = graphene.String(required=True)
registry = graphene.String()
Expand Down
10 changes: 7 additions & 3 deletions src/ai/backend/manager/models/gql_models/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
GlobalIDField,
ResolvedGlobalID,
)

# from ..group import AssocGroupUserRow, GroupRow, ProjectType
from ..minilang.ordering import OrderSpecItem, QueryOrderParser
from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser
from ..rbac import (
Expand All @@ -44,6 +42,7 @@
from ..scaling_group import ScalingGroupForDomainRow, get_scaling_groups
from ..user import UserRole
from ..utils import execute_with_txn_retry
from .base import Bytes
from .scaling_group import ScalinGroupConnection

if TYPE_CHECKING:
Expand Down Expand Up @@ -102,6 +101,7 @@ class Meta:
total_resource_slots = graphene.JSONString()
allowed_vfolder_hosts = graphene.JSONString()
allowed_docker_registries = graphene.List(lambda: graphene.String)
dotfiles = Bytes()
integration_id = graphene.String()

# Dynamic fields.
Expand All @@ -123,6 +123,7 @@ def from_rbac_model(
total_resource_slots=obj.total_resource_slots,
allowed_vfolder_hosts=obj.allowed_vfolder_hosts.to_json(),
allowed_docker_registries=obj.allowed_docker_registries,
dotfiles=obj.dotfiles,
integration_id=obj.integration_id,
)

Expand All @@ -142,6 +143,7 @@ def from_orm_model(
total_resource_slots=obj.total_resource_slots,
allowed_vfolder_hosts=obj.allowed_vfolder_hosts.to_json(),
allowed_docker_registries=obj.allowed_docker_registries,
dotfiles=obj.dotfiles,
integration_id=obj.integration_id,
)

Expand Down Expand Up @@ -290,7 +292,7 @@ class CreateDomainInput(graphene.InputObjectType):
lambda: graphene.String, required=False, default_value=[]
)
integration_id = graphene.String(required=False, default_value=None)
dotfiles = graphene.String(required=False, default_value=b"\x90")
dotfiles = Bytes(required=False, default_value=b"\x90")

scaling_groups = graphene.List(lambda: graphene.String, required=False)

Expand Down Expand Up @@ -366,6 +368,7 @@ class Input:
allowed_vfolder_hosts = graphene.JSONString(required=False)
allowed_docker_registries = graphene.List(lambda: graphene.String, required=False)
integration_id = graphene.String(required=False)
dotfiles = Bytes(required=False)
sgroups_to_add = graphene.List(lambda: graphene.String, required=False)
sgroups_to_remove = graphene.List(lambda: graphene.String, required=False)
client_mutation_id = graphene.String(required=False) # automatic input from relay
Expand Down Expand Up @@ -410,6 +413,7 @@ async def mutate_and_get_payload(
set_if_set(input, data, "allowed_vfolder_hosts")
set_if_set(input, data, "allowed_docker_registries")
set_if_set(input, data, "integration_id")
set_if_set(input, data, "dotfiles")

async def _update(db_session: AsyncSession) -> Optional[DomainRow]:
user = graph_ctx.user
Expand Down
2 changes: 0 additions & 2 deletions src/ai/backend/manager/models/gql_models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
AsyncNode,
Connection,
)

# from ..group import AssocGroupUserRow, GroupRow, ProjectType
from ..scaling_group import (
ScalingGroupForDomainRow,
ScalingGroupForKeypairsRow,
Expand Down
2 changes: 0 additions & 2 deletions src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,12 @@ class ScalingGroupForDomainRow(Base):
id = IDColumn()
scaling_group = sa.Column(
"scaling_group",
# sa.String(length=64),
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
)
domain = sa.Column(
"domain",
# SlugType(length=64, allow_unicode=True, allow_dot=True),
sa.ForeignKey("domains.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
Expand Down

0 comments on commit abffa81

Please sign in to comment.