full ensemble fix, feature model fix #332

Merged
merged 14 commits into from
Sep 9, 2024
Changes from 3 commits
2 changes: 2 additions & 0 deletions apax/md/ase_calc.py
@@ -82,8 +82,10 @@ def make_ensemble(model):
     def ensemble(positions, Z, idx, box, offsets):
         results = model(positions, Z, idx, box, offsets)
         uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()}
+        ensemble = {k + "_ensemble": v for k, v in results.items()}
         results = {k: jnp.mean(v, axis=0) for k, v in results.items()}
         results.update(uncertainty)
+        results.update(ensemble)
Contributor

Does this mean, when using any kind of ensemble, we will have:

{energy: float, energy_ensemble: list[float], forces: ndarray[N, 3], forces_ensemble: ndarray[m, N, 3]}

and so on? I see where this can be very useful, but for the sake of performance it might be a good idea to make this toggleable. When saving to H5, for example, the amount of data will increase noticeably.
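
For illustration, a minimal sketch of the dictionary this produces. It assumes NumPy as a stand-in for `jax.numpy`, a simplified one-argument signature, and a hypothetical `toy_model` that returns every property with a leading ensemble axis of size `m`; none of these come from apax itself.

```python
import numpy as np  # stand-in for jax.numpy; std/mean are used the same way


def make_ensemble(model):
    # Same logic as the diff above, with the call signature reduced
    # to `positions` only (an assumption made to keep the sketch short).
    def ensemble(positions):
        results = model(positions)
        uncertainty = {k + "_uncertainty": np.std(v, axis=0) for k, v in results.items()}
        members = {k + "_ensemble": v for k, v in results.items()}
        results = {k: np.mean(v, axis=0) for k, v in results.items()}
        results.update(uncertainty)
        results.update(members)
        return results

    return ensemble


def toy_model(positions, n_members=4):
    # Hypothetical model: each ensemble member scales a trivial prediction.
    scale = np.arange(1, n_members + 1, dtype=float)
    energy = scale * positions.sum()                 # shape (m,)
    forces = scale[:, None, None] * positions[None]  # shape (m, N, 3)
    return {"energy": energy, "forces": forces}


positions = np.ones((5, 3))  # N = 5 atoms
out = make_ensemble(toy_model)(positions)
shapes = {k: np.shape(v) for k, v in out.items()}
n_values_all = sum(np.size(v) for v in out.values())
n_values_mean_only = np.size(out["energy"]) + np.size(out["forces"])
```

In this toy setup the mean and uncertainty entries drop the ensemble axis, while the `_ensemble` entries keep it, so the number of stored values per frame grows roughly with the number of ensemble members.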

Contributor

I also opened zincware/ZnH5MD#136 but I think this option should be available on both sides?

Contributor Author

Yes, these are now available for all types of ensemble. Passing arguments to the ASE calculator in IPS is a bit cumbersome at the moment: you always need a second node in addition to the training node (otherwise you have to retrain when changing the ASE options).
The inference speed should not actually be affected, but yes, the storage size will increase significantly. Storing the ensemble predictions is relevant for some UQ metrics and for reweighting.
I guess I can add an optional whitelist to the ASE calculator that filters the results by key. Same for jaxMD.
It seems like a cleaner solution to me to specify this in the H5 writer, but I can also implement it on the apax side.
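
A rough sketch of what such a key filter could look like. The helper name `filter_results`, its `allowed_keys` parameter, and the fake result dictionary are all hypothetical and only illustrate the idea; they are not part of apax, ASE, or ZnH5MD.

```python
def filter_results(fn, allowed_keys=None):
    """Wrap `fn` so that only whitelisted keys are returned.

    `allowed_keys=None` means no filtering, so the default behaviour
    of the wrapped function is unchanged.
    """

    def filtered(*args, **kwargs):
        results = fn(*args, **kwargs)
        if allowed_keys is None:
            return results
        return {k: v for k, v in results.items() if k in allowed_keys}

    return filtered


def fake_results():
    # Hypothetical stand-in for the dictionary returned by an ensemble model.
    return {
        "energy": -1.0,
        "energy_uncertainty": 0.1,
        "energy_ensemble": [-1.1, -0.9],
        "forces": [[0.0, 0.0, 0.0]],
    }


slim = filter_results(fake_results, allowed_keys={"energy", "forces"})()
full = filter_results(fake_results)()
```

The same filter works equally well on the writer side, since it only touches the result dictionary and not the model.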

Contributor

I guess it doesn't hurt to implement it here as well, but I agree that it is more useful on the znh5md side.

Contributor Author

Ah right, I forgot about ASE. Yeah I can do that for ASE as well, but it's not going to be very ergonomic for IPSuite.

Contributor Author

I have to say that this would be really unergonomic when using the ASECalculator in IPSuite. I think models should return what they return, and which properties are logged should be decided by the trajectory writer. If I implement this as part of the ASECalc, it would mean that we train a model with a specified list of outputs, and changing these properties would require retraining or using a separate Node.
So I think this was a good addition to jaxMD (as it is implemented in the trajectory writer), but for the ASECalculator it makes everything more complicated.


         return results

6 changes: 4 additions & 2 deletions apax/nn/models.py
@@ -79,8 +79,10 @@ def __call__(
             perturbation,
         )
 
-        gm = self.descriptor(dr_vec, Z, idx)
-        features = jax.vmap(self.readout)(gm)
+        features = self.descriptor(dr_vec, Z, idx)
+
+        if self.readout:
+            features = jax.vmap(self.readout)(features)
PythonFZ marked this conversation as resolved.

         if self.mask_atoms:
             features = mask_by_atom(features, Z)
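
The hunk above makes the readout optional, so the feature model can return the raw descriptor when no readout is set. A minimal sketch of that pattern follows, assuming NumPy instead of JAX, a per-atom loop in place of `jax.vmap`, an explicit `is not None` check instead of the truthiness test in the diff, and hypothetical `toy_descriptor` / `toy_readout` functions.

```python
import numpy as np


def feature_model(descriptor, readout=None):
    # Returns a callable that applies the descriptor and, only if a
    # readout is given, maps the readout over the per-atom features.
    def apply(positions):
        features = descriptor(positions)
        if readout is not None:
            # Per-atom loop standing in for jax.vmap(readout)(features).
            features = np.stack([readout(f) for f in features])
        return features

    return apply


def toy_descriptor(positions):
    # Hypothetical descriptor: [x, y, z, |r|^2] per atom, shape (N, 4).
    r2 = (positions**2).sum(axis=1, keepdims=True)
    return np.concatenate([positions, r2], axis=1)


def toy_readout(atom_features):
    # Hypothetical readout: keep and rescale the first two features, shape (2,).
    return atom_features[:2] * 2.0


positions = np.ones((5, 3))
raw = feature_model(toy_descriptor)(positions)
reduced = feature_model(toy_descriptor, toy_readout)(positions)
```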