Skip to content

Commit

Permalink
Merge pull request #332 from apax-hub/erbs_fixes
Browse files Browse the repository at this point in the history
full ensemble fix, feature model fix
  • Loading branch information
M-R-Schaefer authored Sep 9, 2024
2 parents b97acd1 + 137676e commit b601606
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 27 deletions.
14 changes: 12 additions & 2 deletions apax/config/md_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os

# from types import UnionType
from typing import Literal, Union

import yaml
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt
from typing_extensions import Annotated

from apax.utils.helpers import APAX_PROPERTIES


class ConstantTempSchedule(BaseModel, extra="forbid"):
"""Constant temperature schedule.
Expand Down Expand Up @@ -234,6 +234,14 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
extra_capacity : int, default = 0
| JaxMD allocates a maximal number of neighbors. This argument lets you add
| additional capacity to avoid recompilation. The default is usually fine.
dynamics_checks: list[DynamicsCheck]
| List of termination criteria. Currently energy and force uncertainty
| are available
properties: list[str]
| Whitelist of properties to be saved in the trajectory.
| This does not effect what the model will calculate, e.g..
| an ensemble will still calculate uncertainties.
initial_structure : str, required
| Path to the starting structure of the simulation.
sim_dir : str, default = "."
Expand Down Expand Up @@ -266,6 +274,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):

dynamics_checks: list[DynamicsCheck] = []

properties: list[str] = APAX_PROPERTIES

initial_structure: str
load_momenta: bool = False
sim_dir: str = "."
Expand Down
10 changes: 10 additions & 0 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,18 @@ def make_ensemble(model):
def ensemble(positions, Z, idx, box, offsets):
results = model(positions, Z, idx, box, offsets)
uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()}
ensemble = {k + "_ensemble": v for k, v in results.items()}
results = {k: jnp.mean(v, axis=0) for k, v in results.items()}
if "forces_ensemble" in ensemble.keys():
ensemble["forces_ensemble"] = jnp.transpose(
ensemble["forces_ensemble"], (1, 2, 0)
)
if "forces_ensemble" in ensemble.keys():
ensemble["stress_ensemble"] = jnp.transpose(
ensemble["forces_ensemble"], (1, 2, 0)
)
results.update(uncertainty)
results.update(ensemble)

return results

Expand Down
26 changes: 20 additions & 6 deletions apax/md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,29 @@
from ase.calculators.singlepoint import SinglePointCalculator

from apax.md.sim_utils import System
from apax.utils.helpers import APAX_PROPERTIES
from apax.utils.jax_md_reduced import space

log = logging.getLogger(__name__)


class TrajHandler:
def __init__(self) -> None:
self.system: System
self.sampling_rate: int
self.buffer_size: int
self.traj_path: Path
self.time_step: float
def __init__(
self,
system: System,
sampling_rate: int,
buffer_size: int,
traj_path: Path,
time_step: float = 0.5,
properties: list[str] = APAX_PROPERTIES,
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box > 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.time_step = time_step
self.properties = properties

def step(self, state_and_energy, transform=None):
pass
Expand Down Expand Up @@ -53,6 +64,7 @@ def atoms_from_state(self, state, predictions, nbr_kwargs):
atoms.pbc = np.diag(atoms.cell.array) > 1e-6
predictions = {k: np.array(v) for k, v in predictions.items()}
predictions["energy"] = predictions["energy"].item()
predictions = {k: v for k, v in predictions.items() if k in self.properties}
atoms.calc = SinglePointCalculator(atoms, **predictions)
return atoms

Expand All @@ -65,13 +77,15 @@ def __init__(
buffer_size: int,
traj_path: Path,
time_step: float = 0.5,
properties: list[str] = [],
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box > 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.time_step = time_step
self.properties = properties
self.db = znh5md.IO(
self.traj_path, timestep=self.time_step, store="time", save_units=False
)
Expand Down
5 changes: 3 additions & 2 deletions apax/md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ def run_sim(
n_inner: int,
extra_capacity: int,
rng_key: int,
traj_handler: TrajHandler,
load_momenta: bool = False,
restart: bool = True,
checkpoint_interval: int = 50_000,
traj_handler: TrajHandler = TrajHandler(),
dynamics_checks: list[DynamicsCheckBase] = [],
disable_pbar: bool = False,
):
Expand Down Expand Up @@ -520,6 +520,7 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
md_config.buffer_size,
traj_path,
md_config.ensemble.dt,
properties=md_config.properties,
)
# TODO implement correct chunking

Expand All @@ -531,10 +532,10 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
n_inner=md_config.n_inner,
extra_capacity=md_config.extra_capacity,
load_momenta=md_config.load_momenta,
traj_handler=traj_handler,
rng_key=jax.random.PRNGKey(md_config.seed),
restart=md_config.restart,
checkpoint_interval=md_config.checkpoint_interval,
sim_dir=sim_dir,
traj_handler=traj_handler,
dynamics_checks=dynamics_checks,
)
10 changes: 7 additions & 3 deletions apax/nn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def __call__(
perturbation,
)

gm = self.descriptor(dr_vec, Z, idx)
features = jax.vmap(self.readout)(gm)
features = self.descriptor(dr_vec, Z, idx)

if self.readout:
features = jax.vmap(self.readout)(features)

if self.mask_atoms:
features = mask_by_atom(features, Z)
Expand Down Expand Up @@ -268,7 +270,9 @@ def __call__(

prediction["forces"] = forces_mean
prediction["forces_uncertainty"] = jnp.sqrt(forces_variance)
prediction["forces_ensemble"] = forces_ens

forces_ens = jnp.transpose(forces_ens, (1, 2, 0))
prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members

else:
forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets)
Expand Down
22 changes: 13 additions & 9 deletions apax/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,26 @@

import yaml

APAX_PROPERTIES = [
"energy",
"forces",
"stress",
"forces_uncertainty",
"energy_uncertainty",
"stress_uncertainty",
"energy_ensemble",
"forces_ensemble",
"stress_ensemble",
]


def setup_ase():
"""Add uncertainty keys to ASE all properties.
from https://github.com/zincware/IPSuite/blob/main/ipsuite/utils/helpers.py#L10
"""
from ase.calculators.calculator import all_properties

additional_keys = [
"forces_uncertainty",
"energy_uncertainty",
"stress_uncertainty",
"energy_ensemble",
"forces_ensemble",
]

for val in additional_keys:
for val in APAX_PROPERTIES:
if val not in all_properties:
all_properties.append(val)

Expand Down
13 changes: 9 additions & 4 deletions tests/integration_tests/md/md_config_threshold.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
ensemble:
name: nvt
dt: 0.1 # fs time step
dt: 0.2 # fs time step
temperature_schedule:
name: piecewise
T0: 5 # K
T0: 50 # K
values: [100, 200, 1000]
steps: [10, 20, 30]

duration: 100 # fs
duration: 500 # fs
n_inner: 1
sampling_rate: 1
checkpoint_interval: 2
restart: True
dynamics_checks:
- name: forces_uncertainty
threshold: 1.0
threshold: 0.01
properties:
- energy
- forces
- energy_uncertainty
- forces_ensemble
12 changes: 11 additions & 1 deletion tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset):
}
model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods)

md_confg_path = TEST_PATH / "md_config.yaml"
md_confg_path = TEST_PATH / "md_config_threshold.yaml"

with open(md_confg_path.as_posix(), "r") as stream:
md_config_dict = yaml.safe_load(stream)
Expand All @@ -214,3 +214,13 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset):

traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:]
assert len(traj) < 1000 # num steps

results_keys = list(traj[0].calc.results.keys())

assert "energy" in results_keys
assert "forces" in results_keys
assert "energy_uncertainty" in results_keys
assert "forces_ensemble" in results_keys

assert "energy_ensemble" not in results_keys
assert "forces_uncertainty" not in results_keys

0 comments on commit b601606

Please sign in to comment.