-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Always use special .solve
for Kronecker linear operators
#50
base: main
Are you sure you want to change the base?
Conversation
if isinstance( | ||
linear_op, | ||
( | ||
CholLinearOperator, | ||
TriangularLinearOperator, | ||
KroneckerProductAddedDiagLinearOperator, | ||
KroneckerProductLinearOperator, | ||
KroneckerProductDiagLinearOperator, | ||
KroneckerProductTriangularLinearOperator, | ||
SumKroneckerLinearOperator, | ||
), | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm will this always apply the special solve
method? There may be situations in which we want to use Linear CG solves even for some operators with a special solve
method.
Aside: The name "fast_computations" is a bit weird; whether it's fast or not will depend on the operator structure and the data size...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gpleiss, @jacobrgardner, curious about your thoughts here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aside: The name "fast_computations" is a bit weird; whether it's fast or not will depend on the operator structure and the data size...
Agreed. I regret it.
There may be situations in which we want to use Linear CG solves even for some operators with a special solve method.
@JonathanWenger and I brainstormed this a bit. One thought that we had was that a user could specify (via context manager, inline argument, etc.) when they want to go into iterative solving mode. All other solves would be performed using direct methods otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that could be useful. Another option would be to attach default rules for the decision which solves to use to the respective operators but then allow to override them (either way so a default exact solve may use an iterative instead and vice versa).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @Balandat on the default rules. That would nicely integrate with using a Kronecker, or banded specific solver. The interface Geoff and I were discussing was either via a context manager or with an optional argument that could specify the default per linear operator:
def solve(self, right_tensor: torch.Tensor, left_tensor: Optional[torch.Tensor] = None, linear_solver: LinearSolver = CG()) -> torch.Tensor:
However, there were some potential issues with this interface and the interplay with implementing torch.linalg.solve
, if I remember correctly. In a perfect world torch.linalg.solve
would dispatch on the specific kind of LinearOperator
I suppose.
As titled. These linear operators are generally much larger than their components. If
fast_computations
(in particular_fast_solves
) is turned off, then we try to compute Cholesky over huge matrices, which leads to OOMs.