diff --git a/ju/pydantic_util.py b/ju/pydantic_util.py index a0d9b44..2ec40bb 100644 --- a/ju/pydantic_util.py +++ b/ju/pydantic_util.py @@ -379,6 +379,79 @@ class Simple(BaseModel): from typing import Any, Dict, Type, Union, get_args, get_origin, List +def _get_type_parameters(origin: Type[BaseModel]): + """Helper function to safely retrieve the type parameters of a generic class. + + Note: This is because it's not safe to call `__parameters__` directly on the origin, + so we do it anyway, but encapsulated in a function, to encapsulate and locate the + risk. + """ + return getattr(origin, '__parameters__', ()) + + +def match_typevars_to_args(generic_model: Type[BaseModel]) -> Dict[TypeVar, Type[Any]]: + """ + Given a Pydantic generic model, returns a mapping of type variables to their + concrete types. + + Args: + generic_model (Type[BaseModel]): A generic Pydantic model (e.g., Pair[int, str]). + + Returns: + Dict[TypeVar, Type[Any]]: A dictionary mapping type variables (e.g., T and U) + to their concrete types (e.g., int and str). + + >>> from typing import TypeVar, Generic, List + + >>> T = TypeVar('T') + >>> U = TypeVar('U') + + >>> class Pair(Generic[T, U]): + ... first: T + ... second: U + + >>> X = Pair[int, str] + + >>> match_typevars_to_args(X) + {~T: , ~U: } + """ + # Get the origin (the base class without type args, e.g., Response) + origin = get_origin(generic_model) + + if origin is None: + # If the model isn't generic, return an empty dictionary + return {} + + # Get the actual type arguments (e.g., (EmbeddingT,)) + concrete_types = get_args(generic_model) + + # Get the type variables (e.g., (DatumT,)) + type_vars = _get_type_parameters(origin) + + # Create a mapping from type variables to concrete types + return dict(zip(type_vars, concrete_types)) + + +def is_a_basemodel(obj) -> bool: + """ + Check if an object is a Pydantic BaseModel. + + >>> from typing import List + >>> class MyModel(BaseModel): + ... '''Some model''' + >>> list(map(is_a_basemodel, [BaseModel, MyModel, 3.14, int, List[MyModel]])) + [True, True, False, False, False] + + """ + if not isinstance(obj, type): + return False + else: + # Get the origin in case it's a generic type like List[str] + if get_origin(obj) is not None: + return False # right? Or do we want this to be is_a_basemodel(get_origin(obj))? + return issubclass(obj, BaseModel) + + def field_paths_and_annotations( data_model: Type[BaseModel], ) -> Dict[str, Type[Any]]: @@ -418,35 +491,57 @@ def field_paths_and_annotations( ... a: A >>> paths = field_paths_and_annotations(Model) - >>> assert paths == { - ... 'a.b.*.c': int, 'a.d': str - ... } + >>> expected_paths = {'a.b.*.c': int, 'a.d': str} + >>> assert paths == expected_paths, f"Expected: {expected_paths}, but got: {paths}" + + See that it works with generics: + + >>> from typing import TypeVar, List, Generic + >>> T = TypeVar('T') + >>> class A_with_Generic(BaseModel, Generic[T]): + ... b: List[T] + ... d: str + >>> class Model_with_Generic(BaseModel): + ... a: A_with_Generic[BItem] + >>> + >>> field_paths_and_annotations(Model_with_Generic) + {'a.b.*.c': , 'a.d': } - One of the main use cases for this function is to extract values from a dictionary - that matches the structure of a Pydantic BaseModel. - This can be done using the `glom` library, for example: + """ - >>> from glom import glom # doctest: +SKIP - >>> {k: glom(data, k) for k in paths} # doctest: +SKIP - {'a.b.*.c': [1, 2], 'a.d': 'bar'} + def get_field_type(field_type, model: Type[BaseModel]): + """Resolves the actual type of a field, replacing generics with their concrete types.""" + typevar_mapping = match_typevars_to_args(model) - """ + # Replace any type variables in the field_type with their corresponding concrete types + if typevar_mapping: + if field_type in typevar_mapping: + return typevar_mapping[field_type] + + # If the field_type is a generic collection (e.g., List[DatumT]), replace the args + origin_type = get_origin(field_type) + if origin_type: + args = get_args(field_type) + resolved_args = tuple(typevar_mapping.get(arg, arg) for arg in args) + return origin_type[resolved_args] + + return field_type def recurse_model(model: Type[BaseModel], prefix: str = '') -> Dict[str, Type[Any]]: paths = {} for field_name, field_info in model.model_fields.items(): - field_type = field_info.annotation + field_type = get_field_type(field_info.annotation, model) current_path = f"{prefix}.{field_name}" if prefix else field_name origin = get_origin(field_type) args = get_args(field_type) - if isinstance(field_type, type) and issubclass(field_type, BaseModel): + if is_a_basemodel(field_type) or is_a_basemodel(origin): + # the field type is a BaseModel or a generic BaseModel paths.update(recurse_model(field_type, current_path)) elif ( - origin in {list, set, tuple} + origin in {list, set, tuple, List} and args - and isinstance(args[0], type) - and issubclass(args[0], BaseModel) + and is_a_basemodel(args[0]) # TODO: What if args[1] is a generic? ): paths.update(recurse_model(args[0], f"{current_path}.*")) else: