From fd1b8d796faa7928ef1ca1100ce1e2a82d6009b7 Mon Sep 17 00:00:00 2001 From: a-gardner1 Date: Mon, 14 Oct 2024 19:22:10 +0000 Subject: [PATCH] Respect TypedDict.__required_keys__ --- jsonargparse/_typehints.py | 13 ++++++++++- jsonargparse_tests/test_typehints.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index aab043d2..42f3735d 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -931,7 +931,18 @@ def adapt_typehints( kwargs["prev_val"] = None val[k] = adapt_typehints(v, subtypehints[1], **kwargs) if type(typehint) in typed_dict_meta_types: - if typehint.__total__: + if hasattr(typehint, "__required_keys__"): + required_keys = set(typehint.__required_keys__) + # The standard library TypedDict below Python 3.11 does not store runtime + # information about optional and required keys when using Required or NotRequired. + # Thus, capture explicitly Required keys + required_keys.update( + {k for k, v in typehint.__annotations__.items() if get_typehint_origin(v) in required_types} + ) + # The standard library TypedDict in Python 3.8 does not store runtime information + # about which (if any) keys are optional. See https://bugs.python.org/issue38834. + # Thus, fall back to totality and explicitly Required keys + elif typehint.__total__: required_keys = { k for k, v in typehint.__annotations__.items() if get_typehint_origin(v) not in not_required_types } diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 332e35eb..b5dfb063 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -679,6 +679,38 @@ def test_invalid_inherited_unpack_typeddict(parser, init_args): parser.parse_args([f"--testclass={json.dumps(test_config)}"]) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Python 3.8 lacked runtime inspection of TypedDict required keys") +def test_typeddict_totality_inheritance(parser): + + class BottomDict(TypedDict, total=True): + a: int + + class MiddleDict(BottomDict, total=False): + b: int + + class TopDict(MiddleDict, total=True): + c: int + + parser.add_argument("--middledict", type=MiddleDict, required=False) + parser.add_argument("--topdict", type=TopDict, required=False) + assert {"a": 1} == parser.parse_args(["--middledict={'a': 1}"])["middledict"] + assert {"a": 1, "b": 2} == parser.parse_args(["--middledict={'a': 1, 'b': 2}"])["middledict"] + with pytest.raises(ArgumentError) as ctx: + parser.parse_args(["--middledict={}"]) + ctx.match("Missing required keys") + with pytest.raises(ArgumentError) as ctx: + parser.parse_args(['--middledict={"b": 2}']) + ctx.match("Missing required keys") + assert {"a": 1, "c": 2} == parser.parse_args(["--topdict={'a': 1, 'c': 2}"])["topdict"] + assert {"a": 1, "b": 2, "c": 3} == parser.parse_args(["--topdict={'a': 1, 'b': 2, 'c': 3}"])["topdict"] + with pytest.raises(ArgumentError) as ctx: + parser.parse_args(['--topdict={"a": 1, "b": 2}']) + ctx.match("Missing required keys") + with pytest.raises(ArgumentError) as ctx: + parser.parse_args(['--topdict={"b":2, "c": 3}']) + ctx.match("Missing required keys") + + def test_mapping_proxy_type(parser): parser.add_argument("--mapping", type=MappingProxyType) cfg = parser.parse_args(['--mapping={"x":1}'])