Skip to content

Commit

Permalink
Use ChainMap to improve representer/constructor inheritance
Browse files Browse the repository at this point in the history
Before this change, a child Dumper "inherited" its parent Dumper's
representers by simply copying the parent's representers the first
time a new representer was registered with the child Dumper.  This
approach works as users expect only if representers are never added to
a parent (ancestor) Dumper after representers are added to a
child (descendant) Dumper.  Same goes for Loaders and their registered
constructors.

This commit uses a `collections.ChainMap` built from ancestor `dict`
objects in method resolution order (MRO) to provide true inheritance
of representers (for Dumpers) and constructors (for Loaders).

This is technically a backwards-incompatible change: This change
breaks any code that intentionally subverts the expected inheritance
behavior by registering a function in an ancestor class *after*
registering a function in a descendant class.
  • Loading branch information
rhansen committed Nov 18, 2022
1 parent 957ae4d commit 01085eb
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 18 deletions.
42 changes: 42 additions & 0 deletions lib/yaml/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import collections


class InheritMapMixin:
"""Adds :py:class:`collections.ChainMap` class attributes based on MRO.
The added class attributes provide each subclass with its own mapping that
works just like method resolution: each ancestor type in ``__mro__`` is
visited until an entry is found in the ancestor-owned map.
For example, for an inheritance DAG of ``InheritMapMixin`` <- ``Foo`` <-
``Bar`` <- ``Baz`` and a desired attribute name of ``"_m"``:
1. ``Foo._m`` is set to ``ChainMap({})``.
2. ``Bar._m`` is set to ``ChainMap({}, Foo._m.maps[0])``.
3. ``Baz._m`` is set to ``ChainMap({}, Bar._m.maps[0], Foo._m.maps[0])``.
"""

@classmethod
def __init_subclass__(cls, *, inherit_map_attrs=None, **kwargs):
"""Adds :py:class:`collections.ChainMap` class attributes based on MRO.
:param inherit_map_attrs:
Optional iterable of names of class attributes that will be set to a
:py:class:`collections.ChainMap` containing the MRO-based list of
ancestor maps.
"""
super().__init_subclass__(**kwargs)
attrs = getattr(cls, "_inherit_map_attrs", set())
if inherit_map_attrs:
attrs = {*attrs, *inherit_map_attrs}
cls._inherit_map_attrs = attrs
for attr in attrs:
maps = [{}] # maps[0] is for cls itself.
for c in cls.__mro__[1:]: # cls.__mro__[0] is cls itself.
if (
issubclass(c, InheritMapMixin) and
c is not InheritMapMixin and
attr in getattr(c, "_inherit_map_attrs", set())
):
maps.append(getattr(c, attr).maps[0])
setattr(cls, attr, collections.ChainMap(*maps))
15 changes: 6 additions & 9 deletions lib/yaml/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@

import collections.abc, datetime, base64, binascii, re, sys, types

class ConstructorError(MarkedYAMLError):
pass
from . import _utils

class BaseConstructor:

yaml_constructors = {}
yaml_multi_constructors = {}
class ConstructorError(MarkedYAMLError):
pass

class BaseConstructor(
_utils.InheritMapMixin,
inherit_map_attrs={"yaml_constructors", "yaml_multi_constructors"}):
def __init__(self):
self.constructed_objects = {}
self.recursive_objects = {}
Expand Down Expand Up @@ -158,14 +159,10 @@ def construct_pairs(self, node, deep=False):

@classmethod
def add_constructor(cls, tag, constructor):
if not 'yaml_constructors' in cls.__dict__:
cls.yaml_constructors = cls.yaml_constructors.copy()
cls.yaml_constructors[tag] = constructor

@classmethod
def add_multi_constructor(cls, tag_prefix, multi_constructor):
if not 'yaml_multi_constructors' in cls.__dict__:
cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
cls.yaml_multi_constructors[tag_prefix] = multi_constructor

class SafeConstructor(BaseConstructor):
Expand Down
15 changes: 6 additions & 9 deletions lib/yaml/representer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

import datetime, copyreg, types, base64, collections

class RepresenterError(YAMLError):
pass
from . import _utils

class BaseRepresenter:

yaml_representers = {}
yaml_multi_representers = {}
class RepresenterError(YAMLError):
pass

class BaseRepresenter(
_utils.InheritMapMixin,
inherit_map_attrs={"yaml_representers", "yaml_multi_representers"}):
def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
self.default_style = default_style
self.sort_keys = sort_keys
Expand Down Expand Up @@ -64,14 +65,10 @@ def represent_data(self, data):

@classmethod
def add_representer(cls, data_type, representer):
if not 'yaml_representers' in cls.__dict__:
cls.yaml_representers = cls.yaml_representers.copy()
cls.yaml_representers[data_type] = representer

@classmethod
def add_multi_representer(cls, data_type, representer):
if not 'yaml_multi_representers' in cls.__dict__:
cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
cls.yaml_multi_representers[data_type] = representer

def represent_scalar(self, tag, value, style=None):
Expand Down
72 changes: 72 additions & 0 deletions tests/lib/test_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,78 @@ def test_subclass_blacklist_types(data_filename, verbose=False):

test_subclass_blacklist_types.unittest = ['.subclass_blacklist']

def test_constructor_inheritance(verbose=False):
class Widget:
pass

class Gizmo:
pass

def construct_widget(loader, node):
return Widget()

def construct_gizmo1(loader, node):
return Gizmo()

def construct_gizmo2(loader, node):
return Gizmo()

def construct_gizmo3(loader, node):
return Gizmo()

class LoaderParent(yaml.Loader):
pass

class LoaderChild(LoaderParent):
pass

# Add a constructor to the child. Note that no constructor has been added
# to the parent yet.
LoaderChild.add_constructor("!widget", construct_widget)
if verbose:
print("After adding a constructor to the child Loader:")
print(f" {LoaderParent.yaml_constructors=}")
print(f" {LoaderChild.yaml_constructors=}")
assert LoaderChild.yaml_constructors["!widget"] is construct_widget
assert "!widget" not in LoaderParent.yaml_constructors

# A constructor is now added to the parent. The child should be able to see
# this new constructor even though it was added after a constructor was
# added to the child above.
LoaderParent.add_constructor("!gizmo", construct_gizmo1)
if verbose:
print("After adding a constructor to the parent Loader:")
print(f" {LoaderParent.yaml_constructors=}")
print(f" {LoaderChild.yaml_constructors=}")
assert LoaderChild.yaml_constructors["!widget"] is construct_widget
assert "!widget" not in LoaderParent.yaml_constructors
assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo1
assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo1

# Override a parent constructor in the child.
LoaderChild.add_constructor("!gizmo", construct_gizmo2)
if verbose:
print("After overriding a parent's constructor:")
print(f" {LoaderParent.yaml_constructors=}")
print(f" {LoaderChild.yaml_constructors=}")
assert LoaderChild.yaml_constructors["!widget"] is construct_widget
assert "!widget" not in LoaderParent.yaml_constructors
assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo1
assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo2

# Changing the parent's overridden constructor should not affect the child.
LoaderParent.add_constructor("!gizmo", construct_gizmo3)
if verbose:
print("After changing the parent's overridden constructor:")
print(f" {LoaderParent.yaml_constructors=}")
print(f" {LoaderChild.yaml_constructors=}")
assert LoaderChild.yaml_constructors["!widget"] is construct_widget
assert "!widget" not in LoaderParent.yaml_constructors
assert LoaderParent.yaml_constructors["!gizmo"] is construct_gizmo3
assert LoaderChild.yaml_constructors["!gizmo"] is construct_gizmo2

test_constructor_inheritance.unittest = True

if __name__ == '__main__':
import sys, test_constructor
sys.modules['test_constructor'] = sys.modules['__main__']
Expand Down
72 changes: 72 additions & 0 deletions tests/lib/test_representer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,78 @@ def test_representer_types(code_filename, verbose=False):

test_representer_types.unittest = ['.code']

def test_representer_inheritance(verbose=False):
class Widget:
pass

class Gizmo:
pass

def represent_widget(representer, obj):
return representer.represent_scalar("!widget", "widget")

def represent_gizmo1(representer, obj):
return representer.represent_scalar("!gizmo", "gizmo1")

def represent_gizmo2(representer, obj):
return representer.represent_scalar("!gizmo", "gizmo2")

def represent_gizmo3(representer, obj):
return representer.represent_scalar("!gizmo", "gizmo3")

class DumperParent(yaml.Dumper):
pass

class DumperChild(DumperParent):
pass

# Add a representer to the child. Note that no representer has been added
# to the parent yet.
DumperChild.add_representer(Widget, represent_widget)
if verbose:
print("After adding a representer to the child Dumper:")
print(f" {DumperParent.yaml_representers=}")
print(f" {DumperChild.yaml_representers=}")
assert DumperChild.yaml_representers[Widget] is represent_widget
assert Widget not in DumperParent.yaml_representers

# A representer is now added to the parent. The child should be able to see
# this new representer even though it was added after a representer was
# added to the child above.
DumperParent.add_representer(Gizmo, represent_gizmo1)
if verbose:
print("After adding a representer to the parent Dumper:")
print(f" {DumperParent.yaml_representers=}")
print(f" {DumperChild.yaml_representers=}")
assert DumperChild.yaml_representers[Widget] is represent_widget
assert Widget not in DumperParent.yaml_representers
assert DumperParent.yaml_representers[Gizmo] is represent_gizmo1
assert DumperChild.yaml_representers[Gizmo] is represent_gizmo1

# Override a parent representer in the child.
DumperChild.add_representer(Gizmo, represent_gizmo2)
if verbose:
print("After overriding a parent's representer:")
print(f" {DumperParent.yaml_representers=}")
print(f" {DumperChild.yaml_representers=}")
assert DumperChild.yaml_representers[Widget] is represent_widget
assert Widget not in DumperParent.yaml_representers
assert DumperParent.yaml_representers[Gizmo] is represent_gizmo1
assert DumperChild.yaml_representers[Gizmo] is represent_gizmo2

# Changing the parent's overridden representer should not affect the child.
DumperParent.add_representer(Gizmo, represent_gizmo3)
if verbose:
print("After changing the parent's overridden representer:")
print(f" {DumperParent.yaml_representers=}")
print(f" {DumperChild.yaml_representers=}")
assert DumperChild.yaml_representers[Widget] is represent_widget
assert Widget not in DumperParent.yaml_representers
assert DumperParent.yaml_representers[Gizmo] is represent_gizmo3
assert DumperChild.yaml_representers[Gizmo] is represent_gizmo2

test_representer_inheritance.unittest = True

if __name__ == '__main__':
import test_appliance
test_appliance.run(globals())
Expand Down

0 comments on commit 01085eb

Please sign in to comment.