
Return (inv) root of KroneckerProductAddedDiagLinearOperator as lazy #14

Open
Balandat wants to merge 2 commits into main
Conversation

Balandat
Collaborator

This is a clone of cornellius-gp/gpytorch#1430

Previously, `_root_decomposition` and `_inv_root_decomposition` returned the (inverse) root from the eigendecomposition as a dense tensor, which can be inefficient. This now returns the root as a `MatmulLazyTensor` instead. E.g., a matrix-vector product of the root with some vector `v` is now implicitly computed as `q_matrix @ (evals * v)` rather than `(q_matrix @ diag(evals)) @ v`, which can make a big difference since `q_matrix` is a Kronecker product.

This can help runtime, but more importantly it significantly reduces the memory footprint, since we don't need to instantiate the (inverse) root, only its constituent components.
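
To make the savings concrete, here is a minimal standalone sketch in plain PyTorch (not the library's code; `q_matrix`, `evals`, and `v` below are just local tensors standing in for the operator's components):

```python
import torch

# For K = Q diag(evals) Q^T, the root is R = Q diag(sqrt(evals)). Keeping R
# lazy means a matvec R @ v needs only an elementwise scaling of v plus one
# matmul with Q, instead of first materializing the dense root matrix.

torch.manual_seed(0)
n = 256
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T + torch.eye(n, dtype=torch.float64)  # PSD stand-in for the Kronecker matrix
evals, q_matrix = torch.linalg.eigh(K)
v = torch.randn(n, dtype=torch.float64)

# Dense approach: materialize the root explicitly, then multiply.
root_dense = q_matrix @ torch.diag(evals.sqrt())
out_dense = root_dense @ v

# Lazy approach: scale v by sqrt(evals) first, then a single matmul with Q.
out_lazy = q_matrix @ (evals.sqrt() * v)

assert torch.allclose(out_dense, out_lazy)
```

In the actual operator, `q_matrix` is itself a lazy Kronecker product, so the remaining matmul is likewise performed factor-wise rather than densely.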

This also fixes an issue where `KroneckerProductAddedDiagLinearOperator` implicitly assumed the diagonal to be constant and returned incorrect results when it was not. The changes here make the operator fall back to the superclass implementation for non-constant diagonals.
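
To see why the fallback is needed, here is another plain PyTorch sketch (again illustration only, independent of the library code): adding a constant diagonal `c * I` preserves the eigenvectors of `K` and just shifts the eigenvalues, which is what the specialized path exploits; a non-constant diagonal breaks this identity.

```python
import torch

# For K = Q diag(evals) Q^T, the matrix K + c*I has the same eigenvectors Q
# with eigenvalues evals + c, so the Kronecker eigendecomposition can be reused.

torch.manual_seed(0)
n = 64
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T
evals, q_matrix = torch.linalg.eigh(K)

c = 2.0
shifted = q_matrix @ torch.diag(evals + c) @ q_matrix.T
assert torch.allclose(shifted, K + c * torch.eye(n, dtype=torch.float64))

# For a non-constant diagonal D, Q no longer diagonalizes K + D, so reusing
# the eigendecomposition this way would yield incorrect results.
d = torch.rand(n, dtype=torch.float64) + 0.5
wrong = q_matrix @ torch.diag(evals + d) @ q_matrix.T
assert not torch.allclose(wrong, K + torch.diag(d))
```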

@Balandat
Collaborator Author

I am not sure why gradients aren't being returned properly when using a `ConstantDiagLinearOperator` (hence the test failure).
