Skip to content

Commit

Permalink
Use locking/unlocking for Info.update() (#9914)
Browse files Browse the repository at this point in the history
* Add update method using __setitem__.

* FIX: MNE functions using info.update().

* FIX: Tests using info.update() [circle full]

* FIX: Consider argument other as iterable of key/value pairs instead of Mapping. [circle full]

* FIX: failing tutorial. [circle full]

* Dont run all [skip azp] [skip actions]

* FIX: tutorial. [skip azp] [skip actions]

Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
mscheltienne and larsoner authored Oct 28, 2021
1 parent f10502b commit dac4f16
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 12 deletions.
3 changes: 2 additions & 1 deletion mne/beamformer/tests/test_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default

# Create an info object that holds information about the sensors
info = mne.create_info(fwd['info']['ch_names'], sfreq, ch_types='grad')
info.update(fwd['info']) # Merge in sensor position information
with info._unlock():
info.update(fwd['info']) # Merge in sensor position information
# heavily decimate sensors to make it much faster
info = mne.pick_info(info, np.arange(info['nchan'])[::5])
fwd = mne.pick_channels_forward(fwd, info['ch_names'])
Expand Down
3 changes: 2 additions & 1 deletion mne/beamformer/tests/test_lcmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def test_lcmv_vector():
# For speed and for rank-deficiency calculation simplicity,
# just use grads
info = mne.pick_info(info, mne.pick_types(info, meg='grad', exclude=()))
info.update(bads=[], projs=[])
with info._unlock():
info.update(bads=[], projs=[])

forward = mne.read_forward_solution(fname_fwd)
forward = mne.pick_channels_forward(forward, info['ch_names'])
Expand Down
4 changes: 2 additions & 2 deletions mne/io/fieldtrip/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def _create_info(ft_struct, raw_info):
else:
info = create_info(ch_names, sfreq)
chs, dig = _create_info_chs_dig(ft_struct)
info.update(chs=chs, dig=dig)
info._update_redundant()
with info._unlock(update_redundant=True):
info.update(chs=chs, dig=dig)

return info

Expand Down
2 changes: 1 addition & 1 deletion mne/io/hitachi/hitachi.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ def __init__(self, fname, preload=False, *, verbose=None):

# Create mne structure
info = create_info(ch_names, sfreq, ch_types=ch_types)
info.update(info_extra)
with info._unlock():
info.update(info_extra)
info['meas_date'] = meas_date
for li, loc in enumerate(locs):
info['chs'][li]['loc'][:] = loc
Expand Down
10 changes: 10 additions & 0 deletions mne/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# License: BSD-3-Clause

from collections import Counter, OrderedDict
from collections.abc import Mapping
import contextlib
from copy import deepcopy
import datetime
Expand Down Expand Up @@ -726,6 +727,15 @@ def __setitem__(self, key, val):
DeprecationWarning)
super().__setitem__(key, val)

def update(self, other=None, **kwargs):
"""Update method using __setitem__()."""
iterable = other.items() if isinstance(other, Mapping) else other
if other is not None:
for key, val in iterable:
self[key] = val
for key, val in kwargs.items():
self[key] = val

@contextlib.contextmanager
def _unlock(self, *, update_redundant=False, check_after=False):
"""Context manager unlocking access to attributes."""
Expand Down
2 changes: 1 addition & 1 deletion mne/io/nirx/nirx.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def __init__(self, fname, saturated, preload=False, verbose=None):
info = create_info(chnames,
samplingrate,
ch_types='fnirs_cw_amplitude')
info.update(subject_info=subject_info, dig=dig)
with info._unlock():
info.update(subject_info=subject_info, dig=dig)
info['meas_date'] = meas_date

# Store channel, source, and detector locations
Expand Down
3 changes: 2 additions & 1 deletion mne/io/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None,
# Set other info-keys from original instance.
pick_info = {k: v for k, v in inst.info.items() if k not in
['chs', 'ch_names', 'bads', 'nchan', 'sfreq']}
ref_info.update(pick_info)
with ref_info._unlock():
ref_info.update(pick_info)

# Rereferencing of data.
ref_data = multiplier @ inst._data
Expand Down
6 changes: 4 additions & 2 deletions mne/simulation/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,8 @@ def add_chpi(raw, head_pos=None, interp='cos2', n_jobs=1, verbose=None):
sinusoids = 70e-9 * np.sin(2 * np.pi * hpi_freqs[:, np.newaxis] *
(np.arange(len(times)) / info['sfreq']))
info = pick_info(info, meg_picks)
info.update(projs=[], bads=[]) # Ensure no 'projs' or 'bads'
with info._unlock():
info.update(projs=[], bads=[]) # Ensure no 'projs' or 'bads'
megcoils, _, _, _ = _prep_meg_channels(info, ignore_ref=False)
used = np.zeros(len(raw.times), bool)
dev_head_ts.append(dev_head_ts[-1]) # ZOH after time ends
Expand Down Expand Up @@ -689,7 +690,8 @@ def _iter_forward_solutions(info, trans, src, bem, dev_head_ts, mindist,
"""Calculate a forward solution for a subject."""
logger.info('Setting up forward solutions')
info = pick_info(info, picks)
info.update(projs=[], bads=[]) # Ensure no 'projs' or 'bads'
with info._unlock():
info.update(projs=[], bads=[]) # Ensure no 'projs' or 'bads'
mri_head_t, trans = _get_trans(trans)
megcoils, meg_info, compcoils, megnames, eegels, eegnames, rr, info, \
update_kwargs, bem = _prepare_for_forward(
Expand Down
3 changes: 2 additions & 1 deletion mne/tests/test_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,8 @@ def test_low_rank_cov(raw_epochs_events):
# test that rank=306 is same as rank='full'
epochs_meg = epochs.copy().pick_types(meg=True)
assert len(epochs_meg.ch_names) == 306
epochs_meg.info.update(bads=[], projs=[])
with epochs_meg.info._unlock():
epochs_meg.info.update(bads=[], projs=[])
cov_full = compute_covariance(epochs_meg, method='oas',
rank='full', verbose='error')
assert _cov_rank(cov_full, epochs_meg.info) == 306
Expand Down
3 changes: 1 addition & 2 deletions tutorials/simulation/80_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ def coh_signal_gen():

# Read the info from the sample dataset. This defines the location of the
# sensors and such.
info = mne.io.read_info(raw_fname)
info.update(sfreq=sfreq, bads=[])
info = mne.io.read_raw(raw_fname).crop(0, 1).resample(50).info

# Only use gradiometers
picks = mne.pick_types(info, meg='grad', stim=True, exclude=())
Expand Down

0 comments on commit dac4f16

Please sign in to comment.