How to expose tagged union and alternative constructors in GraphQL #56
-
I have a module called scanspec which allows the definition of a trajectory of points that will be moved through to collect experimental data. It is currently written with pydantic, but the Union-of-subclasses pattern introduces some messiness, and I would prefer to use plain dataclasses. Here is the core of what I would like to achieve (this doesn't currently work with apischema):

from typing import AsyncIterable, TypeVar
from dataclasses import dataclass
from typing_extensions import Annotated as A
from apischema import schema, deserialize, serialize
def desc(description: str):
    return schema(description=description)


class ScanSpec:
    async def points(self) -> AsyncIterable[float]:
        """Iterate through the points of the scan"""
        raise NotImplementedError(self)
T = TypeVar("T")
@dataclass
class Line(ScanSpec):
    """A straight line"""

    # Only declare descriptions once
    _start = desc("The first point")
    start: A[float, _start]
    _stop = desc("The last point")
    stop: A[float, _stop]
    step: A[float, desc("The step between points")] = 1

    async def points(self) -> AsyncIterable[float]:
        for point in range(self.start, self.stop, self.step):
            yield point

    @classmethod
    def sized(
        cls: T,
        start: A[float, _start],
        stop: A[float, _stop],
        size: A[int, desc("Number of points")],
    ) -> T:
        """Alternative constructor with size instead of step"""
        return cls(start, stop, (stop - start) / size)
@dataclass
class Concat(ScanSpec):
    left: A[ScanSpec, desc("First spec to produce")]
    right: A[ScanSpec, desc("Second spec to produce")]

    async def points(self) -> AsyncIterable[float]:
        async for point in self.left.points():
            yield point
        async for point in self.right.points():
            yield point


async def get_points(spec: ScanSpec) -> AsyncIterable[float]:
    async for point in spec.points():
        yield point
def test_de_serialize():
    line = Line(1, 4)
    serialized = {"Line": {"start": 1, "stop": 4, "step": 1}}
    assert serialize(line) == serialized
    assert deserialize(ScanSpec, serialized) == line

I would like to expose a GraphQL API that exposes a subscription, as per the discussion on graphql-python/graphene#729 (comment).
-
Ok, what you want is in fact quite a bit more complex than the "class as union of its subclasses" example of the apischema documentation. In more detail, if you consider that you have created a tagged union type (let's name it …) … This is not a real issue, as apischema provides the material to do it, but the solution will require a little bit more code. Actually, there are two solutions: …
Personally, I prefer the first solution, as I consider that … There are again two possible implementations: …
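Whatever the implementation, the target shape is the same. Here is a minimal sketch of mine (not from the original reply) of the two payload shapes involved, using the Foo and Bar example classes defined in the code below:

# Plain union: the payload does not say which class it represents, so
# deserialization has to try each member of Union[Foo, Bar] in turn.
plain_union_payload = {"foo": 0}

# Tagged union: the payload is wrapped under a key naming the concrete class,
# which maps naturally onto a GraphQL object/input type with one optional
# field per alternative.
tagged_union_payload = {"Foo": {"foo": 0}}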
I will arbitrarily choose the second one, so it will give something like:

from dataclasses import dataclass, fields, make_dataclass
from typing import Any, Union

import graphql
from graphql.utilities import print_schema

from apischema import (
    Undefined,
    UndefinedType,
    deserialize,
    deserializer,
    serialize,
    validator,
)
from apischema.conversions import Conversion
from apischema.graphql import Query, graphql_schema
from apischema.json_schema import deserialization_schema, serialization_schema
# Waiting for the tagged union feature, let's define something that comes close to it
@dataclass
class _TaggedUnion:
    @validator
    def _get_only_field(self) -> Any:
        defined_fields = [
            f.name for f in fields(self) if getattr(self, f.name) is not Undefined
        ]
        if len(defined_fields) != 1:
            raise ValueError("Tagged union must have one and only one field set")
        return getattr(self, defined_fields[0])
class Base:
    pass


@dataclass
class Foo(Base):
    foo: int


@dataclass
class Bar(Base):
    bar: str
# Define the BaseTaggedUnion class
tagged_fields = [
    (sub_cls.__name__, Union[sub_cls, UndefinedType], Undefined)
    for sub_cls in Base.__subclasses__()  # no need to recurse in your example
]
tagged_union = make_dataclass("BaseTaggedUnion", tagged_fields, bases=(_TaggedUnion,))
# Conversions
deserializer(Conversion(_TaggedUnion._get_only_field, source=tagged_union, target=Base))
tagged_union_conversion = Conversion(
    lambda obj: tagged_union(**{type(obj).__name__: obj}),
    source=Base,
    target=tagged_union,
)
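# deserialize(Base, ...) below needs no explicit conversions because a
# deserializer was registered globally above, while serialization only uses
# the tagged form when conversions=tagged_union_conversion is passed.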
assert serialize(Foo(0), conversions=tagged_union_conversion) == {"Foo": {"foo": 0}}
assert deserialize(Base, {"Bar": {"bar": "hello"}}) == Bar("hello")
assert deserialization_schema(Base) == {
    "type": "object",
    "properties": {
        "Foo": {
            "type": "object",
            "properties": {"foo": {"type": "integer"}},
            "required": ["foo"],
            "additionalProperties": False,
        },
        "Bar": {
            "type": "object",
            "properties": {"bar": {"type": "string"}},
            "required": ["bar"],
            "additionalProperties": False,
        },
    },
    "additionalProperties": False,
    "$schema": "http://json-schema.org/draft/2019-09/schema#",
}
assert serialization_schema(Base, conversions=tagged_union_conversion) == {
    "type": "object",
    "properties": {
        "Foo": {
            "type": "object",
            "properties": {"foo": {"type": "integer"}},
            "required": ["foo"],
            "additionalProperties": False,
        },
        "Bar": {
            "type": "object",
            "properties": {"bar": {"type": "string"}},
            "required": ["bar"],
            "additionalProperties": False,
        },
    },
    "additionalProperties": False,
    "$schema": "http://json-schema.org/draft/2019-09/schema#",
}
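# Passing the same conversion to the query exposes Base as the BaseTaggedUnion
# object type and BaseTaggedUnionInput input type in the GraphQL schema below.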
def query(base: Base) -> Base:
    return base


schema = graphql_schema(query=[Query(query, conversions=tagged_union_conversion)])
schema_str = """\
type Query {
  query(base: BaseTaggedUnionInput!): BaseTaggedUnion!
}

type BaseTaggedUnion {
  Foo: Foo
  Bar: Bar
}

type Foo {
  foo: Int!
}

type Bar {
  bar: String!
}

input BaseTaggedUnionInput {
  Foo: FooInput
  Bar: BarInput
}

input FooInput {
  foo: Int!
}

input BarInput {
  bar: String!
}
"""
assert print_schema(schema) == schema_str
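# Note that the conversion is symmetric: both an output type (BaseTaggedUnion)
# and an input type (BaseTaggedUnionInput) were generated above.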
query_str = """
{
  query(base: {Foo: {foo: 0}}){
    Foo {
      foo
    }
    Bar {
      bar
    }
  }
}
"""
assert graphql.graphql_sync(schema, query_str).data == {
    "query": {
        "Foo": {"foo": 0},
        "Bar": None,
    }
}

Of course, … By the way, I invite you to update your apischema version to the latest one (0.14.1), because I have done some refactoring thanks to your use case (especially concerning subscriptions, where conversions were not supported yet).
-
EDIT: An example of this use case is now presented in the documentation.

Actually, I've realized I've missed the whole point on alternative constructors … By the way, I built my previous example on a symmetry between ScanSpec input and output, but you seem to be interested only in the input part, so that is another misunderstanding on my side. I will try to be more in line with your issue in this second response. For that, I will use the new release of apischema, which brings the following elements: dataclass_input_wrapper and TaggedUnion. With both of them, things become a lot easier. Here is your original example reworked with apischema:

import asyncio
from collections.abc import Iterator
from dataclasses import dataclass, field
from types import new_class
from typing import (
    Annotated,
    AsyncIterable,
    Callable,
    ClassVar,
    Collection,
    TypeVar,
)
import graphql
from apischema import deserializer, schema
from apischema.conversions import (
    Conversion,
    dataclass_input_wrapper,
    reset_deserializers,
)
from apischema.graphql import graphql_schema
from apischema.json_schema import deserialization_schema
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged
T = TypeVar("T")
# Recursive implementation of type.__subclasses__
def rec_subclasses(cls: type[T]) -> Iterator[type[T]]:
    for sub_cls in cls.__subclasses__():
        yield sub_cls
        yield from rec_subclasses(sub_cls)


# Shortcut
def desc(description: str):
    return schema(description=description)
class ScanSpec:
    async def points(self) -> AsyncIterable[float]:
        """Iterate through the points of the scan"""
        raise NotImplementedError

    # Will be used in the input tagged union
    _additional_constructors: ClassVar[Collection[Callable[..., "ScanSpec"]]] = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        namespace, annotations = {}, {}
        for spec_cls in rec_subclasses(ScanSpec):
            # Add a tagged field for the ScanSpec subclass
            annotations[spec_cls.__name__] = Tagged[spec_cls]
            # Add tagged fields for all its additional constructors
            # (use the class __dict__ in order to avoid inheritance of these constructors)
            for constructor in spec_cls.__dict__.get("_additional_constructors", ()):
                # Deref the constructor function if it is a classmethod/staticmethod
                constructor = constructor.__get__(None, spec_cls)
                # Build the alias of the field
                alias = "".join(map(str.capitalize, constructor.__name__.split("_")))
                # dataclass_input_wrapper uses get_type_hints, but the constructor
                # return type is stringified and the class is not defined yet,
                # so it must be assigned manually
                constructor.__annotations__["return"] = spec_cls
                # Wrap the constructor and rename its input class
                wrapper, wrapper_cls = dataclass_input_wrapper(constructor)
                wrapper_cls.__name__ = alias
                # Add the constructor tagged field with its conversion
                annotations[alias] = Tagged[spec_cls]
                namespace[alias] = Tagged(deserialization=wrapper)
        # Create the tagged union class
        namespace |= {"__annotations__": annotations}
        tagged_union = new_class(
            "ScanSpec",
            (TaggedUnion,),
            exec_body=lambda ns: ns.update(namespace),
        )
        # Because deserializers stack, they must be reset before being reassigned
        reset_deserializers(ScanSpec)
        # Register the deserializer using get_tagged
        deserializer(Conversion(lambda obj: get_tagged(obj)[1], tagged_union, ScanSpec))
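# Once Line and Concat below are defined, the generated "ScanSpec" tagged union
# has one Tagged field per subclass plus one per additional constructor
# (SizedLine), matching the ScanSpecInput type in the printed schema below.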
@dataclass
class Line(ScanSpec):
    """A straight line"""

    start: float = field(metadata=desc("The first point"))
    stop: float = field(metadata=desc("The last point"))
    step: float = field(
        default=1, metadata=schema(description="The step between points", exc_min=0)
    )

    async def points(self) -> AsyncIterable[float]:
        point = self.start
        while point <= self.stop:
            yield point
            point += self.step

    @staticmethod
    def sized_line(
        start: Annotated[float, desc("The first point")],
        stop: Annotated[float, desc("The last point")],
        size: Annotated[float, schema(description="Number of points", min=1)],
    ) -> "Line":
        """Alternative constructor with size instead of step"""
        return Line(start=start, stop=stop, step=(stop - start) / (size - 1))

    _additional_constructors = [sized_line]
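# The alias built in __init_subclass__ turns "sized_line" into "SizedLine",
# the tag under which this alternative constructor appears in ScanSpecInput.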
@dataclass
class Concat(ScanSpec):
    left: ScanSpec = field(metadata=desc("First spec to produce"))
    right: ScanSpec = field(metadata=desc("Second spec to produce"))

    async def points(self) -> AsyncIterable[float]:
        async for point in self.left.points():
            yield point
        async for point in self.right.points():
            yield point


async def get_points(spec: ScanSpec) -> AsyncIterable[float]:
    async for point in spec.points():
        yield point
def hello() -> str:
    return "world!"


scan_spec_schema = graphql_schema(query=[hello], subscription=[get_points])
assert (
    graphql.utilities.print_schema(scan_spec_schema)
    == '''\
type Query {
  hello: String!
}

type Subscription {
  getPoints(spec: ScanSpecInput!): Float!
}

input ScanSpecInput {
  Line: LineInput
  SizedLine: SizedLineInput
  Concat: ConcatInput
}

input LineInput {
  """The first point"""
  start: Float!

  """The last point"""
  stop: Float!

  """The step between points"""
  step: Float! = 1
}

input SizedLineInput {
  """The first point"""
  start: Float!

  """The last point"""
  stop: Float!

  """Number of points"""
  size: Float!
}

input ConcatInput {
  """First spec to produce"""
  left: ScanSpecInput!

  """Second spec to produce"""
  right: ScanSpecInput!
}
'''
)
subscription_query = """
subscription {
  getPoints(spec: {SizedLine: {start: 0, stop: 12, size: 3}})
}
"""


async def main():
    subscription = await graphql.subscribe(
        scan_spec_schema, graphql.parse(subscription_query)
    )
    assert [res.data["getPoints"] async for res in subscription] == [0.0, 6.0, 12.0]


asyncio.run(main())

This time, I've used …
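As a small addition of mine (not part of the original answer), the deserializer registered in __init_subclass__ also makes the same tagged form usable for plain JSON deserialization, mirroring the test_de_serialize of the question; a sketch assuming the apischema version and classes defined above:

from apischema import deserialize

# The tag can be a subclass name ("Line", "Concat") or an alternative
# constructor alias ("SizedLine"); in every case a ScanSpec instance comes out.
spec = deserialize(ScanSpec, {"SizedLine": {"start": 0, "stop": 12, "size": 3}})
assert spec == Line(start=0, stop=12, step=6.0)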