Skip to content

Commit

Permalink
Fix #268, decrease the pmap size by the number of items removed rathe…
Browse files Browse the repository at this point in the history
…r than one to avoid inconsistencies in case of bad equality implementations in contained instances
  • Loading branch information
tobgu committed Oct 22, 2023
1 parent 4ee2c66 commit 0be88dd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
5 changes: 3 additions & 2 deletions pyrsistent/_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,10 @@ def remove(self, key):

if bucket:
new_bucket = [(k, v) for (k, v) in bucket if k != key]
if len(bucket) > len(new_bucket):
size_diff = len(bucket) - len(new_bucket)
if size_diff > 0:
self._buckets_evolver[index] = new_bucket if new_bucket else None
self._size -= 1
self._size -= size_diff
return self

raise KeyError('{0}'.format(key))
Expand Down
43 changes: 36 additions & 7 deletions tests/map_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import namedtuple
from collections.abc import Mapping, Hashable
from operator import add
import pytest
from pyrsistent import pmap, m, PVector
from pyrsistent import pmap, m
import pickle


Expand Down Expand Up @@ -64,15 +65,15 @@ def test_various_iterations():
assert {('a', 1), ('b', 2)} == set(m(a=1, b=2).iteritems())
assert {('a', 1), ('b', 2)} == set(m(a=1, b=2).items())

pm = pmap({k:k for k in range(100)})
pm = pmap({k: k for k in range(100)})
assert len(pm) == len(pm.keys())
assert len(pm) == len(pm.values())
assert len(pm) == len(pm.items())
ks = pm.keys()
assert all(k in pm for k in ks)
assert all(k in ks for k in ks)
us = pm.items()
assert all(pm[k] == v for (k,v) in us)
assert all(pm[k] == v for (k, v) in us)
vs = pm.values()
assert all(v in vs for v in vs)

Expand Down Expand Up @@ -147,12 +148,11 @@ def test_same_hash_when_content_the_same_but_underlying_vector_size_differs():


class HashabilityControlled(object):

hashable = True

def __hash__(self):
if self.hashable:
return 4 # Proven random
return 4 # Proven random
raise ValueError("I am not currently hashable.")


Expand Down Expand Up @@ -286,7 +286,6 @@ def __eq__(self, other):


def test_hash_collision_is_correctly_resolved():

dummy1 = HashDummy()
dummy2 = HashDummy()
dummy3 = HashDummy()
Expand Down Expand Up @@ -422,7 +421,7 @@ def test_evolver_simple_update():


def test_evolver_update_with_relocation():
x = pmap({'a':1000}, pre_size=1)
x = pmap({'a': 1000}, pre_size=1)
e = x.evolver()
e['b'] = 3000
e['c'] = 4000
Expand Down Expand Up @@ -520,3 +519,33 @@ def test_iterable():
"""

assert pmap(iter([("a", "b")])) == pmap([("a", "b")])


class BrokenPerson(namedtuple('Person', 'name')):
def __eq__(self, other):
return self.__class__ == other.__class__ and self.name == other.name

def __hash__(self):
return hash(self.name)


class BrokenItem(namedtuple('Item', 'name')):
def __eq__(self, other):
return self.__class__ == other.__class__ and self.name == other.name

def __hash__(self):
return hash(self.name)


def test_pmap_removal_with_broken_classes_deriving_from_namedtuple():
"""
The two classes above implement __eq__ but also would need to implement __ne__ to compare
consistently. See issue https://github.com/tobgu/pyrsistent/issues/268 for details.
"""
s = pmap({BrokenPerson('X'): 2, BrokenItem('X'): 3})
s = s.remove(BrokenPerson('X'))

# Both items are removed due to how they are compared for inequality
assert BrokenPerson('X') not in s
assert BrokenItem('X') not in s
assert len(s) == 0

0 comments on commit 0be88dd

Please sign in to comment.