Skip to content

Commit

Permalink
Merge pull request #612 from yukinarit/fix-flatten-default
Browse files Browse the repository at this point in the history
Fix flatten with default
  • Loading branch information
yukinarit authored Nov 17, 2024
2 parents f2d048c + d0f924c commit 3a77152
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
30 changes: 25 additions & 5 deletions serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from beartype.door import is_bearable
from beartype.roar import BeartypeCallHintParamViolation
from dataclasses import dataclass, is_dataclass
from typing import overload, TypeVar, Generic, Any, Optional, Union, Literal
from typing import overload, TypeVar, Generic, Any, Optional, Union, Literal, Iterator
from typing_extensions import dataclass_transform

from .compat import (
Expand Down Expand Up @@ -985,11 +985,31 @@ def literal(self, arg: DeField[Any]) -> str:
)

def default(self, arg: DeField[Any], code: str) -> str:
if arg.alias:
aliases = (f'"{s}"' for s in [arg.name, *arg.alias])
exists = f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])'
"""
Renders supplying default value during deserialization.
"""

def get_aliased_fields(arg: Field[Any]) -> Iterator[str]:
return (f'"{s}"' for s in [arg.name, *arg.alias])

if arg.flatten:
# When a field has the `flatten` attribute, iterate over its dataclass fields.
# This ensures that the code checks keys in the data while considering aliases.
flattened = []
for subarg in defields(arg.type):
if subarg.alias:
aliases = get_aliased_fields(subarg)
flattened.append(f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])')
else:
flattened.append(f'"{subarg.name}" in {arg.datavar}')
exists = " and ".join(flattened)
else:
exists = f'"{arg.conv_name()}" in {arg.datavar}'
if arg.alias:
aliases = get_aliased_fields(arg)
exists = f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])'
else:
exists = f'"{arg.conv_name()}" in {arg.datavar}'

if has_default(arg):
return f'({code}) if {exists} else serde_scope.defaults["{arg.name}"]'
elif has_default_factory(arg):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,37 @@ class Bar:
@serde
class Foo:
bar: list[Bar] = field(flatten=True)


def test_flatten_default() -> None:
@serde
class Bar:
c: float = field(default=0.0)
d: bool = field(default=False)

@serde
class Foo:
a: int
b: str = field(default="foo")
bar: Bar = field(flatten=True, default_factory=Bar)

f = Foo(a=10, b="b", bar=Bar(c=100.0, d=True))
assert from_json(Foo, to_json(f)) == f

assert from_json(Foo, '{"a": 20}') == Foo(20, "foo", Bar())


def test_flatten_default_alias() -> None:
@serde
class Bar:
a: float = field(default=0.0, alias=["aa"]) # type: ignore
b: bool = field(default=False, alias=["bb"]) # type: ignore

@serde
class Foo:
bar: Bar = field(flatten=True, default_factory=Bar)

f = Foo(bar=Bar(100.0, True))
assert from_json(Foo, to_json(f)) == f

assert from_json(Foo, '{"aa": 20.0, "bb": false}') == Foo(Bar(20.0, False))

0 comments on commit 3a77152

Please sign in to comment.