You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I'm trying to solve multiple QPs with vmap in JAXopt using the matvec API. As far as I understand, using matvec should make the computation faster because it relies on matrix-vector products.
But when I compared two ways of solving a QP (1. QP with vmap and JIT; 2. QP with matvec, vmap and JIT), the computation times were almost the same.
Am I using matvec the wrong way? I've been struggling with JAX this week, and I would really appreciate it if anyone could help me speed up the computation.
import jax
import jax.numpy as jnp
from jaxopt import BoxOSQP
import math
import time
import torch
from torch import Tensor
# Define the matrix-vector product for Q.
#
# NOTE: a custom matvec only pays off when it exploits structure in Q.  A
# dense ``params_Q @ x`` does the same work as the default dense product, so
# on its own it is not expected to make the solver any faster.
def matvec_Q(params_Q, x):
    """Return ``Q @ x`` for the quadratic-objective matrix ``Q``.

    Args:
      params_Q: either a dense 2-D matrix ``Q`` of shape ``(n, n)``, or a 1-D
        array of length ``n`` holding the diagonal of a diagonal ``Q``.  The
        diagonal form computes the product in O(n) instead of O(n^2).
      x: vector of shape ``(n,)``.

    Returns:
      The vector ``Q @ x`` of shape ``(n,)``.
    """
    # ndim is static under jit/vmap, so this Python branch is trace-safe.
    if jnp.ndim(params_Q) == 1:
        # Diagonal Q stored as a vector: elementwise scaling equals Q @ x.
        return params_Q * x
    return params_Q @ x
# Matrix-vector product for the constraint matrix A (handed to BoxOSQP as
# ``matvec_A``; the matrix itself travels through ``params_eq``).
def matvec_A(params_A, x):
    """Multiply the constraint matrix ``params_A`` by the vector ``x``.

    ``params_A`` may be non-square (e.g. 3x2 in this file), so the result has
    one entry per constraint row.
    """
    constrained = params_A @ x
    return constrained
class QP:
    """Build and solve a small 2-variable QP with JAXopt's ``BoxOSQP``.

    Per the BoxOSQP documentation, the solver handles problems of the form::

        minimize    0.5 * x^T Q x + c^T x
        subject to  l <= A x <= u

    Here ``Q``, ``c``, ``l`` and ``u`` are fixed, and only the first row of
    ``A`` comes from the per-problem input, which makes the methods suitable
    for ``jax.vmap`` over a batch of such rows.
    """

    def __init__(self):
        # Solver using dense Q / A directly.
        self.qp = BoxOSQP(tol=1e-3)
        # Same solver with user-supplied matvec functions.  With dense
        # matrices these perform the same products as the dense solver, so
        # no speed-up is expected from this variant by itself.
        self.qp_matvec = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-3)

    @staticmethod
    def _problem_data(A_input):
        """Return the ``BoxOSQP.run`` keyword arguments for one problem.

        Args:
          A_input: length-2 array ``(a1, a2)`` forming the first row of ``A``.

        Returns:
          A dict with ``params_obj=(Q, c)``, ``params_eq=A`` (shape (3, 2))
          and ``params_ineq=(l, u)``, all float32 JAX arrays.
        """
        a1 = A_input[0]
        a2 = A_input[1]
        Q = jnp.array([[4, 0], [0, 2]], dtype=jnp.float32)
        c = jnp.array([1, 1], dtype=jnp.float32)
        # Row 0 (l = u = 1) is an equality a1*x1 + a2*x2 = 1; rows 1-2 are the
        # identity, bounding each variable to [0, 0.7].
        A = jnp.array([[a1, a2], [1, 0], [0, 1]], dtype=jnp.float32)
        l = jnp.array([1, 0, 0], dtype=jnp.float32)
        u = jnp.array([1, 0.7, 0.7], dtype=jnp.float32)
        return dict(params_obj=(Q, c), params_eq=A, params_ineq=(l, u))

    def runQP(self, A_input):
        """Solve one QP with the dense solver, without initial parameters.

        Args:
          A_input: length-2 array ``(a1, a2)`` for the first row of ``A``.

        Returns:
          The solution (``params`` of the step returned by ``BoxOSQP.run``).
        """
        sol, _state = self.qp.run(None, **self._problem_data(A_input))
        # Return the solution: a jitted function with no outputs gives the
        # caller nothing to block on and may let the compiler drop the solve.
        return sol

    def runQP_matvec(self, A_input):
        """Solve one QP with the matvec-configured solver.

        Uses exactly the same problem data as ``runQP``; only the solver
        instance (``self.qp_matvec``) differs.

        Args:
          A_input: length-2 array ``(a1, a2)`` for the first row of ``A``.

        Returns:
          The solution (``params`` of the step returned by ``BoxOSQP.run``).
        """
        sol, _state = self.qp_matvec.run(None, **self._problem_data(A_input))
        return sol
# Benchmark: solve `num_qp` copies of the QP in one batched, jitted call.
#
# Timing notes:
# * Each jitted function is built ONCE.  `jax.jit(jax.vmap(...))` creates a
#   new function object, so re-wrapping it inside the timing loop re-traces
#   and re-compiles on every iteration, and the loop then measures
#   compilation rather than solving.
# * JAX dispatches work asynchronously; `jax.block_until_ready` waits for the
#   result before the clock stops, otherwise only dispatch time is measured.
# * The first call (trace + compile) is reported separately from the
#   steady-state runs.
num_qp = 1000  # number of independent QP instances per batched call
n_repeats = 10  # timed calls per variant after the warm-up call

my_qp = QP()
base_input = jnp.array([1.0, 1.0])  # named so the builtin `input` is not shadowed
input_vector = jnp.tile(base_input, (num_qp, 1))


def _benchmark(label, fn, args, repeats=10):
    """Print the warm-up time and ``repeats`` steady-state times of ``fn(*args)``.

    Args:
      label: function name shown in the printed messages.
      fn: an already-jitted callable.
      args: tuple of positional arguments passed to ``fn``.
      repeats: number of timed calls after the warm-up call.
    """
    start_time = time.perf_counter()
    jax.block_until_ready(fn(*args))
    print("Function '{}' compile + first run time: {}s".format(
        label, time.perf_counter() - start_time))
    for _ in range(repeats):
        start_time = time.perf_counter()
        jax.block_until_ready(fn(*args))
        elapsed_time = time.perf_counter() - start_time
        print("Function '{}' execution time: {}s".format(label, elapsed_time))


# 1. Apply vmap & jit to QP
auto_batch_runQP = jax.vmap(my_qp.runQP)
jitted_runQP_vectorized = jax.jit(auto_batch_runQP)
_benchmark("runQP_vmap_jit", jitted_runQP_vectorized, (input_vector,), n_repeats)

# 2. Apply vmap & jit to QP_matvec
# With dense Q and A passed to matvec_Q / matvec_A, both variants perform the
# same products, so similar timings are expected here.
auto_batch_runQP_matvec = jax.vmap(my_qp.runQP_matvec)
jitted_runQP_matvec_vectorized = jax.jit(auto_batch_runQP_matvec)
_benchmark("runQP_matvec_vmap_jit", jitted_runQP_matvec_vectorized,
           (input_vector,), n_repeats)
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
Hi.
I'm trying to solve multiple QPs with vmap in JAXopt using the matvec API. As far as I understand, using matvec should make the computation faster because it relies on matrix-vector products.
But when I compared two ways of solving a QP (1. QP with vmap and JIT; 2. QP with matvec, vmap and JIT), the computation times were almost the same.
Am I using matvec the wrong way? I've been struggling with JAX this week, and I would really appreciate it if anyone could help me speed up the computation.
Beta Was this translation helpful? Give feedback.
All reactions