From a1054e8cbd302309b929801f683235e68e33ea17 Mon Sep 17 00:00:00 2001 From: pierre-ubuntu Date: Sun, 5 May 2024 13:49:09 -0700 Subject: [PATCH] use consisten rounding operation to retrieve all samples from StreamResampler --- pedalboard/io/ResampledReadableAudioFile.h | 2 +- pedalboard/io/StreamResampler.h | 11 +++++++---- tests/test_stream_resampler.py | 14 ++++++++++---- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/pedalboard/io/ResampledReadableAudioFile.h b/pedalboard/io/ResampledReadableAudioFile.h index ed0f2efbc..b5fbc1884 100644 --- a/pedalboard/io/ResampledReadableAudioFile.h +++ b/pedalboard/io/ResampledReadableAudioFile.h @@ -111,7 +111,7 @@ class ResampledReadableAudioFile length -= (std::round(resampler.getOutputLatency()) - resampler.getOutputLatency()); } - return (long)length; + return (long)std::round(length); } double getDuration() const { diff --git a/pedalboard/io/StreamResampler.h b/pedalboard/io/StreamResampler.h index 10e8c7386..333b41f7d 100644 --- a/pedalboard/io/StreamResampler.h +++ b/pedalboard/io/StreamResampler.h @@ -78,21 +78,24 @@ template class StreamResampler { targetSampleRate / sourceSampleRate) - totalSamplesOutput); - // TODO: Don't copy the entire input buffer multiple times here! + int roundedExpectedResampledSamples = + (int)std::round(expectedResampledSamples); + + // TODO: Dont copy the entire input buffer multiple times here! juce::AudioBuffer output(input.getNumChannels(), - (int)expectedResampledSamples); + roundedExpectedResampledSamples); for (size_t c = 0; c < input.getNumChannels(); c++) { if (input.getNumSamples() > 0) { long long inputSamplesConsumed = resamplers[c].process( resamplerRatio, input.getReadPointer(c), output.getWritePointer(c), - (int)expectedResampledSamples); + roundedExpectedResampledSamples); if (c == 0) { if (!isFlushing) { totalSamplesInput += inputSamplesConsumed; } - totalSamplesOutput += (int)expectedResampledSamples; + totalSamplesOutput += roundedExpectedResampledSamples; } if (!isFlushing) { diff --git a/tests/test_stream_resampler.py b/tests/test_stream_resampler.py index c1c42a239..6c0bf8f65 100644 --- a/tests/test_stream_resampler.py +++ b/tests/test_stream_resampler.py @@ -40,6 +40,7 @@ @pytest.mark.parametrize( "quality", TOLERANCE_PER_QUALITY.keys(), ids=[q.name for q in TOLERANCE_PER_QUALITY.keys()] ) +@pytest.mark.parametrize("num_seconds", [1.0, 1.23]) def test_stream_resample( fundamental_hz: float, sample_rate: float, @@ -47,18 +48,19 @@ def test_stream_resample( buffer_size: int, num_channels: int, quality: Resample.Quality, + num_seconds: float ): sine_wave = generate_sine_at( sample_rate, fundamental_hz, num_channels=num_channels, - num_seconds=1, + num_seconds=num_seconds, ).astype(np.float32) expected_sine_wave = generate_sine_at( target_sample_rate, fundamental_hz, num_channels=num_channels, - num_seconds=1, + num_seconds=num_seconds, ).astype(np.float32) if num_channels == 1: sine_wave = np.expand_dims(sine_wave, 0) @@ -73,8 +75,12 @@ def test_stream_resample( outputs.append(resampler.process(None)) output = np.concatenate(outputs, axis=1) - num_samples = min(output.shape[1], expected_sine_wave.shape[1]) + # In case we have a round number of input and output samples, + # we check that the number of output samples is as expected + if (num_seconds * sample_rate).is_integer() and (num_seconds * target_sample_rate).is_integer(): + assert output.shape[1] == expected_sine_wave.shape[1] + num_samples = min(output.shape[1], expected_sine_wave.shape[1]) np.testing.assert_allclose( expected_sine_wave[:, :num_samples], output[:, :num_samples], @@ -178,7 +184,7 @@ def test_flush(sample_rate: float, target_sample_rate: float, quality: Resample. @pytest.mark.parametrize( "quality", TOLERANCE_PER_QUALITY.keys(), ids=[q.name for q in TOLERANCE_PER_QUALITY.keys()] ) -def test_returned_sample_count( +def test_returned_sample_count_from_chunks( sample_rate: float, target_sample_rate: float, chunk_size: int, quality ): input_signal = np.linspace(0, 3, num=int(sample_rate), dtype=np.float32)