diff --git a/doc/changes/dev/13866.bugfix.rst b/doc/changes/dev/13866.bugfix.rst new file mode 100644 index 00000000000..9e0eff077d1 --- /dev/null +++ b/doc/changes/dev/13866.bugfix.rst @@ -0,0 +1 @@ +Completed triaxial OPM topomap grouping by rendering separate radial and tangential maps in evoked topomap and ICA component plotting paths, by `Pragnya Khandelwal`_. diff --git a/examples/datasets/kernel_phantom.py b/examples/datasets/kernel_phantom.py index da17f708454..da2297ba88f 100644 --- a/examples/datasets/kernel_phantom.py +++ b/examples/datasets/kernel_phantom.py @@ -51,6 +51,11 @@ t_peak = 0.016 # based on visual inspection of evoked fig.axes[0].axvline(t_peak, color="k", ls=":", lw=3, zorder=2) +# %% +# Because these OPM sensors are colocated in biaxial pairs, topomaps are +# grouped into radial and tangential components. +evoked.plot_topomap(times=[t_peak], ch_type="mag", show=True) + # %% # The data covariance has an interesting structure because of densely packed sensors: @@ -106,3 +111,9 @@ ) mne.viz.plot_dipole_locations(dipoles=dip, mode="arrow", color=(0.2, 1.0, 0.5), fig=fig) mne.viz.set_3d_view(figure=fig, azimuth=30, elevation=70, distance=0.4) + +# %% +# For more information on OPM data visualization, see the OPM preprocessing +# tutorial: +# +# - :ref:`tut-opm-processing` diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index a62d2379f03..923f80f8eda 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -2059,7 +2059,7 @@ def plot_evoked_joint( zorder=1, clip_on=False, ) - fig.add_artist(con) + ts_ax.add_artist(con) # mark times in time series plot for timepoint in times_ts: diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 46d1145e851..c1b50340d17 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -573,7 +573,8 @@ def test_plot_components_opm(): ica = ICA(max_iter=1, random_state=0, n_components=10) ica.fit(RawArray(evoked.data, evoked.info), picks="mag", verbose="error") fig = ica.plot_components() - assert len(fig.axes) == 10 + # Biaxial OPM pairs trigger grouped rendering (radial + tangential axes) + assert len(fig.axes) == 20 @pytest.mark.slowtest @@ -585,4 +586,7 @@ def test_plot_components_opm_triaxial(triaxial_raw): ica = ICA(max_iter=1, random_state=0, n_components=3) ica.fit(triaxial_raw, picks="mag", verbose="error") fig = ica.plot_components() - assert len(fig.axes) == 3 + assert len(fig.axes) == 6 + titles = [ax.get_title() for ax in fig.axes] + assert any("[radial]" in title for title in titles) + assert any("[tangential]" in title for title in titles) diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index fe6c938d244..0cb01c03a4f 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -136,18 +136,6 @@ def return_inds(d): # to test function kwarg to zorder arg of evoked.plot plt.close("all") -def test_plot_joint_opm_triaxial(triaxial_evoked): - """Test joint plot with triaxial colocated OPM channels.""" - fig = triaxial_evoked.plot_joint( - times=[0.0], - picks="mag", - show=False, - ts_args=dict(time_unit="s"), - topomap_args=dict(time_unit="s", contours=0, res=8, sensors=False), - ) - assert len(fig.axes) >= 2 - - def test_plot_topo(): """Test plotting of ERP topography.""" # Show topography diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index d364d196f18..3fbb31ce668 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -794,7 +794,8 @@ def test_plot_topomap_opm(): fig_evoked = evoked.plot_topomap( times=[-0.1, 0, 0.1, 0.2], ch_type="mag", show=False ) - assert len(fig_evoked.axes) == 5 + # Biaxial OPM pairs trigger grouped rendering (4 radial + 4 tangential + 1 colorbar) + assert len(fig_evoked.axes) == 9 def test_prepare_topomap_plot_opm_non_quspin_coils(): @@ -851,6 +852,47 @@ def test_split_opm_overlaps(triaxial_evoked): assert tangential == ["OPM002", "OPM003", "OPM005", "OPM006"] +def test_should_use_opm_orientation_groups_only_for_triaxial(): + """Test that OPM orientation grouping works for biaxial and triaxial overlaps.""" + ch_names = [f"OPM{k:03}" for k in range(1, 7)] + info = create_info(ch_names, 1000.0, ch_types="mag") + with info._unlock(): + for ch in info["chs"]: + ch["coil_type"] = FIFF.FIFFV_COIL_FIELDLINE_OPM_MAG_GEN1 + + picks = np.arange(len(ch_names)) + pair_overlaps = [ + np.array(["OPM001", "OPM002"]), + np.array(["OPM003", "OPM004"]), + ] + triax_overlaps = [ + np.array(["OPM001", "OPM002", "OPM003"]), + np.array(["OPM004", "OPM005", "OPM006"]), + ] + + # Both biaxial and triaxial overlaps should trigger grouping + assert topomap._should_use_opm_orientation_groups(info, picks, pair_overlaps, "mag") + assert topomap._should_use_opm_orientation_groups( + info, picks, triax_overlaps, "mag" + ) + + +def test_plot_evoked_topomap_opm_triaxial_groups(triaxial_evoked): + """Test grouped radial/tangential topomap rendering for triaxial OPM.""" + fig = triaxial_evoked.plot_topomap( + times=[0.0], + ch_type="mag", + contours=0, + res=8, + sensors=False, + show=False, + ) + assert len(fig.axes) == 3 + titles = [ax.get_title() for ax in fig.axes] + assert any("radial" in title for title in titles) + assert any("tangential" in title for title in titles) + + def test_plot_topomap_nirs_overlap(fnirs_epochs): """Test plotting nirs topomap with overlapping channels (gh-7414).""" fig = fnirs_epochs["A"].average(picks="hbo").plot_topomap() diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index fcbd213ce78..8d491589588 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -226,7 +226,8 @@ def _find_overlaps(info, ch_type, sphere, modality="fnirs"): channels_to_exclude = list() if len(locs3d) > 1 and np.min(dist) < 1e-10: - overlapping_mask = np.triu(squareform(dist < 1e-10)) + # Use symmetric distance matrix to find all colocated channel groups + overlapping_mask = squareform(dist < 1e-10) for chan_idx in range(overlapping_mask.shape[0]): already_overlapped = list( itertools.chain.from_iterable(overlapping_channels) @@ -329,6 +330,56 @@ def _split_opm_overlaps(overlapping_channels): return radial, tangential +def _compute_opm_orientation_topomap_data(data, ch_names, pos, overlapping_channels): + """Compute radial and tangential OPM topomap data from overlap sets.""" + from ..channels.layout import _merge_ch_data + + # Radial data matches the existing OPM merge behavior and position layout. + radial_data, radial_names = _merge_ch_data( + data.copy(), "mag", copy.copy(ch_names), modality="opm" + ) + radial_pos = pos + + name_lookup = [name.removesuffix("_MERGE-REMOVE") for name in ch_names] + tangential_data = [] + tangential_names = [] + tangential_pos = [] + for overlap_set in overlapping_channels: + idx = [name_lookup.index(ch_name) for ch_name in overlap_set[1:]] + # Collapse multiple tangential channels at one location using RMS. + tangential_data.append(np.sqrt(np.mean(data[idx] ** 2, axis=0))) + tangential_names.append(f"{overlap_set[0]}t") + tangential_pos.append(radial_pos[radial_names.index(overlap_set[0])]) + + tangential_data = np.array(tangential_data) + tangential_pos = np.array(tangential_pos) + + return [ + ("radial", radial_data, radial_pos, radial_names), + ("tangential", tangential_data, tangential_pos, tangential_names), + ] + + +def _should_use_opm_orientation_groups(info, picks, merge_channels, ch_type): + """Return whether OPM orientation grouping should be enabled. + + Grouping is used for OPM magnetometer channels with overlap sets that + include at least 2 colocated channels (biaxial or triaxial sensors). + """ + if ch_type != "mag" or not merge_channels: + return False + + pick_chs = [info["chs"][pick] for pick in picks] + if not pick_chs or not all(ch["coil_type"] in _opm_coils for ch in pick_chs): + return False + + # merge_channels should be a list of overlap sets, not a boolean + if not isinstance(merge_channels, (list, tuple)): + return False + + return any(len(overlap_set) >= 2 for overlap_set in merge_channels) + + def _plot_update_evoked_topomap(params, bools): """Update topomaps.""" from ..channels.layout import _merge_ch_data @@ -1714,9 +1765,9 @@ def plot_ica_components( axes = axes.flatten() if isinstance(axes, np.ndarray) else axes for k, picks in enumerate(pick_groups): - try: # either an iterable, 1D numpy array or others - _axes = axes[k * max_subplots : (k + 1) * max_subplots] - except TypeError: # None or Axes + if axes is None: + _axes = None + else: _axes = axes ( @@ -1729,7 +1780,6 @@ def plot_ica_components( clip_origin, ) = _prepare_topomap_plot(ica, ch_type, sphere=sphere) cmap = _setup_cmap(cmap, n_axes=len(picks)) - disp_names = _prepare_sensor_names(names, show_names) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) data = np.dot( @@ -1738,64 +1788,94 @@ def plot_ica_components( data = np.atleast_2d(data) data = data[:, data_picks] + use_opm_orientation_groups = _should_use_opm_orientation_groups( + ica.info, data_picks, merge_channels, ch_type + ) + n_group_axes = 2 if use_opm_orientation_groups else 1 + if title is None: title = "ICA components" user_passed_axes = _axes is not None if not user_passed_axes: - fig, _axes, _, _ = _prepare_trellis(len(data), ncols=ncols, nrows=nrows) + fig, _axes, _, _ = _prepare_trellis( + len(data) * n_group_axes, ncols=ncols, nrows=nrows + ) fig.suptitle(title) else: _axes = [_axes] if isinstance(_axes, Axes) else _axes + if len(_axes) != len(data) * n_group_axes: + raise RuntimeError( + "You must provide one axis per component and orientation " + "group for colocated OPM data." + ) fig = _axes[0].get_figure() subplot_titles = list() - for ii, data_, ax in zip(picks, data, _axes): + for comp_offset, (ii, data_) in enumerate(zip(picks, data)): kwargs = dict(color="gray") if ii in ica.exclude else dict() comp_title = ica._ica_names[ii] if len(set(ica.get_channel_types())) > 1: comp_title += f" ({ch_type})" - subplot_titles.append(ax.set_title(comp_title, fontsize=12, **kwargs)) - if merge_channels: - data_, names_ = _merge_ch_data(data_, ch_type, copy.copy(names)) - # ↓↓↓ NOTE: we intentionally use the default norm=False here, so that - # ↓↓↓ we get vlims that are symmetric-about-zero, even if the data for - # ↓↓↓ a given component happens to be one-sided. - _vlim = _setup_vmin_vmax(data_, *vlim) - im = plot_topomap( - data_.flatten(), - pos, - ch_type=ch_type, - sensors=sensors, - names=disp_names, - contours=contours, - outlines=outlines, - sphere=sphere, - image_interp=image_interp, - extrapolate=extrapolate, - border=border, - res=res, - size=size, - cmap=cmap[0], - vlim=_vlim, - cnorm=cnorm, - axes=ax, - show=False, - )[0] - - im.axes.set_label(ica._ica_names[ii]) - if colorbar: - cbar, cax = _add_colorbar( - ax, - im, - cmap, - title="AU", - format_=cbar_fmt, - kind="ica_comp_topomap", - ch_type=ch_type, + + if use_opm_orientation_groups: + grouped_data = _compute_opm_orientation_topomap_data( + data_[:, np.newaxis], names, pos, merge_channels ) - cbar.ax.tick_params(labelsize=12) - cbar.set_ticks(_vlim) - _hide_frame(ax) + else: + if merge_channels: + data_, names_ = _merge_ch_data(data_, ch_type, copy.copy(names)) + grouped_data = [(None, data_[:, np.newaxis], pos, names_)] + else: + grouped_data = [(None, data_[:, np.newaxis], pos, names)] + + for group_idx, ( + group_label, + group_data, + group_pos, + group_names, + ) in enumerate(grouped_data): + ax_idx = comp_offset * n_group_axes + group_idx + ax = _axes[ax_idx] + plot_title = comp_title + if group_label is not None: + plot_title += f" [{group_label}]" + subplot_titles.append(ax.set_title(plot_title, fontsize=12, **kwargs)) + _vlim = _setup_vmin_vmax(group_data[:, 0], *vlim) + im = plot_topomap( + group_data[:, 0].flatten(), + group_pos, + ch_type=ch_type, + sensors=sensors, + names=_prepare_sensor_names(group_names, show_names), + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap[0], + vlim=_vlim, + cnorm=cnorm, + axes=ax, + show=False, + )[0] + + im.axes.set_label(ica._ica_names[ii]) + if colorbar: + cbar, cax = _add_colorbar( + ax, + im, + cmap, + title="AU", + format_=cbar_fmt, + kind="ica_comp_topomap", + ch_type=ch_type, + ) + cbar.ax.tick_params(labelsize=12) + cbar.set_ticks(_vlim) + _hide_frame(ax) del pos fig.canvas.draw() @@ -2259,11 +2339,18 @@ def plot_evoked_topomap( clip_origin, ) = _prepare_topomap_plot(evoked, ch_type, sphere=sphere) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) + use_opm_orientation_groups = _should_use_opm_orientation_groups( + evoked.info, picks, merge_channels, ch_type + ) # check interactive axes_given = axes is not None interactive = isinstance(times, str) and times == "interactive" if interactive and axes_given: raise ValueError("User-provided axes not allowed when times='interactive'.") + if interactive and use_opm_orientation_groups: + raise NotImplementedError( + "times='interactive' is not supported for grouped OPM topomaps." + ) # units, scalings key = "grad" if ch_type.startswith("planar") else ch_type default_scaling = _handle_default("scalings", None)[key] @@ -2273,7 +2360,6 @@ def plot_evoked_topomap( unit = _handle_default("units", units)[key] # ch_names (required for NIRS) ch_names = names - names = _prepare_sensor_names(names, show_names) # apply projections before picking. NOTE: the `if proj is True` # anti-pattern is needed here to exclude proj='interactive' _check_option("proj", proj, (True, False, "interactive", "reconstruct")) @@ -2299,7 +2385,8 @@ def plot_evoked_topomap( f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}." ) # create axes - want_axes = n_times + int(colorbar) + n_groups = 2 if use_opm_orientation_groups else 1 + want_axes = n_times * n_groups + int(colorbar) if interactive: height_ratios = [5, 1] nrows = 2 @@ -2313,7 +2400,7 @@ def plot_evoked_topomap( axes.append(plt.subplot(gs[0, ax_idx])) elif axes is None: fig, axes, ncols, nrows = _prepare_trellis( - n_times, ncols=ncols, nrows=nrows, size=size + n_times * n_groups, ncols=ncols, nrows=nrows, size=size ) else: nrows, ncols = None, None # Deactivate ncols when axes were passed @@ -2386,31 +2473,38 @@ def plot_evoked_topomap( # apply scalings and merge channels data *= scaling + grouped_data = None if merge_channels: # check modality - if any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"]): + is_opm_picks = len(evoked.info["chs"]) > 0 and all( + ch["coil_type"] in _opm_coils for ch in evoked.info["chs"] + ) + if is_opm_picks: modality = "opm" elif ch_type in _fnirs_types: modality = "fnirs" else: modality = "other" - # merge data - data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality) - # if ch_type in _fnirs_types: - if modality != "other": - merge_channels = False - # apply mask if requested - if mask is not None: - mask = mask.astype(bool, copy=False) - if ch_type == "grad": - mask_ = ( - mask[np.ix_(picks[::2], time_idx)] | mask[np.ix_(picks[1::2], time_idx)] + if modality == "opm" and use_opm_orientation_groups: + grouped_data = _compute_opm_orientation_topomap_data( + data, ch_names, pos, merge_channels ) - else: # mag, eeg, planar1, planar2 - mask_ = mask[np.ix_(picks, time_idx)] + merge_channels = False + else: + # merge data + data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality) + # if ch_type in _fnirs_types: + if modality != "other": + merge_channels = False # set up colormap + if grouped_data is None: + all_data = [data] + else: + all_data = [group_[1] for group_ in grouped_data] _vlim = [ - _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times) + _setup_vmin_vmax(group_data[:, i], *vlim, norm=merge_channels) + for group_data in all_data + for i in range(n_times) ] _vlim = [np.min(_vlim), np.max(_vlim)] cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0) @@ -2427,7 +2521,6 @@ def plot_evoked_topomap( kwargs = dict( sensors=sensors, res=res, - names=names, cmap=cmap[0], cnorm=cnorm, mask_params=mask_params, @@ -2441,33 +2534,42 @@ def plot_evoked_topomap( ch_type=ch_type, ) images, contours_ = [], [] - # loop over times - for average_idx, (time, this_average) in enumerate(zip(times, average)): - tp, cn, interp = _plot_topomap( - data[:, average_idx], - pos, - axes=axes[average_idx], - mask=mask_[:, average_idx] if mask is not None else None, - vmin=_vlim[0], - vmax=_vlim[1], - **kwargs, - ) + if grouped_data is None: + grouped_data = [(None, data, pos, ch_names)] - images.append(tp) - if cn is not None: - contours_.append(cn) - if time_format != "": - if this_average is None: - axes_title = time_format % (time * scaling_time) - else: - tmin_ = averaged_times[average_idx][0] - tmax_ = averaged_times[average_idx][-1] - from_time = time_format % (tmin_ * scaling_time) - from_time = from_time.split(" ")[0] # Remove unit - to_time = time_format % (tmax_ * scaling_time) - axes_title = f"{from_time} – {to_time}" - del from_time, to_time, tmin_, tmax_ - axes[average_idx].set_title(axes_title) + for group_idx, (group_label, group_data, group_pos, group_names) in enumerate( + grouped_data + ): + kwargs["names"] = _prepare_sensor_names(group_names, show_names) + for average_idx, (time, this_average) in enumerate(zip(times, average)): + ax_idx = group_idx * n_times + average_idx + tp, cn, interp = _plot_topomap( + group_data[:, average_idx], + group_pos, + axes=axes[ax_idx], + mask=None, + vmin=_vlim[0], + vmax=_vlim[1], + **kwargs, + ) + + images.append(tp) + if cn is not None: + contours_.append(cn) + if time_format != "": + if this_average is None: + axes_title = time_format % (time * scaling_time) + else: + tmin_ = averaged_times[average_idx][0] + tmax_ = averaged_times[average_idx][-1] + from_time = time_format % (tmin_ * scaling_time) + from_time = from_time.split(" ")[0] # Remove unit + to_time = time_format % (tmax_ * scaling_time) + axes_title = f"{from_time} – {to_time}" + del from_time, to_time, tmin_, tmax_ + if group_label is not None: + axes_title = f"{group_label}\n{axes_title}" + axes[ax_idx].set_title(axes_title) if interactive: # Add a slider to the figure and start publishing and subscribing to time_change