diff --git a/changelog/167.feature.rst b/changelog/167.feature.rst new file mode 100644 index 0000000..1d9215f --- /dev/null +++ b/changelog/167.feature.rst @@ -0,0 +1 @@ +Added :meth:`radiospectra.spectrogram.GenericSpectrogram.slice` which allows slicing of a spectrogram by time and frequency ranges. diff --git a/radiospectra/spectrogram/spectrogrambase.py b/radiospectra/spectrogram/spectrogrambase.py index f1cb2b3..9044bb1 100644 --- a/radiospectra/spectrogram/spectrogrambase.py +++ b/radiospectra/spectrogram/spectrogrambase.py @@ -1,3 +1,10 @@ +import numpy as np + +from astropy.time import Time + +from sunpy.net import attrs as a +from sunpy.time import parse_time + from radiospectra.exceptions import SpectraMetaValidationError from radiospectra.mixins import NonUniformImagePlotMixin, PcolormeshPlotMixin @@ -84,6 +91,55 @@ def frequencies(self): """ return self.meta["freqs"] + def slice(self, time=None, freq=None): + """ + times = [t0, t1, t2, t3, t4, t5] + freqs = [f0, f1, f2, f3, f4] + + Before slice method (manual slicing): + sliced_times = times[1:5] + sliced_freqs = freqs[1:4] + sliced_data = data[1:5, 1:4] + + After slice method: + sliced_data = slice(time=(t1, t4), freq=(f1, f3)) + """ + times = self.times + freqs = self.frequencies + + if time is not None: + t_start, t_end = time + if not isinstance(t_start, Time): + t_start = parse_time(t_start) + if not isinstance(t_end, Time): + t_end = parse_time(t_end) + time_mask = (times >= t_start) & (times <= t_end) + else: + time_mask = np.ones(len(times), dtype=bool) + + if freq is not None: + f_min, f_max = freq + if hasattr(f_min, "unit"): + f_min = f_min.to(freqs.unit) + if hasattr(f_max, "unit"): + f_max = f_max.to(freqs.unit) + freq_mask = (freqs >= f_min) & (freqs <= f_max) + else: + freq_mask = np.ones(len(freqs), dtype=bool) + + sliced_data = self.data[np.ix_(time_mask, freq_mask)] + sliced_times = times[time_mask] + sliced_freqs = freqs[freq_mask] + + new_meta = dict(self.meta) + new_meta["times"] = sliced_times + new_meta["freqs"] = sliced_freqs + new_meta["start_time"] = sliced_times[0] + new_meta["end_time"] = sliced_times[-1] + new_meta["wavelength"] = a.Wavelength(sliced_freqs.min(), sliced_freqs.max()) + + return self.__class__(sliced_data, new_meta) + def _validate_meta(self): """ Validates the meta-information associated with a Spectrogram. diff --git a/radiospectra/spectrogram/tests/test_spectrogrambase.py b/radiospectra/spectrogram/tests/test_spectrogrambase.py index ddd3a36..8966cbd 100644 --- a/radiospectra/spectrogram/tests/test_spectrogrambase.py +++ b/radiospectra/spectrogram/tests/test_spectrogrambase.py @@ -135,3 +135,95 @@ def test_plotim_uses_time_support_for_datetime_conversion(make_spectrogram): np.testing.assert_allclose(x_values, expected_tt) np.testing.assert_allclose(y_values, spec.frequencies.value) np.testing.assert_allclose(image, spec.data) + + +# --------- Tests for GenericSpectrogram.slice() --------- + + +def test_slice_by_time_only(make_spectrogram): + """Slicing by time should keep only matching rows.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + t0 = spec.times[1] + t1 = spec.times[2] + + sliced = spec.slice(time=(t0, t1)) + + assert sliced.data.shape == (2, 4) + np.testing.assert_array_equal(sliced.times, spec.times[1:3]) + np.testing.assert_array_equal(sliced.frequencies, spec.frequencies) + np.testing.assert_array_equal(sliced.data, spec.data[1:3, :]) + assert sliced.start_time == t0 + assert sliced.end_time == t1 + + +def test_slice_by_freq_only(make_spectrogram): + """Slicing by frequency should keep only matching columns.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + sliced = spec.slice(freq=(20 * u.kHz, 30 * u.kHz)) + + assert sliced.data.shape == (4, 2) + np.testing.assert_array_equal(sliced.frequencies, np.array([20, 30]) * u.kHz) + np.testing.assert_array_equal(sliced.data, spec.data[:, 1:3]) + + +def test_slice_by_time_and_freq(make_spectrogram): + """Slicing by both axes simultaneously.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + t0, t1 = spec.times[1], spec.times[2] + sliced = spec.slice(time=(t0, t1), freq=(20 * u.kHz, 30 * u.kHz)) + + assert sliced.data.shape == (2, 2) + np.testing.assert_array_equal(sliced.data, spec.data[1:3, 1:3]) + + +def test_slice_no_arguments_returns_copy(make_spectrogram): + """Calling slice() with no arguments returns equivalent spectrogram.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + sliced = spec.slice() + + assert sliced is not spec + assert sliced.data.shape == spec.data.shape + np.testing.assert_array_equal(sliced.data, spec.data) + np.testing.assert_array_equal(sliced.times, spec.times) + np.testing.assert_array_equal(sliced.frequencies, spec.frequencies) + + +def test_slice_with_string_times(make_spectrogram): + """Time range can be given as ISO-format strings.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + sliced = spec.slice(time=("2020-01-01 00:01", "2020-01-01 00:02")) + + assert sliced.data.shape[0] == 2 + assert sliced.start_time == spec.times[1] + assert sliced.end_time == spec.times[2] + + +def test_slice_freq_with_unit_conversion(make_spectrogram): + """Frequency limits in a different unit should be converted automatically.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + sliced = spec.slice(freq=(0.015 * u.MHz, 0.035 * u.MHz)) + + assert sliced.data.shape == (4, 2) + np.testing.assert_array_equal(sliced.frequencies.value, [20, 30]) + + +def test_slice_preserves_class(make_spectrogram): + """Sliced result should be the same class as the original.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + sliced = spec.slice(freq=(10 * u.kHz, 30 * u.kHz)) + + assert type(sliced) is type(spec) + + +def test_slice_does_not_modify_original(make_spectrogram): + """The original spectrogram must remain unchanged after slicing.""" + spec = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + original_shape = spec.data.shape + original_times_len = len(spec.times) + original_freqs_len = len(spec.frequencies) + + spec.slice(time=(spec.times[1], spec.times[2]), freq=(20 * u.kHz, 30 * u.kHz)) + + assert spec.data.shape == original_shape + assert len(spec.times) == original_times_len + assert len(spec.frequencies) == original_freqs_len