upgrade packages
Jeremie Coullon committed Aug 7, 2023
1 parent a2a2ae8 commit c629e15
Showing 8 changed files with 17 additions and 82 deletions.
requirements-dev.txt (6 changes: 3 additions & 3 deletions)

@@ -4,9 +4,9 @@ nbsphinx==0.8.1
 nb-black==1.0.7
 matplotlib==3.3.3
 sphinx
-mypy==0.910
-mypy-extensions==0.4.3
+mypy==1.4.1
+mypy-extensions==1.0.0
 pytest==6.2.4
-black==22.1.0
+black==23.7.0
 isort==5.10.1
 pytest-cov==3.0.0
requirements.txt (14 changes: 7 additions & 7 deletions)

@@ -1,10 +1,10 @@
-absl-py==0.13.0
+absl-py==1.4.0
 flatbuffers==2.0
-jax==0.2.14
-jaxlib==0.1.67
-numpy==1.21.0
+jax==0.4.13
+jaxlib==0.4.13
+numpy==1.24.4
 opt-einsum==3.3.0
-scipy==1.7.0
+scipy==1.10.1
 six==1.16.0
-tqdm==4.61.1
-optax==0.1.1
+tqdm==4.65.0
+optax==0.1.7
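
As a quick sanity check (added here for illustration, not part of the commit), the new runtime pins can be verified against an installed environment with importlib.metadata; the package list below is copied from the updated requirements.txt above.

    from importlib.metadata import version

    # Pins copied from the updated requirements.txt above; extend as needed.
    pinned = {
        "jax": "0.4.13",
        "jaxlib": "0.4.13",
        "numpy": "1.24.4",
        "scipy": "1.10.1",
        "tqdm": "4.65.0",
        "optax": "0.1.7",
    }

    for package, expected in pinned.items():
        installed = version(package)
        assert installed == expected, f"{package}: pinned {expected}, found {installed}"

    print("all runtime pins match")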
sgmcmcjax/examples/gauss_optimize.py (3 changes: 1 addition & 2 deletions)

@@ -4,7 +4,7 @@
 import optax
 from jax import random
 
-from sgmcmcjax.optimizer import build_adam_optimizer, build_optax_optimizer
+from sgmcmcjax.optimizer import build_optax_optimizer
 
 
 def loglikelihood(theta, x):
@@ -26,7 +26,6 @@ def logprior(theta):
 dt = 1e-2
 opt = optax.adam(learning_rate=dt)
 optimizer = build_optax_optimizer(opt, loglikelihood, logprior, (X_data,), batch_size)
-# optimizer = build_adam_optimizer(dt, loglikelihood, logprior, (X_data,), batch_size)
 
 Nsamples = 10_000
 params, log_post_list = optimizer(key, Nsamples, jnp.zeros(D))
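
For context, a self-contained sketch of the updated optax-only example follows. The diff above only shows fragments of the script, so the Gaussian log-likelihood/log-prior bodies and the toy data generation here are illustrative assumptions rather than the repository's exact code.

    import jax.numpy as jnp
    import optax
    from jax import random

    from sgmcmcjax.optimizer import build_optax_optimizer


    # Assumed Gaussian model for illustration: unit-variance likelihood, weak prior.
    def loglikelihood(theta, x):
        return -0.5 * jnp.dot(x - theta, x - theta)


    def logprior(theta):
        return -0.005 * jnp.dot(theta, theta)


    # Toy dataset: N draws from a D-dimensional Gaussian centred at mu_true.
    N, D = 10_000, 5
    key = random.PRNGKey(0)
    key, subkey1, subkey2 = random.split(key, 3)
    mu_true = random.normal(subkey1, (D,))
    X_data = mu_true + random.normal(subkey2, (N, D))

    batch_size = int(0.1 * N)
    dt = 1e-2
    opt = optax.adam(learning_rate=dt)
    optimizer = build_optax_optimizer(opt, loglikelihood, logprior, (X_data,), batch_size)

    Nsamples = 10_000
    params, log_post_list = optimizer(key, Nsamples, jnp.zeros(D))
    print(params.shape, log_post_list.shape)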
sgmcmcjax/gradient_estimation.py (6 changes: 3 additions & 3 deletions)

@@ -3,7 +3,7 @@
 
 import jax.numpy as jnp
 from jax import jit, lax, random
-from jax.tree_util import tree_flatten, tree_map, tree_multimap, tree_unflatten
+from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 
 from .types import PRNGKey, PyTree, SamplerState, SVRGState
 
@@ -79,7 +79,7 @@ def init_gradient(key: PRNGKey, param: PyTree):
     grad_center = grad_log_post(centering_value, *minibatch_data)
     flat_param_grad, tree_param_grad = tree_flatten(param_grad)
     flat_grad_center, tree_grad_center = tree_flatten(grad_center)
-    new_flat_param_grad = tree_multimap(
+    new_flat_param_grad = tree_map(
         update_fn, flat_fb_grad_center, flat_param_grad, flat_grad_center
     )
     param_grad = tree_unflatten(tree_param_grad, new_flat_param_grad)
@@ -144,7 +144,7 @@ def estimate_gradient(
     grad_center = grad_log_post(svrg_state.centering_value, *minibatch_data)
     flat_param_grad, tree_param_grad = tree_flatten(param_grad)
     flat_grad_center, tree_grad_center = tree_flatten(grad_center)
-    new_flat_param_grad = tree_multimap(
+    new_flat_param_grad = tree_map(
         update_fn, svrg_state.fb_grad_center, flat_param_grad, flat_grad_center
     )
     param_grad = tree_unflatten(tree_param_grad, new_flat_param_grad)
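
The tree_multimap removal is a mechanical rename: since JAX 0.3, jax.tree_util.tree_map accepts additional pytrees, so multi-argument calls carry over directly. A minimal sketch (not from the repository):

    import jax.numpy as jnp
    from jax.tree_util import tree_map

    a = {"w": jnp.ones(3), "b": jnp.zeros(2)}
    b = {"w": 2.0 * jnp.ones(3), "b": jnp.ones(2)}

    # Before (removed in recent JAX): tree_multimap(lambda x, y: x + y, a, b)
    # After: tree_map maps over several pytrees with matching structure.
    summed = tree_map(lambda x, y: x + y, a, b)
    print(summed)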
sgmcmcjax/ksd.py (2 changes: 1 addition & 1 deletion)

@@ -41,7 +41,7 @@ def k_0_fun(
 
 
 @jit
-def imq_KSD(samples: Array, grads: Array) -> float:
+def imq_KSD(samples: Array, grads: Array) -> Array:
     """Kernel Stein Discrepancy with IMQ kernel
 
     Args:
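
The annotation change reflects that a jit-compiled function returns a zero-dimensional jax Array rather than a Python float. A hedged usage sketch (the sample shapes and values below are assumptions for illustration):

    import jax.numpy as jnp
    from jax import random

    from sgmcmcjax.ksd import imq_KSD

    # Assumed shapes: (num_samples, dim) for both the samples and their log-density gradients.
    key = random.PRNGKey(0)
    samples = random.normal(key, (100, 2))
    grads = -samples  # gradient of a standard Gaussian log-density, for illustration

    ksd_value = imq_KSD(samples, grads)  # zero-dimensional jax Array
    print(float(ksd_value))              # convert to a Python float when needed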
sgmcmcjax/models/bayesian_NN/NN_model.py (1 change: 1 addition & 0 deletions)

@@ -12,6 +12,7 @@
 
 N_data = X_train.shape[0]
 
+
 # ==========
 # Functions to initialise parameters
 # initialise params: list of tuples (W, b) for each layer
sgmcmcjax/optimizer.py (51 changes: 0 additions & 51 deletions)

@@ -3,63 +3,12 @@
 import jax.numpy as jnp
 import optax
 from jax import jit, lax, random
-from jax.experimental.optimizers import adam
 from jax.tree_util import tree_map
 
 from .gradient_estimation import build_gradient_estimation_fn
 from .util import build_grad_log_post, progress_bar_scan
 
 
-def build_adam_optimizer(
-    dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int
-) -> Callable:
-    """build adam optimizer using JAX `optimizers` module
-
-    Args:
-        dt (float): step size
-        loglikelihood (Callable): log-likelihood for a single data point
-        logprior (Callable): log-prior for a single data point
-        data (Tuple): tuple of data. It should either have a single array (for unsupervised problems) or have two arrays (for supervised problems)
-        batch_size (int): batch size
-
-    Returns:
-        Callable: optimizer function with signature:
-            Args:
-                key (PRNGKey): random key
-                Niters (int): number of iterations
-                params_IC (PyTree): initial parameters
-            Returns:
-                PyTree: final parameters
-                jnp.ndarray: array of log-posterior values during the optimization
-    """
-    grad_log_post = build_grad_log_post(loglikelihood, logprior, data, with_val=True)
-    estimate_gradient, _ = build_gradient_estimation_fn(grad_log_post, data, batch_size)
-
-    init_fn, update, get_params = adam(dt)
-
-    @jit
-    def body(carry, i):
-        key, state = carry
-        key, subkey = random.split(key)
-        (lp_val, param_grad), _ = estimate_gradient(i, subkey, get_params(state))
-        neg_param_grad = tree_map(lambda x: -x, param_grad)
-        state = update(i, neg_param_grad, state)
-        return (key, state), lp_val
-
-    def run_adam(key, Niters, params_IC):
-        state = init_fn(params_IC)
-        body_pbar = progress_bar_scan(Niters)(body)
-        (_, state_opt), logpost_array = lax.scan(
-            body_pbar, (key, state), jnp.arange(Niters)
-        )
-        return get_params(state_opt), logpost_array
-
-    return run_adam
-
-
 def build_optax_optimizer(
     optimizer: optax.GradientTransformation,
     loglikelihood: Callable,
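
The removed import reflects that jax.experimental.optimizers is gone in recent JAX releases; callers should switch to build_optax_optimizer as in the example file above. If the stock JAX adam is still wanted outside of optax, it lives under jax.example_libraries.optimizers (an assumption based on JAX's module move, not something this commit uses); a minimal sketch:

    import jax.numpy as jnp
    from jax.example_libraries.optimizers import adam

    # Same (init, update, get_params) triple the removed code used, imported
    # from jax.example_libraries instead of jax.experimental.
    init_fn, update_fn, get_params = adam(step_size=1e-2)

    state = init_fn(jnp.zeros(3))
    state = update_fn(0, jnp.ones(3), state)  # one step with a dummy gradient
    print(get_params(state))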
tests/test_optimizers.py (16 changes: 1 addition & 15 deletions)

@@ -2,7 +2,7 @@
 import optax
 from jax import random
 
-from sgmcmcjax.optimizer import build_adam_optimizer, build_optax_optimizer
+from sgmcmcjax.optimizer import build_optax_optimizer
 
 
 def loglikelihood(theta, x):
@@ -34,17 +34,3 @@ def test_optax_optimizer():
     print(log_post_list.shape)
     print(params.shape)
     assert jnp.allclose(params, mu_true, atol=1e-1)
-
-
-def test_jax_Adam_optimizer():
-    # Adam
-    batch_size = int(0.1 * N)
-    dt = 1e-2
-    opt = optax.adam(learning_rate=dt)
-    optimizer = build_adam_optimizer(dt, loglikelihood, logprior, (X_data,), batch_size)
-
-    Nsamples = 10_000
-    params, log_post_list = optimizer(key, Nsamples, jnp.zeros(D))
-    print(log_post_list.shape)
-    print(params.shape)
-    assert jnp.allclose(params, mu_true, atol=1e-1)
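
To confirm the remaining optimizer test still passes after the removal, the suite can be run with pytest (pinned in requirements-dev.txt); a minimal programmatic equivalent of running `pytest tests/test_optimizers.py` from the repository root:

    import sys

    import pytest

    # Exit with pytest's return code, mirroring the command-line invocation.
    sys.exit(pytest.main(["-q", "tests/test_optimizers.py"]))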
