-
-
Notifications
You must be signed in to change notification settings - Fork 1.5k
ENH: Add EpochsTFR.from_raw_tfr() to epoch RawTFR objects
#13750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Add new classmethod :meth:`mne.time_frequency.EpochsTFR.from_raw_tfr` to create an :class:`~mne.time_frequency.EpochsTFR` from a :class:`~mne.time_frequency.RawTFR` by slicing the already-computed TFR data at event times (similar to how :class:`~mne.Epochs` works with :class:`~mne.io.Raw`), by `Aman Srivastava`_. (:gh:`13750`) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,6 +42,7 @@ | |
| _freq_mask, | ||
| _import_h5io_funcs, | ||
| _is_numeric, | ||
| _on_missing, | ||
| _pl, | ||
| _prepare_read_metadata, | ||
| _prepare_write_metadata, | ||
|
|
@@ -3238,6 +3239,133 @@ def _update_epoch_attributes(self): | |
| # we need this for compatibility with equalize_event_counts() | ||
| self._bad_dropped = True | ||
|
|
||
| @classmethod | ||
| @verbose | ||
| def from_raw_tfr( | ||
| cls, | ||
| raw_tfr, | ||
| events, | ||
| event_id=None, | ||
| tmin=-0.2, | ||
| tmax=0.5, | ||
| picks=None, | ||
| reject_by_annotation=True, | ||
| on_missing="raise", | ||
| event_repeated="error", | ||
| metadata=None, | ||
| verbose=None, | ||
| ): | ||
| """Create an EpochsTFR from a RawTFR object. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| raw_tfr : instance of RawTFR | ||
| The continuous TFR data to epoch. | ||
| %(events_epochs)s | ||
| %(event_id)s | ||
| %(epochs_tmin_tmax)s | ||
| %(picks_good_data_noref)s | ||
| %(reject_by_annotation_epochs)s | ||
| %(on_missing_epochs)s | ||
| %(event_repeated_epochs)s | ||
| %(metadata_epochs)s | ||
| %(verbose)s | ||
|
|
||
| Returns | ||
| ------- | ||
| epochs_tfr : instance of EpochsTFR | ||
| The epoched TFR data. | ||
| """ | ||
| from ..epochs import _handle_event_repeated | ||
| from .tfr import RawTFR | ||
|
|
||
| _validate_type(raw_tfr, RawTFR, "raw_tfr") | ||
| events = _ensure_events(events) | ||
| event_id = _check_event_id(event_id, events) | ||
|
|
||
| sfreq = raw_tfr.sfreq | ||
| raw_times = raw_tfr.times | ||
|
|
||
| # figure out which time samples correspond to tmin/tmax | ||
| tmin_idx = int(round(tmin * sfreq)) | ||
| tmax_idx = int(round(tmax * sfreq)) | ||
| n_times = tmax_idx - tmin_idx + 1 | ||
| epoch_times = np.arange(tmin_idx, tmax_idx + 1) / sfreq | ||
|
|
||
| # picks | ||
| if picks is not None: | ||
| pick_idx = _picks_to_idx(raw_tfr.info, picks, "data", with_ref_meg=False) | ||
| else: | ||
| pick_idx = slice(None) | ||
|
|
||
| # get the raw TFR data: shape (n_channels, n_freqs, n_raw_times) | ||
| raw_data = raw_tfr.data[pick_idx] | ||
|
|
||
| # find valid events (those where the epoch window fits in the data) | ||
| values = list(event_id.values()) | ||
| selected = np.where(np.isin(events[:, 2], values))[0] | ||
| selection = selected.copy() | ||
|
|
||
| drop_log = [() for _ in range(len(events))] | ||
| good_data = [] | ||
| good_event_indices = [] | ||
|
|
||
| for idx in selected: | ||
| event_samp = events[idx, 0] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Samples don't necessarily start from 0 with the start of the recording. Often, the first time in the time list for the raw object corresponds with a high valued sample number, e.g. 30000. To match times to event samples, the first sample of the raw object must be subtracted from the event sample, then the result divided by sfreq. To my knowledge, this isn't stored in the RawTFR object. @drammock Is the original raw first sample stored somewhere that I've missed? Basically, the code here needs to add a way to account for the starting sample offset value. One way to address this might be to add a parameter in the function call where the user must specify the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think first_samp is preserved in the |
||
| # convert event sample to index in raw_tfr.times | ||
| start = np.searchsorted(raw_times * sfreq, event_samp + tmin_idx) | ||
| stop = start + n_times | ||
|
|
||
| # check bounds before clamping via searchsorted | ||
| first_samp = int(round(raw_times[0] * sfreq)) | ||
| if (event_samp + tmin_idx) < first_samp or stop > raw_data.shape[-1]: | ||
| drop_log[idx] = ("TOO_SHORT",) | ||
| continue | ||
|
|
||
| epoch_data = raw_data[:, :, start:stop] | ||
|
|
||
| if ( | ||
| reject_by_annotation | ||
| and hasattr(raw_tfr, "inst") | ||
| and raw_tfr.inst is not None | ||
| ): | ||
| # skip annotation rejection for now (RawTFR doesn't carry annotations) | ||
| pass | ||
|
|
||
| good_data.append(epoch_data) | ||
| good_event_indices.append(idx) | ||
|
|
||
| good_events = events[good_event_indices] | ||
| selection = np.array(good_event_indices) | ||
| drop_log = tuple(tuple(d) for d in drop_log) | ||
|
|
||
| # handle event_repeated | ||
| good_events, event_id, selection, drop_log = _handle_event_repeated( | ||
| good_events, event_id, event_repeated, selection, drop_log | ||
| ) | ||
|
|
||
| # handle on_missing | ||
| for key, val in event_id.items(): | ||
| if val not in good_events[:, 2]: | ||
| msg = f"No matching events found for {key} (event id {val})" | ||
| _on_missing(on_missing, msg) | ||
|
|
||
| data = np.stack(good_data, axis=0) # (n_epochs, n_channels, n_freqs, n_times) | ||
|
|
||
| state = raw_tfr.__getstate__() | ||
| state["data"] = data | ||
| state["times"] = epoch_times | ||
| state["events"] = good_events | ||
| state["event_id"] = event_id | ||
| state["selection"] = selection | ||
| state["drop_log"] = drop_log | ||
| state["metadata"] = metadata | ||
| state["raw_times"] = epoch_times | ||
|
|
||
| out = cls.__new__(cls) | ||
| out.__setstate__(state) | ||
| return out | ||
|
|
||
| def average(self, method="mean", *, dim="epochs", copy=False): | ||
| """Aggregate the EpochsTFR across epochs, frequencies, or times. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Baseline attribute probably should go here. I'd use
(None, 0)as the default value to use the start of the epoch to stim onset.