Skip to content

Commit

Permalink
refactor: Shuffling things around
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Mar 22, 2024
1 parent bf6265e commit a22a154
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 84 deletions.
25 changes: 23 additions & 2 deletions ju/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
"""

from typing import Mapping, Sequence, Callable
from inspect import Parameter
from functools import partial

from pydantic import BaseModel

from typing import Mapping, Sequence, Callable
from ju.util import is_type


class _BasicPythonTypes(BaseModel):
Expand Down Expand Up @@ -73,6 +77,23 @@ def gen():

DFLT_JSON_TYPE = 'string' # TODO: 'string' or 'object'?


def parametrized_param_to_type(
param: Parameter,
*,
type_mapping=DFLT_PY_JSON_TYPE_PAIRS,
default=DFLT_JSON_TYPE,
):
for python_type, json_type in type_mapping:
if is_type(param, python_type):
return json_type
return default


DFLT_PARAM_TO_TYPE = partial(
parametrized_param_to_type, type_mapping=DFLT_PY_JSON_TYPE_PAIRS
)

# -------------------------------------------------------------------------------------
# util

Expand Down Expand Up @@ -104,7 +125,6 @@ def wrap_schema_in_opus_spec(schema: dict):

from typing import Mapping, Sequence
import inspect
from inspect import Parameter
from operator import attrgetter
from i2 import Sig

Expand All @@ -129,6 +149,7 @@ def wrap_schema_in_opus_spec(schema: dict):
)


# TODO: See
def function_to_json_schema(
func: Callable,
*,
Expand Down
86 changes: 4 additions & 82 deletions ju/rjsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,86 +77,7 @@ def func_to_form_spec(func: Callable):
}


# def is_type(param: Parameter, type_: type):
# return param.annotation is type_ or isinstance(param.default, type_)

from typing import get_args, get_origin, Any, Union, GenericAlias, Type
from types import GenericAlias

SomeType = Union[Type, GenericAlias, Any]
SomeType.__doc__ = "A type or a GenericAlias, but also Any, just in case"


def is_type(param: Parameter, type_: SomeType):
"""
Checks if the type of a parameter's default value or its annotation matches a
given type.
This function handles both regular types and subscripted generics.
Args:
param (Parameter): The parameter to check.
type_ (type): The type to check against.
Returns:
bool: True if the parameter's type matches the given type, False otherwise.
Doctests:
>>> from inspect import Parameter
>>> param = Parameter('p', Parameter.KEYWORD_ONLY, default=3.14)
>>> is_type(param, float)
True
>>> is_type(param, int)
False
>>> param = Parameter('p', Parameter.KEYWORD_ONLY, default=[1, 2, 3])
>>> is_type(param, list)
True
>>> from typing import List, Union
>>> is_type(param, List[int])
True
>>> is_type(param, List[str])
False
>>> is_type(param, Union[int, List[int]])
True
"""
if param.annotation is type_:
return True
if isinstance(type_, type):
return isinstance(param.default, type_)
if hasattr(type_, '__origin__'):
origin = get_origin(type_)
if origin is Union:
args = get_args(type_)
return any(is_type(param, arg) for arg in args)
else:
args = get_args(type_)
if isinstance(param.default, origin):
if all(
any(isinstance(element, arg) for element in param.default)
for arg in args
):
return True
return False


from ju.json_schema import DFLT_PY_JSON_TYPE_PAIRS, DFLT_JSON_TYPE


def parametrized_param_to_type(
param: Parameter,
*,
type_mapping=DFLT_PY_JSON_TYPE_PAIRS,
default=DFLT_JSON_TYPE,
):
for python_type, json_type in type_mapping:
if is_type(param, python_type):
return json_type
return default


_dflt_param_to_type = partial(
parametrized_param_to_type, type_mapping=DFLT_PY_JSON_TYPE_PAIRS
)
from ju.json_schema import DFLT_PARAM_TO_TYPE


# TODO: The loop body could be factored out
Expand All @@ -173,9 +94,10 @@ def get_properties(parameters, *, param_to_prop_type):
... ):
... '''A Foo function'''
>>>
>>> from ju.json_schema import DFLT_PARAM_TO_TYPE
>>> parameters = inspect.signature(foo).parameters
>>> assert (
... get_properties(parameters, param_to_prop_type=_dflt_param_to_type)
... get_properties(parameters, param_to_prop_type=DFLT_PARAM_TO_TYPE)
... == {
... 'a_bool': {'type': 'boolean'},
... 'a_float': {'type': 'number', 'default': 3.14},
Expand Down Expand Up @@ -208,7 +130,7 @@ def get_required(properties: dict):


# TODO: This all should really use meshed instead, to be easily composable.
def _func_to_rjsf_schemas(func, *, param_to_prop_type: Callable = _dflt_param_to_type):
def _func_to_rjsf_schemas(func, *, param_to_prop_type: Callable = DFLT_PARAM_TO_TYPE):
"""
Returns the JSON schema and the UI schema for a function.
Expand Down
66 changes: 66 additions & 0 deletions ju/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@
Tuple,
runtime_checkable,
Protocol,
get_args,
get_origin,
Any,
Type,
)
from collections import defaultdict
from dataclasses import dataclass
from types import GenericAlias
from inspect import Parameter

from i2 import mk_sentinel

try:
Expand All @@ -37,6 +44,65 @@ def is_jsonable(x):
return False


SomeType = Union[Type, GenericAlias, Any]
SomeType.__doc__ = "A type or a GenericAlias, but also Any, just in case"


def is_type(param: Parameter, type_: SomeType):
"""
Checks if the type of a parameter's default value or its annotation matches a
given type.
This function handles both regular types and subscripted generics.
Args:
param (Parameter): The parameter to check.
type_ (type): The type to check against.
Returns:
bool: True if the parameter's type matches the given type, False otherwise.
Doctests:
>>> from inspect import Parameter
>>> param = Parameter('p', Parameter.KEYWORD_ONLY, default=3.14)
>>> is_type(param, float)
True
>>> is_type(param, int)
False
>>> param = Parameter('p', Parameter.KEYWORD_ONLY, default=[1, 2, 3])
>>> is_type(param, list)
True
>>> from typing import List, Union
>>> is_type(param, List[int])
True
>>> is_type(param, List[str])
False
>>> is_type(param, Union[int, List[int]])
True
"""
if param.annotation is type_:
return True
if isinstance(type_, type):
return isinstance(param.default, type_)
if hasattr(type_, '__origin__'):
origin = get_origin(type_)
if origin is Union:
args = get_args(type_)
return any(is_type(param, arg) for arg in args)
else:
args = get_args(type_)
if isinstance(param.default, origin):
if all(
any(isinstance(element, arg) for element in param.default)
for arg in args
):
return True
return False


# -------------------------------------------------------------------------------------
# Mappers

CallableMapper = Callable[[KT], VT]
MappingMapper = Mapping[KT, VT]
PairsMapper = Sequence[Tuple[KT, VT]]
Expand Down

0 comments on commit a22a154

Please sign in to comment.