Allow einsum to support naive contraction strategy #24915

Open
ryan112358 opened this issue Nov 15, 2024 · 0 comments
Labels: enhancement (New feature or request)

ryan112358 commented Nov 15, 2024

I would like to compute an einsum according to the following formula:

```python
import jax
import jax.numpy as jnp
n = 8192
arrays = [jax.random.normal(key=jax.random.PRNGKey(0), shape=(n, n)) for _ in range(6)]
formula = 'ij,ik,il,jk,jl,kl->ij'
```
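For scale (assuming JAX's default float32, i.e. 4 bytes per element): each of i, j, k, l appears in three of the six operands, so the first pairwise step of any contraction path cannot sum an index away and must produce at least a three-index intermediate. A quick back-of-the-envelope estimate of what that costs at this size:

```python
# Rough memory estimate, assuming float32 (4 bytes/element), JAX's default
# dtype when x64 is disabled.
n = 8192
BYTES_PER_ELEMENT = 4


def gib(num_elements: int) -> float:
    """Convert an element count to GiB at BYTES_PER_ELEMENT bytes each."""
    return num_elements * BYTES_PER_ELEMENT / 2**30


operand_gib = gib(n * n)          # one of the six (n, n) inputs
all_inputs_gib = 6 * operand_gib  # all six inputs together
three_index_gib = gib(n ** 3)     # e.g. an 'ikl' intermediate from ik,il

print(f"one operand: {operand_gib:.2f} GiB")
print(f"six operands: {all_inputs_gib:.2f} GiB")
print(f"one (n, n, n) intermediate: {three_index_gib:.0f} GiB")
```

This prints 0.25 GiB per operand, 1.50 GiB for all six inputs, and 2048 GiB (2 TiB) for a single three-index intermediate, which is why a path that materializes nothing beyond n × n arrays is attractive here.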

I want to express the computation as 4 nested for loops over the indices i, j, k, l, without creating any intermediate arrays. As far as `einsum_path` is concerned, I can do this by passing the path `[(0, 1, 2, 3, 4, 5)]` directly via the `optimize` kwarg:

```
>>> jax.numpy.einsum_path(formula, *arrays, optimize=[(0,1,2,3,4,5)])
  Complete contraction:  ij,ik,il,jk,jl,kl->ij
         Naive scaling:  4
     Optimized scaling:  4
      Naive FLOP count:  2.702e+16
  Optimized FLOP count:  2.702e+16
   Theoretical speedup:  1.000e+0
  Largest intermediate:  6.711e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4              0  kl,jl,jk,il,ik,ij->ij                                ij->ij
```
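Written out, this single-step path evaluates every output element as a double sum over the contracted indices k and l, which is exactly the four nested loops over i, j, k, l. Here A through F are just illustrative names for the six operands in order (ij, ik, il, jk, jl, kl). Assuming the cost model counts five multiplications plus one addition per (i, j, k, l) tuple, this reproduces the naive FLOP count above, and the only array created is the n × n output, matching the "Largest intermediate" line:

$$
\mathrm{out}_{ij} = \sum_{k=1}^{n} \sum_{l=1}^{n} A_{ij}\, B_{ik}\, C_{il}\, D_{jk}\, E_{jl}\, F_{kl}
$$

$$
\text{FLOPs} = 6\,n^4 = 6 \cdot 8192^4 = 6 \cdot 2^{52} \approx 2.702 \times 10^{16},
\qquad
n^2 = 2^{26} \approx 6.711 \times 10^{7}\ \text{elements}
$$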

However, when I actually evaluate the einsum, I get a `NotImplementedError`, raised next to a comment that says `# if this is actually reachable, open an issue!`.

https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9775

```
>>> ans = jnp.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
>>> ans.block_until_ready()
```
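Until `einsum` supports multi-operand steps like this, one possible workaround is to write the loop over k by hand with `jax.lax.fori_loop` and let one matmul handle the sum over l. This is a sketch, not a JAX API: the helper name `contract_ij_naive` and the operand names `A` through `F` are hypothetical. For a fixed k, the inner sum over l equals `(C * F[k]) @ E.T`, so every intermediate stays n × n while the total work is still O(n^4).

```python
import jax
import jax.numpy as jnp


def contract_ij_naive(A, B, C, D, E, F):
    """Hypothetical helper: evaluate 'ij,ik,il,jk,jl,kl->ij' with O(n^2) memory.

    Operand order matches the formula: A=ij, B=ik, C=il, D=jk, E=jl, F=kl.
    """

    def body(k, acc):
        # s[i, j] = sum_l C[i, l] * F[k, l] * E[j, l], computed as one matmul.
        s = (C * F[k][None, :]) @ E.T
        # Scale by B[i, k] * D[j, k] and add this k's contribution.
        return acc + B[:, k][:, None] * D[:, k][None, :] * s

    # Loop over k; no (n, n, n) or (n, n, n, n) tensor is ever materialized.
    acc = jax.lax.fori_loop(0, F.shape[0], body, jnp.zeros_like(A))
    return A * acc


# Compare against jnp.einsum's default path at a small size, where the
# pairwise intermediates it creates are still affordable.
n_small = 32
keys = jax.random.split(jax.random.PRNGKey(0), 6)
small = [jax.random.normal(key, (n_small, n_small)) for key in keys]
formula = "ij,ik,il,jk,jl,kl->ij"
expected = jnp.einsum(formula, *small)
result = jax.jit(contract_ij_naive)(*small)
print(jnp.allclose(result, expected, rtol=1e-4, atol=1e-3))
```

Each iteration is one n × n by n × n matmul plus elementwise work, so peak extra memory is a few n × n buffers. On accelerators whose default float32 matmul precision is reduced, the comparison may need a looser tolerance or an explicit `precision=jax.lax.Precision.HIGHEST` on the matmul.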
ryan112358 added the enhancement (New feature or request) label on Nov 15, 2024.