diff --git a/serde/de.py b/serde/de.py index fd96f99b..bd0d78ee 100644 --- a/serde/de.py +++ b/serde/de.py @@ -16,7 +16,7 @@ from beartype.door import is_bearable from beartype.roar import BeartypeCallHintParamViolation from dataclasses import dataclass, is_dataclass -from typing import overload, TypeVar, Generic, Any, Optional, Union, Literal +from typing import overload, TypeVar, Generic, Any, Optional, Union, Literal, Iterator from typing_extensions import dataclass_transform from .compat import ( @@ -985,11 +985,31 @@ def literal(self, arg: DeField[Any]) -> str: ) def default(self, arg: DeField[Any], code: str) -> str: - if arg.alias: - aliases = (f'"{s}"' for s in [arg.name, *arg.alias]) - exists = f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])' + """ + Renders supplying default value during deserialization. + """ + + def get_aliased_fields(arg: Field[Any]) -> Iterator[str]: + return (f'"{s}"' for s in [arg.name, *arg.alias]) + + if arg.flatten: + # When a field has the `flatten` attribute, iterate over its dataclass fields. + # This ensures that the code checks keys in the data while considering aliases. + flattened = [] + for subarg in defields(arg.type): + if subarg.alias: + aliases = get_aliased_fields(subarg) + flattened.append(f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])') + else: + flattened.append(f'"{subarg.name}" in {arg.datavar}') + exists = " and ".join(flattened) else: - exists = f'"{arg.conv_name()}" in {arg.datavar}' + if arg.alias: + aliases = get_aliased_fields(arg) + exists = f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])' + else: + exists = f'"{arg.conv_name()}" in {arg.datavar}' + if has_default(arg): return f'({code}) if {exists} else serde_scope.defaults["{arg.name}"]' elif has_default_factory(arg): diff --git a/tests/test_flatten.py b/tests/test_flatten.py index 85c0b390..7adf26d6 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -81,3 +81,37 @@ class Bar: @serde class Foo: bar: list[Bar] = field(flatten=True) + + +def test_flatten_default() -> None: + @serde + class Bar: + c: float = field(default=0.0) + d: bool = field(default=False) + + @serde + class Foo: + a: int + b: str = field(default="foo") + bar: Bar = field(flatten=True, default_factory=Bar) + + f = Foo(a=10, b="b", bar=Bar(c=100.0, d=True)) + assert from_json(Foo, to_json(f)) == f + + assert from_json(Foo, '{"a": 20}') == Foo(20, "foo", Bar()) + + +def test_flatten_default_alias() -> None: + @serde + class Bar: + a: float = field(default=0.0, alias=["aa"]) # type: ignore + b: bool = field(default=False, alias=["bb"]) # type: ignore + + @serde + class Foo: + bar: Bar = field(flatten=True, default_factory=Bar) + + f = Foo(bar=Bar(100.0, True)) + assert from_json(Foo, to_json(f)) == f + + assert from_json(Foo, '{"aa": 20.0, "bb": false}') == Foo(Bar(20.0, False))