Torch Model #253

Draft
wants to merge 61 commits into dev from torch

Commits (61)
485e903
added multi backend implementation of most used ops
M-R-Schaefer Dec 29, 2023
67236ec
use ops in activation.py
M-R-Schaefer Dec 29, 2023
e07cf05
switched to backend agnostic basis function impls
M-R-Schaefer Dec 29, 2023
55cf7ca
added multi backend einsum
M-R-Schaefer Jan 2, 2024
622724a
added todo
M-R-Schaefer Jan 2, 2024
feaf888
sketch of multi backend gm descriptor
M-R-Schaefer Jan 2, 2024
0bbdde8
Merge branch 'dev' into torch
M-R-Schaefer Feb 26, 2024
49b49e0
Merge branch 'dev' into torch
M-R-Schaefer Mar 27, 2024
d6bda34
restructured NN submodules for multi backend
M-R-Schaefer Apr 2, 2024
59246c3
implemented NTK linear
M-R-Schaefer Apr 2, 2024
8e1ff0c
added inverse softplus to activations
M-R-Schaefer Apr 2, 2024
1bdd421
switch to backend agnostic activations in empirical
M-R-Schaefer Apr 2, 2024
6fa5ce2
implemented atomistic readout
M-R-Schaefer Apr 2, 2024
51cce10
implemented scale shift layer
M-R-Schaefer Apr 2, 2024
9a0eb96
method name typos
M-R-Schaefer Apr 2, 2024
a1fa924
implemented torch descriptor
M-R-Schaefer Apr 2, 2024
8938c7f
implemented atomistic torch model
M-R-Schaefer Apr 3, 2024
5138587
linting
M-R-Schaefer Apr 3, 2024
015aee8
sketch of energy model
M-R-Schaefer Apr 3, 2024
c16c29a
Merge branch 'dev' into torch
M-R-Schaefer Apr 3, 2024
956d816
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2024
763bf14
sketch of derivative model and builder
M-R-Schaefer Apr 3, 2024
844a393
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2024
78e9ff8
Merge branch 'dev' into torch
M-R-Schaefer Apr 8, 2024
9f3f3c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
0552d4d
sketch of displacement function
M-R-Schaefer Apr 8, 2024
7fe06c1
Merge branch 'torch' of https://github.com/apax-hub/apax into torch
M-R-Schaefer Apr 8, 2024
380ac31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
8261cf2
added torch dependency
M-R-Schaefer Apr 8, 2024
a945e27
poetry update
M-R-Schaefer Apr 8, 2024
57f1b2e
renamed torch modules to avoid name clash
M-R-Schaefer Apr 8, 2024
06b965b
fixed various semantic errors
M-R-Schaefer Apr 8, 2024
3a634fa
first working torch model
M-R-Schaefer Apr 9, 2024
461b25c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
5c91a81
debug
M-R-Schaefer Apr 10, 2024
51dfe86
Merge branch 'torch' of https://github.com/apax-hub/apax into torch
M-R-Schaefer Apr 17, 2024
a0df48b
first working torch model export
M-R-Schaefer Apr 18, 2024
3e4315e
moved torchmodel to selfcontained impl
M-R-Schaefer Apr 20, 2024
66b414d
added back original jax layers
M-R-Schaefer Apr 25, 2024
7f2d31a
missing file
M-R-Schaefer Apr 25, 2024
cc23f3d
removed multi backend jax impl
M-R-Schaefer Apr 26, 2024
7f77443
implemented initialization from jax weights for linear and atomistic …
M-R-Schaefer Apr 26, 2024
9db6953
integration tests for torch layers
M-R-Schaefer Apr 26, 2024
185d1e4
added back old model builder
M-R-Schaefer Apr 26, 2024
2c242a2
moved integration tests
M-R-Schaefer Apr 26, 2024
dd18f3f
fixed torch layers to reproduce jax results
M-R-Schaefer Apr 26, 2024
6572016
completed integration tests
M-R-Schaefer Apr 29, 2024
e49b523
fixed ncontr train config
M-R-Schaefer Apr 29, 2024
c15ba8b
added torch ase calculator
M-R-Schaefer Apr 29, 2024
9c191b7
Merge branch 'dev' into torch
M-R-Schaefer Apr 29, 2024
07cace9
poetry update
M-R-Schaefer Apr 29, 2024
1b26df1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2024
c77cce2
implemented neighborlist reuse in ase calc
M-R-Schaefer May 15, 2024
9b16d8c
removed unused ops module
M-R-Schaefer May 15, 2024
2de9bd0
linting
M-R-Schaefer May 15, 2024
4ebd04d
removed unused torch builder
M-R-Schaefer May 15, 2024
31454da
fixed vectorized readout model. implemented pbc calculation
M-R-Schaefer May 15, 2024
d88df67
added torchscript compatible scatter add implementation
M-R-Schaefer May 15, 2024
320b03f
update import path
M-R-Schaefer May 15, 2024
7078070
fixed tests, made torch model compatible with pbc
M-R-Schaefer May 15, 2024
02c4b96
linting
M-R-Schaefer May 15, 2024
2 changes: 1 addition & 1 deletion apax/config/train_config.py
@@ -171,7 +171,7 @@ class ModelConfig(BaseModel, extra="forbid"):
n_radial: PositiveInt = 5
r_min: NonNegativeFloat = 0.5
r_max: PositiveFloat = 6.0
- n_contr: int = -1
+ n_contr: int = 8
emb_init: Optional[str] = "uniform"

nn: List[PositiveInt] = [512, 512]
1 change: 1 addition & 0 deletions apax/layers/scaling.py
@@ -24,6 +24,7 @@ def setup(self):
scale = einops.repeat(scale, "species -> species 1")
if len(shift.shape) == 1:
shift = einops.repeat(shift, "species -> species 1")

scale_init = nn.initializers.constant(scale)
shift_init = nn.initializers.constant(shift)

84 changes: 84 additions & 0 deletions apax/md/torch_ase_calc.py
@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Callable, Union

import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes
from matscipy.neighbours import neighbour_list


class TorchASECalculator(Calculator):
"""
ASE Calculator for apax models.
"""

implemented_properties = [
"energy",
"forces",
]

def __init__(
self,
model_path: Union[Path, list[Path]],
dr_threshold: float = 0.5,
transformations: list[Callable] = [],
**kwargs
):
Calculator.__init__(self, **kwargs)
self.skin = dr_threshold

self.model = torch.jit.load(model_path)
self.r_max = (
self.model.energy_model.atomistic_model.descriptor.radial_fn.basis_fn.r_max
)

self.step = None
self.neighbor_fn = None
self.neighbors = None
self.offsets = None
self.pos0 = 0
self.Z = [0, 0]
self.pbc = False

def set_neighbours_and_offsets(self, atoms, box):
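# Rebuild the neighbor list only if pbc changed, the atom count changed, or any atom moved farther than skin/2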
condition = (
np.any(self.pbc != atoms.pbc)
or len(self.Z) != len(atoms.numbers)
or np.max(np.sum(((self.pos0 - atoms.positions) ** 2), axis=1))
> self.skin**2 / 4.0
)
if condition:
idxs_i, idxs_j, offsets = neighbour_list(
"ijS", positions=atoms.positions, pbc=atoms.pbc, cutoff=self.r_max
)

self.neighbors = np.array([idxs_i, idxs_j], dtype=np.int32)
self.offsets = np.matmul(offsets, box)  # convert fractional cell shifts to Cartesian offsets
self.pos0 = atoms.positions
self.Z = atoms.numbers
self.pbc = atoms.pbc

def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
Calculator.calculate(self, atoms, properties, system_changes)
positions = atoms.positions
box = atoms.cell.array
# if np.any(atoms.pbc):
# positions = atoms.positions @ np.linalg.inv(box)

# predict
self.set_neighbours_and_offsets(atoms, box)

inputt = (
torch.from_numpy(positions),
torch.from_numpy(atoms.numbers),
torch.from_numpy(np.asarray(self.neighbors, dtype=np.int64)),
torch.from_numpy(np.asarray(box, dtype=np.float64)),
torch.from_numpy(np.asarray(self.offsets, dtype=np.float64)),
)

results = self.model(*inputt)

self.results = {
k: np.array(v.detach().numpy(), dtype=np.float64) for k, v in results.items()
}
self.results["energy"] = self.results["energy"].item()
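A minimal usage sketch for the new calculator; the model path and molecule below are placeholders, assuming a TorchScript model exported from this branch:

from ase.build import molecule

from apax.md.torch_ase_calc import TorchASECalculator

# "model.pt" is a placeholder path to an exported TorchScript model
atoms = molecule("H2O")
atoms.calc = TorchASECalculator(model_path="model.pt")

energy = atoms.get_potential_energy()  # scalar taken from results["energy"]
forces = atoms.get_forces()  # (n_atoms, 3) array from results["forces"]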
1 change: 0 additions & 1 deletion apax/model/gmnn.py
@@ -66,7 +66,6 @@ def __call__(
gm = self.descriptor(dr_vec, Z, idx)
h = jax.vmap(self.readout)(gm)
output = self.scale_shift(h, Z)

if self.mask_atoms:
output = mask_by_atom(output, Z)
return output
Empty file added apax/nn/__init__.py
Empty file.
Empty file added apax/nn/common.py
Empty file.
Empty file added apax/nn/impl/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions apax/nn/impl/triangular_indices.py
@@ -0,0 +1,49 @@
import numpy as np

# def tril_2d_indices(n: int):
# indices = np.zeros((int(n * (n + 1) / 2), 2), dtype=int)
# sparse_idx = 0
# for i in range(0, n):
# for j in range(i, n):
# indices[sparse_idx] = i, j
# sparse_idx += 1

# return jnp.asarray(indices)


# def tril_3d_indices(n: int):
# indices = np.zeros((int(n * (n + 1) * (n + 2) / 6), 3), dtype=int)
# sparse_idx = 0
# for i in range(0, n):
# for j in range(i, n):
# for k in range(j, n):
# indices[sparse_idx] = i, j, k
# sparse_idx += 1

# return jnp.asarray(indices)


def tril_2d_indices(n_radial):
tril_idxs = []
for i in range(n_radial):
tril_idxs.append([i, i])
for j in range(i + 1, n_radial):
tril_idxs.append([i, j])
tril_idxs = np.array(tril_idxs)
return tril_idxs


def tril_3d_indices(n_radial):
tril_idxs = []
for i in range(n_radial):
tril_idxs.append([i, i, i])
for j in range(n_radial):
if j != i:
tril_idxs.append([i, j, j])
for j in range(i + 1, n_radial):
for k in range(j + 1, n_radial):
tril_idxs.append([i, j, k])
tril_idxs = np.array(tril_idxs)
return tril_idxs
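For reference, what these helpers produce for n = 3 (the value is arbitrary; the output follows directly from the loops above):

from apax.nn.impl.triangular_indices import tril_2d_indices, tril_3d_indices

print(tril_2d_indices(3))
# [[0 0]
#  [0 1]
#  [0 2]
#  [1 1]
#  [1 2]
#  [2 2]]  -> n * (n + 1) / 2 = 6 index pairs

print(tril_3d_indices(3).shape)  # (10, 3) -> n * (n + 1) * (n + 2) / 6 = 10 triples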
Empty file added apax/nn/torch/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions apax/nn/torch/layers/__init__.py
@@ -0,0 +1,3 @@
from apax.nn.torch.layers import descriptor, ntk_linear, scaling

__all__ = ["descriptor", "ntk_linear", "scaling"]
11 changes: 11 additions & 0 deletions apax/nn/torch/layers/activation.py
@@ -0,0 +1,11 @@
import torch
import torch.nn as nn


class SwishT(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x):
h = 1.6765324703310907 * torch.nn.functional.silu(x)
return h
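A quick sanity check of the scaled activation; the expected values simply evaluate 1.6765324703310907 * silu(x):

import torch

from apax.nn.torch.layers.activation import SwishT

act = SwishT()
print(act(torch.tensor([0.0, 1.0])))
# ~ tensor([0.0000, 1.2256]), i.e. 1.6765324703310907 * x * sigmoid(x)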
5 changes: 5 additions & 0 deletions apax/nn/torch/layers/descriptor/__init__.py
@@ -0,0 +1,5 @@
from apax.nn.torch.layers.descriptor.gaussian_moment_descriptor import (
GaussianMomentDescriptorT,
)

__all__ = ["GaussianMomentDescriptorT"]
114 changes: 114 additions & 0 deletions apax/nn/torch/layers/descriptor/basis.py
@@ -0,0 +1,114 @@
from typing import Any

import einops
import numpy as np
import torch
import torch.nn as nn


def cosine_cutoff(dr, r_max):
# shape: neighbors
dr_clipped = torch.clamp(dr, max=r_max)
cos_cutoff = 0.5 * (torch.cos(np.pi * dr_clipped / r_max) + 1.0)
cutoff = cos_cutoff[:, None]
return cutoff


class GaussianBasisT(nn.Module):
def __init__(
self,
n_basis: int = 7,
r_min: float = 0.5,
r_max: float = 6.0,
dtype: Any = torch.float32,
) -> None:
super().__init__()
self.n_basis = n_basis
self.r_min = r_min
self.r_max = r_max
self.dtype = dtype

self.betta = self.n_basis**2 / self.r_max**2
self.rad_norm = (2.0 * self.betta / np.pi) ** 0.25
shifts = self.r_min + (self.r_max - self.r_min) / self.n_basis * np.arange(
self.n_basis
)

self.betta = torch.tensor(self.betta)
self.rad_norm = torch.tensor(self.rad_norm)

# shape: 1 x n_basis
shifts = einops.repeat(shifts, "n_basis -> 1 n_basis")
self.shifts = torch.tensor(shifts, dtype=self.dtype)

def forward(self, dr: torch.Tensor) -> torch.Tensor:
# dr shape: neighbors
# neighbors -> neighbors x 1
dr = dr[:, None].type(self.dtype)
# 1 x n_basis, neighbors x 1 -> neighbors x n_basis
distances = self.shifts - dr

# shape: neighbors x n_basis
basis = torch.exp(-self.betta * (distances**2))
basis = self.rad_norm * basis
return basis


class RadialFunctionT(nn.Module):
def __init__(
self,
n_radial: int = 5,
basis_fn: nn.Module = GaussianBasisT(),
emb_init: str = "uniform",
n_species: int = 119,
params=None,
dtype: Any = torch.float32,
) -> None:
super().__init__()
self.n_radial = n_radial
self.basis_fn = basis_fn
self.n_species = n_species
self.emb_init = emb_init
self.dtype = dtype

self.r_max = torch.tensor(self.basis_fn.r_max)
norm = 1.0 / np.sqrt(self.basis_fn.n_basis)
self.embed_norm = torch.tensor(norm, dtype=self.dtype)
self.embeddings = None

if params:
emb = params["atomic_type_embedding"]
emb = torch.from_numpy(np.array(emb))
self.embeddings = nn.Parameter(emb)
elif self.emb_init is not None:
self.n_radial = n_radial
emb = torch.rand(
(self.n_species, self.n_species, self.n_radial, self.basis_fn.n_basis)
)
self.embeddings = nn.Parameter(emb)
else:
self.n_radial = self.basis_fn.n_basis

def forward(self, dr, Z_i, Z_j):
dr = dr.type(self.dtype)
# basis shape: neighbors x n_basis
basis = self.basis_fn(dr)

if self.embeddings is None:
radial_function = basis
else:
# coeffs shape: n_neighbors x n_radial x n_basis
# reverse convention to match original
species_pair_coeffs = self.embeddings[Z_j, Z_i, ...]
species_pair_coeffs = self.embed_norm * species_pair_coeffs

radial_function = torch.einsum(
"nrb, nb -> nr",
species_pair_coeffs,
basis,
)
cutoff = cosine_cutoff(dr, self.r_max)
radial_function = radial_function * cutoff

assert radial_function.dtype == self.dtype
return radial_function
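A shape-level smoke test of the new basis and radial function, using dummy distances and atomic numbers:

import torch

from apax.nn.torch.layers.descriptor.basis import GaussianBasisT, RadialFunctionT

radial_fn = RadialFunctionT(n_radial=5, basis_fn=GaussianBasisT(n_basis=7))

dr = torch.tensor([1.0, 2.5, 5.9])  # neighbor distances
Z_i = torch.tensor([1, 8, 8])  # atomic numbers of the central atoms
Z_j = torch.tensor([8, 1, 1])  # atomic numbers of the neighbors

print(radial_fn(dr, Z_i, Z_j).shape)  # torch.Size([3, 5]) -> neighbors x n_radial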