diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py
index a59d822..c821389 100644
--- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py
+++ b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py
@@ -5,16 +5,13 @@
 import torch
 
 from .added_diag_linear_operator import AddedDiagLinearOperator
-from .diag_linear_operator import DiagLinearOperator
+from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
+from .matmul_linear_operator import MatmulLinearOperator
 
 
 class KroneckerProductAddedDiagLinearOperator(AddedDiagLinearOperator):
     def __init__(self, *linear_operators, preconditioner_override=None):
-        # TODO: implement the woodbury formula for diagonal tensors that are non constants.
-
-        super(KroneckerProductAddedDiagLinearOperator, self).__init__(
-            *linear_operators, preconditioner_override=preconditioner_override
-        )
+        super().__init__(*linear_operators, preconditioner_override=preconditioner_override)
         if len(linear_operators) > 2:
             raise RuntimeError("An AddedDiagLinearOperator can only have two components")
         elif isinstance(linear_operators[0], DiagLinearOperator):
@@ -27,59 +24,67 @@ def __init__(self, *linear_operators, preconditioner_override=None):
             raise RuntimeError(
                 "One of the LinearOperators input to AddedDiagLinearOperator must be a DiagLinearOperator!"
             )
+        self._diag_is_constant = isinstance(self.diag_tensor, ConstantDiagLinearOperator)
 
     def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
-        # we want to call the standard InvQuadLogDet to easily get the probe vectors and do the
-        # solve but we only want to cache the probe vectors for the backwards
-        inv_quad_term, _ = super().inv_quad_logdet(
-            inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
-        )
-
-        if logdet is not False:
-            logdet_term = self._logdet()
-        else:
-            logdet_term = None
-
-        return inv_quad_term, logdet_term
+        if self._diag_is_constant:
+            # we want to call the standard InvQuadLogDet to easily get the probe vectors and do the
+            # solve but we only want to cache the probe vectors for the backwards
+            inv_quad_term, _ = super().inv_quad_logdet(
+                inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
+            )
+            logdet_term = self._logdet() if logdet else None
+            return inv_quad_term, logdet_term
+        return super().inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad)
 
     def _logdet(self):
-        # symeig requires computing the eigenvectors so that it's differentiable
-        evals, _ = self.linear_operator.symeig(eigenvectors=True)
-        evals_plus_diag = evals + self.diag_tensor.diag()
-        return torch.log(evals_plus_diag).sum(dim=-1)
+        if self._diag_is_constant:
+            # symeig requires computing the eigenvectors so that it's differentiable
+            evals, _ = self.linear_operator.symeig(eigenvectors=True)
+            evals_plus_diag = evals + self.diag_tensor.diag()
+            return torch.log(evals_plus_diag).sum(dim=-1)
+        return super()._logdet()
 
     def _preconditioner(self):
         # solves don't use CG so don't waste time computing it
         return None, None, None
 
     def _solve(self, rhs, preconditioner=None, num_tridiag=0):
-        # we do the solve in double for numerical stability issues
-        # TODO: Use fp64 registry once #1213 is addressed
+        if self._diag_is_constant:
+            # we do the solve in double for numerical stability issues
+            # TODO: Use fp64 registry once #1213 is addressed
+
+            rhs_dtype = rhs.dtype
+            rhs = rhs.double()
+
+            evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
+            evals, q_matrix = evals.double(), q_matrix.double()
 
-        rhs_dtype = rhs.dtype
-        rhs = rhs.double()
+            evals_plus_diagonal = evals + self.diag_tensor.diag()
+            evals_root = evals_plus_diagonal.pow(0.5)
+            inv_mat_sqrt = DiagLinearOperator(evals_root.reciprocal())
 
-        evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
-        evals, q_matrix = evals.double(), q_matrix.double()
+            res = q_matrix.transpose(-2, -1).matmul(rhs)
+            res2 = inv_mat_sqrt.matmul(res)
 
-        evals_plus_diagonal = evals + self.diag_tensor.diag()
-        evals_root = evals_plus_diagonal.pow(0.5)
-        inv_mat_sqrt = DiagLinearOperator(evals_root.reciprocal())
+            lhs = q_matrix.matmul(inv_mat_sqrt)
+            return lhs.matmul(res2).type(rhs_dtype)
 
-        res = q_matrix.transpose(-2, -1).matmul(rhs)
-        res2 = inv_mat_sqrt.matmul(res)
+        # TODO: implement woodbury formula for non-constant Kronecker-structured diagonal operators
 
-        lhs = q_matrix.matmul(inv_mat_sqrt)
-        return lhs.matmul(res2).type(rhs_dtype)
+        return super()._solve(rhs, preconditioner=preconditioner, num_tridiag=num_tridiag)
 
     def _root_decomposition(self):
-        evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
-        updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5))
-        matrix_root = q_matrix.matmul(updated_evals)
-        return matrix_root
+        if self._diag_is_constant:
+            # we can use the eigendecomposition and shift the eigenvalues
+            evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
+            updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5))
+            return MatmulLinearOperator(q_matrix, updated_evals)
+        return super()._root_decomposition()
 
     def _root_inv_decomposition(self, initial_vectors=None):
-        evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
-        inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5))
-        matrix_inv_root = q_matrix.matmul(inv_sqrt_evals)
-        return matrix_inv_root
+        if self._diag_is_constant:
+            evals, q_matrix = self.linear_operator.symeig(eigenvectors=True)
+            inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5))
+            return MatmulLinearOperator(q_matrix, inv_sqrt_evals)
+        return super()._root_inv_decomposition(initial_vectors=initial_vectors)
diff --git a/linear_operator/operators/linear_operator.py b/linear_operator/operators/linear_operator.py
index 5a83eee..790ddd6 100644
--- a/linear_operator/operators/linear_operator.py
+++ b/linear_operator/operators/linear_operator.py
@@ -1925,7 +1925,7 @@ def __getitem__(self, index):
         # Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices
         if row_col_are_absorbed:
             # Convert all indices into tensor indices
-            *batch_indices, row_index, col_index, = _convert_indices_to_tensors(
+            (*batch_indices, row_index, col_index,) = _convert_indices_to_tensors(
                 self, (*batch_indices, row_index, col_index)
             )
             res = self._get_indices(row_index, col_index, *batch_indices)
diff --git a/linear_operator/utils/__init__.py b/linear_operator/utils/__init__.py
index 2ee2eb7..a753cfd 100644
--- a/linear_operator/utils/__init__.py
+++ b/linear_operator/utils/__init__.py
@@ -9,8 +9,7 @@
 
 
 def prod(items):
-    """
-    """
+    """"""
     if len(items):
         res = items[0]
         for item in items[1:]:
diff --git a/linear_operator/utils/interpolation.py b/linear_operator/utils/interpolation.py
index dad521b..c7c39a2 100644
--- a/linear_operator/utils/interpolation.py
+++ b/linear_operator/utils/interpolation.py
@@ -156,8 +156,7 @@ def interpolate(self, x_grid: List[torch.Tensor], x_target: torch.Tensor, interp
 
 
 def left_interp(interp_indices, interp_values, rhs):
-    """
-    """
+    """"""
     is_vector = rhs.ndimension() == 1
 
     if is_vector:
@@ -181,8 +180,7 @@ def left_interp(interp_indices, interp_values, rhs):
 
 
 def left_t_interp(interp_indices, interp_values, rhs, output_dim):
-    """
-    """
+    """"""
     from .. import dsmm
 
     is_vector = rhs.ndimension() == 1
diff --git a/linear_operator/utils/lanczos.py b/linear_operator/utils/lanczos.py
index ad3ebf5..999cc12 100644
--- a/linear_operator/utils/lanczos.py
+++ b/linear_operator/utils/lanczos.py
@@ -18,8 +18,7 @@ def lanczos_tridiag(
     num_init_vecs=1,
     tol=1e-5,
 ):
-    """
-    """
+    """"""
     # Determine batch mode
     multiple_init_vecs = False
 
diff --git a/linear_operator/utils/sparse.py b/linear_operator/utils/sparse.py
index d3729eb..cc0d332 100644
--- a/linear_operator/utils/sparse.py
+++ b/linear_operator/utils/sparse.py
@@ -140,8 +140,7 @@ def sparse_eye(size):
 
 
 def sparse_getitem(sparse, idxs):
-    """
-    """
+    """"""
     if not isinstance(idxs, tuple):
         idxs = (idxs,)
 
@@ -201,8 +200,7 @@ def sparse_getitem(sparse, idxs):
 
 
 def sparse_repeat(sparse, *repeat_sizes):
-    """
-    """
+    """"""
     if len(repeat_sizes) == 1 and isinstance(repeat_sizes, tuple):
         repeat_sizes = repeat_sizes[0]
 
@@ -243,8 +241,7 @@ def sparse_repeat(sparse, *repeat_sizes):
 
 
 def to_sparse(dense):
-    """
-    """
+    """"""
     mask = dense.ne(0)
     indices = mask.nonzero(as_tuple=False)
     if indices.storage():
diff --git a/test/operators/test_kronecker_product_added_diag_linear_operator.py b/test/operators/test_kronecker_product_added_diag_linear_operator.py
index e2584ba..f482367 100644
--- a/test/operators/test_kronecker_product_added_diag_linear_operator.py
+++ b/test/operators/test_kronecker_product_added_diag_linear_operator.py
@@ -9,6 +9,7 @@
 
 from linear_operator import settings
 from linear_operator.operators import (
+    ConstantDiagLinearOperator,
     DenseLinearOperator,
     DiagLinearOperator,
     KroneckerProductAddedDiagLinearOperator,
@@ -23,7 +24,7 @@ class TestKroneckerProductAddedDiagLinearOperator(unittest.TestCase, LinearOpera
     should_call_lanczos = False
     should_call_cg = False
 
-    def create_linear_operator(self):
+    def create_linear_operator(self, constant_diag=True):
         a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
         b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
         c = torch.tensor([[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float)
@@ -33,10 +34,13 @@ def create_linear_operator(self):
         kp_linear_operator = KroneckerProductLinearOperator(
            DenseLinearOperator(a), DenseLinearOperator(b), DenseLinearOperator(c)
         )
-
-        return KroneckerProductAddedDiagLinearOperator(
-            kp_linear_operator, DiagLinearOperator(0.1 * torch.ones(kp_linear_operator.shape[-1]))
-        )
+        if constant_diag:
+            diag_linear_operator = ConstantDiagLinearOperator(
+                torch.tensor([0.25], dtype=torch.float), kp_linear_operator.shape[-1],
+            )
+        else:
+            diag_linear_operator = DiagLinearOperator(0.5 * torch.rand(kp_linear_operator.shape[-1], dtype=torch.float))
+        return KroneckerProductAddedDiagLinearOperator(kp_linear_operator, diag_linear_operator)
 
     def evaluate_linear_operator(self, linear_operator):
         tensor = linear_operator._linear_operator.to_dense()
@@ -59,7 +63,7 @@ def test_root_inv_decomposition_no_cholesky(self):
         with mock.patch.object(linear_operator, "cholesky") as chol_mock:
             root_approx = linear_operator.root_inv_decomposition()
             res = root_approx.matmul(test_mat)
-            actual = linear_operator.inv_matmul(test_mat)
+            actual = torch.solve(test_mat, linear_operator.to_dense()).solution
             self.assertAllClose(res, actual, rtol=0.05, atol=0.02)
             chol_mock.assert_not_called()
 
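
Usage note (not part of the patch): a minimal sketch of how the new constant-diagonal fast path is exercised, mirroring create_linear_operator(constant_diag=True) from the test diff above. The Kronecker factors and the 0.25 noise value are copied from that test; the rhs shape and the particular calls shown are illustrative assumptions, not part of this change.

    import torch

    from linear_operator.operators import (
        ConstantDiagLinearOperator,
        DenseLinearOperator,
        KroneckerProductAddedDiagLinearOperator,
        KroneckerProductLinearOperator,
    )

    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    kp = KroneckerProductLinearOperator(DenseLinearOperator(a), DenseLinearOperator(b))

    # A ConstantDiagLinearOperator makes _diag_is_constant True, so inv_quad_logdet,
    # _solve, and _root_decomposition take the symeig-based branches added in this patch.
    noise = ConstantDiagLinearOperator(torch.tensor([0.25], dtype=torch.float), kp.shape[-1])
    op = KroneckerProductAddedDiagLinearOperator(kp, noise)

    rhs = torch.randn(kp.shape[-1], 2)
    inv_quad, logdet = op.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
    solve = op.inv_matmul(rhs)  # routed through the eigendecomposition-based _solve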