I was using JAX as the backend and couldn't get some simulations to agree with the analytical form. Switching the backend to the CPU fixed the problem, and I found that JAX uses reduced precision by default for some operations:
I don't have time to test this right now, but from that second issue: "Try setting `jax.default_matmul_precision` to `float32`". If anybody runs into a similar problem, this may be the cause. If so, it may be worth noting in the README.
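A minimal sketch of the suggested fix, assuming a recent JAX release where the `jax_default_matmul_precision` config option and the `jax.default_matmul_precision` context manager are available. It shows three scopes for requesting full float32 matmuls: global, inside a `with` block, and on a single call via the `precision` argument. The random matrices and the float64 NumPy reference are illustrative only, not part of the original report.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Option 1: set the default globally, ideally at startup before any computation.
# "float32" asks for full float32 matmuls/convolutions instead of the reduced
# (e.g. bfloat16 or TF32) passes some accelerators use by default.
jax.config.update("jax_default_matmul_precision", "float32")


def max_abs_error(a, b):
    """Largest elementwise gap between a JAX matmul and a float64 NumPy reference."""
    ref = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)
    out = np.asarray(jnp.matmul(a, b), dtype=np.float64)
    return float(np.max(np.abs(out - ref)))


# Illustrative inputs: two random float32 matrices.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)

err_global = max_abs_error(a, b)

# Option 2: scope the setting to a block with the context manager.
with jax.default_matmul_precision("float32"):
    err_scoped = max_abs_error(a, b)

# Option 3: request the highest precision on one call only.
c = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

print(err_global, err_scoped)
```

On a CPU backend all three already run in full float32, so the errors stay small either way; the difference only appears on accelerators that default to lower-precision passes (for example TPUs, or GPUs using TF32). The global option is the least intrusive for a whole simulation, while the per-call `precision` argument keeps the cost of full precision confined to the operations that need it.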
Yes, that does fix it; the difference is definitely the precision. I raised it here only because I didn't see a related issue in this repository, and I was surprised to find that the cause was JAX's default behavior.
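To check whether precision explains a discrepancy like this one, a quick diagnostic is to run the same float32 matmul at the backend's default precision and at `jax.lax.Precision.HIGHEST`, then compare. This is a sketch under the assumption that no global precision override has been set; the helper name `precision_gap` and the matrix size are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def precision_gap(n=512, seed=0):
    """Hypothetical helper: largest elementwise difference between one float32
    matmul at the default precision and the same matmul at HIGHEST precision."""
    ka, kb = jax.random.split(jax.random.PRNGKey(seed))
    a = jax.random.normal(ka, (n, n), dtype=jnp.float32)
    b = jax.random.normal(kb, (n, n), dtype=jnp.float32)
    default = jnp.matmul(a, b)  # whatever precision the backend uses by default
    highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
    return float(jnp.max(jnp.abs(default - highest)))


# Report which backend is active ("cpu", "gpu", "tpu") alongside the gap.
print(jax.default_backend(), precision_gap())
```

A gap at or near zero suggests the backend is already computing in full float32 (as on CPU); a noticeably larger gap on an accelerator points to reduced default precision as the source of the mismatch.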