Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jax implementation of DII #146

Merged
merged 22 commits into from
Nov 13, 2024
Merged

Add jax implementation of DII #146

merged 22 commits into from
Nov 13, 2024

Conversation

vdeltatto
Copy link
Collaborator

@vdeltatto vdeltatto commented Nov 11, 2024

Proposed changes

  1. Added new class DiffImbalance, containing the JAX implementation of the DII
  2. Added tutorial notebook for new class, tests and updated documentation

Comment 1 - JAX version

The code works with jax=0.4.30, which unfortunately is not available for Python<3.9. Since Python 3.8 is still widely employed I found a workaround by installing this JAX version only for Python>=3.9. For Python<3.9 then the DiffImbalance class is not available (this is specified in the README and in the documentation), and the corresponding tests are not run during the installation.

Comment 2 - JAX on GPU

To avoid installation problems, the JAX library is installed by default on CPU (as from PR #140). The GPU installation depends on several factors, so I would leave it completely to the user (I added some tips in the README and a reference to the installation tutorial by Google).

TODO

  • The tutorial can be expanded.

  • The code is missing of a method for the forward / backward greedy selection that is present in the Cython implementation

@vdeltatto vdeltatto marked this pull request as ready for review November 12, 2024 22:55
@wildromi wildromi merged commit 4c6a41f into main Nov 13, 2024
8 checks passed
@vdeltatto vdeltatto deleted the jax_DII_pr branch November 14, 2024 15:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants