Skip to content

Commit

Permalink
rename "inner" expressions to "children"
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Dec 25, 2024
1 parent f3c80ba commit 66b097d
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 146 deletions.
3 changes: 0 additions & 3 deletions src/icepool/evaluator/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self,
self._truth_value = truth_value

def next_state(self, state, outcome, *counts):
"""Adjusts the counts, then forwards to inner."""
if state is None:
expression_states = (None, ) * len(self._expressions)
evaluator_state = None
Expand Down Expand Up @@ -67,13 +66,11 @@ def final_outcome(
return self._evaluator.final_outcome(evaluator_state)

def order(self) -> Order:
"""Forwards to inner."""
expression_order = Order.merge(*(expression.order()
for expression in self._expressions))
return Order.merge(expression_order, self._evaluator.order())

def extra_outcomes(self, *generators) -> Collection[T_contra]:
"""Forwards to inner."""
return self._evaluator.extra_outcomes(*generators)

@cached_property
Expand Down
27 changes: 14 additions & 13 deletions src/icepool/evaluator/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
class JointEvaluator(MultisetEvaluator[T_contra, tuple]):
"""A `MultisetEvaluator` that jointly evaluates sub-evaluators on the same set of input generators."""

def __init__(self, *inners: MultisetEvaluator) -> None:
self._inners = inners
def __init__(self, *children: MultisetEvaluator) -> None:
self._children = children

def next_state(self, state, outcome, *counts):
"""Runs `next_state` for all sub-evaluator.
Expand All @@ -35,7 +35,7 @@ def next_state(self, state, outcome, *counts):
*evaluator_prefix_counts,
*counts,
) for evaluator, evaluator_prefix_counts in zip(
self._inners, self._split_prefix_counts(*prefix_counts)))
self._children, self._split_prefix_counts(*prefix_counts)))
else:
result = tuple(
evaluator.next_state(
Expand All @@ -44,7 +44,7 @@ def next_state(self, state, outcome, *counts):
*evaluator_prefix_counts,
*counts,
) for evaluator, substate, evaluator_prefix_counts in zip(
self._inners, state,
self._children, state,
self._split_prefix_counts(*prefix_counts)))
if icepool.Reroll in result:
return icepool.Reroll
Expand All @@ -59,11 +59,12 @@ def final_outcome(self, final_state) -> 'tuple | icepool.RerollType':
If any sub-evaluator returns `Reroll`, the result as a whole is `Reroll`.
"""
if final_state is None:
result = tuple(inner.final_outcome(None) for inner in self._inners)
result = tuple(
child.final_outcome(None) for child in self._children)
else:
result = tuple(
inner.final_outcome(final_substate)
for inner, final_substate in zip(self._inners, final_state))
child.final_outcome(final_substate)
for child, final_substate in zip(self._children, final_state))
if icepool.Reroll in result:
return icepool.Reroll
else:
Expand All @@ -76,23 +77,23 @@ def order(self) -> Order:
ValueError: If sub-evaluators have conflicting orders, i.e. some are
ascending and others are descending.
"""
return Order.merge(*(inner.order() for inner in self._inners))
return Order.merge(*(child.order() for child in self._children))

def extra_outcomes(self, outcomes) -> Collection[T_contra]:
return icepool.sorted_union(*(evaluator.extra_outcomes(outcomes)
for evaluator in self._inners))
for evaluator in self._children))

@cached_property
def _prefix_generators(self) -> 'tuple[icepool.MultisetGenerator, ...]':
return tuple(
itertools.chain.from_iterable(expression.prefix_generators()
for expression in self._inners))
for expression in self._children))

def prefix_generators(self) -> 'tuple[icepool.MultisetGenerator, ...]':
return self._prefix_generators

def validate_arity(self, arity: int) -> None:
for evaluator in self._inners:
for evaluator in self._children:
evaluator.validate_arity(arity)

@cached_property
Expand All @@ -105,7 +106,7 @@ def _prefix_slices(self) -> tuple[slice, ...]:
"""Precomputed slices for determining which prefix counts go with which sub-evaluator."""
result = []
index = 0
for expression in self._inners:
for expression in self._children:
counts_length = sum(
generator.output_arity()
for generator in expression.prefix_generators())
Expand All @@ -120,4 +121,4 @@ def _split_prefix_counts(self, *extra_counts:

def __str__(self) -> str:
return 'JointEvaluator(\n' + ''.join(
f' {evaluator},\n' for evaluator in self._inners) + ')'
f' {evaluator},\n' for evaluator in self._children) + ')'
65 changes: 32 additions & 33 deletions src/icepool/expression/adjust_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,72 +18,71 @@ class MapCountsExpression(MultisetExpression[T_contra]):

_function: Callable[..., int]

def __init__(self, *inners: MultisetExpression[T_contra],
def __init__(self, *children: MultisetExpression[T_contra],
function: Callable[..., int]) -> None:
"""Constructor.
Args:
inner: The inner expression.
children: The children expression(s).
function: A function that takes `outcome, *counts` and produces a
combined count.
"""
for inner in inners:
self._validate_output_arity(inner)
self._inners = inners
for child in children:
self._validate_output_arity(child)
self._children = children
self._function = function

def _make_unbound(self, *unbound_inners) -> 'icepool.MultisetExpression':
return MapCountsExpression(*unbound_inners, function=self._function)
def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression':
return MapCountsExpression(*unbound_children, function=self._function)

def _next_state(self, state, outcome: T_contra, *counts:
int) -> tuple[Hashable, int]:

inner_states = state or (None, ) * len(self._inners)
inner_states, inner_counts = zip(
*(inner._next_state(inner_state, outcome, *counts)
for inner, inner_state in zip(self._inners, inner_states)))
child_states = state or (None, ) * len(self._children)
child_states, child_counts = zip(
*(child._next_state(child_state, outcome, *counts)
for child, child_state in zip(self._children, child_states)))

count = self._function(outcome, *inner_counts)
count = self._function(outcome, *child_counts)
return state, count

def order(self) -> Order:
return Order.merge(*(inner.order() for inner in self._inners))
return Order.merge(*(child.order() for child in self._children))

@cached_property
def _cached_arity(self) -> int:
return max(inner._free_arity() for inner in self._inners)
return max(child._free_arity() for child in self._children)

def _free_arity(self) -> int:
return self._cached_arity


class AdjustCountsExpression(MultisetExpression[T_contra]):

def __init__(self, inner: MultisetExpression[T_contra], /, *,
def __init__(self, child: MultisetExpression[T_contra], /, *,
constant: int) -> None:
self._validate_output_arity(inner)
self._inner = inner
self._inners = (inner, )
self._validate_output_arity(child)
self._children = (child, )
self._constant = constant

def _make_unbound(self, *unbound_inners) -> 'icepool.MultisetExpression':
return type(self)(*unbound_inners, constant=self._constant)
def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression':
return type(self)(*unbound_children, constant=self._constant)

@abstractmethod
def adjust_count(self, count: int, constant: int) -> int:
"""Adjusts the count."""

def _next_state(self, state, outcome: T_contra, *counts:
int) -> tuple[Hashable, int]:
state, count = self._inner._next_state(state, outcome, *counts)
state, count = self._children[0]._next_state(state, outcome, *counts)
count = self.adjust_count(count, self._constant)
return state, count

def order(self) -> Order:
return self._inner.order()
return self._children[0].order()

def _free_arity(self) -> int:
return self._inner._free_arity()
return self._children[0]._free_arity()


class MultiplyCountsExpression(AdjustCountsExpression):
Expand All @@ -93,7 +92,7 @@ def adjust_count(self, count: int, constant: int) -> int:
return count * constant

def __str__(self) -> str:
return f'({self._inner} * {self._constant})'
return f'({self._children[0]} * {self._constant})'


class FloorDivCountsExpression(AdjustCountsExpression):
Expand All @@ -103,7 +102,7 @@ def adjust_count(self, count: int, constant: int) -> int:
return count // constant

def __str__(self) -> str:
return f'({self._inner} // {self._constant})'
return f'({self._children[0]} // {self._constant})'


class ModuloCountsExpression(AdjustCountsExpression):
Expand All @@ -113,15 +112,15 @@ def adjust_count(self, count: int, constant: int) -> int:
return count % constant

def __str__(self) -> str:
return f'({self._inner} % {self._constant})'
return f'({self._children[0]} % {self._constant})'


class KeepCountsExpression(AdjustCountsExpression):

def __init__(self, inner: MultisetExpression[T_contra], /, *,
def __init__(self, child: MultisetExpression[T_contra], /, *,
comparison: Literal['==', '!=', '<=', '<', '>=',
'>'], constant: int):
super().__init__(inner, constant=constant)
super().__init__(child, constant=constant)
operators = {
'==': operator.eq,
'!=': operator.ne,
Expand All @@ -135,8 +134,8 @@ def __init__(self, inner: MultisetExpression[T_contra], /, *,
self._comparison = comparison
self._op = operators[comparison]

def _make_unbound(self, *unbound_inners) -> 'icepool.MultisetExpression':
return KeepCountsExpression(*unbound_inners,
def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression':
return KeepCountsExpression(*unbound_children,
comparison=self._comparison,
constant=self._constant)

Expand All @@ -147,7 +146,7 @@ def adjust_count(self, count: int, constant: int) -> int:
return 0

def __str__(self) -> str:
return f"{self._inner}.keep_counts('{self._comparison}', {self._constant})"
return f"{self._children[0]}.keep_counts('{self._comparison}', {self._constant})"


class UniqueExpression(AdjustCountsExpression):
Expand All @@ -158,6 +157,6 @@ def adjust_count(self, count: int, constant: int) -> int:

def __str__(self) -> str:
if self._constant == 1:
return f'{self._inner}.unique()'
return f'{self._children[0]}.unique()'
else:
return f'{self._inner}.unique({self._constant})'
return f'{self._children[0]}.unique({self._constant})'
34 changes: 17 additions & 17 deletions src/icepool/expression/binary_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@

class BinaryOperatorExpression(MultisetExpression[T_contra]):

def __init__(self, *inners: MultisetExpression[T_contra]) -> None:
def __init__(self, *children: MultisetExpression[T_contra]) -> None:
"""Constructor.
Args:
*inners: Any number of expressions to feed into the operator.
*children: Any number of expressions to feed into the operator.
If zero expressions are provided, the result will have all zero
counts.
If more than two expressions are provided, the counts will be
`reduce`d.
"""
for inner in inners:
self._validate_output_arity(inner)
self._inners = inners
for child in children:
self._validate_output_arity(child)
self._children = children

def _make_unbound(self, *unbound_inners) -> 'icepool.MultisetExpression':
return type(self)(*unbound_inners)
def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression':
return type(self)(*unbound_children)

@staticmethod
@abstractmethod
Expand All @@ -43,30 +43,30 @@ def symbol() -> str:

def _next_state(self, state, outcome: T_contra, *counts:
int) -> tuple[Hashable, int]:
if len(self._inners) == 0:
if len(self._children) == 0:
return (), 0
inner_states = state or (None, ) * len(self._inners)
child_states = state or (None, ) * len(self._children)

inner_states, inner_counts = zip(
*(inner._next_state(inner_state, outcome, *counts)
for inner, inner_state in zip(self._inners, inner_states)))
child_states, child_counts = zip(
*(child._next_state(child_state, outcome, *counts)
for child, child_state in zip(self._children, child_states)))

count = reduce(self.merge_counts, inner_counts)
return inner_states, count
count = reduce(self.merge_counts, child_counts)
return child_states, count

def order(self) -> Order:
return Order.merge(*(inner.order() for inner in self._inners))
return Order.merge(*(child.order() for child in self._children))

@cached_property
def _cached_arity(self) -> int:
return max(inner._free_arity() for inner in self._inners)
return max(child._free_arity() for child in self._children)

def _free_arity(self) -> int:
return self._cached_arity

def __str__(self) -> str:
return '(' + (' ' + self.symbol() + ' ').join(
str(inner) for inner in self._inners) + ')'
str(child) for child in self._children) + ')'


class IntersectionExpression(BinaryOperatorExpression):
Expand Down
Loading

0 comments on commit 66b097d

Please sign in to comment.