Return (inv) root of KroneckerProductAddedDiagLinearOperator as lazy #14

Open
wants to merge 2 commits into main
@@ -5,16 +5,13 @@
import torch

from .added_diag_linear_operator import AddedDiagLinearOperator
from .diag_linear_operator import DiagLinearOperator
from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from .matmul_linear_operator import MatmulLinearOperator


class KroneckerProductAddedDiagLinearOperator(AddedDiagLinearOperator):
def __init__(self, *linear_operators, preconditioner_override=None):
# TODO: implement the woodbury formula for diagonal tensors that are non constants.

super(KroneckerProductAddedDiagLinearOperator, self).__init__(
*linear_operators, preconditioner_override=preconditioner_override
)
super().__init__(*linear_operators, preconditioner_override=preconditioner_override)
if len(linear_operators) > 2:
raise RuntimeError("An AddedDiagLinearOperator can only have two components")
elif isinstance(linear_operators[0], DiagLinearOperator):
@@ -27,59 +24,67 @@ def __init__(self, *linear_operators, preconditioner_override=None):
raise RuntimeError(
"One of the LinearOperators input to AddedDiagLinearOperator must be a DiagLinearOperator!"
)
self._diag_is_constant = isinstance(self.diag_tensor, ConstantDiagLinearOperator)

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
# we want to call the standard InvQuadLogDet to easily get the probe vectors and do the
# solve but we only want to cache the probe vectors for the backwards
inv_quad_term, _ = super().inv_quad_logdet(
inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
)

if logdet is not False:
logdet_term = self._logdet()
else:
logdet_term = None

return inv_quad_term, logdet_term
if self._diag_is_constant:
# we want to call the standard InvQuadLogDet to easily get the probe vectors and do the
# solve but we only want to cache the probe vectors for the backwards
inv_quad_term, _ = super().inv_quad_logdet(
inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
)
logdet_term = self._logdet() if logdet else None
return inv_quad_term, logdet_term
return super().inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad)

def _logdet(self):
# symeig requires computing the eigenvectors so that it's differentiable
evals, _ = self.linear_operator.symeig(eigenvectors=True)
evals_plus_diag = evals + self.diag_tensor.diag()
return torch.log(evals_plus_diag).sum(dim=-1)
if self._diag_is_constant:
# symeig requires computing the eigenvectors so that it's differentiable
evals, _ = self.linear_operator.symeig(eigenvectors=True)
evals_plus_diag = evals + self.diag_tensor.diag()
return torch.log(evals_plus_diag).sum(dim=-1)
return super()._logdet()
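(Editorial aside: the constant-diagonal branch above works because adding a constant diagonal shifts every eigenvalue of the Kronecker product by the same amount, so if K = Q diag(evals) Q^T then logdet(K + sigma_sq * I) = sum_i log(eval_i + sigma_sq). A minimal standalone sketch of that identity in plain PyTorch; the names kron_mat and sigma_sq are made up for illustration, and torch.linalg.eigh stands in for the operator's symeig.)

import torch

# Illustrative only: a small symmetric matrix standing in for the Kronecker product,
# and a constant shift standing in for the ConstantDiagLinearOperator.
kron_mat = torch.tensor([[4.0, 2.0], [2.0, 3.0]], dtype=torch.float64)
sigma_sq = 0.25

# K = Q diag(evals) Q^T
evals, _ = torch.linalg.eigh(kron_mat)

# logdet(K + sigma_sq * I) = sum_i log(eval_i + sigma_sq)
logdet_shifted = torch.log(evals + sigma_sq).sum()

# Reference value computed directly on the dense shifted matrix.
logdet_direct = torch.logdet(kron_mat + sigma_sq * torch.eye(2, dtype=torch.float64))
assert torch.allclose(logdet_shifted, logdet_direct)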

def _preconditioner(self):
# solves don't use CG so don't waste time computing it
return None, None, None

def _solve(self, rhs, preconditioner=None, num_tridiag=0):
# we do the solve in double for numerical stability issues
# TODO: Use fp64 registry once #1213 is addressed
if self._diag_is_constant:
# we do the solve in double for numerical stability issues
# TODO: Use fp64 registry once #1213 is addressed

rhs_dtype = rhs.dtype
rhs = rhs.double()

evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
evals, q_matrix = evals.double(), q_matrix.double()

rhs_dtype = rhs.dtype
rhs = rhs.double()
evals_plus_diagonal = evals + self.diag_tensor.diag()
evals_root = evals_plus_diagonal.pow(0.5)
inv_mat_sqrt = DiagLinearOperator(evals_root.reciprocal())

evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
evals, q_matrix = evals.double(), q_matrix.double()
res = q_matrix.transpose(-2, -1).matmul(rhs)
res2 = inv_mat_sqrt.matmul(res)

evals_plus_diagonal = evals + self.diag_tensor.diag()
evals_root = evals_plus_diagonal.pow(0.5)
inv_mat_sqrt = DiagLinearOperator(evals_root.reciprocal())
lhs = q_matrix.matmul(inv_mat_sqrt)
return lhs.matmul(res2).type(rhs_dtype)

res = q_matrix.transpose(-2, -1).matmul(rhs)
res2 = inv_mat_sqrt.matmul(res)
# TODO: implement woodbury formula for non-constant Kronecker-structured diagonal operators

lhs = q_matrix.matmul(inv_mat_sqrt)
return lhs.matmul(res2).type(rhs_dtype)
return super()._solve(rhs, preconditioner=preconditioner, num_tridiag=num_tridiag)
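(Editorial aside: the constant-diagonal solve above relies on the same decomposition, (K + sigma_sq * I)^-1 = Q diag(1 / (evals + sigma_sq)) Q^T, applied as two matmuls through an inverse square root. A rough sketch of that identity on made-up inputs; kron_mat, sigma_sq, and rhs are assumptions for illustration, not names from the diff.)

import torch

kron_mat = torch.tensor([[4.0, 2.0], [2.0, 3.0]], dtype=torch.float64)
sigma_sq = 0.25
rhs = torch.tensor([[1.0], [2.0]], dtype=torch.float64)

# K = Q diag(evals) Q^T, so (K + sigma_sq * I)^-1 = Q diag(1 / (evals + sigma_sq)) Q^T
evals, q_matrix = torch.linalg.eigh(kron_mat)
inv_sqrt = (evals + sigma_sq).rsqrt()  # 1 / sqrt(eval_i + sigma_sq)

# Apply the inverse in two halves, mirroring the inv_mat_sqrt structure in the diff.
res = inv_sqrt.unsqueeze(-1) * (q_matrix.transpose(-2, -1) @ rhs)
solve = (q_matrix * inv_sqrt) @ res

# Reference: direct dense solve against the shifted matrix.
expected = torch.linalg.solve(kron_mat + sigma_sq * torch.eye(2, dtype=torch.float64), rhs)
assert torch.allclose(solve, expected)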

def _root_decomposition(self):
evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5))
matrix_root = q_matrix.matmul(updated_evals)
return matrix_root
if self._diag_is_constant:
# we can use the eigendecomposition and shift the eigenvalues
evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5))
return MatmulLinearOperator(q_matrix, updated_evals)
return super()._root_decomposition()

def _root_inv_decomposition(self, initial_vectors=None):
evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5))
matrix_inv_root = q_matrix.matmul(inv_sqrt_evals)
return matrix_inv_root
if self._diag_is_constant:
evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5))
return MatmulLinearOperator(q_matrix, inv_sqrt_evals)
return super()._root_inv_decomposition(initial_vectors=initial_vectors)
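(Editorial aside: this hunk is what the PR title refers to, the root Q diag(evals + sigma_sq)^(1/2) and its inverse are now returned lazily as MatmulLinearOperator products instead of being multiplied out. A hedged plain-torch check that such a root reproduces the shifted matrix; all tensors below are made up for illustration.)

import torch

kron_mat = torch.tensor([[4.0, 2.0], [2.0, 3.0]], dtype=torch.float64)
sigma_sq = 0.25

evals, q_matrix = torch.linalg.eigh(kron_mat)

# Root R = Q diag(sqrt(evals + sigma_sq)); the diff keeps this product lazy.
root = q_matrix * (evals + sigma_sq).sqrt()
# Inverse root R_inv = Q diag(1 / sqrt(evals + sigma_sq)).
inv_root = q_matrix * (evals + sigma_sq).rsqrt()

shifted = kron_mat + sigma_sq * torch.eye(2, dtype=torch.float64)
assert torch.allclose(root @ root.transpose(-2, -1), shifted)
assert torch.allclose(inv_root @ inv_root.transpose(-2, -1), torch.linalg.inv(shifted))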
2 changes: 1 addition & 1 deletion linear_operator/operators/linear_operator.py
@@ -1925,7 +1925,7 @@ def __getitem__(self, index):
# Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices
if row_col_are_absorbed:
# Convert all indices into tensor indices
*batch_indices, row_index, col_index, = _convert_indices_to_tensors(
(*batch_indices, row_index, col_index,) = _convert_indices_to_tensors(
self, (*batch_indices, row_index, col_index)
)
res = self._get_indices(row_index, col_index, *batch_indices)
3 changes: 1 addition & 2 deletions linear_operator/utils/__init__.py
@@ -9,8 +9,7 @@


def prod(items):
"""
"""
""""""
if len(items):
res = items[0]
for item in items[1:]:
6 changes: 2 additions & 4 deletions linear_operator/utils/interpolation.py
@@ -156,8 +156,7 @@ def interpolate(self, x_grid: List[torch.Tensor], x_target: torch.Tensor, interp


def left_interp(interp_indices, interp_values, rhs):
"""
"""
""""""
is_vector = rhs.ndimension() == 1

if is_vector:
@@ -181,8 +180,7 @@ def left_interp(interp_indices, interp_values, rhs):


def left_t_interp(interp_indices, interp_values, rhs, output_dim):
"""
"""
""""""
from .. import dsmm

is_vector = rhs.ndimension() == 1
3 changes: 1 addition & 2 deletions linear_operator/utils/lanczos.py
@@ -18,8 +18,7 @@ def lanczos_tridiag(
num_init_vecs=1,
tol=1e-5,
):
"""
"""
""""""
# Determine batch mode
multiple_init_vecs = False

9 changes: 3 additions & 6 deletions linear_operator/utils/sparse.py
@@ -140,8 +140,7 @@ def sparse_eye(size):


def sparse_getitem(sparse, idxs):
"""
"""
""""""
if not isinstance(idxs, tuple):
idxs = (idxs,)

@@ -201,8 +200,7 @@ def sparse_getitem(sparse, idxs):


def sparse_repeat(sparse, *repeat_sizes):
"""
"""
""""""
if len(repeat_sizes) == 1 and isinstance(repeat_sizes, tuple):
repeat_sizes = repeat_sizes[0]

@@ -243,8 +241,7 @@ def sparse_repeat(sparse, *repeat_sizes):


def to_sparse(dense):
"""
"""
""""""
mask = dense.ne(0)
indices = mask.nonzero(as_tuple=False)
if indices.storage():
@@ -9,6 +9,7 @@

from linear_operator import settings
from linear_operator.operators import (
ConstantDiagLinearOperator,
DenseLinearOperator,
DiagLinearOperator,
KroneckerProductAddedDiagLinearOperator,
@@ -23,7 +24,7 @@ class TestKroneckerProductAddedDiagLinearOperator(unittest.TestCase, LinearOpera
should_call_lanczos = False
should_call_cg = False

def create_linear_operator(self):
def create_linear_operator(self, constant_diag=True):
a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
c = torch.tensor([[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float)
@@ -33,10 +34,13 @@ def create_linear_operator(self):
kp_linear_operator = KroneckerProductLinearOperator(
DenseLinearOperator(a), DenseLinearOperator(b), DenseLinearOperator(c)
)

return KroneckerProductAddedDiagLinearOperator(
kp_linear_operator, DiagLinearOperator(0.1 * torch.ones(kp_linear_operator.shape[-1]))
)
if constant_diag:
diag_linear_operator = ConstantDiagLinearOperator(
torch.tensor([0.25], dtype=torch.float), kp_linear_operator.shape[-1],
)
else:
diag_linear_operator = DiagLinearOperator(0.5 * torch.rand(kp_linear_operator.shape[-1], dtype=torch.float))
return KroneckerProductAddedDiagLinearOperator(kp_linear_operator, diag_linear_operator)

def evaluate_linear_operator(self, linear_operator):
tensor = linear_operator._linear_operator.to_dense()
@@ -59,7 +63,7 @@ def test_root_inv_decomposition_no_cholesky(self):
with mock.patch.object(linear_operator, "cholesky") as chol_mock:
root_approx = linear_operator.root_inv_decomposition()
res = root_approx.matmul(test_mat)
actual = linear_operator.inv_matmul(test_mat)
actual = torch.solve(test_mat, linear_operator.to_dense()).solution
self.assertAllClose(res, actual, rtol=0.05, atol=0.02)
chol_mock.assert_not_called()
