diff --git a/doc/conf.py b/doc/conf.py index 49b4f028fee..fdb8ae0bd36 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -259,7 +259,7 @@ 'mapping', 'to', 'any', # unlinkable 'mayavi.mlab.pipeline.surface', - 'CoregFrame', 'Kit2FiffFrame', 'FiducialsFrame', + 'CoregFrame', 'Kit2FiffFrame', 'FiducialsFrame', 'CoregistrationUI', 'IntracranialElectrodeLocator' } numpydoc_validate = True diff --git a/environment.yml b/environment.yml index 3b716fbff67..b406d5a3390 100644 --- a/environment.yml +++ b/environment.yml @@ -23,6 +23,7 @@ dependencies: - spyder-kernels>=1.10.0 - imageio-ffmpeg>=0.4.1 - vtk>=9.0.1 +- traitlets - pyvista>=0.30 - pyvistaqt>=0.4 - qdarkstyle diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index 3e18053cdea..6203192cd5a 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -6,7 +6,7 @@ import os -from ..utils import _check_mayavi_version, verbose, get_config +from ..utils import _check_mayavi_version, verbose, get_config, warn from ._backend import _testing_mode @@ -117,7 +117,7 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, Returns ------- - frame : instance of CoregFrame + frame : instance of CoregFrame or CoregistrationUI The coregistration frame. Notes @@ -132,6 +132,32 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, subjects for which no MRI is available `_. """ + from ..viz.backends.renderer import _get_3d_backend + pyvistaqt = _get_3d_backend() == 'pyvistaqt' + if pyvistaqt: + # unsupported parameters + params = { + 'tabbed': (tabbed, False), + 'split': (split, True), + 'scrollable': (scrollable, True), + 'head_inside': (head_inside, True), + 'guess_mri_subject': guess_mri_subject, + 'head_opacity': head_opacity, + 'project_eeg': project_eeg, + 'scale_by_distance': scale_by_distance, + 'mark_inside': mark_inside, + 'interaction': interaction, + 'scale': scale, + 'advanced_rendering': advanced_rendering, + } + for key, val in params.items(): + if isinstance(val, tuple): + to_raise = val[0] != val[1] + else: + to_raise = val is not None + if to_raise: + warn(f"The parameter {key} is not supported with" + " the pyvistaqt 3d backend. It will be ignored.") config = get_config(home_dir=os.environ.get('_MNE_FAKE_HOME_DIR')) if guess_mri_subject is None: guess_mri_subject = config.get( @@ -174,20 +200,32 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, width = int(width) height = int(height) scale = float(scale) - _check_mayavi_version() - from ._backend import _check_backend - _check_backend() - from ._coreg_gui import CoregFrame, _make_view - view = _make_view(tabbed, split, width, height, scrollable) - frame = CoregFrame(inst, subject, subjects_dir, guess_mri_subject, - head_opacity, head_high_res, trans, config, - project_eeg=project_eeg, - orient_to_surface=orient_to_surface, - scale_by_distance=scale_by_distance, - mark_inside=mark_inside, interaction=interaction, - scale=scale, advanced_rendering=advanced_rendering, - head_inside=head_inside) - return _initialize_gui(frame, view) + if pyvistaqt: + from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + from ._coreg import CoregistrationUI + show = not MNE_3D_BACKEND_TESTING + standalone = not MNE_3D_BACKEND_TESTING + return CoregistrationUI( + info_file=inst, subject=subject, subjects_dir=subjects_dir, + head_resolution=head_high_res, orient_glyphs=orient_to_surface, + trans=trans, size=(width, height), show=show, standalone=standalone, + verbose=verbose + ) + else: + _check_mayavi_version() + from ._backend import _check_backend + _check_backend() + from ._coreg_gui import CoregFrame, _make_view + view = _make_view(tabbed, split, width, height, scrollable) + frame = CoregFrame(inst, subject, subjects_dir, guess_mri_subject, + head_opacity, head_high_res, trans, config, + project_eeg=project_eeg, + orient_to_surface=orient_to_surface, + scale_by_distance=scale_by_distance, + mark_inside=mark_inside, interaction=interaction, + scale=scale, advanced_rendering=advanced_rendering, + head_inside=head_inside) + return _initialize_gui(frame, view) def fiducials(subject=None, fid_file=None, subjects_dir=None): diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py new file mode 100644 index 00000000000..7637b6dd4b4 --- /dev/null +++ b/mne/gui/_coreg.py @@ -0,0 +1,976 @@ +from contextlib import contextmanager +from functools import partial +import os +import os.path as op + +import numpy as np +from traitlets import observe, HasTraits, Unicode, Bool, Float + +from ..defaults import DEFAULTS +from ..io import read_info, read_fiducials +from ..io.pick import pick_types +from ..coreg import Coregistration, _is_mri_subject +from ..viz._3d import (_plot_head_surface, _plot_head_fiducials, + _plot_head_shape_points, _plot_mri_fiducials, + _plot_hpi_coils, _plot_sensors) +from ..transforms import (read_trans, write_trans, _ensure_trans, + rotation_angles, _get_transforms_to_coord_frame) +from ..utils import get_subjects_dir, check_fname, _check_fname, fill_doc, warn + + +@fill_doc +class CoregistrationUI(HasTraits): + """Class for coregistration assisted by graphical interface. + + Parameters + ---------- + info_file : None | str + The FIFF file with digitizer data for coregistration. + %(subject)s + %(subjects_dir)s + fiducials : list | dict | str + The fiducials given in the MRI (surface RAS) coordinate + system. If a dict is provided it must be a dict with 3 entries + with keys 'lpa', 'rpa' and 'nasion' with as values coordinates in m. + If a list it must be a list of DigPoint instances as returned + by the read_fiducials function. + If set to 'estimated', the fiducials are initialized + automatically using fiducials defined in MNI space on fsaverage + template. If set to 'auto', one tries to find the fiducials + in a file with the canonical name (``bem/{subject}-fiducials.fif``) + and if abstent one falls back to 'estimated'. Defaults to 'auto'. + head_resolution : bool + If True, use a high-resolution head surface. Defaults to False. + head_transparency : bool + If True, display the head surface with transparency. Defaults to False. + hpi_coils : bool + If True, display the HPI coils. Defaults to True. + head_shape_points : bool + If True, display the head shape points. Defaults to True. + eeg_channels : bool + If True, display the EEG channels. Defaults to True. + orient_glyphs : bool + If True, orient the sensors towards the head surface. Default to False. + sensor_opacity : float + The opacity of the sensors between 0 and 1. Defaults to 1.0. + trans : str + The path to the Head<->MRI transform FIF file ("-trans.fif"). + size : tuple + The dimensions (width, height) of the rendering view. The default is + (800, 600). + bgcolor : tuple | str + The background color as a tuple (red, green, blue) of float + values between 0 and 1 or a valid color name (i.e. 'white' + or 'w'). Defaults to 'grey'. + show : bool + Display the window as soon as it is ready. Defaults to True. + standalone : bool + If True, start the Qt application event loop. Default to False. + %(verbose)s + """ + + _subject = Unicode() + _subjects_dir = Unicode() + _lock_fids = Bool() + _fiducials_file = Unicode() + _current_fiducial = Unicode() + _info_file = Unicode() + _orient_glyphs = Bool() + _hpi_coils = Bool() + _head_shape_points = Bool() + _eeg_channels = Bool() + _head_resolution = Bool() + _head_transparency = Bool() + _grow_hair = Float() + _scale_mode = Unicode() + _icp_fid_match = Unicode() + + def __init__(self, info_file, subject=None, subjects_dir=None, + fiducials='auto', head_resolution=None, + head_transparency=None, hpi_coils=None, + head_shape_points=None, eeg_channels=None, orient_glyphs=None, + sensor_opacity=None, trans=None, size=None, bgcolor=None, + show=True, standalone=False, verbose=None): + from ..viz.backends.renderer import _get_renderer + + def _get_default(var, val): + return var if var is not None else val + self._actors = dict() + self._surfaces = dict() + self._widgets = dict() + self._verbose = verbose + self._plot_locked = False + self._head_geo = None + self._coord_frame = "mri" + self._mouse_no_mvt = -1 + self._to_cf_t = None + self._omit_hsp_distance = 0.0 + self._head_opacity = 1.0 + self._fid_colors = tuple( + DEFAULTS['coreg'][f'{key}_color'] for key in + ('lpa', 'nasion', 'rpa')) + self._defaults = dict( + size=_get_default(size, (800, 600)), + bgcolor=_get_default(bgcolor, "grey"), + orient_glyphs=_get_default(orient_glyphs, False), + hpi_coils=_get_default(hpi_coils, True), + head_shape_points=_get_default(head_shape_points, True), + eeg_channels=_get_default(eeg_channels, True), + head_resolution=_get_default(head_resolution, False), + head_transparency=_get_default(head_transparency, False), + head_opacity=0.5, + sensor_opacity=_get_default(sensor_opacity, 1.0), + fiducials=("LPA", "Nasion", "RPA"), + fiducial="LPA", + lock_fids=False, + grow_hair=0.0, + scale_modes=["None", "uniform", "3-axis"], + scale_mode="None", + icp_fid_matches=('nearest', 'matched'), + icp_fid_match='nearest', + icp_n_iterations=20, + omit_hsp_distance=10.0, + lock_head_opacity=self._head_opacity < 1.0, + weights=dict( + lpa=1.0, + nasion=10.0, + rpa=1.0, + hsp=1.0, + eeg=1.0, + hpi=1.0, + ), + ) + + # process requirements + info = read_info(info_file) if info_file is not None else None + subjects_dir = get_subjects_dir( + subjects_dir=subjects_dir, raise_error=True) + subject = _get_default(subject, self._get_subjects(subjects_dir)[0]) + + # setup the window + self._renderer = _get_renderer( + size=self._defaults["size"], bgcolor=self._defaults["bgcolor"]) + self._renderer._window_close_connect(self._clean) + + # setup the model + self._info = info + self._fiducials = fiducials + self._coreg = Coregistration( + self._info, subject, subjects_dir, fiducials) + for fid in self._defaults["weights"].keys(): + setattr(self, f"_{fid}_weight", self._defaults["weights"][fid]) + + # set main traits + self._set_subjects_dir(subjects_dir) + self._set_subject(subject) + self._set_info_file(info_file) + self._set_orient_glyphs(self._defaults["orient_glyphs"]) + self._set_hpi_coils(self._defaults["hpi_coils"]) + self._set_head_shape_points(self._defaults["head_shape_points"]) + self._set_eeg_channels(self._defaults["eeg_channels"]) + self._set_head_resolution(self._defaults["head_resolution"]) + self._set_head_transparency(self._defaults["head_transparency"]) + self._set_grow_hair(self._defaults["grow_hair"]) + self._set_omit_hsp_distance(self._defaults["omit_hsp_distance"]) + self._set_icp_n_iterations(self._defaults["icp_n_iterations"]) + self._set_icp_fid_match(self._defaults["icp_fid_match"]) + + # configure UI + self._reset_fitting_parameters() + self._configure_dock() + self._configure_picking() + + # once the docks are initialized + self._set_current_fiducial(self._defaults["fiducial"]) + self._set_lock_fids(self._defaults["lock_fids"]) + self._set_scale_mode(self._defaults["scale_mode"]) + if trans is not None: + self._load_trans(trans) + + # must be done last + if show: + self._renderer.show() + if standalone: + self._renderer.figure.store["app"].exec() + + def _set_subjects_dir(self, subjects_dir): + self._subjects_dir = _check_fname( + subjects_dir, overwrite=True, must_exist=True, need_dir=True) + + def _set_subject(self, subject): + self._subject = subject + + def _set_lock_fids(self, state): + self._lock_fids = bool(state) + + def _set_fiducials_file(self, fname): + if not self._check_fif('fiducials', fname): + return + self._fiducials_file = _check_fname( + fname, overwrite=True, must_exist=True, need_dir=False) + + def _set_current_fiducial(self, fid): + self._current_fiducial = fid.lower() + + def _set_info_file(self, fname): + if fname is None: + return + if not self._check_fif('info', fname): + return + self._info_file = _check_fname( + fname, overwrite=True, must_exist=True, need_dir=False) + + def _set_omit_hsp_distance(self, distance): + self._omit_hsp_distance = distance + + def _set_orient_glyphs(self, state): + self._orient_glyphs = bool(state) + + def _set_hpi_coils(self, state): + self._hpi_coils = bool(state) + + def _set_head_shape_points(self, state): + self._head_shape_points = bool(state) + + def _set_eeg_channels(self, state): + self._eeg_channels = bool(state) + + def _set_head_resolution(self, state): + self._head_resolution = bool(state) + + def _set_head_transparency(self, state): + self._head_transparency = bool(state) + + def _set_grow_hair(self, value): + self._grow_hair = value + + def _set_scale_mode(self, mode): + self._scale_mode = mode + + def _set_fiducial(self, value, coord): + fid = self._current_fiducial.lower() + coords = ["X", "Y", "Z"] + idx = coords.index(coord) + getattr(self._coreg, f"_{fid}")[0][idx] = value / 1e3 + self._update_plot("mri_fids") + + def _set_parameter(self, value, mode_name, coord): + params = dict( + rotation=self._coreg._rotation, + translation=self._coreg._translation, + scale=self._coreg._scale, + ) + idx = ["X", "Y", "Z"].index(coord) + if mode_name == "rotation": + params[mode_name][idx] = np.deg2rad(value) + elif mode_name == "translation": + params[mode_name][idx] = value / 1e3 + else: + assert mode_name == "scale" + params[mode_name][idx] = value / 1e2 + self._coreg._update_params( + rot=params["rotation"], + tra=params["translation"], + sca=params["scale"], + ) + self._update_plot("sensors") + + def _set_icp_n_iterations(self, n_iterations): + self._icp_n_iterations = n_iterations + + def _set_icp_fid_match(self, method): + self._icp_fid_match = method + + def _set_point_weight(self, weight, point): + setattr(self, f"_{point}_weight", weight) + + @observe("_subjects_dir") + def _subjects_dir_changed(self, change=None): + # XXX: add coreg.set_subjects_dir + self._coreg._subjects_dir = self._subjects_dir + subjects = self._get_subjects() + self._subject = subjects[0] + self._reset() + + @observe("_subject") + def _subject_changed(self, changed=None): + # XXX: add coreg.set_subject() + self._coreg._subject = self._subject + self._coreg._setup_bem() + self._coreg._setup_fiducials(self._fiducials) + self._reset() + rr = (self._coreg._processed_low_res_mri_points * + self._coreg._scale) + self._head_geo = dict(rr=rr, tris=self._coreg._bem_low_res["tris"], + nn=self._coreg._bem_low_res["nn"]) + + @observe("_lock_fids") + def _lock_fids_changed(self, change=None): + view_widgets = ["orient_glyphs", "show_hpi", "show_hsp", + "show_eeg", "high_res_head"] + fid_widgets = ["fid_X", "fid_Y", "fid_Z", "fids_file", "fids"] + if self._lock_fids: + self._forward_widget_command(view_widgets, "set_enabled", True) + self._actors["msg"].SetInput("") + else: + self._forward_widget_command(view_widgets, "set_enabled", False) + self._actors["msg"].SetInput("Picking fiducials...") + self._set_sensors_visibility(self._lock_fids) + self._forward_widget_command("lock_fids", "set_value", self._lock_fids) + self._forward_widget_command(fid_widgets, "set_enabled", + not self._lock_fids) + + @observe("_fiducials_file") + def _fiducials_file_changed(self, change=None): + fids, _ = read_fiducials(self._fiducials_file) + self._coreg._setup_fiducials(fids) + self._reset() + + @observe("_current_fiducial") + def _current_fiducial_changed(self, change=None): + self._update_fiducials() + self._follow_fiducial_view() + + @observe("_info_file") + def _info_file_changed(self, change=None): + self._info = read_info(self._info_file) + # XXX: add coreg.set_info() + self._coreg._info = self._info + self._coreg._setup_digs() + self._reset() + + @observe("_orient_glyphs") + def _orient_glyphs_changed(self, change=None): + self._update_plot(["hpi", "hsp", "eeg"]) + + @observe("_hpi_coils") + def _hpi_coils_changed(self, change=None): + self._update_plot("hpi") + + @observe("_head_shape_points") + def _head_shape_point_changed(self, change=None): + self._update_plot("hsp") + + @observe("_eeg_channels") + def _eeg_channels_changed(self, change=None): + self._update_plot("eeg") + + @observe("_head_resolution") + def _head_resolution_changed(self, change=None): + self._update_plot("head") + self._grow_hair_changed() + + @observe("_head_transparency") + def _head_transparency_changed(self, change=None): + self._head_opacity = self._defaults["head_opacity"] \ + if self._head_transparency else 1.0 + self._actors["head"].GetProperty().SetOpacity(self._head_opacity) + self._renderer._update() + + @observe("_grow_hair") + def _grow_hair_changed(self, change=None): + self._coreg.set_grow_hair(self._grow_hair) + if "head" in self._surfaces: + res = "high" if self._head_resolution else "low" + self._surfaces["head"].points = \ + self._coreg._get_processed_mri_points(res) + self._renderer._update() + + @observe("_scale_mode") + def _scale_mode_changed(self, change=None): + mode = None if self._scale_mode == "None" else self._scale_mode + self._coreg.set_scale_mode(mode) + self._forward_widget_command(["sX", "sY", "sZ"], "set_enabled", + mode is not None) + + @observe("_icp_fid_match") + def _icp_fid_match_changed(self, change=None): + self._coreg.set_fid_match(self._icp_fid_match) + + def _configure_picking(self): + self._renderer._update_picking_callback( + self._on_mouse_move, + self._on_button_press, + self._on_button_release, + self._on_pick + ) + self._actors["msg"] = self._renderer.text2d(0, 0, "") + + def _on_mouse_move(self, vtk_picker, event): + if self._mouse_no_mvt: + self._mouse_no_mvt -= 1 + + def _on_button_press(self, vtk_picker, event): + self._mouse_no_mvt = 2 + + def _on_button_release(self, vtk_picker, event): + if self._mouse_no_mvt > 0: + x, y = vtk_picker.GetEventPosition() + # XXX: plotter/renderer should not be exposed if possible + plotter = self._renderer.figure.plotter + picked_renderer = self._renderer.figure.plotter.renderer + # trigger the pick + plotter.picker.Pick(x, y, 0, picked_renderer) + self._mouse_no_mvt = 0 + + def _on_pick(self, vtk_picker, event): + if self._lock_fids: + return + # XXX: taken from Brain, can be refactored + cell_id = vtk_picker.GetCellId() + mesh = vtk_picker.GetDataSet() + if mesh is None or cell_id == -1 or not self._mouse_no_mvt: + return + if not getattr(mesh, "_picking_target", False): + return + pos = np.array(vtk_picker.GetPickPosition()) + vtk_cell = mesh.GetCell(cell_id) + cell = [vtk_cell.GetPointId(point_id) for point_id + in range(vtk_cell.GetNumberOfPoints())] + vertices = mesh.points[cell] + idx = np.argmin(abs(vertices - pos), axis=0) + vertex_id = cell[idx[0]] + + fiducials = [s.lower() for s in self._defaults["fiducials"]] + idx = fiducials.index(self._current_fiducial.lower()) + # XXX: add coreg.set_fids + self._coreg._fid_points[idx] = self._surfaces["head"].points[vertex_id] + self._coreg._reset_fiducials() + self._update_fiducials() + self._update_plot("mri_fids") + + def _reset_fitting_parameters(self): + self._forward_widget_command("icp_n_iterations", "set_value", + self._defaults["icp_n_iterations"]) + self._forward_widget_command("icp_fid_match", "set_value", + self._defaults["icp_fid_match"]) + weights_widgets = [f"{w}_weight" + for w in self._defaults["weights"].keys()] + self._forward_widget_command(weights_widgets, "set_value", + list(self._defaults["weights"].values())) + + def _reset_fiducials(self): + self._set_current_fiducial(self._defaults["fiducial"]) + + def _omit_hsp(self): + self._coreg.omit_head_shape_points(self._omit_hsp_distance / 1e3) + self._update_plot("hsp") + + def _reset_omit_hsp_filter(self): + self._coreg._extra_points_filter = None + self._update_plot("hsp") + + def _update_plot(self, changes="all"): + if self._plot_locked: + return + if self._info is None: + changes = ["head", "mri_fids"] + self._to_cf_t = dict(mri=dict(trans=np.eye(4)), head=None) + else: + self._to_cf_t = _get_transforms_to_coord_frame( + self._info, self._coreg.trans, coord_frame=self._coord_frame) + if not isinstance(changes, list): + changes = [changes] + forced = "all" in changes + sensors = "sensors" in changes + if "head" in changes or forced: + self._add_head_surface() + if "hsp" in changes or forced or sensors: + self._add_head_shape_points() + if "hpi" in changes or forced or sensors: + self._add_hpi_coils() + if "eeg" in changes or forced or sensors: + self._add_eeg_channels() + if "head_fids" in changes or forced or sensors: + self._add_head_fiducials() + if "mri_fids" in changes or forced or sensors: + self._add_mri_fiducials() + + @contextmanager + def _lock_plot(self): + old_plot_locked = self._plot_locked + self._plot_locked = True + try: + yield + finally: + self._plot_locked = old_plot_locked + + @contextmanager + def _display_message(self, msg): + old_msg = self._actors["msg"].GetInput() + self._actors["msg"].SetInput(msg) + self._renderer._update() + try: + yield + finally: + self._actors["msg"].SetInput(old_msg) + self._renderer._update() + + def _follow_fiducial_view(self): + fid = self._current_fiducial.lower() + view = dict(lpa='left', rpa='right', nasion='front') + kwargs = dict(front=(90., 90.), left=(180, 90), right=(0., 90)) + kwargs = dict(zip(('azimuth', 'elevation'), kwargs[view[fid]])) + if not self._lock_fids: + self._renderer.set_camera(distance=None, **kwargs) + + def _update_fiducials(self): + fid = self._current_fiducial.lower() + val = getattr(self._coreg, f"_{fid}")[0] * 1e3 + with self._lock_plot(): + self._forward_widget_command( + ["fid_X", "fid_Y", "fid_Z"], "set_value", val) + + def _update_parameters(self): + with self._lock_plot(): + # rotation + self._forward_widget_command(["rX", "rY", "rZ"], "set_value", + np.rad2deg(self._coreg._rotation)) + # translation + self._forward_widget_command(["tX", "tY", "tZ"], "set_value", + self._coreg._translation * 1e3) + # scale + self._forward_widget_command(["sX", "sY", "sZ"], "set_value", + self._coreg._scale * 1e2) + + def _reset(self): + self._reset_fitting_parameters() + self._coreg.reset() + self._update_plot() + self._update_parameters() + + def _forward_widget_command(self, names, command, value): + names = [names] if not isinstance(names, list) else names + value = list(value) if isinstance(value, np.ndarray) else value + for idx, name in enumerate(names): + val = value[idx] if isinstance(value, list) else value + if name in self._widgets: + getattr(self._widgets[name], command)(val) + + def _set_sensors_visibility(self, state): + sensors = ["head_fiducials", "hpi_coils", "head_shape_points", + "eeg_channels"] + for sensor in sensors: + if sensor in self._actors and self._actors[sensor] is not None: + actors = self._actors[sensor] + actors = actors if isinstance(actors, list) else [actors] + for actor in actors: + actor.SetVisibility(state) + self._renderer._update() + + def _update_actor(self, actor_name, actor): + self._renderer.plotter.remove_actor(self._actors.get(actor_name)) + self._actors[actor_name] = actor + self._renderer._update() + + def _add_mri_fiducials(self): + mri_fids_actors = _plot_mri_fiducials( + self._renderer, self._coreg._fid_points, self._subjects_dir, + self._subject, self._to_cf_t, self._fid_colors) + # disable picking on the markers + for actor in mri_fids_actors: + actor.SetPickable(False) + self._update_actor("mri_fiducials", mri_fids_actors) + + def _add_head_fiducials(self): + head_fids_actors = _plot_head_fiducials( + self._renderer, self._info, self._to_cf_t, self._fid_colors) + self._update_actor("head_fiducials", head_fids_actors) + + def _add_hpi_coils(self): + if self._hpi_coils: + hpi_actors = _plot_hpi_coils( + self._renderer, self._info, self._to_cf_t, + opacity=self._defaults["sensor_opacity"], + orient_glyphs=self._orient_glyphs, surf=self._head_geo) + else: + hpi_actors = None + self._update_actor("hpi_coils", hpi_actors) + + def _add_head_shape_points(self): + if self._head_shape_points: + hsp_actors = _plot_head_shape_points( + self._renderer, self._info, self._to_cf_t, + opacity=self._defaults["sensor_opacity"], + orient_glyphs=self._orient_glyphs, surf=self._head_geo, + mask=self._coreg._extra_points_filter) + else: + hsp_actors = None + self._update_actor("head_shape_points", hsp_actors) + + def _add_eeg_channels(self): + if self._eeg_channels: + eeg = ["original"] + picks = pick_types(self._info, eeg=(len(eeg) > 0)) + eeg_actors = _plot_sensors( + self._renderer, self._info, self._to_cf_t, picks, meg=False, + eeg=eeg, fnirs=False, warn_meg=False, head_surf=self._head_geo, + units='m', sensor_opacity=self._defaults["sensor_opacity"], + orient_glyphs=self._orient_glyphs, surf=self._head_geo) + eeg_actors = eeg_actors["eeg"] + else: + eeg_actors = None + self._update_actor("eeg_channels", eeg_actors) + + def _add_head_surface(self): + bem = None + surface = "head-dense" if self._head_resolution else "head" + try: + head_actor, head_surf, _ = _plot_head_surface( + self._renderer, surface, self._subject, + self._subjects_dir, bem, self._coord_frame, self._to_cf_t, + alpha=self._head_opacity) + except IOError: + head_actor, head_surf, _ = _plot_head_surface( + self._renderer, "head", self._subject, self._subjects_dir, + bem, self._coord_frame, self._to_cf_t, + alpha=self._head_opacity) + # mark head surface mesh to restrict picking + head_surf._picking_target = True + self._update_actor("head", head_actor) + self._surfaces["head"] = head_surf + + def _fit_fiducials(self): + self._coreg.fit_fiducials( + lpa_weight=self._lpa_weight, + nasion_weight=self._nasion_weight, + rpa_weight=self._rpa_weight, + verbose=self._verbose, + ) + self._update_plot("sensors") + self._update_parameters() + + def _fit_icp(self): + with self._display_message("Fitting..."): + self._coreg.fit_icp( + n_iterations=self._icp_n_iterations, + lpa_weight=self._lpa_weight, + nasion_weight=self._nasion_weight, + rpa_weight=self._rpa_weight, + callback=lambda x, y: self._update_plot("sensors"), + verbose=self._verbose, + ) + self._update_parameters() + + def _save_trans(self, fname): + write_trans(fname, self._coreg.trans) + + def _load_trans(self, fname): + mri_head_t = _ensure_trans(read_trans(fname, return_all=True), + 'mri', 'head')['trans'] + rot_x, rot_y, rot_z = rotation_angles(mri_head_t) + x, y, z = mri_head_t[:3, 3] + self._coreg._update_params( + rot=np.array([rot_x, rot_y, rot_z]), + tra=np.array([x, y, z]), + ) + self._update_plot("sensors") + self._update_parameters() + + def _get_subjects(self, sdir=None): + # XXX: would be nice to move this function to util + sdir = sdir if sdir is not None else self._subjects_dir + is_dir = sdir and op.isdir(sdir) + if is_dir: + dir_content = os.listdir(sdir) + subjects = [s for s in dir_content if _is_mri_subject(s, sdir)] + if len(subjects) == 0: + subjects.append('') + else: + subjects = [''] + return sorted(subjects) + + def _check_fif(self, filetype, fname): + try: + check_fname(fname, filetype, ('.fif'), ('.fif')) + except IOError: + warn(f"The filename {fname} for {filetype} must end with '.fif'.") + self._widgets[f"{filetype}_file"].set_value(0, '') + return False + return True + + def _configure_dock(self): + self._renderer._dock_initialize(name="Input", area="left") + layout = self._renderer._dock_add_group_box("MRI Subject") + self._widgets["subjects_dir"] = self._renderer._dock_add_file_button( + name="subjects_dir", + desc="Load", + func=self._set_subjects_dir, + value=self._subjects_dir, + placeholder="Subjects Directory", + directory=True, + layout=layout, + ) + self._widgets["subject"] = self._renderer._dock_add_combo_box( + name="Subject", + value=self._subject, + rng=self._get_subjects(), + callback=self._set_subject, + compact=True, + layout=layout + ) + + layout = self._renderer._dock_add_group_box("MRI Fiducials") + self._widgets["lock_fids"] = self._renderer._dock_add_check_box( + name="Lock fiducials", + value=self._lock_fids, + callback=self._set_lock_fids, + layout=layout + ) + self._widgets["fiducials_file"] = self._renderer._dock_add_file_button( + name="fiducials_file", + desc="Load", + func=self._set_fiducials_file, + value=self._fiducials_file, + placeholder="Path to fiducials", + layout=layout, + ) + self._widgets["fids"] = self._renderer._dock_add_radio_buttons( + value=self._defaults["fiducial"], + rng=self._defaults["fiducials"], + callback=self._set_current_fiducial, + vertical=False, + layout=layout, + ) + hlayout = self._renderer._dock_add_layout() + for coord in ("X", "Y", "Z"): + name = f"fid_{coord}" + self._widgets[name] = self._renderer._dock_add_spin_box( + name=coord, + value=0., + rng=[-1e3, 1e3], + callback=partial( + self._set_fiducial, + coord=coord, + ), + compact=True, + double=True, + layout=hlayout + ) + self._renderer._layout_add_widget(layout, hlayout) + + layout = self._renderer._dock_add_group_box("Digitization Source") + self._widgets["info_file"] = self._renderer._dock_add_file_button( + name="info_file", + desc="Load", + func=self._set_info_file, + value=self._info_file, + placeholder="Path to info", + layout=layout, + ) + self._widgets["grow_hair"] = self._renderer._dock_add_spin_box( + name="Grow Hair", + value=self._grow_hair, + rng=[0.0, 10.0], + callback=self._set_grow_hair, + layout=layout, + ) + hlayout = self._renderer._dock_add_layout(vertical=False) + self._widgets["omit_distance"] = self._renderer._dock_add_spin_box( + name="Omit Distance", + value=self._omit_hsp_distance, + rng=[0.0, 100.0], + callback=self._set_omit_hsp_distance, + layout=hlayout, + ) + self._widgets["omit"] = self._renderer._dock_add_button( + name="Omit", + callback=self._omit_hsp, + layout=hlayout, + ) + self._widgets["reset_omit"] = self._renderer._dock_add_button( + name="Reset", + callback=self._reset_omit_hsp_filter, + layout=hlayout, + ) + self._renderer._layout_add_widget(layout, hlayout) + + layout = self._renderer._dock_add_group_box("View") + self._widgets["orient_glyphs"] = self._renderer._dock_add_check_box( + name="Orient glyphs", + value=self._orient_glyphs, + callback=self._set_orient_glyphs, + layout=layout + ) + self._widgets["show_hpi"] = self._renderer._dock_add_check_box( + name="Show HPI Coils", + value=self._hpi_coils, + callback=self._set_hpi_coils, + layout=layout + ) + self._widgets["show_hsp"] = self._renderer._dock_add_check_box( + name="Show Head Shape Points", + value=self._head_shape_points, + callback=self._set_head_shape_points, + layout=layout + ) + self._widgets["show_eeg"] = self._renderer._dock_add_check_box( + name="Show EEG Channels", + value=self._eeg_channels, + callback=self._set_eeg_channels, + layout=layout + ) + self._widgets["high_res_head"] = self._renderer._dock_add_check_box( + name="Show High Resolution Head", + value=self._head_resolution, + callback=self._set_head_resolution, + layout=layout + ) + self._widgets["make_transparent"] = self._renderer._dock_add_check_box( + name="Make skin surface transparent", + value=self._head_transparency, + callback=self._set_head_transparency, + layout=layout + ) + self._renderer._dock_add_stretch() + + self._renderer._dock_initialize(name="Parameters", area="right") + self._widgets["scaling_mode"] = self._renderer._dock_add_combo_box( + name="Scaling Mode", + value=self._defaults["scale_mode"], + rng=self._defaults["scale_modes"], + callback=self._set_scale_mode, + compact=True, + ) + hlayout = self._renderer._dock_add_group_box( + name="Scaling Parameters", + ) + for coord in ("X", "Y", "Z"): + name = f"s{coord}" + self._widgets[name] = self._renderer._dock_add_spin_box( + name=name, + value=0., + rng=[-1e3, 1e3], + callback=partial( + self._set_parameter, + mode_name="scale", + coord=coord, + ), + compact=True, + double=True, + layout=hlayout + ) + + for mode, mode_name in (("t", "Translation"), ("r", "Rotation")): + hlayout = self._renderer._dock_add_group_box( + f"{mode_name} ({mode})") + for coord in ("X", "Y", "Z"): + name = f"{mode}{coord}" + self._widgets[name] = self._renderer._dock_add_spin_box( + name=name, + value=0., + rng=[-1e3, 1e3], + callback=partial( + self._set_parameter, + mode_name=mode_name.lower(), + coord=coord, + ), + compact=True, + double=True, + step=1, + layout=hlayout + ) + + layout = self._renderer._dock_add_group_box("Fitting") + hlayout = self._renderer._dock_add_layout(vertical=False) + self._renderer._dock_add_button( + name="Fit Fiducials", + callback=self._fit_fiducials, + layout=hlayout, + ) + self._renderer._dock_add_button( + name="Fit ICP", + callback=self._fit_icp, + layout=hlayout, + ) + self._renderer._layout_add_widget(layout, hlayout) + self._widgets["icp_n_iterations"] = self._renderer._dock_add_spin_box( + name="Number Of ICP Iterations", + value=self._defaults["icp_n_iterations"], + rng=[1, 100], + callback=self._set_icp_n_iterations, + compact=True, + double=False, + layout=layout, + ) + self._widgets["icp_fid_match"] = self._renderer._dock_add_combo_box( + name="Fiducial point matching", + value=self._defaults["icp_fid_match"], + rng=self._defaults["icp_fid_matches"], + callback=self._set_icp_fid_match, + compact=True, + layout=layout + ) + layout = self._renderer._dock_add_group_box( + name="Weights", + layout=layout, + ) + for point, fid in zip(("HSP", "EEG", "HPI"), + self._defaults["fiducials"]): + hlayout = self._renderer._dock_add_layout(vertical=False) + point_lower = point.lower() + name = f"{point_lower}_weight" + self._widgets[name] = self._renderer._dock_add_spin_box( + name=point, + value=getattr(self, f"_{point_lower}_weight"), + rng=[1., 100.], + callback=partial(self._set_point_weight, point=point_lower), + compact=True, + double=True, + layout=hlayout + ) + + fid_lower = fid.lower() + name = f"{fid_lower}_weight" + self._widgets[name] = self._renderer._dock_add_spin_box( + name=fid, + value=getattr(self, f"_{fid_lower}_weight"), + rng=[1., 100.], + callback=partial(self._set_point_weight, point=fid_lower), + compact=True, + double=True, + layout=hlayout + ) + self._renderer._layout_add_widget(layout, hlayout) + self._renderer._dock_add_button( + name="Reset Fitting Options", + callback=self._reset_fitting_parameters, + layout=layout, + ) + layout = self._renderer._dock_layout + hlayout = self._renderer._dock_add_layout(vertical=False) + self._renderer._dock_add_button( + name="Reset", + callback=self._reset, + layout=hlayout, + ) + self._widgets["save_trans"] = self._renderer._dock_add_file_button( + name="save_trans", + desc="Save...", + save=True, + func=self._save_trans, + input_text_widget=False, + layout=hlayout, + ) + self._widgets["load_trans"] = self._renderer._dock_add_file_button( + name="load_trans", + desc="Load...", + func=self._load_trans, + input_text_widget=False, + layout=hlayout, + ) + self._renderer._layout_add_widget(layout, hlayout) + self._renderer._dock_add_stretch() + + def _clean(self): + self._renderer = None + self._coreg = None + self._widgets.clear() + self._actors.clear() + self._surfaces.clear() + self._defaults.clear() + self._head_geo = None + + def close(self): + """Close interface and cleanup data structure.""" + self._renderer.close() diff --git a/mne/gui/tests/test_coreg_gui.py b/mne/gui/tests/test_coreg_gui.py index 3ae34adc5a1..b36f2c5e860 100644 --- a/mne/gui/tests/test_coreg_gui.py +++ b/mne/gui/tests/test_coreg_gui.py @@ -16,7 +16,7 @@ from mne.io.kit.tests import data_dir as kit_data_dir from mne.surface import dig_mri_distances from mne.transforms import invert_transform -from mne.utils import requires_mayavi, traits_test, modified_env +from mne.utils import requires_mayavi, traits_test, modified_env, get_config data_path = testing.data_path(download=False) raw_path = op.join(data_path, 'MEG', 'sample', 'sample_audvis_trunc_raw.fif') @@ -24,6 +24,7 @@ 'sample_audvis_trunc-trans.fif') kit_raw_path = op.join(kit_data_dir, 'test_bin_raw.fif') subjects_dir = op.join(data_path, 'subjects') +fid_fname = op.join(subjects_dir, 'sample', 'bem', 'sample-fiducials.fif') @testing.requires_testing_data @@ -336,3 +337,92 @@ def test_coreg_gui_automation(): errs_nearest = np.median( dig_mri_distances(info, fname_trans, subject, subjects_dir)) assert 1e-3 < errs_nearest < 2e-3 + + +class TstVTKPicker(object): + """Class to test cell picking.""" + + def __init__(self, mesh, cell_id, event_pos): + self.mesh = mesh + self.cell_id = cell_id + self.point_id = None + self.event_pos = event_pos + + def GetCellId(self): + """Return the picked cell.""" + return self.cell_id + + def GetDataSet(self): + """Return the picked mesh.""" + return self.mesh + + def GetPickPosition(self): + """Return the picked position.""" + vtk_cell = self.mesh.GetCell(self.cell_id) + cell = [vtk_cell.GetPointId(point_id) for point_id + in range(vtk_cell.GetNumberOfPoints())] + self.point_id = cell[0] + return self.mesh.points[self.point_id] + + def GetEventPosition(self): + """Return event position.""" + return self.event_pos + + +@pytest.mark.slowtest +@testing.requires_testing_data +def test_coreg_gui_pyvista(tmpdir, renderer_interactive_pyvistaqt): + """Test that using CoregistrationUI matches mne coreg.""" + from mne.gui import coregistration + tempdir = str(tmpdir) + config = get_config(home_dir=os.environ.get('_MNE_FAKE_HOME_DIR')) + tmp_trans = op.join(tempdir, 'tmp-trans.fif') + coreg = coregistration(subject='sample', subjects_dir=subjects_dir, + trans=fname_trans) + coreg._reset_fiducials() + coreg.close() + coreg = coregistration(inst=raw_path, subject='sample', + subjects_dir=subjects_dir) + coreg._set_fiducials_file(fid_fname) + assert coreg._fiducials_file == fid_fname + # picking + vtk_picker = TstVTKPicker(coreg._surfaces['head'], 0, (0, 0)) + coreg._on_mouse_move(vtk_picker, None) + coreg._on_button_press(vtk_picker, None) + coreg._on_pick(vtk_picker, None) + coreg._on_button_release(vtk_picker, None) + coreg._set_lock_fids(True) + assert coreg._lock_fids + coreg._on_pick(vtk_picker, None) # also pick when locked + coreg._set_lock_fids(False) + assert not coreg._lock_fids + coreg._set_lock_fids(True) + assert coreg._lock_fids + assert coreg._nasion_weight == 10. + coreg._set_point_weight(11., 'nasion') + assert coreg._nasion_weight == 11. + coreg._fit_fiducials() + coreg._fit_icp() + assert coreg._coreg._extra_points_filter is None + coreg._omit_hsp() + assert coreg._coreg._extra_points_filter is not None + coreg._reset_omit_hsp_filter() + assert coreg._coreg._extra_points_filter is None + assert coreg._grow_hair == 0 + coreg._set_grow_hair(0.1) + assert coreg._grow_hair == 0.1 + assert coreg._orient_glyphs == \ + (config.get('MNE_COREG_ORIENT_TO_SURFACE', '') == 'true') + assert coreg._hpi_coils + assert coreg._eeg_channels + assert coreg._head_shape_points + assert coreg._scale_mode == 'None' + assert coreg._icp_fid_match == 'nearest' + assert coreg._head_resolution == \ + (config.get('MNE_COREG_HEAD_HIGH_RES', 'true') == 'true') + assert not coreg._head_transparency + coreg._set_head_transparency(True) + assert coreg._head_transparency + coreg._save_trans(tmp_trans) + assert op.isfile(tmp_trans) + coreg.close() diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 70058c4bc4f..3201b626203 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -465,7 +465,7 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, glyph_height=None, glyph_center=None, glyph_resolution=None, opacity=1.0, scale_mode='none', scalars=None, backface_culling=False, line_width=2., name=None, - glyph_width=None, glyph_depth=None, + glyph_width=None, glyph_depth=None, glyph_radius=0.15, solid_transform=None): _check_option('mode', mode, ALLOWED_QUIVER_MODES) with warnings.catch_warnings(): @@ -484,6 +484,8 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, if scale_mode == 'scalar': _point_data(grid)['mag'] = np.array(scalars) scale = 'mag' + elif scale_mode == 'vector': + scale = True else: scale = False if mode == '2darrow': @@ -501,10 +503,12 @@ def quiver3d(self, x, y, z, u, v, w, color, scale, mode, resolution=8, if mode == 'cone': glyph = vtk.vtkConeSource() glyph.SetCenter(0.5, 0, 0) - glyph.SetRadius(0.15) + if glyph_radius is not None: + glyph.SetRadius(glyph_radius) elif mode == 'cylinder': glyph = vtk.vtkCylinderSource() - glyph.SetRadius(0.15) + if glyph_radius is not None: + glyph.SetRadius(glyph_radius) elif mode == 'oct': glyph = vtk.vtkPlatonicSolidSource() glyph.SetSolidTypeToOctahedron() diff --git a/requirements.txt b/requirements.txt index c95d84741f5..5b308ff407d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,10 +29,11 @@ nilearn xlrd imageio>=2.6.1 imageio-ffmpeg>=0.4.1 +traitlets pyvista>=0.30 pyvistaqt>=0.4 tqdm mffpy>=0.5.7 ipywidgets ipyvtklink -pooch \ No newline at end of file +pooch