From 01085ebb6baaefd9f36c57ba54d0836e86bfb9f5 Mon Sep 17 00:00:00 2001 From: Richard Hansen <rhansen@rhansen.org> Date: Thu, 13 Oct 2022 19:44:39 -0400 Subject: [PATCH] Use `ChainMap` to improve representer/constructor inheritance Before this change, a child Dumper "inherited" its parent Dumper's representers by simply copying the parent's representers the first time a new representer was registered with the child Dumper. This approach works as users expect only if representers are never added to a parent (ancestor) Dumper after representers are added to a child (descendant) Dumper. Same goes for Loaders and their registered constructors. This commit uses a `collections.ChainMap` built from ancestor `dict` objects in method resolution order (MRO) to provide true inheritance of representers (for Dumpers) and constructors (for Loaders). This is technically a backwards-incompatible change: This change breaks any code that intentionally subverts the expected inheritance behavior by registering a function in an ancestor class *after* registering a function in a descendant class. --- lib/yaml/_utils.py | 42 ++++++++++++++++++++ lib/yaml/constructor.py | 15 +++----- lib/yaml/representer.py | 15 +++----- tests/lib/test_constructor.py | 72 +++++++++++++++++++++++++++++++++++ tests/lib/test_representer.py | 72 +++++++++++++++++++++++++++++++++++ 5 files changed, 198 insertions(+), 18 deletions(-) create mode 100644 lib/yaml/_utils.py diff --git a/lib/yaml/_utils.py b/lib/yaml/_utils.py new file mode 100644 index 00000000..26e013e1 --- /dev/null +++ b/lib/yaml/_utils.py @@ -0,0 +1,42 @@ +import collections + + +class InheritMapMixin: + """Adds :py:class:`collections.ChainMap` class attributes based on MRO. + + The added class attributes provide each subclass with its own mapping that + works just like method resolution: each ancestor type in ``__mro__`` is + visited until an entry is found in the ancestor-owned map. + + For example, for an inheritance DAG of ``InheritMapMixin`` <- ``Foo`` <- + ``Bar`` <- ``Baz`` and a desired attribute name of ``"_m"``: + + 1. ``Foo._m`` is set to ``ChainMap({})``. + 2. ``Bar._m`` is set to ``ChainMap({}, Foo._m.maps[0])``. + 3. ``Baz._m`` is set to ``ChainMap({}, Bar._m.maps[0], Foo._m.maps[0])``. + """ + + @classmethod + def __init_subclass__(cls, *, inherit_map_attrs=None, **kwargs): + """Adds :py:class:`collections.ChainMap` class attributes based on MRO. + + :param inherit_map_attrs: + Optional iterable of names of class attributes that will be set to a + :py:class:`collections.ChainMap` containing the MRO-based list of + ancestor maps. + """ + super().__init_subclass__(**kwargs) + attrs = getattr(cls, "_inherit_map_attrs", set()) + if inherit_map_attrs: + attrs = {*attrs, *inherit_map_attrs} + cls._inherit_map_attrs = attrs + for attr in attrs: + maps = [{}] # maps[0] is for cls itself. + for c in cls.__mro__[1:]: # cls.__mro__[0] is cls itself. + if ( + issubclass(c, InheritMapMixin) and + c is not InheritMapMixin and + attr in getattr(c, "_inherit_map_attrs", set()) + ): + maps.append(getattr(c, attr).maps[0]) + setattr(cls, attr, collections.ChainMap(*maps)) diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py index 619acd30..0259baf1 100644 --- a/lib/yaml/constructor.py +++ b/lib/yaml/constructor.py @@ -13,14 +13,15 @@ import collections.abc, datetime, base64, binascii, re, sys, types -class ConstructorError(MarkedYAMLError): - pass +from . import _utils -class BaseConstructor: - yaml_constructors = {} - yaml_multi_constructors = {} +class ConstructorError(MarkedYAMLError): + pass +class BaseConstructor( + _utils.InheritMapMixin, + inherit_map_attrs={"yaml_constructors", "yaml_multi_constructors"}): def __init__(self): self.constructed_objects = {} self.recursive_objects = {} @@ -158,14 +159,10 @@ def construct_pairs(self, node, deep=False): @classmethod def add_constructor(cls, tag, constructor): - if not 'yaml_constructors' in cls.__dict__: - cls.yaml_constructors = cls.yaml_constructors.copy() cls.yaml_constructors[tag] = constructor @classmethod def add_multi_constructor(cls, tag_prefix, multi_constructor): - if not 'yaml_multi_constructors' in cls.__dict__: - cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy() cls.yaml_multi_constructors[tag_prefix] = multi_constructor class SafeConstructor(BaseConstructor): diff --git a/lib/yaml/representer.py b/lib/yaml/representer.py index 808ca06d..77934821 100644 --- a/lib/yaml/representer.py +++ b/lib/yaml/representer.py @@ -7,14 +7,15 @@ import datetime, copyreg, types, base64, collections -class RepresenterError(YAMLError): - pass +from . import _utils -class BaseRepresenter: - yaml_representers = {} - yaml_multi_representers = {} +class RepresenterError(YAMLError): + pass +class BaseRepresenter( + _utils.InheritMapMixin, + inherit_map_attrs={"yaml_representers", "yaml_multi_representers"}): def __init__(self, default_style=None, default_flow_style=False, sort_keys=True): self.default_style = default_style self.sort_keys = sort_keys @@ -64,14 +65,10 @@ def represent_data(self, data): @classmethod def add_representer(cls, data_type, representer): - if not 'yaml_representers' in cls.__dict__: - cls.yaml_representers = cls.yaml_representers.copy() cls.yaml_representers[data_type] = representer @classmethod def add_multi_representer(cls, data_type, representer): - if not 'yaml_multi_representers' in cls.__dict__: - cls.yaml_multi_representers = cls.yaml_multi_representers.copy() cls.yaml_multi_representers[data_type] = representer def represent_scalar(self, tag, value, style=None): diff --git a/tests/lib/test_constructor.py b/tests/lib/test_constructor.py index 0783a21b..2e554de6 100644 --- a/tests/lib/test_constructor.py +++ b/tests/lib/test_constructor.py @@ -296,6 +296,78 @@ def test_subclass_blacklist_types(data_filename, verbose=False): test_subclass_blacklist_types.unittest = ['.subclass_blacklist'] +def test_constructor_inheritance(verbose=False): + class Widget: + pass + + class Gizmo: + pass + + def construct_widget(loader, node): + return Widget() + + def construct_gizmo1(loader, node): + return Gizmo() + + def construct_gizmo2(loader, node): + return Gizmo() + + def construct_gizmo3(loader, node): + return Gizmo() + + class LoaderParent(yaml.Loader): + pass + + class LoaderChild(LoaderParent): + pass + + # Add a constructor to the child. Note that no constructor has been added + # to the parent yet. + LoaderChild.add_constructor("!widget", construct_widget) + if verbose: + print("After adding a constructor to the child Loader:") + print(f" {LoaderParent.yaml_constructors=}") + print(f" {LoaderChild.yaml_constructors=}") + assert LoaderChild.yaml_constructors["!widget"] is construct_widget + assert "!widget" not in LoaderParent.yaml_constructors + + # A constructor is now added to the parent. The child should be able to see + # this new constructor even though it was added after a constructor was + # added to the child above. + LoaderParent.add_constructor("!gizmo", construct_gizmo1) + if verbose: + print("After adding a constructor to the parent Loader:") + print(f" {LoaderParent.yaml_constructors=}") + print(f" {LoaderChild.yaml_constructors=}") + assert LoaderChild.yaml_constructors["!widget"] is construct_widget + assert "!widget" not in LoaderParent.yaml_constructors + assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo1 + assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo1 + + # Override a parent constructor in the child. + LoaderChild.add_constructor("!gizmo", construct_gizmo2) + if verbose: + print("After overriding a parent's constructor:") + print(f" {LoaderParent.yaml_constructors=}") + print(f" {LoaderChild.yaml_constructors=}") + assert LoaderChild.yaml_constructors["!widget"] is construct_widget + assert "!widget" not in LoaderParent.yaml_constructors + assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo1 + assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo2 + + # Changing the parent's overridden constructor should not affect the child. + LoaderParent.add_constructor("!gizmo", construct_gizmo3) + if verbose: + print("After changing the parent's overridden constructor:") + print(f" {LoaderParent.yaml_constructors=}") + print(f" {LoaderChild.yaml_constructors=}") + assert LoaderChild.yaml_constructors["!widget"] is construct_widget + assert "!widget" not in LoaderParent.yaml_constructors + assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo3 + assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo2 + +test_constructor_inheritance.unittest = True + if __name__ == '__main__': import sys, test_constructor sys.modules['test_constructor'] = sys.modules['__main__'] diff --git a/tests/lib/test_representer.py b/tests/lib/test_representer.py index f3095bfc..40c33e6b 100644 --- a/tests/lib/test_representer.py +++ b/tests/lib/test_representer.py @@ -38,6 +38,78 @@ def test_representer_types(code_filename, verbose=False): test_representer_types.unittest = ['.code'] +def test_representer_inheritance(verbose=False): + class Widget: + pass + + class Gizmo: + pass + + def represent_widget(representer, obj): + return representer.represent_scalar("!widget", "widget") + + def represent_gizmo1(representer, obj): + return representer.represent_scalar("!gizmo", "gizmo1") + + def represent_gizmo2(representer, obj): + return representer.represent_scalar("!gizmo", "gizmo2") + + def represent_gizmo3(representer, obj): + return representer.represent_scalar("!gizmo", "gizmo3") + + class DumperParent(yaml.Dumper): + pass + + class DumperChild(DumperParent): + pass + + # Add a representer to the child. Note that no representer has been added + # to the parent yet. + DumperChild.add_representer(Widget, represent_widget) + if verbose: + print("After adding a representer to the child Dumper:") + print(f" {DumperParent.yaml_representers=}") + print(f" {DumperChild.yaml_representers=}") + assert DumperChild.yaml_representers[Widget] is represent_widget + assert Widget not in DumperParent.yaml_representers + + # A representer is now added to the parent. The child should be able to see + # this new representer even though it was added after a representer was + # added to the child above. + DumperParent.add_representer(Gizmo, represent_gizmo1) + if verbose: + print("After adding a representer to the parent Dumper:") + print(f" {DumperParent.yaml_representers=}") + print(f" {DumperChild.yaml_representers=}") + assert DumperChild.yaml_representers[Widget] is represent_widget + assert Widget not in DumperParent.yaml_representers + assert DumperParent.yaml_representers[Gizmo] is represent_gizmo1 + assert DumperChild.yaml_representers[Gizmo] is represent_gizmo1 + + # Override a parent representer in the child. + DumperChild.add_representer(Gizmo, represent_gizmo2) + if verbose: + print("After overriding a parent's representer:") + print(f" {DumperParent.yaml_representers=}") + print(f" {DumperChild.yaml_representers=}") + assert DumperChild.yaml_representers[Widget] is represent_widget + assert Widget not in DumperParent.yaml_representers + assert DumperParent.yaml_representers[Gizmo] is represent_gizmo1 + assert DumperChild.yaml_representers[Gizmo] is represent_gizmo2 + + # Changing the parent's overridden representer should not affect the child. + DumperParent.add_representer(Gizmo, represent_gizmo3) + if verbose: + print("After changing the parent's overridden representer:") + print(f" {DumperParent.yaml_representers=}") + print(f" {DumperChild.yaml_representers=}") + assert DumperChild.yaml_representers[Widget] is represent_widget + assert Widget not in DumperParent.yaml_representers + assert DumperParent.yaml_representers[Gizmo] is represent_gizmo3 + assert DumperChild.yaml_representers[Gizmo] is represent_gizmo2 + +test_representer_inheritance.unittest = True + if __name__ == '__main__': import test_appliance test_appliance.run(globals())