
Wrong gradients when inputs are dynamically broadcasted #1089

Open
ricardoV94 opened this issue Aug 1, 2022 · 80 comments
Labels
bug · gradient implementations · help wanted · important

Comments

@ricardoV94
Contributor

ricardoV94 commented Aug 1, 2022

This bug is an unexpected consequence of #928 and rewrites that make certain assumptions: #1089 (comment)

import aesara
import aesara.tensor as at
import numpy as np

x_row = at.row("x_row")
x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_row = aesara.function([x_row, y], x_row_grad)
print(f_row(np.ones((1, 5)), np.ones((5, 5))))
# [[5. 5. 5. 5. 5.]]

f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
print(f_matrix(np.ones((1, 5)), np.ones((5, 5))))
# [[1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]]

The faulty logic is found here:

# sum out the broadcasted dimensions
for i, ipt in enumerate(inputs):
    if isinstance(rval[i].type, (NullType, DisconnectedType)):
        continue
    # List of all the dimensions that are broadcastable for input[i] so
    # we can sum over them
    # TODO: only count dimensions that were effectively broadcasted
    to_sum = [
        j
        for j, bcast in enumerate(ipt.type.broadcastable)
        if bcast and not outs[0].broadcastable[j]
    ]
    if to_sum:
        sr = at_sum(rval[i], axis=to_sum, keepdims=True)
        rval[i] = sr
return rval
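
For concreteness, here is how that check plays out for the two inputs in the example above (a minimal sketch that just applies the same list comprehension to the static broadcast patterns):

out_bcast = (False, False)      # broadcastable pattern of x_row + y (y is a matrix)
row_bcast = (True, False)       # at.row: first dim statically known to have length 1
matrix_bcast = (False, False)   # at.matrix: nothing known statically

to_sum_row = [j for j, b in enumerate(row_bcast) if b and not out_bcast[j]]
to_sum_matrix = [j for j, b in enumerate(matrix_bcast) if b and not out_bcast[j]]

print(to_sum_row)     # [0] -> the gradient is summed over axis 0, as expected
print(to_sum_matrix)  # []  -> the gradient is never reduced, hence the wrong result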

This is also likely a problem in the grad of BroadcastTo, which calls infer_broadcastable, which in turn defaults to assuming something will not have been broadcasted if a static shape of 1 can't be inferred.

_, shape_bcast = at.infer_broadcastable(shape)

def infer_broadcastable(shape):

And also in GEMM since #986.

I am not sure if there's a good solution to this problem, as we would need an expression with different output shapes depending on whether the runtime inputs are broadcasted or not.

A solution might look something like: #1089 (comment)

ricardoV94 added the bug and gradient implementations labels on Aug 1, 2022
brandonwillard added the help wanted and important labels on Aug 1, 2022
@aseyboldt

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@aseyboldt

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@aseyboldt

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@ricardoV94
Contributor Author

ricardoV94 commented Aug 2, 2022

It seems like we might need a new Op that unbroadcasts (reduces) arrays to a given shape or leaves the input unchanged.

Check https://mostafa-samir.github.io/auto-diff-pt2/#unbroadcasting-adjoints

https://github.com/Mostafa-Samir/Hands-on-Intro-to-Auto-Diff/blob/29a9e5157421e2603846a15bceff21d3b2104f3d/autodiff/grads.py#L103

Something like:

x = at.matrix("x")

# at.reduce_to is probably a better name
# maybe sum is all we will ever need
y1 = at.unbroadcast_to(x, shape=(1, 5), reduce_op="sum”) 
y1.eval({x: np.ones((5, 5))})  # [[5, 5, 5, 5, 5]]

# This won't do anything, but shape may be only 
# known at runtime, as in the example in this issue!
y2 = at.unbroadcast_to(x, shape=(5, 5))
y2.eval({x: np.ones((5, 5))})  # np.ones((5, 5))

# If the shape is not compatible with something that could 
# have been broadcasted to the input shape, an error is raised
y3 = at.unbroadcast_to(x, shape=(2, 5))
y3.eval({x: np.ones((5, 5))})  # ValueError

This was also brought up by @Sayam753 and @purna135 on Slack in relation to their work on batched solve, where dynamic unbroadcasting of gradients also crops up. It was that discussion that led me to suspect this bug!

Edit: This may already be possible without a specialized Op, if sum allows for a symbolic axis? Does it?

In that case we could cook up a helper pretty quickly, and perhaps add a rewrite for the case where the axis gets constant-folded during compilation, since a sum with a constant axis is more efficient.

Edit: Sum does not allow for a variable axis.
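
For reference, a minimal snippet demonstrating that limitation (the exact exception raised may depend on the Aesara version):

import aesara.tensor as at

x = at.matrix("x")
axis = at.lscalar("axis")  # a symbolic axis

try:
    at.sum(x, axis=axis, keepdims=True)
except Exception as err:
    # reductions like Sum need their axes at graph-construction time,
    # so a symbolic axis is rejected here
    print(f"{type(err).__name__}: {err}")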

@ricardoV94
Contributor Author

ricardoV94 commented Aug 2, 2022

Edit: mentioned other related issues in the top comment.

ricardoV94 changed the title from "Wrong Elemwise gradient when inputs are dynamically broadcasted" to "Wrong gradients when inputs are dynamically broadcasted" on Aug 2, 2022
@brandonwillard
Member

Let's try to move to the Blockwise form of this problem/situation. That way, we can attempt to make some progress on #695 (e.g. finish implementing Blockwise.Lop in #757) and address these issues (via inheritance from the Elemwise case).

@brandonwillard
Member

brandonwillard commented Aug 2, 2022

This is also likely a problem in the grad of BroadcastTo, which calls infer_broadcastable, which in turn defaults to assuming something will not have been broadcasted if a static shape of 1 can't be inferred.

Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.

@aseyboldt

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@ricardoV94
Contributor Author

ricardoV94 commented Aug 3, 2022

For example if we broadcast arrays of shape (), (2, 5, 0) and (1, 5, 1) we'd get

What's up with that shape of (2, 5, 0)?

I think that combo is invalid as per numpy broadcasting rules due to the zero?

Maybe it would help if we added an op that precisely computes what actually happens when we broadcast several arrays?

That sounds valid, but it seems like a more convoluted solution if you ONLY want to fix this problem.

In these cases we have the input and the broadcasted gradient output, so the only thing we need is to reduce that gradient along the dimensions that were of size 1 in the input.

Actually, in the Elemwise case we don't even have to worry about new dims, because make_node adds Dimshuffles to the inputs to align the number of dims (but we will perhaps remove that one day)

So what we need is just:

# the perform method of the new Op boils down to this
import numpy as np

def unbroadcast_to(x, shape):
    axis_to_sum = [
        i
        for i, (s, xs) in enumerate(zip(shape, x.shape))
        if s == 1 and xs != 1
    ]

    if not axis_to_sum:
        return x
    return np.sum(x, axis=axis_to_sum, keepdims=True)

In the grad where this issue would crop up we would do something like

grad = unbroadcast_to(bcast_grad, shape=input.shape)

And then we could have some rewrites to try to get rid of this Op during compilation. For instance:

  1. if the target shape and the input shapes are known to be equivalent, we can remove the Op

  2. if the target shape has some (constant-folded) 1s we can replace the original input by one where we summed the known 1s dimensions already.

  3. Hopefully there's already a rewrite to get rid of useless sum along dimensions of size 1. If not, we can add one.

  4. if the target shape has no (constant-folded) 1s we can remove the Op (or replace by some Asserts related to the shape if we want to)
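
As a quick NumPy-level check of the perform sketch above, exercised on the cases from the earlier unbroadcast_to comment (note that the sketch does not yet implement the shape-compatibility check, so that validation would have to live in the real Op):

x_val = np.ones((5, 5))

print(unbroadcast_to(x_val, shape=(1, 5)))  # [[5. 5. 5. 5. 5.]]
print(unbroadcast_to(x_val, shape=(5, 5)))  # unchanged, np.ones((5, 5))
# unbroadcast_to(x_val, shape=(2, 5)) would also pass through unchanged here;
# the real Op would have to raise a ValueError for incompatible shapes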

@ricardoV94
Contributor Author

ricardoV94 commented Aug 3, 2022

Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.

The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time if something will have had a shape of 1 (or certainly not 1) is to raise in the grad method.

If IfElse allowed for different shapes in the two branches we could also write a symbolic graph that applies the needed logic, but from one of the open issues it seems that both branches must have the same shape.
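
For illustration, the IfElse graph in question would look roughly like this (a hypothetical sketch; per the issue mentioned above it is expected to be rejected, because the two branches have different shapes/broadcast patterns):

import aesara.tensor as at
from aesara.ifelse import ifelse

x = at.matrix("x")
gy = at.matrix("gy")  # the broadcasted gradient coming from the output

# "was x's first dimension actually broadcast at runtime?"
x_was_broadcast = at.and_(at.eq(x.shape[0], 1), at.neq(gy.shape[0], 1))

# sum the gradient over that axis only if broadcasting happened
gx = ifelse(x_was_broadcast, gy.sum(axis=0, keepdims=True), gy)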

@ricardoV94
Contributor Author

ricardoV94 commented Aug 3, 2022

Allowing sum to have a symbolic axis (as long as keepdims is used, this should be fine for Aesara) would also allow for a simple solution without new Ops. But maybe that would raise a whole new set of problems.

@brandonwillard
Member

The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time if something will have had a shape of 1 (or certainly not 1) is to raise in the grad method.

Simply put, if we don't have the information at compile time, then it needs to be handled at run-time.
The latter is the only scenario in which an Op would be helpful; however, I have yet to see why a new Op is necessary for anything in this scenario. The reason(s) why a new Op is completely necessary also need to be very clear in order to merge anything that takes such an approach.

@brandonwillard
Member

Additionally, the description of this issue needs to clarify which result is correct and why.

@ricardoV94
Contributor Author

ricardoV94 commented Aug 3, 2022

Additionally, the description of this issue needs to clarify which result is correct and why.

The gradients should have the same shape as the inputs, so the case where row is used is correct.

This issue will arise for any Op that may or may not broadcast its inputs at runtime. If broadcasting occurs, you need to sum the gradient across the broadcasted dimensions; otherwise you should not. However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.
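
As a plain-NumPy illustration of that rule (a hedged sketch; reference_grad is just a hypothetical helper showing what the correct gradient of at.sum(x + y) with respect to x evaluates to, depending only on the runtime shape of x):

import numpy as np

def reference_grad(x_val, y_val):
    # gradient of sum(x + y) w.r.t. the broadcasted output is all ones ...
    out_shape = np.broadcast_shapes(x_val.shape, y_val.shape)
    g = np.ones(out_shape)
    # ... and it must be summed back onto the axes where x was actually broadcast
    axes = tuple(
        i for i, (xs, os) in enumerate(zip(x_val.shape, out_shape)) if xs == 1 and os != 1
    )
    return g.sum(axis=axes, keepdims=True) if axes else g

y_val = np.ones((5, 5))
print(reference_grad(np.ones((1, 5)), y_val))  # [[5. 5. 5. 5. 5.]] -- the "row" result
print(reference_grad(np.ones((5, 5)), y_val))  # all ones, shape (5, 5)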

@ricardoV94
Contributor Author

ricardoV94 commented Aug 3, 2022

Note that explicitly broadcasting all the inputs (like the explicit Dimshuffles introduced by Elemwise) wouldn't fix this either. The gradient of BroadcastTo shares the same limitations as Elemwise.

@brandonwillard
Member

However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.

It does.

@ricardoV94
Contributor Author

However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.

It does.

Via what?

@ricardoV94
Contributor Author

ricardoV94 commented Aug 3, 2022

To be clear, we need something that can do the following.

def foo(x, y):
    ...


x = at.matrix("x")
y = np.random.normal(size=(5, 5))
f = aesara.function([x], foo(x, y))

assert np.array_equal(f(np.ones((1, 5))), np.sum(y, axis=0, keepdims=True))
assert np.array_equal(f(np.ones((5, 1))), np.sum(y, axis=1, keepdims=True))
assert np.array_equal(f(np.ones((1, 1))), np.sum(y, axis=(0, 1), keepdims=True))
assert np.array_equal(f(np.ones((5, 5))), y)

@aseyboldt

This comment was marked as off-topic.

@ricardoV94

This comment was marked as off-topic.

@aseyboldt

This comment was marked as off-topic.

@aseyboldt

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@ricardoV94
Contributor Author

ricardoV94 commented Aug 22, 2022

This helper will probably also have to be updated when/if the Ops that use it in their grad allow for runtime broadcasting:

def _sum_grad_over_bcasted_dims(x, gx):
    """
    Sum of gx over dimensions to reproduce x.broadcastable.
    This is useful to sum gradients over certain dimensions when
    x has been broadcasted, and we need to sum the gradient contributions
    over all duplications.
    """
    if gx.broadcastable != x.broadcastable:
        x_dim_added = gx.ndim - x.ndim
        x_broad = (True,) * x_dim_added + x.broadcastable
        assert sum(gx.broadcastable) < sum(x_broad)
        axis_to_sum = []
        for i in range(gx.ndim):
            if gx.broadcastable[i] is False and x_broad[i] is True:
                axis_to_sum.append(i)
            elif gx.broadcastable[i] is True and x_broad[i] is False:
                # This means that Aesara was able to infer that
                # gx.shape[i] is 1, so x.shape[i] is 1, but we
                # didn't know it. It is fine.
                pass
            else:
                assert gx.broadcastable[i] == x_broad[i]
        gx = gx.sum(axis=axis_to_sum, keepdims=True)
        if gx.ndim != x.ndim:
            assert gx.ndim > x.ndim
            for i in range(x_dim_added):
                assert gx.broadcastable[i]
            gx = gx.dimshuffle(*list(range(x_dim_added, gx.ndim)))
        assert gx.broadcastable == x.broadcastable
    return gx

@ricardoV94

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@ricardoV94

This comment was marked as off-topic.

@brandonwillard
Member

brandonwillard commented Oct 17, 2022

To clarify the relationship between this Aesara issue and Theano's old broadcasting assumptions/TensorType.broadcastable interpretation, consider this issue's example in Theano:

import theano
import theano.tensor as tt
import numpy as np


theano.__version__
# '1.0.5'

X_row = tt.row("X_row")
X_matrix = tt.matrix("X_matrix")


def X_grad_fn_constructor(X):
    Y = tt.matrix("Y")
    X_sum = tt.sum(X + Y)
    X_grad = tt.grad(X_sum, wrt=X)
    X_grad_fn = theano.function([X, Y], X_grad)
    return X_grad_fn


X_grad_fn_row = X_grad_fn_constructor(X_row)
X_grad_fn_matrix = X_grad_fn_constructor(X_matrix)

# This input is broadcastable in the first dimension, but the `Type`-level
# representation of that fact is lacking in the `X_matrix` case.  Let's see how
# Theano handles this broadcast information disparity.
# To be clear, *both* cases should theoretically return the same values for the
# same inputs.
X_val = np.ones((1, 5))
Y_val = np.ones((5, 5))

# The "row"-`Type` case (i.e. we tell Theano that the first dimension of `X` is
# broadcastable)
row_res = X_grad_fn_row(X_val, Y_val)
row_res
# array([[5., 5., 5., 5., 5.]])

# The "matrix"-`Type` case (i.e. we *don't* tell Theano that the first
# dimension of `X` is actually broadcastable)
matrix_res = X_grad_fn_matrix(X_val, Y_val)
matrix_res
# array([[1., 1., 1., 1., 1.]])

assert np.array_equal(matrix_res, row_res)
# AssertionError:

In other words, Theano's assumptions were not capable of solving the issue raised here. Instead, the same shape inference problem(s) simply took different forms. (N.B. This also means that no amount of "reverting" will fix this issue.)

The thing we're trying to fix in this issue has always been an issue; only now (e.g. with TensorType.shape and the requisite logic clarifications) are we becoming equipped to sufficiently address it.

@aseyboldt

This comment was marked as off-topic.

@brandonwillard

This comment was marked as off-topic.

@aseyboldt

This comment was marked as off-topic.

@brandonwillard
Member

brandonwillard commented Oct 18, 2022

To reiterate, this issue can be closed by

  1. raising an error at compile/construction-time stating that Elemwise.grad does not support cases with insufficient broadcast information, or
  2. making Elemwise.grad support cases with incomplete broadcast information.

This issue concerns itself with the relevant details of the above two approaches, and any new ones we may not have considered yet.

Conversations about TensorType.broadcastable, its relationship to the new TensorType.shape, etc., belong in #1170 or a new Discussion altogether.

@ricardoV94

This comment was marked as off-topic.

@rlouf
Member

rlouf commented Oct 19, 2022

To be fair, from an outsider's perspective, the discussion here is impossible to follow. The concerns that the issue raises are valid, but the discussion has since diverged to questions that are more fundamental and/or historical in nature. To move forward I suggest:

  1. We open a discussion where we lay out exactly where we want to go with shape inference and how we want to move forward generally;
  2. We then refer to that discussion here and use this issue to talk about the original problem specifically.

@rlouf
Member

rlouf commented Oct 25, 2022

I just ran @ricardoV94's example both on the current HEAD and on b60cf7240 (the commit before the changes in #928 were merged):

def test_ambiguous_broadcast():
    import aesara
    import aesara.tensor as at
    import numpy as np

    x_row = at.row("x_row")
    x_matrix = at.matrix("x_matrix")
    y = at.matrix("y")

    x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
    x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

    f_row = aesara.function([x_row, y], x_row_grad)
    row_res = f_row(np.ones((1, 5)), np.ones((5, 5)))

    f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
    assert np.array_equal(f_matrix(np.ones((1, 5)), np.ones((5, 5))), row_res)

test_ambiguous_broadcast()

and it fails in both situations. This means that this issue is not a consequence of #928. In other words, reverting this change is not going to fix this bug.

Try this for yourself (don't forget to clear the cache), so we can all agree on this point moving forward. Comment "I see it" (and only that for now) if you can reproduce it, comment something else only if you can't reproduce it.

@brandonwillard
Member

I see it

@ricardoV94
Contributor Author

ricardoV94 commented Oct 26, 2022

@rlouf This is clearly explained in one of the "off-topic" comments as arising from an "inconsistent" rewrite: #1089 (comment)

Running on b60cf7240

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
import numpy as np

x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_matrix = aesara.function(
    [x_matrix, y], 
    x_matrix_grad, 
    mode=Mode().excluding("local_fill_to_alloc"),
)
matrix_res = f_matrix(np.ones((1, 5)), np.ones((5, 5)))  # ValueError

It's one consequence of the "rewrites assume original graphs were valid" mindset: https://theano-pymc.readthedocs.io/en/latest/tutorial/shape_info.html#shape-inference-problem (they mention the same applies to Elemwise after the example).


If you want an example that clearly fails before that commit and passes after, make it just a bit more complex:

import aesara
import aesara.tensor as at
import numpy as np

x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_matrix_grad = at.grad((at.sum(at.exp(x_matrix + y))), wrt=x_matrix)

f_matrix = aesara.function(
    [x_matrix, y], 
    x_matrix_grad, 
)
# ValueError before `b60cf7240` and not after
matrix_res = f_matrix(np.ones((1, 5)), np.ones((5, 5)))

@rlouf
Member

rlouf commented Oct 26, 2022

The reason I am writing this is that newcomers to the issue will take your first comment at face value, assume #928 is the problem, and get confused if they try to run it with the commit before that. We're not writing these just for ourselves.

If the example you gave me indeed fails after #928 and not before could you please edit your original comment?

@ricardoV94
Contributor Author

Thanks. I updated the original message to link to that comment.

@brandonwillard
Member

brandonwillard commented Oct 26, 2022

If the example you gave me indeed fails after #928 and not before could you please edit your original comment?

Doing that changes the entire nature of this issue, which was never originally about that ValueError, so let's not add that confusion to the mix.

It would make more sense to describe the errant rewrites in the opening comment.

@rlouf
Member

rlouf commented Oct 26, 2022

Thanks. I updated the original message to link to that comment.

Unless I'm mistaken you left the original code snippet that fails both after and before #928?

The reason I'm asking is that issues and PRs in the repository are not only read by those who interacted with them. We should always keep this in mind and strive to make them understandable to anyone vaguely familiar with Aesara. I spent hours trying to understand what was going on in this issue before realizing yesterday that the assertions in the original comment did not hold (namely, that the particular example you provided fails because of #928). Most people will not do this legwork and will take your original comment at face value. General confusion ensues.
