From 81bd50f80dfa9cb209c5514ef7a9ee42773874c4 Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Fri, 29 Aug 2025 11:35:01 +0300
Subject: [PATCH 1/3] Fix torchscript related test failures.
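
torch.autograd.Function.apply is not scriptable, so torch.jit.script
fails on lfilter and rnnt_loss. Each Function gains a ts_apply helper:
under torch.jit.is_scripting() it calls the forward computation directly
(ctx.save_for_backward is guarded accordingly), otherwise it defers to
the regular autograd apply. A minimal sketch of the dispatch pattern,
where MyOp is an illustrative stand-in rather than code from this patch:

    import torch

    class MyOp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # On the scripted path ctx is a dummy tensor and is unused.
            return 2 * x

        @staticmethod
        def ts_apply(x):
            if torch.jit.is_scripting():
                # TorchScript path: call forward directly; no autograd
                # graph is recorded, so backward is unavailable here.
                return MyOp.forward(torch.empty(0), x)
            else:
                return MyOp.apply(x)

The mask_along_axis and mask_along_axis_iid error messages are reworked
so they compile under TorchScript, and _AxisMasking.forward casts a
Tensor mask_value to float before calling F.mask_along_axis, whose
mask_value parameter is annotated as float.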
---
 src/torchaudio/functional/filtering.py   | 50 ++++++++++++++++--------
 src/torchaudio/functional/functional.py  | 37 ++++++++++++++++--
 src/torchaudio/transforms/_transforms.py |  3 +-
 3 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py
index 1a7aa3e37e..9aec37a5f8 100644
--- a/src/torchaudio/functional/filtering.py
+++ b/src/torchaudio/functional/filtering.py
@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
         b_coeff_flipped = b_coeffs.flip(1).contiguous()
         padded_waveform = F.pad(waveform, (n_order - 1, 0))
         output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
-        ctx.save_for_backward(waveform, b_coeffs, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, b_coeffs, output)
         return output
 
     @staticmethod
@@ -955,21 +956,28 @@ def backward(ctx, dy):
         n_batch = x.size(0)
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
-        db = (
-            F.conv1d(
-                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-                dy.view(n_batch * n_channel, 1, -1),
-                groups=n_batch * n_channel,
-            )
-            .view(n_batch, n_channel, -1)
-            .sum(0)
-            .flip(1)
-            if b_coeffs.requires_grad
-            else None
-        )
-        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
+
+        db = F.conv1d(
+            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+            dy.view(n_batch * n_channel, 1, -1),
+            groups=n_batch * n_channel
+        ).view(
+            n_batch, n_channel, -1
+        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+        dx = F.conv1d(
+            F.pad(dy, (0, n_order - 1)),
+            b_coeffs.unsqueeze(1),
+            groups=n_channel
+        ) if x.requires_grad else None
         return (dx, db)
 
+    @staticmethod
+    def ts_apply(waveform, b_coeffs):
+        if torch.jit.is_scripting():
+            return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
+        else:
+            return DifferentiableFIR.apply(waveform, b_coeffs)
+
 
 class DifferentiableIIR(torch.autograd.Function):
     @staticmethod
@@ -984,7 +992,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
         )
         _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
         output = padded_output_waveform[:, :, n_order - 1 :]
-        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, a_coeffs_normalized, output)
         return output
 
     @staticmethod
@@ -1006,10 +1015,17 @@ def backward(ctx, dy):
         )
         return (dx, da)
 
+    @staticmethod
+    def ts_apply(waveform, a_coeffs_normalized):
+        if torch.jit.is_scripting():
+            return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
+        else:
+            return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
+
 
 def _lfilter(waveform, a_coeffs, b_coeffs):
-    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
-    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
 
 
 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 4070141958..e9abb84c61 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -848,7 +848,8 @@ def mask_along_axis_iid(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -920,7 +921,8 @@ def mask_along_axis(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -1732,6 +1734,35 @@ def backward(ctx, dy):
         result = grad * grad_out
         return (result, None, None, None, None, None, None, None)
 
+    @staticmethod
+    def ts_apply(
+            logits,
+            targets,
+            logit_lengths,
+            target_lengths,
+            blank: int,
+            clamp: float,
+            fused_log_softmax: bool):
+        if torch.jit.is_scripting():
+            output, saved = torch.ops.torchaudio.rnnt_loss_forward(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+            return output
+        else:
+            return RnntLoss.apply(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+
 
 def _rnnt_loss(
     logits: Tensor,
@@ -1775,7 +1806,7 @@ def _rnnt_loss(
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
-    costs = RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
+    costs = RnntLoss.ts_apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
 
     if reduction == "mean":
         return costs.mean()
diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py
index 08d2dcef11..7eb50da3f8 100644
--- a/src/torchaudio/transforms/_transforms.py
+++ b/src/torchaudio/transforms/_transforms.py
@@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0
                 specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
             )
         else:
-            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
+            mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value
+            return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p)
 
 
 class FrequencyMasking(_AxisMasking):

From 8115d76149770b2073d1b0048cfa5e860723696f Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Wed, 10 Sep 2025 13:32:45 +0300
Subject: [PATCH 2/3] Rebase against main
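
No functional change: restore the upstream black formatting of
DifferentiableFIR.backward that the previous commit had reflowed, and
collapse the RnntLoss.ts_apply signature and calls onto single lines. A
numerical spot check that the restored backward is intact could look
like this (illustrative only, not part of the patch):

    import torch
    from torchaudio.functional import lfilter

    x = torch.randn(1, 1, 16, dtype=torch.double, requires_grad=True)
    a = torch.tensor([[1.0, -0.5]], dtype=torch.double, requires_grad=True)
    b = torch.tensor([[0.4, 0.2]], dtype=torch.double, requires_grad=True)
    # gradcheck exercises DifferentiableFIR/IIR.backward numerically;
    # clamp=False keeps lfilter smooth for the finite-difference test.
    assert torch.autograd.gradcheck(
        lambda x, a, b: lfilter(x, a, b, clamp=False), (x, a, b)
    )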
---
 src/torchaudio/functional/filtering.py  | 25 ++++++++++++-----------
 src/torchaudio/functional/functional.py | 27 ++++---------------------
 2 files changed, 17 insertions(+), 35 deletions(-)

diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py
index 9aec37a5f8..8f18b35de2 100644
--- a/src/torchaudio/functional/filtering.py
+++ b/src/torchaudio/functional/filtering.py
@@ -957,18 +957,19 @@ def backward(ctx, dy):
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
 
-        db = F.conv1d(
-            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-            dy.view(n_batch * n_channel, 1, -1),
-            groups=n_batch * n_channel
-        ).view(
-            n_batch, n_channel, -1
-        ).sum(0).flip(1) if b_coeffs.requires_grad else None
-        dx = F.conv1d(
-            F.pad(dy, (0, n_order - 1)),
-            b_coeffs.unsqueeze(1),
-            groups=n_channel
-        ) if x.requires_grad else None
+        db = (
+            F.conv1d(
+                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+                dy.view(n_batch * n_channel, 1, -1),
+                groups=n_batch * n_channel,
+            )
+            .view(n_batch, n_channel, -1)
+            .sum(0)
+            .flip(1)
+            if b_coeffs.requires_grad
+            else None
+        )
+        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
         return (dx, db)
 
     @staticmethod
diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index e9abb84c61..884beec1f7 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1735,33 +1735,14 @@ def backward(ctx, dy):
         return (result, None, None, None, None, None, None, None)
 
     @staticmethod
-    def ts_apply(
-            logits,
-            targets,
-            logit_lengths,
-            target_lengths,
-            blank: int,
-            clamp: float,
-            fused_log_softmax: bool):
+    def ts_apply(logits, targets, logit_lengths, target_lengths, blank: int, clamp: float, fused_log_softmax: bool):
         if torch.jit.is_scripting():
             output, saved = torch.ops.torchaudio.rnnt_loss_forward(
-                logits,
-                targets,
-                logit_lengths,
-                target_lengths,
-                blank,
-                clamp,
-                fused_log_softmax)
+                logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax
+            )
             return output
         else:
-            return RnntLoss.apply(
-                logits,
-                targets,
-                logit_lengths,
-                target_lengths,
-                blank,
-                clamp,
-                fused_log_softmax)
+            return RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
 
 
 def _rnnt_loss(

From e30a9b7a863bd1f13e9eebdf6a1249066e06e971 Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Wed, 10 Sep 2025 13:50:01 +0300
Subject: [PATCH 3/3] Enable torchscript tests in CI workflow
---
 .github/scripts/unittest-linux/run_test.sh   | 2 +-
 .github/scripts/unittest-windows/run_test.sh | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh
index 6cc935b444..5b235c772c 100755
--- a/.github/scripts/unittest-linux/run_test.sh
+++ b/.github/scripts/unittest-linux/run_test.sh
@@ -34,5 +34,5 @@ fi
   export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_inflect=true
   export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_pytorch_lightning=true
   cd test
-  pytest torchaudio_unittest -k "not torchscript and not fairseq and not demucs ${PYTEST_K_EXTRA}"
+  pytest torchaudio_unittest -k "not fairseq and not demucs ${PYTEST_K_EXTRA}"
 )
diff --git a/.github/scripts/unittest-windows/run_test.sh b/.github/scripts/unittest-windows/run_test.sh
index 25d8e14196..9f6ffb1375 100644
--- a/.github/scripts/unittest-windows/run_test.sh
+++ b/.github/scripts/unittest-windows/run_test.sh
@@ -12,5 +12,5 @@ python -m torch.utils.collect_env
 env | grep TORCHAUDIO || true
 
 cd test
-pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not torchscript and not fairseq and not demucs and not librosa"
+pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not fairseq and not demucs and not librosa"
 coverage html
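
Series note (illustrative, not part of any patch): with the "not
torchscript" filters removed from both CI scripts, the TorchScript
consistency tests run again. A minimal local smoke test of the main
entry point fixed by this series, assuming a build of this branch is
installed:

    import torch
    import torchaudio.functional as F

    # Scripting lfilter previously failed on DifferentiableFIR.apply.
    scripted_lfilter = torch.jit.script(F.lfilter)
    waveform = torch.randn(1, 2, 1000)
    a_coeffs = torch.tensor([[1.0, -0.3], [1.0, -0.3]])
    b_coeffs = torch.tensor([[0.7, 0.2], [0.7, 0.2]])
    out = scripted_lfilter(waveform, a_coeffs, b_coeffs)
    assert out.shape == waveform.shape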