Skip to content

Commit

Permalink
feat: field_paths_and_annotations works with generics (add doctest fo…
Browse files Browse the repository at this point in the history
…r this too)
  • Loading branch information
thorwhalen committed Sep 9, 2024
1 parent 103f31e commit 46a66e8
Showing 1 changed file with 110 additions and 15 deletions.
125 changes: 110 additions & 15 deletions ju/pydantic_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,79 @@ class Simple(BaseModel):
from typing import Any, Dict, Type, Union, get_args, get_origin, List


def _get_type_parameters(origin: Type[BaseModel]):
"""Helper function to safely retrieve the type parameters of a generic class.
Note: This is because it's not safe to call `__parameters__` directly on the origin,
so we do it anyway, but encapsulated in a function, to encapsulate and locate the
risk.
"""
return getattr(origin, '__parameters__', ())


def match_typevars_to_args(generic_model: Type[BaseModel]) -> Dict[TypeVar, Type[Any]]:
"""
Given a Pydantic generic model, returns a mapping of type variables to their
concrete types.
Args:
generic_model (Type[BaseModel]): A generic Pydantic model (e.g., Pair[int, str]).
Returns:
Dict[TypeVar, Type[Any]]: A dictionary mapping type variables (e.g., T and U)
to their concrete types (e.g., int and str).
>>> from typing import TypeVar, Generic, List
>>> T = TypeVar('T')
>>> U = TypeVar('U')
>>> class Pair(Generic[T, U]):
... first: T
... second: U
>>> X = Pair[int, str]
>>> match_typevars_to_args(X)
{~T: <class 'int'>, ~U: <class 'str'>}
"""
# Get the origin (the base class without type args, e.g., Response)
origin = get_origin(generic_model)

if origin is None:
# If the model isn't generic, return an empty dictionary
return {}

# Get the actual type arguments (e.g., (EmbeddingT,))
concrete_types = get_args(generic_model)

# Get the type variables (e.g., (DatumT,))
type_vars = _get_type_parameters(origin)

# Create a mapping from type variables to concrete types
return dict(zip(type_vars, concrete_types))


def is_a_basemodel(obj) -> bool:
"""
Check if an object is a Pydantic BaseModel.
>>> from typing import List
>>> class MyModel(BaseModel):
... '''Some model'''
>>> list(map(is_a_basemodel, [BaseModel, MyModel, 3.14, int, List[MyModel]]))
[True, True, False, False, False]
"""
if not isinstance(obj, type):
return False
else:
# Get the origin in case it's a generic type like List[str]
if get_origin(obj) is not None:
return False # right? Or do we want this to be is_a_basemodel(get_origin(obj))?
return issubclass(obj, BaseModel)


def field_paths_and_annotations(
data_model: Type[BaseModel],
) -> Dict[str, Type[Any]]:
Expand Down Expand Up @@ -418,35 +491,57 @@ def field_paths_and_annotations(
... a: A
>>> paths = field_paths_and_annotations(Model)
>>> assert paths == {
... 'a.b.*.c': int, 'a.d': str
... }
>>> expected_paths = {'a.b.*.c': int, 'a.d': str}
>>> assert paths == expected_paths, f"Expected: {expected_paths}, but got: {paths}"
See that it works with generics:
>>> from typing import TypeVar, List, Generic
>>> T = TypeVar('T')
>>> class A_with_Generic(BaseModel, Generic[T]):
... b: List[T]
... d: str
>>> class Model_with_Generic(BaseModel):
... a: A_with_Generic[BItem]
>>>
>>> field_paths_and_annotations(Model_with_Generic)
{'a.b.*.c': <class 'int'>, 'a.d': <class 'str'>}
One of the main use cases for this function is to extract values from a dictionary
that matches the structure of a Pydantic BaseModel.
This can be done using the `glom` library, for example:
"""

>>> from glom import glom # doctest: +SKIP
>>> {k: glom(data, k) for k in paths} # doctest: +SKIP
{'a.b.*.c': [1, 2], 'a.d': 'bar'}
def get_field_type(field_type, model: Type[BaseModel]):
"""Resolves the actual type of a field, replacing generics with their concrete types."""
typevar_mapping = match_typevars_to_args(model)

"""
# Replace any type variables in the field_type with their corresponding concrete types
if typevar_mapping:
if field_type in typevar_mapping:
return typevar_mapping[field_type]

# If the field_type is a generic collection (e.g., List[DatumT]), replace the args
origin_type = get_origin(field_type)
if origin_type:
args = get_args(field_type)
resolved_args = tuple(typevar_mapping.get(arg, arg) for arg in args)
return origin_type[resolved_args]

return field_type

def recurse_model(model: Type[BaseModel], prefix: str = '') -> Dict[str, Type[Any]]:
paths = {}
for field_name, field_info in model.model_fields.items():
field_type = field_info.annotation
field_type = get_field_type(field_info.annotation, model)
current_path = f"{prefix}.{field_name}" if prefix else field_name
origin = get_origin(field_type)
args = get_args(field_type)

if isinstance(field_type, type) and issubclass(field_type, BaseModel):
if is_a_basemodel(field_type) or is_a_basemodel(origin):
# the field type is a BaseModel or a generic BaseModel
paths.update(recurse_model(field_type, current_path))
elif (
origin in {list, set, tuple}
origin in {list, set, tuple, List}
and args
and isinstance(args[0], type)
and issubclass(args[0], BaseModel)
and is_a_basemodel(args[0]) # TODO: What if args[1] is a generic?
):
paths.update(recurse_model(args[0], f"{current_path}.*"))
else:
Expand Down

0 comments on commit 46a66e8

Please sign in to comment.