Skip to content

Commit

Permalink
handle float point error in sys_probs (#1919)
Browse files Browse the repository at this point in the history
Fix #1917.
  • Loading branch information
njzjz authored and root committed Mar 8, 2023
1 parent 5aa3adf commit 9e244fe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
9 changes: 5 additions & 4 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,17 +496,18 @@ def _process_sys_probs(self, sys_probs) :
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
assigned_sum_prob = np.sum(type_filter * sys_probs)
assert assigned_sum_prob <= 1, "the sum of assigned probability should be less than 1"
# 1e-8 is to handle floating point error; See #1917
assert assigned_sum_prob <= 1. + 1e-8, "the sum of assigned probability should be less than 1"
rest_sum_prob = 1. - assigned_sum_prob
if rest_sum_prob != 0 :
if not np.isclose(rest_sum_prob, 0):
rest_nbatch = (1 - type_filter) * self.nbatches
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
ret_prob = rest_prob + type_filter * sys_probs
else :
ret_prob = sys_probs
assert np.sum(ret_prob) == 1, "sum of probs should be 1"
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
return ret_prob

def _prob_sys_size_ext(self, keywords):
block_str = keywords.split(';')[1:]
block_stt = []
Expand Down
37 changes: 36 additions & 1 deletion source/tests/test_deepmd_data_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,39 @@ def _in_array(self, target, idx_map, ndof, array):
for idx,ii in enumerate(all_find) :
self.assertTrue(ii, msg = 'does not find frame %d in array' % idx)


def test_sys_prob_floating_point_error(self):
# test floating point error; See #1917
sys_probs = [
0.010,
0.010,
0.010,
0.010,
0.010,
0.010,
0.010,
0.010,
0.010,
0.150,
0.100,
0.100,
0.050,
0.050,
0.020,
0.015,
0.015,
0.050,
0.020,
0.015,
0.040,
0.055,
0.025,
0.025,
0.015,
0.025,
0.055,
0.040,
0.040,
0.005,
]
ds = DeepmdDataSystem(self.sys_name, 3, 2, 2.0, sys_probs=sys_probs)
self.assertEqual(ds.sys_probs.size, len(sys_probs))

0 comments on commit 9e244fe

Please sign in to comment.