Skip to content

Commit

Permalink
52 save velocities (#54)
Browse files Browse the repository at this point in the history
* save momenta

* improve tests
  • Loading branch information
PythonFZ authored Apr 17, 2023
1 parent 7b25fb0 commit d5313c4
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def atoms_list(request) -> list[ase.Atoms]:
"""
if getattr(request, "param", "").startswith("vary_size"):
atoms = [ase.build.molecule(x) for x in ase.collections.g2.names]
for atom in atoms:
atom.set_velocities(np.random.rand(len(atom), 3))

else:
random.seed(1234)
atoms = [
Expand Down
22 changes: 22 additions & 0 deletions tests/test_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def test_AtomsReader(tmp_path, reader, atoms_list, use_add):
)


@pytest.mark.parametrize("atoms_list", ["vary_size"], indirect=True)
def test_momenta(tmp_path, atoms_list):
os.chdir(tmp_path)
print(tmp_path)

db = znh5md.io.DataWriter(filename="db.h5")
db.initialize_database_groups()
db.add(znh5md.io.AtomsReader(atoms_list))

data = znh5md.ASEH5MD("db.h5")
assert "momenta" in data[0].arrays

with pytest.raises(AssertionError):
# Assert array not equal zero array
npt.assert_array_equal(
data[0].get_velocities, np.zeros_like(data[0].get_positions())
)


@pytest.mark.parametrize("reader", [znh5md.io.ASEFileReader, znh5md.io.AtomsReader])
@pytest.mark.parametrize("atoms_list", ["vary_size", "vary_pbc_vary_pbc"], indirect=True)
def test_AtomsReader_with_pbc_group(tmp_path, reader, atoms_list):
Expand Down Expand Up @@ -96,6 +115,7 @@ def test_AtomsReader_with_pbc_group(tmp_path, reader, atoms_list):
npt.assert_array_almost_equal(a.get_potential_energy(), b.get_potential_energy())
npt.assert_array_equal(a.get_pbc(), b.get_pbc())
npt.assert_array_almost_equal(a.get_stress(), b.get_stress())
npt.assert_array_almost_equal(a.get_velocities(), b.get_velocities())

# now test with Dask
traj = znh5md.DaskH5MD("db.h5")
Expand Down Expand Up @@ -130,6 +150,8 @@ def test_AtomsReader_observables(tmp_path, atoms_list, save_atoms_results):
for a, b in zip(atoms, atoms_list):
for key in b.calc.results:
npt.assert_array_almost_equal(a.calc.results[key], b.calc.results[key])
assert "momenta" not in a.arrays
assert "momenta" not in b.arrays
else:
assert "predicted_energy" not in atoms[0].calc.results
assert "predicted_forces" not in atoms[0].calc.results
Expand Down
2 changes: 2 additions & 0 deletions znh5md/format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class GRP:
forces: str = "forces"
stress: str = "stress"
velocity: str = "velocity"
momentum: str = "momentum"
pbc: str = "pbc"
dimension: str = "dimension"

Expand Down Expand Up @@ -56,6 +57,7 @@ def decode_boundary(value) -> np.ndarray:
GRP.species,
GRP.forces,
GRP.velocity,
GRP.momentum,
GRP.pbc,
GRP.dimension,
GRP.edges,
Expand Down
10 changes: 9 additions & 1 deletion znh5md/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def _get_forces(self, atoms: list[ase.Atoms]) -> np.ndarray:
except ValueError:
return self._fill_with_nan(data).astype(float)

def _get_momenta(self, atoms: list[ase.Atoms]) -> np.ndarray:
data = [x.arrays["momenta"] for x in atoms]
try:
return np.array(data).astype(float)
except ValueError:
return self._fill_with_nan(data).astype(float)

def _get_stress(self, atoms: list[ase.Atoms]) -> np.ndarray:
return np.array([x.get_stress() for x in atoms]).astype(float)

Expand Down Expand Up @@ -113,6 +120,7 @@ def yield_chunks(
GRP.stress: self._get_stress,
GRP.edges: self._get_edges,
GRP.boundary: self._get_boundary,
GRP.momentum: self._get_momenta,
}
if self.use_pbc_group:
functions[GRP.pbc] = self._get_pbc
Expand All @@ -128,7 +136,7 @@ def yield_chunks(
step=self.step,
time=self.time,
)
except (PropertyNotImplementedError, RuntimeError) as err:
except (PropertyNotImplementedError, RuntimeError, KeyError) as err:
if group_names is not None:
# if the property was specifically selected, raise the error
raise err
Expand Down
5 changes: 3 additions & 2 deletions znh5md/znh5md/h5ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def _gather_value(particles_data, key, idx):
Returns None if the key is not present in the data.
"""
if key in particles_data:
if key in [GRP.species, GRP.position, GRP.velocity, GRP.forces]:
if key in [GRP.species, GRP.position, GRP.velocity, GRP.forces, GRP.momentum]:
# use PARTICLES_GRP
return rm_nan(particles_data[key][idx])
return particles_data[key][idx]
return None
Expand Down Expand Up @@ -108,7 +109,7 @@ def get_atoms_list(self, item=None) -> typing.List[ase.Atoms]:
obj = ase.Atoms(
symbols=_gather_value(particles_data, GRP.species, idx),
positions=_gather_value(particles_data, GRP.position, idx),
velocities=_gather_value(particles_data, GRP.velocity, idx),
momenta=_gather_value(particles_data, GRP.momentum, idx),
cell=_gather_value(particles_data, GRP.edges, idx),
pbc=_gather_value(particles_data, GRP.pbc, idx),
)
Expand Down

0 comments on commit d5313c4

Please sign in to comment.