Skip to content

Commit

Permalink
Merge pull request #22 from datacamp/fix/get_cls
Browse files Browse the repository at this point in the history
[LE-1175] Fix get_cls in BaseNodeRegistry
  • Loading branch information
TimSangster authored Sep 27, 2019
2 parents 41110e5 + f304d19 commit 81a5b52
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

All notable changes to this project will be documented in this file.

## v0.8.1

- Fix get_cls in BaseNodeRegistry, now updates fields of classes already in the registry

## v0.8.0

Expand Down
2 changes: 1 addition & 1 deletion antlr_ast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.8.0"
__version__ = "0.8.1"

from . import ast
6 changes: 6 additions & 0 deletions antlr_ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ def get_cls(self, cls_name: str, field_names: tuple) -> Type[BaseNode]:
self.dynamic_node_classes[cls_name] = BaseNode.create_cls(
cls_name, field_names
)
else:
existing_cls = self.dynamic_node_classes[cls_name]
all_fields = tuple(set(existing_cls._fields) | set(field_names))
if len(all_fields) > len(existing_cls._fields):
existing_cls._fields = all_fields

return self.dynamic_node_classes[cls_name]

def isinstance(self, instance: BaseNode, class_name: str) -> bool:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_base_node_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest

from antlr_ast.ast import BaseNodeRegistry


def test_base_node_registry_get_cls():
# Given
base_node_registry = BaseNodeRegistry()

# When
cls1 = base_node_registry.get_cls("cls1", ("field1",))
cls1_2 = base_node_registry.get_cls("cls1", ("field1", "field2"))

# Then
assert set(cls1._fields) == {"field1", "field2"}
assert set(cls1_2._fields) == {"field1", "field2"}


def test_base_node_registry_isinstance():
# Given
base_node_registry = BaseNodeRegistry()

# When
Cls1 = base_node_registry.get_cls("cls1", ("field1",))
Cls1_2 = base_node_registry.get_cls("cls1", ("field1", "field2"))
Cls2 = base_node_registry.get_cls("cls2", ("field_a", "field_b"))

cls1_obj = Cls1([], [], [])
cls1_2_obj = Cls1_2([], [], [])
cls2_obj = Cls2([], [], [])

# Then
assert isinstance(cls1_obj, type(cls1_2_obj))
assert base_node_registry.isinstance(cls1_obj, "cls1")
assert base_node_registry.isinstance(cls1_2_obj, "cls1")
assert base_node_registry.isinstance(cls2_obj, "cls2")
assert not base_node_registry.isinstance(cls1_obj, "cls2")
assert not base_node_registry.isinstance(cls1_2_obj, "cls2")
assert not base_node_registry.isinstance(cls2_obj, "cls1")

assert not base_node_registry.isinstance(cls2_obj, "cls3")

with pytest.raises(
TypeError, match="This function can only be used for BaseNode objects"
):
base_node_registry.isinstance([], "cls1")

0 comments on commit 81a5b52

Please sign in to comment.