This repository has been archived by the owner on Aug 20, 2023. It is now read-only.

A library which trains the Fermionic Neural Network to find the ground state wave functions of an atom or a molecule using neural network quantum states.

cevenkadir/ferminet_tum

ferminet_tum - Fermionic Neural Network

I wrote this library during my internship at TUM. It reproduces the results of Pfau et al. (2020) by training the so-called Fermionic Neural Network (FermiNet) to find the ground-state wave function of an atom or a molecule using neural network quantum states (see Carleo and Troyer, 2017).
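
For orientation, the objective behind this kind of training is usually written as the variational energy below. This is the standard variational Monte Carlo formulation, not something taken from this repository's code; the symbols (network parameters θ, Hamiltonian Ĥ, local energy E_L, exact ground-state energy E_0) are notation assumed here.

```latex
E(\theta)
  = \frac{\langle \psi_\theta | \hat{H} | \psi_\theta \rangle}
         {\langle \psi_\theta | \psi_\theta \rangle}
  = \mathbb{E}_{x \sim |\psi_\theta|^2}\!\left[ E_L(x) \right]
  \;\ge\; E_0,
\qquad
E_L(x) = \frac{(\hat{H}\psi_\theta)(x)}{\psi_\theta(x)}
```

The expectation is estimated from electron configurations x sampled from |ψ_θ|², which is why the examples below need a sampler and a number of samples in addition to the network and the Hamiltonian.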

Installation

ferminet_tum uses JAX for automatic differentiation, Chex for initializing the network parameters, optax for optimizing them, h5py for storing the variational energy values, and mendeleev for the chemical element data.

To install it, run:

git clone https://github.com/cevenkadir/ferminet_tum
cd ferminet_tum
python setup.py install --user

Usage

Here is an example for a lithium atom.

import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt

from ferminet_tum.atom import Atom
from ferminet_tum.network import FermiNet
from ferminet_tum.hamiltonian import Hamiltonian
from ferminet_tum.sampler import Sampler
from ferminet_tum.driver import NNQS
li_atom = [Atom(id="Li", pos=jnp.array([0.0, 0.0, 0.0]))]
n_electrons = [2, 1]  # number of up- and down-spin electrons

n_1 = [8, 8]  # numbers of hidden units for the one-electron stream
n_2 = [4, 4]  # numbers of hidden units for the two-electron stream
n_k = 16  # number of many-electron determinants
L = len(n_1)  # number of layers

# initialize the FermiNet
fn = FermiNet(
    li_atom,
    n_electrons,
    L,
    n_k,
    n_1,
    n_2,
)
params = fn.init_params()  # initialize the parameters
# initialize the Hamiltonian
hamiltonian = Hamiltonian(fn)
n_step = 10  # number of steps in each sampling
step_std_per_dim = (
    0.02  # standard deviation of the Gaussian noise in each step for each dimension
)

# initialize the sampler
sampler = Sampler(fn, n_step, step_std_per_dim)
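
The README does not say which algorithm `Sampler` uses internally. The parameter descriptions (a number of steps per sampling round and a Gaussian noise scale per dimension) suggest a random-walk Metropolis scheme, so the sketch below shows what the two parameters would mean under that assumption. It is a toy in plain Python, not the library's implementation: `log_prob` and `metropolis_sweep` are hypothetical names, and the target is a hydrogen-like 1s orbital rather than a FermiNet.

```python
import math
import random


def log_prob(r):
    """Toy target log|psi|^2 for psi(r) = exp(-|r|), a hydrogen-like 1s orbital."""
    return -2.0 * math.sqrt(sum(c * c for c in r))


def metropolis_sweep(walker, n_step, step_std_per_dim, rng):
    """Apply n_step Gaussian random-walk Metropolis updates to one walker."""
    lp = log_prob(walker)
    n_accepted = 0
    for _ in range(n_step):
        # Propose a move by adding Gaussian noise to every coordinate.
        proposal = [c + rng.gauss(0.0, step_std_per_dim) for c in walker]
        lp_new = log_prob(proposal)
        # Accept with probability min(1, |psi(new)|^2 / |psi(old)|^2).
        if rng.random() < math.exp(min(0.0, lp_new - lp)):
            walker, lp = proposal, lp_new
            n_accepted += 1
    return walker, n_accepted


rng = random.Random(0)
n_step, step_std_per_dim = 10, 0.02  # same values as in the example above
walkers = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(256)]
results = [metropolis_sweep(w, n_step, step_std_per_dim, rng) for w in walkers]
walkers = [w for w, _ in results]
acceptance = sum(n for _, n in results) / (n_step * len(walkers))
print(f"acceptance rate: {acceptance:.2f}")
```

In such a scheme a small noise scale gives a high acceptance rate but slowly decorrelating samples, so the step size and the number of steps trade off against each other.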
n_samples = 4096  # number of samples

# initialize the driver
nnqs = NNQS(sampler, hamiltonian, fn, n_samples)
n_iters = 7000  # number of iterations

# training the network
params, energies = nnqs.train(n_iters, params, optax.adam(learning_rate=1e-3))
expected_energy = -7.47798  # reference ground-state energy in Hartree
fig, axs = plt.subplots(1,2,dpi=150)

axs[0].plot(energies, linewidth=0.25, label="FermiNet")
axs[0].set_xlabel("iteration")
axs[0].set_ylabel("<E> (Hartree)")
axs[0].axhline(expected_energy, color="r", linestyle="--", label="true energy")
axs[0].legend()
axs[0].patch.set_facecolor("white")

per_err = jnp.abs((energies - expected_energy)/expected_energy)*100
per_err = per_err[~jnp.isnan(per_err)]
axs[1].plot(per_err, linewidth=0.25, label="FermiNet")
axs[1].set_xlabel("iteration")
axs[1].set_ylabel("relative energy error (%)")  # per_err is the relative error of the total energy
axs[1].axhline(jnp.abs(1.59e-3/expected_energy)*100, color="y", linestyle="--", label="chemical accuracy")
axs[1].legend()
axs[1].set_yscale("log")
axs[1].patch.set_facecolor("white")

fig.suptitle("a Li atom ({:.3f} ± {:.3f} % error from last 400 iterations)".format(per_err[-400:].mean(), per_err[-400:].std()))

fig.tight_layout()
fig.patch.set_facecolor("white")

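[Figure: variational energy (left) and percentage error on a log scale (right) versus iteration for the Li atom]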
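
The yellow "chemical accuracy" line in the right-hand panel comes from a short piece of arithmetic, worked through below in plain Python. The reference energy and the 1.59e-3 Hartree threshold are the values used in the plotting code above; `percent_error` is a hypothetical helper and `sample_energy` is a made-up number for illustration, not a result produced by the library.

```python
expected_energy = -7.47798   # Hartree, reference value from the example above
chemical_accuracy = 1.59e-3  # Hartree, threshold from the example above


def percent_error(energy, reference):
    """Relative error of an energy with respect to a reference, in percent."""
    return abs((energy - reference) / reference) * 100


# Height of the "chemical accuracy" line in the error plot.
threshold = percent_error(expected_energy + chemical_accuracy, expected_energy)
print(f"chemical accuracy threshold: {threshold:.4f} %")  # 0.0213 %

# Made-up energy, used only to show how a value compares against the line.
sample_energy = -7.47700
within = percent_error(sample_energy, expected_energy) <= threshold
print(f"within chemical accuracy: {within}")  # True
```

A run sits below the yellow line once its absolute error drops under 1.59e-3 Hartree, which for lithium corresponds to roughly 0.02 % of the total energy.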

Here is an example for a LiH molecule.

lih_mol = [
    Atom(id="Li", pos=jnp.array([0.0, 0.0, 3.015 / 2])),
    Atom(id="H", pos=jnp.array([0.0, 0.0, -3.015 / 2])),
]
n_electrons = [2, 2]  # number of up- and down-spin electrons

n_1 = [16, 16, 16]  # numbers of hidden units for the one-electron stream
n_2 = [8, 8, 8]  # numbers of hidden units for the two-electron stream
n_k = 16  # number of many-electron determinants
L = len(n_1)  # number of layers

# initialize the FermiNet
fn_mol = FermiNet(
    lih_mol,
    n_electrons,
    L,
    n_k,
    n_1,
    n_2,
)
params = fn_mol.init_params()  # initialize the parameters
# initialize the Hamiltonian
hamiltonian = Hamiltonian(fn_mol)
n_step = 10  # number of steps in each sampling
step_std_per_dim = (
    0.02  # standard deviation of the Gaussian noise in each step for each dimension
)

# initialize the sampler
sampler = Sampler(fn_mol, n_step, step_std_per_dim)
n_samples = 4096  # number of samples

# initialize the driver
nnqs = NNQS(sampler, hamiltonian, fn_mol, n_samples)
n_iters = 7000  # number of iterations

# training the network
optimizer = optax.chain(
    optax.adam(learning_rate=0.01),
    optax.add_noise(0.01, 0.55, 0),  # positional arguments: eta, gamma, seed
)
params, energies = nnqs.train(n_iters, params, optimizer)
expected_energy = -8.07050  # reference ground-state energy in Hartree
fig, axs = plt.subplots(1,2,dpi=150)

axs[0].plot(energies, linewidth=0.25, label="FermiNet")
axs[0].set_xlabel("iteration")
axs[0].set_ylabel("<E> (Hartree)")
axs[0].axhline(expected_energy, color="r", linestyle="--", label="true energy")
axs[0].legend()
axs[0].patch.set_facecolor("white")

per_err = jnp.abs((energies - expected_energy)/expected_energy)*100
per_err = per_err[~jnp.isnan(per_err)]
axs[1].plot(per_err, linewidth=0.25, label="FermiNet")
axs[1].set_xlabel("iteration")
axs[1].set_ylabel("relative energy error (%)")  # per_err is the relative error of the total energy
axs[1].axhline(jnp.abs(1.59e-3/expected_energy)*100, color="y", linestyle="--", label="chemical accuracy")
axs[1].legend()
axs[1].set_yscale("log")
axs[1].patch.set_facecolor("white")

fig.suptitle("a LiH molecule ({:.3f} ± {:.3f} % error from last 400 iterations)".format(per_err[-400:].mean(), per_err[-400:].std()))

fig.tight_layout()
fig.patch.set_facecolor("white")

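[Figure: variational energy (left) and percentage error on a log scale (right) versus iteration for the LiH molecule]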
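
Both examples pass `n_electrons` as a two-element list of up- and down-spin counts: `[2, 1]` for Li and `[2, 2]` for LiH. A small sketch of how those numbers follow from the nuclear charges, assuming a neutral system unless a charge is given. `spin_counts` is a hypothetical helper written for this README, not part of ferminet_tum, and the number of unpaired electrons has to be supplied by the user.

```python
def spin_counts(atomic_numbers, n_unpaired=0, charge=0):
    """Return [n_up, n_down] for nuclei with the given atomic numbers."""
    n_total = sum(atomic_numbers) - charge
    if n_unpaired > n_total or (n_total - n_unpaired) % 2 != 0:
        raise ValueError("n_unpaired is inconsistent with the electron count")
    n_down = (n_total - n_unpaired) // 2
    # By convention here, the unpaired electrons are counted as spin-up.
    return [n_down + n_unpaired, n_down]


print(spin_counts([3], n_unpaired=1))  # Li atom: [2, 1]
print(spin_counts([3, 1]))             # LiH molecule: [2, 2]
```

The check inside the helper catches combinations that cannot occur, such as an odd electron count with no unpaired electron.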
