Skip to content

Commit

Permalink
refactor: move field_paths_and_annotations in pydantic_model_to_code
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Sep 3, 2024
1 parent 1d78912 commit cdc1550
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 84 deletions.
7 changes: 4 additions & 3 deletions ju/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from ju.pydantic_util import (
is_valid_wrt_model,
valid_models,
data_to_pydantic_model,
pydantic_model_to_code,
data_to_pydantic_model, # data to pydantic model
pydantic_model_to_code, # pydantic model to code
field_paths_and_annotations, # flattened field paths & annotations from model
)
from ju.viz import model_digraph
from ju.viz import model_digraph # visualize pydantic models
83 changes: 83 additions & 0 deletions ju/pydantic_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,86 @@ class Simple(BaseModel):
**extra_json_schema_parser_kwargs,
)
return parser.parse()


# -------------------------------------------------------------------------------------
# utils for extraction
from typing import Any, Dict, Type, Union, get_args, get_origin, List


def field_paths_and_annotations(
data_model: Type[BaseModel],
) -> Dict[str, Type[Any]]:
"""
Get flattened field paths and their corresponding annotations from a Pydantic model.
Generates a dictionary of dot-separated paths and their corresponding types
from the fields of a given Pydantic BaseModel and any nested BaseModels within it.
The function recursively traverses the fields of the BaseModel and its nested models,
including fields that are lists, sets, tuples, or iterables containing BaseModels.
If the field is a collection containing a BaseModel, the path is marked with a '*'.
This structure is compatible with the `glom` library, allowing extraction of values
from a dictionary that matches the BaseModel structure.
Args:
data_model (Type[BaseModel]): The Pydantic BaseModel to extract field paths and annotations from.
Returns:
Dict[str, Type[Any]]: A dictionary where the keys are the dot-separated paths to fields
and the values are their corresponding types.
Example:
>>> from pydantic import BaseModel
>>> from typing import List
>>> class BItem(BaseModel):
... c: int
>>> class A(BaseModel):
... b: List[BItem]
... d: str
>>> class Model(BaseModel):
... a: A
>>> paths = field_paths_and_annotations(Model)
>>> assert paths == {
... 'a.b.*.c': int, 'a.d': str
... }
One of the main use cases for this function is to extract values from a dictionary
that matches the structure of a Pydantic BaseModel.
This can be done using the `glom` library, for example:
>>> from glom import glom # doctest: +SKIP
>>> {k: glom(data, k) for k in paths} # doctest: +SKIP
{'a.b.*.c': [1, 2], 'a.d': 'bar'}
"""

def recurse_model(model: Type[BaseModel], prefix: str = '') -> Dict[str, Type[Any]]:
paths = {}
for field_name, field_info in model.model_fields.items():
field_type = field_info.annotation
current_path = f"{prefix}.{field_name}" if prefix else field_name
origin = get_origin(field_type)
args = get_args(field_type)

if isinstance(field_type, type) and issubclass(field_type, BaseModel):
paths.update(recurse_model(field_type, current_path))
elif (
origin in {list, set, tuple}
and args
and isinstance(args[0], type)
and issubclass(args[0], BaseModel)
):
paths.update(recurse_model(args[0], f"{current_path}.*"))
else:
paths[current_path] = field_type

return paths

return recurse_model(data_model)
81 changes: 0 additions & 81 deletions ju/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,84 +427,3 @@ def feature_similarity_search(

FeatureSwitch = FuncFactory(feature_switch)


# -------------------------------------------------------------------------------------
# utils for extraction
from typing import Any, Dict, Type, Union, get_args, get_origin, List
from pydantic import BaseModel


def field_paths_and_annotations(
data_model: Type[BaseModel],
) -> Dict[str, Type[Any]]:
"""
Generates a dictionary of dot-separated paths and their corresponding types
from the fields of a given Pydantic BaseModel and any nested BaseModels within it.
The function recursively traverses the fields of the BaseModel and its nested models,
including fields that are lists, sets, tuples, or iterables containing BaseModels.
If the field is a collection containing a BaseModel, the path is marked with a '*'.
This structure is compatible with the `glom` library, allowing extraction of values
from a dictionary that matches the BaseModel structure.
Args:
data_model (Type[BaseModel]): The Pydantic BaseModel to extract field paths and annotations from.
Returns:
Dict[str, Type[Any]]: A dictionary where the keys are the dot-separated paths to fields
and the values are their corresponding types.
Example:
>>> from pydantic import BaseModel
>>> from typing import List
>>> class BItem(BaseModel):
... c: int
>>> class A(BaseModel):
... b: List[BItem]
... d: str
>>> class Model(BaseModel):
... a: A
>>> paths = field_paths_and_annotations(Model)
>>> assert paths == {
... 'a.b.*.c': int, 'a.d': str
... }
One of the main use cases for this function is to extract values from a dictionary
that matches the structure of a Pydantic BaseModel.
This can be done using the `glom` library, for example:
>>> from glom import glom # doctest: +SKIP
>>> {k: glom(data, k) for k in paths} # doctest: +SKIP
{'a.b.*.c': [1, 2], 'a.d': 'bar'}
"""

def recurse_model(model: Type[BaseModel], prefix: str = '') -> Dict[str, Type[Any]]:
paths = {}
for field_name, field_info in model.model_fields.items():
field_type = field_info.annotation
current_path = f"{prefix}.{field_name}" if prefix else field_name
origin = get_origin(field_type)
args = get_args(field_type)

if isinstance(field_type, type) and issubclass(field_type, BaseModel):
paths.update(recurse_model(field_type, current_path))
elif (
origin in {list, set, tuple}
and args
and isinstance(args[0], type)
and issubclass(args[0], BaseModel)
):
paths.update(recurse_model(args[0], f"{current_path}.*"))
else:
paths[current_path] = field_type

return paths

return recurse_model(data_model)
2 changes: 2 additions & 0 deletions ju/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def model_digraph(
>>> dot.render(f"User_model_graph", format="png", cleanup=True) # doctest: +SKIP
'User_model_graph.png'
See also the online tool: https://navneethg.github.io/jsonschemaviewer/
"""

# pylint: disable=import-outside-toplevel
Expand Down

0 comments on commit cdc1550

Please sign in to comment.