Skip to content

Commit

Permalink
Changed upper_tri.py to support versions <3.11.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitSolomonPrinceton committed Sep 16, 2024
1 parent 20fce79 commit 4ea58cb
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions cvxtorch/torch_numerics/affine/upper_tri.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import torch
from cvxpy.expressions.expression import Expression


def torch_numeric(expr: Expression, values: list[torch.Tensor]) -> torch.Tensor:
def torch_numeric(expr: Expression, values: list[torch.Tensor]) -> torch.Tensor:
inds = torch.triu_indices(row=values[0].shape[0], col=values[0].shape[1], offset=1)
return values[0][*inds]
# This can be simplified as `return values[0][*inds]`. However, this doesn't work on <3.11.
# I ended up doing the following solution to support both versions (checking the version
# during runtime causes issues during testing).
# In future Python versions >=3.11, this should be replaced.
if len(inds) != 2:
raise ValueError(f"Expected inds to be of length 2, but got {len(inds)} instead.")
return values[0][inds[0], inds[1]]

0 comments on commit 4ea58cb

Please sign in to comment.