From 40243c06493214a0b4846b4c87554649d56959ba Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 5 Aug 2025 23:23:10 +0000 Subject: [PATCH 1/3] Remove generic lfilter loop --- src/libtorchaudio/lfilter.cpp | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp index 9d9b05c7d8..acb5f9607a 100644 --- a/src/libtorchaudio/lfilter.cpp +++ b/src/libtorchaudio/lfilter.cpp @@ -73,27 +73,6 @@ void cpu_lfilter_core_loop( }); } -void lfilter_core_generic_loop( - const torch::Tensor& input_signal_windows, - const torch::Tensor& a_coeff_flipped, - torch::Tensor& padded_output_waveform) { - int64_t n_samples_input = input_signal_windows.size(2); - int64_t n_order = a_coeff_flipped.size(1); - auto coeff = a_coeff_flipped.unsqueeze(2); - for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - auto windowed_output_signal = - torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order) - .transpose(0, 1); - auto o0 = torch::select(input_signal_windows, 2, i_sample) - - at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); - padded_output_waveform.index_put_( - {torch::indexing::Slice(), - torch::indexing::Slice(), - i_sample + n_order - 1}, - o0); - } -} - } // namespace TORCH_LIBRARY(torchaudio, m) { @@ -110,7 +89,3 @@ TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop); } #endif - -TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { - m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop); -} From d923531fcb8542edad2ccf46bc5a43fa30164831 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 5 Sep 2025 14:48:23 +0000 Subject: [PATCH 2/3] Add dispatcher for mac core loop --- src/torchaudio/functional/filtering.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 1a7aa3e37e..c37b21445a 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -931,9 +931,14 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t() padded_output_waveform[:, :, i_sample + n_order - 1] = o0 +def _lfilter_core_loop_dispatcher(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor): + if input_signal_windows.is_cuda or input_signal_windows.is_cpu: + return torch.ops.torchaudio._lfilter_core_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) + else: + return _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) if _IS_TORCHAUDIO_EXT_AVAILABLE: - _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop + _lfilter_core_loop = _lfilter_core_loop_dispatcher else: _lfilter_core_loop = _lfilter_core_generic_loop From 48f077ae13e89dc39f660ac52776582e04beace5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 11 Sep 2025 14:49:33 +0300 Subject: [PATCH 3/3] Fix lint --- src/torchaudio/functional/filtering.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index c37b21445a..f302e602bd 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -931,12 +931,16 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t() padded_output_waveform[:, :, i_sample + n_order - 1] = o0 -def _lfilter_core_loop_dispatcher(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor): + +def _lfilter_core_loop_dispatcher( + input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor +): if input_signal_windows.is_cuda or input_signal_windows.is_cpu: return torch.ops.torchaudio._lfilter_core_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) else: return _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) + if _IS_TORCHAUDIO_EXT_AVAILABLE: _lfilter_core_loop = _lfilter_core_loop_dispatcher else: