Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use MBAR error as uncertainty with a single protocol repeat in RBFE #883

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
43 changes: 29 additions & 14 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def _generate_bad_legs_error_message(set_vals, ligpair):
def _parse_raw_units(results: dict) -> list[tuple]:
# grab individual unit results from master results dict
# returns list of (estimate, uncertainty) tuples
list_of_pur = list(results['protocol_result']['data'].values())[0]
list_of_pur = list(results['protocol_result']['data'].values())

return [(pu['outputs']['unit_estimate'],
pu['outputs']['unit_estimate_error'])
# could add to each tuple pu[0]["source_key"] for ID
return [(pu[0]['outputs']['unit_estimate'],
pu[0]['outputs']['unit_estimate_error'])
for pu in list_of_pur]


Expand Down Expand Up @@ -178,10 +179,10 @@ def _get_ddgs(legs, error_on_missing=True):
return DDGs


def _write_ddg(legs, writer, allow_partial):
def _write_ddg(legs, writer, allow_partial): # unc
DDGs = _get_ddgs(legs, error_on_missing=not allow_partial)
writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)",
"uncertainty (kcal/mol)"])
"uncertainty (kcal/mol)"]) # unc])
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
if DDGbind is not None:
DDGbind, bind_unc = format_estimate_uncertainty(DDGbind, bind_unc)
Expand All @@ -191,19 +192,19 @@ def _write_ddg(legs, writer, allow_partial):
writer.writerow([ligA, ligB, DDGhyd, hyd_unc])


def _write_raw(legs, writer, allow_partial=True):
writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)",
"MBAR uncertainty (kcal/mol)"])
def _write_raw(legs, writer, allow_partial=True): # *args?
writer.writerow(["leg", "repeat", "ligand_i", "ligand_j",
"DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"])

for ligpair, vals in sorted(legs.items()):
for simtype, repeats in sorted(vals.items()):
for m, u in repeats:
for rep, (m, u) in enumerate(repeats, 1):
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = format_estimate_uncertainty(m.m, u.m)

writer.writerow([simtype, *ligpair, m, u])
writer.writerow([simtype, rep, *ligpair, m, u])


def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover
Expand All @@ -218,7 +219,7 @@ def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover
writer.writerow([simtype, *ligpair, m, u])


def _write_dg_mle(legs, writer, allow_partial):
def _write_dg_mle(legs, writer, allow_partial): # unc
import networkx as nx
import numpy as np
from cinnabar.stats import mle
Expand Down Expand Up @@ -264,7 +265,7 @@ def _write_dg_mle(legs, writer, allow_partial):
MLEs.append((ligname, f, df))

writer.writerow(["ligand", "DG(MLE) (kcal/mol)",
"uncertainty (kcal/mol)"])
"uncertainty (kcal/mol)"]) # unc])
for ligA, DG, unc_DG in MLEs:
DG, unc_DG = format_estimate_uncertainty(DG, unc_DG)
writer.writerow([ligA, DG, unc_DG])
Expand Down Expand Up @@ -336,6 +337,9 @@ def gather(rootdir, output, report, allow_partial):
# 3) pair legs of simulations together into dict of dicts
legs = defaultdict(dict)

######## CHECK IF ALL RESULTS HAVE SAME # OF PROTOCOLUNITS?
# MBAR_errors = True

for result_fn in result_fns:
result = load_results(result_fn)
if result is None:
Expand All @@ -344,6 +348,8 @@ def gather(rootdir, output, report, allow_partial):
click.echo(f"WARNING: Calculations for {result_fn} did not finish successfully!",
err=True)



try:
names = get_names(result)
except KeyError:
Expand All @@ -353,8 +359,15 @@ def gather(rootdir, output, report, allow_partial):
except KeyError:
simtype = legacy_get_type(result_fn)

raw_units = _parse_raw_units(result)
######## CHECK IF ALL RESULTS HAVE SAME # OF PROTOCOLUNITS?
# if MBAR_errors and len(raw_units) > 1:
# MBAR_errors = False

if report.lower() == 'raw':
legs[names][simtype] = _parse_raw_units(result)
legs[names][simtype] = raw_units
elif len(raw_units) == 1:
legs[names][simtype] = raw_units[0]
else:
legs[names][simtype] = result['estimate'], result['uncertainty']

Expand All @@ -364,6 +377,8 @@ def gather(rootdir, output, report, allow_partial):
lineterminator="\n", # to exactly reproduce previous, prefer "\r\n"
)

# unc = "MBAR uncertainty (kcal/mol)" if MBAR_errors else "uncertainty (kcal/mol)"

# 5a) write out MLE values
# 5b) write out DDG values
# 5c) write out each leg
Expand All @@ -373,7 +388,7 @@ def gather(rootdir, output, report, allow_partial):
# 'dg-raw': _write_dg_raw,
'raw': _write_raw,
}[report.lower()]
writing_func(legs, writer, allow_partial)
writing_func(legs, writer, allow_partial) # , unc)


PLUGIN = OFECommandPlugin(
Expand Down
Loading