-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
ENH: Support concatenate_epochs() for EpochsSpectrum
#13748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
258b94a
a016371
e2cfc81
b2ec20e
1c1f5e9
61b6c94
7d6be81
c1fb318
7199a84
c1c75bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Add support for :func:`mne.concatenate_epochs` with :class:`~mne.time_frequency.EpochsSpectrum` instances, by ``aman-coder03``. (:gh:`13747`) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -4693,6 +4693,65 @@ def _concatenate_epochs( | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def _concatenate_epochs_spectrum(epochs_list, add_offset=True): | ||||||
| """Concatenate a list of EpochsSpectrum instances.""" | ||||||
| for ii, ep in enumerate(epochs_list): | ||||||
| if type(ep).__name__ != "EpochsSpectrum": | ||||||
| raise TypeError( | ||||||
| f"epochs_list[{ii}] must be an instance of EpochsSpectrum, " | ||||||
| f"got {type(ep)}" | ||||||
| ) | ||||||
| ref = epochs_list[0] | ||||||
| for ii, ep in enumerate(epochs_list[1:], 1): | ||||||
| if not np.array_equal(ep.freqs, ref.freqs): | ||||||
| raise ValueError(f"epochs_list[{ii}] freqs do not match epochs_list[0]") | ||||||
| _ensure_infos_match(ep.info, ref.info, f"epochs_list[{ii}]") | ||||||
|
|
||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'd better add checking here for equivalent |
||||||
| data = np.concatenate([ep.data for ep in epochs_list], axis=0) | ||||||
|
|
||||||
| shift = np.int64(10 * ref.info["sfreq"]) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| events_offset = int(np.max(epochs_list[0].events[:, 0])) + shift | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| all_events = [epochs_list[0].events.copy()] | ||||||
| for ep in epochs_list[1:]: | ||||||
| evs = ep.events.copy() | ||||||
| if add_offset: | ||||||
| evs[:, 0] += events_offset | ||||||
| events_offset += int(np.max(ep.events[:, 0])) + shift | ||||||
| all_events.append(evs) | ||||||
| events = np.concatenate(all_events, axis=0) | ||||||
|
|
||||||
| event_id = deepcopy(ref.event_id) | ||||||
| for ep in epochs_list[1:]: | ||||||
| event_id.update(ep.event_id) | ||||||
|
|
||||||
| selection = np.concatenate([ep.selection for ep in epochs_list]) | ||||||
| drop_log = sum([ep.drop_log for ep in epochs_list], ()) | ||||||
|
|
||||||
| metadatas = [ep.metadata for ep in epochs_list] | ||||||
| n_have = sum(m is not None for m in metadatas) | ||||||
| if n_have == 0: | ||||||
| metadata = None | ||||||
| elif n_have != len(metadatas): | ||||||
| raise ValueError( | ||||||
| f"{n_have} of {len(metadatas)} EpochsSpectrum instances have metadata, " | ||||||
| "all or none must have metadata" | ||||||
| ) | ||||||
| else: | ||||||
| pd = _check_pandas_installed(strict=False) | ||||||
| metadata = pd.concat(metadatas) if pd is not False else sum(metadatas, list()) | ||||||
|
|
||||||
| state = ref.__getstate__() | ||||||
| state["data"] = data | ||||||
| state["events"] = events | ||||||
| state["event_id"] = event_id | ||||||
| state["selection"] = selection | ||||||
| state["drop_log"] = drop_log | ||||||
| state["metadata"] = metadata | ||||||
| out = type(epochs_list[0]).__new__(type(epochs_list[0])) | ||||||
| out.__setstate__(state) | ||||||
| return out | ||||||
|
|
||||||
|
|
||||||
| @verbose | ||||||
| def concatenate_epochs( | ||||||
| epochs_list, add_offset=True, *, on_mismatch="raise", verbose=None | ||||||
|
|
@@ -4725,6 +4784,8 @@ def concatenate_epochs( | |||||
| ----- | ||||||
| .. versionadded:: 0.9.0 | ||||||
| """ | ||||||
| if epochs_list and type(epochs_list[0]).__name__ == "EpochsSpectrum": | ||||||
| return _concatenate_epochs_spectrum(epochs_list, add_offset=add_offset) | ||||||
| ( | ||||||
| info, | ||||||
| data, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR number gets added automatically