
[Bug] Indexing ConstantMulLinearOperator with a SumBatchLinearOperator base operator #25

Closed
j-wilson opened this issue Oct 18, 2022 · 3 comments
Labels
bug Something isn't working

Comments

@j-wilson

🐛 Bug

To reproduce

import linear_operator.operators as ops
from torch import rand

A = ops.DenseLinearOperator(rand(4, 3, 2, 2))
B = ops.SumBatchLinearOperator(A, block_dim=-3)
C = ops.ConstantMulLinearOperator(B, rand([]))
C[:, -1:, :].to_dense()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-e4d937430d51> in <module>
----> 1 C[:, -1:, :].to_dense()
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/operators/sum_batch_linear_operator.py in to_dense(self)
     59 
     60     def to_dense(self):
---> 61         return self.base_linear_op.to_dense().sum(dim=-3)  # BlockLinearOperators always use dim3 for the block_dim
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/utils/memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/operators/constant_mul_linear_operator.py in to_dense(self)
    164     def to_dense(self):
    165         res = self.base_linear_op.to_dense()
--> 166         return res * self.expanded_constant
    167 
    168     @cached(name="root_decomposition")
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1

Expected Behavior

The code is expected to behave the same way its dense analogue would: summing over the block dimension, multiplying by the constant, and then indexing should succeed.
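For reference, the dense analogue can be sketched in plain torch (a sketch, assuming `SumBatchLinearOperator(A, block_dim=-3)` corresponds to summing a dense tensor over `dim=-3`, and `ConstantMulLinearOperator` to scalar multiplication); the indexing succeeds on the dense tensor where the operator version raises:

```python
import torch

# Dense analogue of the failing operator expression above.
A = torch.rand(4, 3, 2, 2)
c = torch.rand([])

dense_C = A.sum(dim=-3) * c   # shape (4, 2, 2), like C.to_dense()
result = dense_C[:, -1:, :]   # the indexing that fails on the operator
print(result.shape)           # torch.Size([4, 1, 2])
```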

@j-wilson j-wilson added the bug Something isn't working label Oct 18, 2022
@j-wilson
Author

j-wilson commented Oct 18, 2022

The problem seems to stem from ConstantMulLinearOperator._getitem. The following appears to work, but I am not sure how its runtime profile compares to the existing implementation. We still index into base_linear_op and the constant directly. If I'm not mistaken, this new version may be faster (in the particular case considered here), since we now multiply the SumBatchLinearOperator instance by the constant rather than multiplying instance.base_linear_op.

    def _getitem(self, row_index, col_index, *batch_indices):
        # NOTE TO FUTURE SELF:
        # This custom __getitem__ is actually very important!
        # It prevents constructing an InterpolatedLinearOperator when one isn't needed
        # This affects runtimes by up to 5x on simple exact GPs
        # Run __getitem__ on the base_linear_op and the constant
        base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices)
        constant = self._constant.expand(self.batch_shape)[batch_indices]
        return type(self)(base_linear_op=base_linear_op, constant=constant)
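For what it's worth, the constant-indexing line above can be checked in isolation with plain torch (the names below are illustrative stand-ins, not library internals; `batch_indices` here is the batch portion of `C[:, -1:, :]` from the repro):

```python
import torch

# Standalone check of `self._constant.expand(self.batch_shape)[batch_indices]`
# from the proposed _getitem, using the shapes from the repro above.
constant = torch.rand([])        # 0-d constant, as in the repro
batch_shape = torch.Size([4])    # batch shape of C in the repro
batch_indices = (slice(None),)   # the `:` batch index in C[:, -1:, :]

indexed = constant.expand(batch_shape)[batch_indices]
print(indexed.shape)             # torch.Size([4])
```

The expand-then-index keeps the constant's batch shape in sync with the indexed base operator, which is what the original implementation failed to do.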

@gpleiss
Member

gpleiss commented Oct 22, 2022

Your fix seems reasonable, and I also suspect that it is faster :) Want to throw up a PR for this?

@JonathanWenger
Collaborator

Fixed by #37.
