Compute dimension sums in Elemwise.grad at run-time #1260

Draft · wants to merge 4 commits into base: main
49 changes: 29 additions & 20 deletions aesara/ifelse.py
@@ -24,7 +24,7 @@
from aesara.graph.op import _NoPythonOp
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast, specify_shape


if TYPE_CHECKING:
@@ -254,33 +254,42 @@ def grad(self, ins, grads):
# Since input true/false entries must have the same dtypes, we need to
# cast the zeros to the corresponding `grads` dtypes and not the input
# dtypes.
inputs_true_grad = (
[condition]
+ grads
+ [
at.basic.zeros_like(t, dtype=grads[i].dtype)
for i, t in enumerate(inputs_true_branch)
]
# The `grads` can also have different shapes than the `inputs`, so we
# effectively assert that the shapes are preserved in each branch.
# TODO FIXME: This doesn't seem like a sufficient solution to the
# problem.
inputs_true_grads = if_true_op(
*(
[condition]
+ [specify_shape(g, i.shape) for g, i in zip(grads, inputs_true_branch)]
+ [
at.basic.zeros_like(t, dtype=grads[i].dtype)
for i, t in enumerate(inputs_true_branch)
]
),
return_list=True,
)
inputs_false_grad = (
[condition]
+ [
at.basic.zeros_like(f, dtype=grads[i].dtype)
for i, f in enumerate(inputs_false_branch)
]
+ grads
inputs_false_grads = if_false_op(
*(
[condition]
+ [
at.basic.zeros_like(f, dtype=grads[i].dtype)
for i, f in enumerate(inputs_false_branch)
]
+ [
specify_shape(g, i.shape)
for g, i in zip(grads, inputs_false_branch)
]
),
return_list=True,
)

# `condition` does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
condition_grad = condition.zeros_like().astype(config.floatX)

return (
[condition_grad]
+ if_true_op(*inputs_true_grad, return_list=True)
+ if_false_op(*inputs_false_grad, return_list=True)
)
return [condition_grad] + inputs_true_grads + inputs_false_grads

def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
cond = node.inputs[0]
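The gradient change above wraps each incoming gradient in `specify_shape` so that it is asserted to have the shape of the corresponding branch input. As a minimal standalone sketch of what `specify_shape` provides (the variables here are illustrative, not taken from the PR):

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara.tensor.shape import specify_shape

x = at.vector("x")      # static shape (None,)

# `specify_shape` attaches shape information to `x`: it sharpens the static
# shape where possible and adds a run-time check that the actual shape matches.
y = specify_shape(x, (3,))

f = aesara.function([x], y)
print(f(np.ones(3)))    # shape (3,) satisfies the assertion
# f(np.ones(2)) would raise, because (2,) violates the specified shape
```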
6 changes: 3 additions & 3 deletions aesara/sparse/sandbox/sp.py
@@ -181,7 +181,7 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):

# taking into account multiple
# input features
col = (
col = int(
iy * inshp[2] + ix + fmapi * np.prod(inshp[1:])
)

@@ -196,13 +196,13 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):

# convert to row index of sparse matrix
if ws:
row = (
row = int(
(y * outshp[1] + x) * inshp[0] * ksize
+ l
+ fmapi * ksize
)
else:
row = y * outshp[1] + x
row = int(y * outshp[1] + x)

# Store something at that location
# in sparse matrix. The written
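The only change in this file is wrapping row/column index computations in `int(...)`. Presumably this guards against the indices coming out as NumPy float scalars, which strict consumers reject as integer indices; an illustrative sketch of that situation (not taken from the PR):

```python
import numpy as np

# Shape-derived quantities are often NumPy floats (e.g. after `np.ceil`), so
# an index computed from them is a `np.float64` scalar rather than a Python int.
outshp = np.ceil(np.array([4.0, 4.0]))
row = 2 * outshp[1] + 3
assert isinstance(row, np.floating)     # not usable as a strict integer index

# The explicit cast used in the hunks above yields a plain Python int.
row = int(2 * outshp[1] + 3)
assert row == 11
```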
23 changes: 7 additions & 16 deletions aesara/tensor/elemwise.py
@@ -258,15 +258,16 @@ def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
gz = as_tensor_variable(gz)
grad_order = ["x"] * len(x.type.broadcastable)
grad_order = ["x"] * x.type.ndim
for i, v in enumerate(self.new_order):
if v != "x":
grad_order[v] = i

# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph.
if inp[0].dtype in discrete_dtypes:
return [inp[0].zeros_like(dtype=config.floatX)]
if x.dtype in discrete_dtypes:
return [x.zeros_like(dtype=config.floatX)]
else:
return [
DimShuffle(gz.type.broadcastable, grad_order)(
@@ -542,7 +543,6 @@ def connection_pattern(self, node):
return [[True for output in node.outputs] for ipt in node.inputs]

def L_op(self, inputs, outs, ograds):
from aesara.tensor.math import sum as at_sum

# Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds)
@@ -573,18 +573,9 @@ def L_op(self, inputs, outs, ograds):
if isinstance(rval[i].type, (NullType, DisconnectedType)):
continue

# List of all the dimensions that are broadcastable for input[i] so
# we can sum over them
# TODO: only count dimensions that were effectively broadcasted
to_sum = [
j
for j, bcast in enumerate(ipt.type.broadcastable)
if bcast and not outs[0].broadcastable[j]
]

if to_sum:
sr = at_sum(rval[i], axis=to_sum, keepdims=True)
rval[i] = sr
rval[i] = aesara.tensor.extra_ops.sum_broadcasted_dims(
rval[i], ipt, outs[0].type.shape
)

return rval

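For context on the `L_op` change: when an input was broadcast along a dimension, the incoming gradient has the output's larger shape and must be summed back over the broadcast axes. A plain NumPy sketch of that bookkeeping, using the same shapes as the new test added below; the PR delegates the axis selection to `sum_broadcasted_dims` so that it can also happen at run-time:

```python
import numpy as np

X = np.ones((1, 5))           # broadcast along axis 0 in X + Y
Y = np.ones((5, 5))
gz = np.ones_like(X + Y)      # gradient of (X + Y).sum() w.r.t. (X + Y)

# The gradient w.r.t. X is obtained by summing over the axes where X was
# broadcast, keeping those dimensions so the result has X's shape.
gX = gz.sum(axis=0, keepdims=True)
assert gX.shape == X.shape                          # (1, 5)
assert np.array_equal(gX, np.full((1, 5), 5.0))
```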
60 changes: 46 additions & 14 deletions aesara/tensor/extra_ops.py
@@ -1,6 +1,6 @@
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Set, Tuple, Union
from typing import Iterable, Optional, Sequence, Set, Tuple, Union

import numpy as np
import numpy.core.numeric
@@ -1669,19 +1669,11 @@ def grad(self, inputs, outputs_gradients):

d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

# Determine the dimensions that were broadcast
_, static_shape = at.infer_static_shape(shape)

# TODO: This needs to be performed at run-time when static shape
# information isn't available.
bcast_sums = [
i
for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
if a_s == 1 and s_s != 1
]

if bcast_sums:
d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
# Determine the dimensions that were broadcast and sum them
static_out_shape = tuple(
s.data if isinstance(s, Constant) else None for s in shape[-a.ndim :]
)
d_wrt_a = sum_broadcasted_dims(d_wrt_a, a, static_out_shape)

return [d_wrt_a] + [
grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
@@ -1808,6 +1800,46 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)


def sum_broadcasted_dims(
value: TensorVariable,
inp: TensorVariable,
out_shape: Sequence[Optional[int]],
) -> TensorVariable:
"""Sum dimensions in `value` that are broadcasted between `inp`'s shape and `out_shape`.

For ambiguous cases, this builds a graph that determines at run-time
whether or not dimensions are to be summed.

"""
dims_to_sum = ()
ambiguous_dim_conds = ()

in_shape = inp.type.shape

for i, (s1, s2) in enumerate(zip(in_shape, out_shape)):
if s1 == 1 and s2 != 1:
dims_to_sum += (i,)
elif s1 is None and s2 != 1:
ambiguous_dim_conds += (
(i, aes.eq(at.scalar_from_tensor(inp.shape[i]), 1)),
)

if dims_to_sum:
value = at_sum(value, axis=dims_to_sum, keepdims=True)

if ambiguous_dim_conds:
from aesara.ifelse import ifelse

for i, cond in ambiguous_dim_conds:
value = ifelse(
cond,
at_sum(value, axis=i, keepdims=True),
value,
)

return value


__all__ = [
"searchsorted",
"cumsum",
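When a static shape entry of `inp` is unknown, `sum_broadcasted_dims` wraps the sum in an `ifelse` so the decision is deferred to run-time. A simplified hand-written sketch of the same idea (not the helper itself; the variable names are illustrative):

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse

value = at.matrix("value")    # incoming gradient
inp = at.matrix("inp")        # original input; static shape is (None, None)

# Sum over axis 0 only if the input's first dimension is actually 1 at
# run-time (i.e. it was broadcast); otherwise pass the gradient through.
out = ifelse(
    at.eq(inp.shape[0], 1),
    value.sum(axis=0, keepdims=True),
    value,
)

f = aesara.function([value, inp], out)
assert f(np.ones((4, 3)), np.ones((1, 3))).shape == (1, 3)
assert f(np.ones((4, 3)), np.ones((4, 3))).shape == (4, 3)
```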
2 changes: 1 addition & 1 deletion aesara/tensor/math.py
@@ -504,7 +504,7 @@ def makeKeepDims(x, y, axis):
newaxis.append(a)
i = 0
new_dims = []
for j, _ in enumerate(x.type.broadcastable):
for j in range(x.type.ndim):
if j in newaxis:
new_dims.append("x")
else:
36 changes: 35 additions & 1 deletion tests/tensor/test_elemwise.py
@@ -8,8 +8,9 @@

import aesara
import aesara.scalar as aes
import aesara.tensor as at
import tests.unittest_tools as utt
from aesara.compile.mode import Mode
from aesara.compile.mode import Mode, get_default_mode
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
@@ -889,6 +890,39 @@ def test_invalid_static_shape(self):
):
x + y

def test_grad_sum_bcast_input_dims(self):
"""Make sure broadcasted dimensions in the gradients are summed when static shape information isn't available."""
Y = matrix("Y")
X = matrix("X")
X_grad = aesara.grad((X + Y).sum(), wrt=X)

mode = get_default_mode().including("fast_run")

X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))

# When the shapes are known at compile-time, the compiled graph should
# simplify
Y = tensor(np.float64, shape=(5, None), name="Y")
X = tensor(np.float64, shape=(1, 5), name="X")
X_grad = aesara.grad((X + Y).sum(), wrt=X)

X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))

assert X_grad_fn.maker.fgraph.apply_nodes

def test_grad_of_grad(self):
"""This tests a special case in which the static shapes of a `DimShuffle` and its gradient don't match."""
a = at.vector("a")

out = aesara.grad((a * a).sum(), a).sum()
out = aesara.grad(out, a)

assert out.type.shape == (None,)


def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
1 change: 1 addition & 0 deletions tests/tensor/test_extra_ops.py
@@ -1318,6 +1318,7 @@ def test_memory_leak(self):
[
[lambda x: broadcast_to(x, (1,)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
[lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
],
53 changes: 53 additions & 0 deletions tests/test_ifelse.py
@@ -718,3 +718,56 @@ def test_nested():
linker = aesara.link.vm.VMLinker(lazy=True)
f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5


def test_DimShuffle_drop():
c = scalar("c")
x = scalar("x")
y = vector("y")

cost = ifelse(c, x.dimshuffle("x"), y).sum()

# Sum{acc_dtype=float64} [id A] <TensorType(float64, ())>
# |if{} [id B] <TensorType(float64, (None,))>
# |c [id C] <TensorType(float64, ())>
# |InplaceDimShuffle{x} [id D] <TensorType(float64, (1,))>
# | |x [id E] <TensorType(float64, ())>
# |y [id F] <TensorType(float64, (None,))>

out = aesara.grad(cost, y)
assert out.type.shape == (None,)

out = aesara.grad(cost, x)

#
# `DimShuffle.L_op` `inputs`
#
# x [id A] <TensorType(float64, ())>

#
# `DimShuffle.L_op` `outputs`
#
# InplaceDimShuffle{x} [id B] <TensorType(float64, (1,))>
# |x [id A] <TensorType(float64, ())>

#
# `DimShuffle.L_op` `output_grads`
#
# if{} [id C] <TensorType(float64, (None,))>
# |c [id D] <TensorType(float64, ())>
# |Elemwise{second} [id E] <TensorType(float64, (None,))>
# | |if{} [id F] <TensorType(float64, (None,))>
# | | |c [id D] <TensorType(float64, ())>
# | | |InplaceDimShuffle{x} [id B] <TensorType(float64, (1,))>
# | | |y [id G] <TensorType(float64, (None,))>
# | |InplaceDimShuffle{x} [id H] <TensorType(float64, (1,))>
# | |Elemwise{second,no_inplace} [id I] <TensorType(float64, ())>
# | |Sum{acc_dtype=float64} [id J] <TensorType(float64, ())>
# | | |if{} [id F] <TensorType(float64, (None,))>
# | |TensorConstant{1.0} [id K] <TensorType(float64, ())>
# |Elemwise{second,no_inplace} [id L] <TensorType(float64, (1,))>
# |InplaceDimShuffle{x} [id B] <TensorType(float64, (1,))>
# |InplaceDimShuffle{x} [id M] <TensorType(float64, (1,))>
# |TensorConstant{0.0} [id N] <TensorType(float64, ())>

assert out.type.shape == ()