Numerical error with JAX #60

Open
shawnwwimer opened this issue Nov 20, 2024 · 2 comments

Comments

@shawnwwimer

I was using JAX as the backend and couldn't get some simulations to agree with an analytical solution. I noticed that switching the backend to the CPU fixed the problem, and found that JAX uses reduced precision by default for some operations:

I don't have time to test it right now, but that second issue suggests: "Try setting jax.default_matmul_precision to float32". If anybody runs into a similar problem, this may be the cause; if so, it may be worth noting in the README.
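For reference, here is a minimal sketch of the three ways JAX exposes the setting quoted above: the process-wide `jax_default_matmul_precision` config option, the `jax.default_matmul_precision` context manager, and a per-call `precision=` argument. The matrix sizes and random inputs are illustrative assumptions. These options only control how 32-bit matrix multiplications and convolutions are carried out on hardware that may default to reduced-precision passes (e.g. bfloat16 on TPU, TF32 on recent NVIDIA GPUs); they do not switch arrays to float64, and on CPU they should make no visible difference.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative float32 inputs
rng = np.random.default_rng(0)
a_np = rng.standard_normal((64, 64)).astype(np.float32)
b_np = rng.standard_normal((64, 64)).astype(np.float32)
a, b = jnp.asarray(a_np), jnp.asarray(b_np)

# 1) Process-wide default for float32 matmuls/convolutions
jax.config.update("jax_default_matmul_precision", "float32")

# 2) Scoped default, limited to the `with` block
with jax.default_matmul_precision("float32"):
    c_scoped = a @ b

# 3) Per-operation override
c_op = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Float64 NumPy reference for comparison
c_ref = a_np.astype(np.float64) @ b_np.astype(np.float64)
print(np.abs(np.asarray(c_scoped) - c_ref).max())
print(np.abs(np.asarray(c_op) - c_ref).max())
```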

@rafael-fuente
Owner

Diffraction calculations lose significant accuracy with JAX's float32, which is what you get by default when importing JAX (64-bit types are disabled unless x64 is explicitly enabled).
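To make the float32 accuracy loss concrete, here is a minimal NumPy sketch (a generic angular-spectrum example, not diffractsim's own code) that evaluates the phase z*kz of a 1-D transfer function exp(i*z*kz), with kz = sqrt(k**2 - kx**2), once in float32 and once in float64. The wavelength (633 nm), propagation distance (1 m), grid size and pixel pitch are illustrative assumptions. Because the total phase is on the order of 10^7 rad, float32's roughly 7 significant digits leave an absolute phase error that can reach a sizable fraction of a radian, which is enough to visibly distort interference patterns.

```python
import numpy as np

def angular_spectrum_phase(dtype, wavelength=633e-9, z=1.0, n=1024, dx=5e-6):
    """Return z * kz for the 1-D angular-spectrum transfer function exp(i*z*kz),
    with kz = sqrt(k**2 - kx**2), evaluated entirely in `dtype`.
    All parameter values are illustrative assumptions."""
    k = dtype(2 * np.pi / wavelength)  # wavenumber in rad/m
    # Transverse angular spatial frequencies of an n-point grid with pitch dx
    kx = (2 * np.pi * np.fft.fftfreq(n, d=dx)).astype(dtype)
    # Clamp at zero so evanescent components (none for these parameters) can't give NaN
    kz = np.sqrt(np.maximum(k * k - kx * kx, dtype(0)))
    return dtype(z) * kz

phase32 = angular_spectrum_phase(np.float32)
phase64 = angular_spectrum_phase(np.float64)

# Absolute phase error of the float32 result against the float64 reference (radians)
phase_error = np.abs(phase32.astype(np.float64) - phase64)
print(f"total phase: about {phase64.max():.3e} rad")
print(f"max float32 phase error: {phase_error.max():.3f} rad")
```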

Can you check if enabling JAX x64 solves your problem?

You need to add the following lines before importing diffractsim:

```python
import jax
jax.config.update("jax_enable_x64", True)
```
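A quick way to confirm that the flag took effect is to check JAX's default dtypes right after the update: with x64 enabled, Python floats and complex numbers become float64 and complex128 arrays instead of float32 and complex64. This is a minimal sketch; the commented-out diffractsim import only marks where the library would be imported.

```python
import jax

# Enable 64-bit types before diffractsim (or anything else) creates JAX arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
# import diffractsim  # import only after the config update above

x = jnp.array(1.0)         # default float dtype
z = jnp.array(1.0 + 0.5j)  # default complex dtype
print(x.dtype, z.dtype)    # float64 complex128 when x64 is enabled
```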

@shawnwwimer
Author

Yes, that does fix it. The difference is definitely due to precision. I brought it up simply because I didn't see a related issue here, and I was surprised to find that it came from JAX's default behavior.
