From 01085ebb6baaefd9f36c57ba54d0836e86bfb9f5 Mon Sep 17 00:00:00 2001
From: Richard Hansen <rhansen@rhansen.org>
Date: Thu, 13 Oct 2022 19:44:39 -0400
Subject: [PATCH] Use `ChainMap` to improve representer/constructor inheritance

Before this change, a child Dumper "inherited" its parent Dumper's
representers by simply copying the parent's representers the first
time a new representer was registered with the child Dumper.  This
approach works as users expect only if representers are never added to
a parent (ancestor) Dumper after representers are added to a
child (descendant) Dumper.  The same applies to Loaders and their
registered constructors.

This commit uses a `collections.ChainMap` built from ancestor `dict`
objects in method resolution order (MRO) to provide true inheritance
of representers (for Dumpers) and constructors (for Loaders).

This is technically a backwards-incompatible change: it breaks any
code that intentionally subverts the expected inheritance behavior by
registering a function in an ancestor class *after* registering a
function in a descendant class.  In addition, the registry class
attributes are now `collections.ChainMap` objects rather than plain
`dict` objects.
---
 lib/yaml/_utils.py            | 42 ++++++++++++++++++++
 lib/yaml/constructor.py       | 15 +++-----
 lib/yaml/representer.py       | 15 +++-----
 tests/lib/test_constructor.py | 72 +++++++++++++++++++++++++++++++++++
 tests/lib/test_representer.py | 72 +++++++++++++++++++++++++++++++++++
 5 files changed, 198 insertions(+), 18 deletions(-)
 create mode 100644 lib/yaml/_utils.py

diff --git a/lib/yaml/_utils.py b/lib/yaml/_utils.py
new file mode 100644
index 00000000..26e013e1
--- /dev/null
+++ b/lib/yaml/_utils.py
@@ -0,0 +1,42 @@
+import collections
+
+
+class InheritMapMixin:
+    """Adds :py:class:`collections.ChainMap` class attributes based on MRO.
+
+    The added class attributes provide each subclass with its own mapping that
+    works just like method resolution: each ancestor type in ``__mro__`` is
+    visited until an entry is found in the ancestor-owned map.
+
+    For example, for an inheritance DAG of ``InheritMapMixin`` <- ``Foo`` <-
+    ``Bar`` <- ``Baz`` and a desired attribute name of ``"_m"``:
+
+    1. ``Foo._m`` is set to ``ChainMap({})``.
+    2. ``Bar._m`` is set to ``ChainMap({}, Foo._m.maps[0])``.
+    3. ``Baz._m`` is set to ``ChainMap({}, Bar._m.maps[0], Foo._m.maps[0])``.
+    """
+
+    @classmethod
+    def __init_subclass__(cls, *, inherit_map_attrs=None, **kwargs):
+        """Adds :py:class:`collections.ChainMap` class attributes based on MRO.
+
+        :param inherit_map_attrs:
+            Optional iterable of names of class attributes that will be set to a
+            :py:class:`collections.ChainMap` containing the MRO-based list of
+            ancestor maps.
+        """
+        super().__init_subclass__(**kwargs)
+        attrs = getattr(cls, "_inherit_map_attrs", set())
+        if inherit_map_attrs:
+            attrs = {*attrs, *inherit_map_attrs}
+            cls._inherit_map_attrs = attrs
+        for attr in attrs:
+            maps = [{}]  # maps[0] is for cls itself.
+            for c in cls.__mro__[1:]:  # cls.__mro__[0] is cls itself.
+                if (
+                        issubclass(c, InheritMapMixin) and
+                        c is not InheritMapMixin and
+                        attr in getattr(c, "_inherit_map_attrs", set())
+                ):
+                    maps.append(getattr(c, attr).maps[0])
+            setattr(cls, attr, collections.ChainMap(*maps))
diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py
index 619acd30..0259baf1 100644
--- a/lib/yaml/constructor.py
+++ b/lib/yaml/constructor.py
@@ -13,14 +13,15 @@
 
 import collections.abc, datetime, base64, binascii, re, sys, types
 
-class ConstructorError(MarkedYAMLError):
-    pass
+from . import _utils
 
-class BaseConstructor:
 
-    yaml_constructors = {}
-    yaml_multi_constructors = {}
+class ConstructorError(MarkedYAMLError):
+    pass
 
+class BaseConstructor(
+        _utils.InheritMapMixin,
+        inherit_map_attrs={"yaml_constructors", "yaml_multi_constructors"}):
     def __init__(self):
         self.constructed_objects = {}
         self.recursive_objects = {}
@@ -158,14 +159,10 @@ def construct_pairs(self, node, deep=False):
 
     @classmethod
     def add_constructor(cls, tag, constructor):
-        if not 'yaml_constructors' in cls.__dict__:
-            cls.yaml_constructors = cls.yaml_constructors.copy()
         cls.yaml_constructors[tag] = constructor
 
     @classmethod
     def add_multi_constructor(cls, tag_prefix, multi_constructor):
-        if not 'yaml_multi_constructors' in cls.__dict__:
-            cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
         cls.yaml_multi_constructors[tag_prefix] = multi_constructor
 
 class SafeConstructor(BaseConstructor):
diff --git a/lib/yaml/representer.py b/lib/yaml/representer.py
index 808ca06d..77934821 100644
--- a/lib/yaml/representer.py
+++ b/lib/yaml/representer.py
@@ -7,14 +7,15 @@
 
 import datetime, copyreg, types, base64, collections
 
-class RepresenterError(YAMLError):
-    pass
+from . import _utils
 
-class BaseRepresenter:
 
-    yaml_representers = {}
-    yaml_multi_representers = {}
+class RepresenterError(YAMLError):
+    pass
 
+class BaseRepresenter(
+        _utils.InheritMapMixin,
+        inherit_map_attrs={"yaml_representers", "yaml_multi_representers"}):
     def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
         self.default_style = default_style
         self.sort_keys = sort_keys
@@ -64,14 +65,10 @@ def represent_data(self, data):
 
     @classmethod
     def add_representer(cls, data_type, representer):
-        if not 'yaml_representers' in cls.__dict__:
-            cls.yaml_representers = cls.yaml_representers.copy()
         cls.yaml_representers[data_type] = representer
 
     @classmethod
     def add_multi_representer(cls, data_type, representer):
-        if not 'yaml_multi_representers' in cls.__dict__:
-            cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
         cls.yaml_multi_representers[data_type] = representer
 
     def represent_scalar(self, tag, value, style=None):
diff --git a/tests/lib/test_constructor.py b/tests/lib/test_constructor.py
index 0783a21b..2e554de6 100644
--- a/tests/lib/test_constructor.py
+++ b/tests/lib/test_constructor.py
@@ -296,6 +296,78 @@ def test_subclass_blacklist_types(data_filename, verbose=False):
 
 test_subclass_blacklist_types.unittest = ['.subclass_blacklist']
 
+def test_constructor_inheritance(verbose=False):
+    class Widget:
+        pass
+
+    class Gizmo:
+        pass
+
+    def construct_widget(loader, node):
+        return Widget()
+
+    def construct_gizmo1(loader, node):
+        return Gizmo()
+
+    def construct_gizmo2(loader, node):
+        return Gizmo()
+
+    def construct_gizmo3(loader, node):
+        return Gizmo()
+
+    class LoaderParent(yaml.Loader):
+        pass
+
+    class LoaderChild(LoaderParent):
+        pass
+
+    # Add a constructor to the child.  Note that no constructor has been added
+    # to the parent yet.
+    LoaderChild.add_constructor("!widget", construct_widget)
+    if verbose:
+        print("After adding a constructor to the child Loader:")
+        print(f"  {LoaderParent.yaml_constructors=}")
+        print(f"  {LoaderChild.yaml_constructors=}")
+    assert LoaderChild.yaml_constructors["!widget"] is construct_widget
+    assert "!widget" not in LoaderParent.yaml_constructors
+
+    # A constructor is now added to the parent.  The child should be able to see
+    # this new constructor even though it was added after a constructor was
+    # added to the child above.
+    LoaderParent.add_constructor("!gizmo", construct_gizmo1)
+    if verbose:
+        print("After adding a constructor to the parent Loader:")
+        print(f"  {LoaderParent.yaml_constructors=}")
+        print(f"  {LoaderChild.yaml_constructors=}")
+    assert LoaderChild.yaml_constructors["!widget"] is construct_widget
+    assert "!widget" not in LoaderParent.yaml_constructors
+    assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo1
+    assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo1
+
+    # Override a parent constructor in the child.
+    LoaderChild.add_constructor("!gizmo", construct_gizmo2)
+    if verbose:
+        print("After overriding a parent's constructor:")
+        print(f"  {LoaderParent.yaml_constructors=}")
+        print(f"  {LoaderChild.yaml_constructors=}")
+    assert LoaderChild.yaml_constructors["!widget"] is construct_widget
+    assert "!widget" not in LoaderParent.yaml_constructors
+    assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo1
+    assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo2
+
+    # Changing the parent's overridden constructor should not affect the child.
+    LoaderParent.add_constructor("!gizmo", construct_gizmo3)
+    if verbose:
+        print("After changing the parent's overridden constructor:")
+        print(f"  {LoaderParent.yaml_constructors=}")
+        print(f"  {LoaderChild.yaml_constructors=}")
+    assert LoaderChild.yaml_constructors["!widget"] is construct_widget
+    assert "!widget" not in LoaderParent.yaml_constructors
+    assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo3
+    assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo2
+
+test_constructor_inheritance.unittest = True
+
 if __name__ == '__main__':
     import sys, test_constructor
     sys.modules['test_constructor'] = sys.modules['__main__']
diff --git a/tests/lib/test_representer.py b/tests/lib/test_representer.py
index f3095bfc..40c33e6b 100644
--- a/tests/lib/test_representer.py
+++ b/tests/lib/test_representer.py
@@ -38,6 +38,78 @@ def test_representer_types(code_filename, verbose=False):
 
 test_representer_types.unittest = ['.code']
 
+def test_representer_inheritance(verbose=False):
+    class Widget:
+        pass
+
+    class Gizmo:
+        pass
+
+    def represent_widget(representer, obj):
+        return representer.represent_scalar("!widget", "widget")
+
+    def represent_gizmo1(representer, obj):
+        return representer.represent_scalar("!gizmo", "gizmo1")
+
+    def represent_gizmo2(representer, obj):
+        return representer.represent_scalar("!gizmo", "gizmo2")
+
+    def represent_gizmo3(representer, obj):
+        return representer.represent_scalar("!gizmo", "gizmo3")
+
+    class DumperParent(yaml.Dumper):
+        pass
+
+    class DumperChild(DumperParent):
+        pass
+
+    # Add a representer to the child.  Note that no representer has been added
+    # to the parent yet.
+    DumperChild.add_representer(Widget, represent_widget)
+    if verbose:
+        print("After adding a representer to the child Dumper:")
+        print(f"  {DumperParent.yaml_representers=}")
+        print(f"  {DumperChild.yaml_representers=}")
+    assert DumperChild.yaml_representers[Widget] is represent_widget
+    assert Widget not in DumperParent.yaml_representers
+
+    # A representer is now added to the parent.  The child should be able to see
+    # this new representer even though it was added after a representer was
+    # added to the child above.
+    DumperParent.add_representer(Gizmo, represent_gizmo1)
+    if verbose:
+        print("After adding a representer to the parent Dumper:")
+        print(f"  {DumperParent.yaml_representers=}")
+        print(f"  {DumperChild.yaml_representers=}")
+    assert DumperChild.yaml_representers[Widget] is represent_widget
+    assert Widget not in DumperParent.yaml_representers
+    assert DumperParent.yaml_representers[Gizmo] is represent_gizmo1
+    assert DumperChild.yaml_representers[Gizmo] is represent_gizmo1
+
+    # Override a parent representer in the child.
+    DumperChild.add_representer(Gizmo, represent_gizmo2)
+    if verbose:
+        print("After overriding a parent's representer:")
+        print(f"  {DumperParent.yaml_representers=}")
+        print(f"  {DumperChild.yaml_representers=}")
+    assert DumperChild.yaml_representers[Widget] is represent_widget
+    assert Widget not in DumperParent.yaml_representers
+    assert DumperParent.yaml_representers[Gizmo] is represent_gizmo1
+    assert DumperChild.yaml_representers[Gizmo] is represent_gizmo2
+
+    # Changing the parent's overridden representer should not affect the child.
+    DumperParent.add_representer(Gizmo, represent_gizmo3)
+    if verbose:
+        print("After changing the parent's overridden representer:")
+        print(f"  {DumperParent.yaml_representers=}")
+        print(f"  {DumperChild.yaml_representers=}")
+    assert DumperChild.yaml_representers[Widget] is represent_widget
+    assert Widget not in DumperParent.yaml_representers
+    assert DumperParent.yaml_representers[Gizmo] is represent_gizmo3
+    assert DumperChild.yaml_representers[Gizmo] is represent_gizmo2
+
+test_representer_inheritance.unittest = True
+
 if __name__ == '__main__':
     import test_appliance
     test_appliance.run(globals())