import math
from typing import Any, Tuple

import torch
import torch.nn as nn


class _SpectralConvNdMixin:
    """Mixin that keeps a convolution kernel in the frequency domain.

    The mixin must be initialized *after* the host convolution class has
    run its own ``__init__``: it reads the host's freshly initialized
    ``weight`` parameter and its ``kernel_size`` attribute.

    Once initialized, ``weight_transformed`` is the registered kernel
    parameter, and ``weight`` is a read-only property that is recomputed
    from ``weight_transformed`` on every access.
    """

    def __init__(self, dim: Tuple[int, ...]):
        """Swap the host's ``weight`` parameter for its real-input FFT.

        Args:
            dim: Axes of the weight tensor over which the transform is
                taken.
        """
        self.dim = dim

        # Only the *values* of the transformed kernel are needed to seed the
        # new parameter (nn.Parameter wraps them as a fresh leaf), so there
        # is no reason to record the FFT in an autograd graph.
        with torch.no_grad():
            spectrum = self._to_transform_domain(self.weight)
        self.weight_transformed = nn.Parameter(spectrum)

        # Unregister the spatial-domain weight; from here on ``weight`` is
        # served exclusively by the property below.
        del self._parameters["weight"]

    @property
    def weight(self) -> Tensor:
        """Spatial-domain kernel derived from ``weight_transformed``."""
        return self._from_transform_domain(self.weight_transformed)

    def _to_transform_domain(self, x: Tensor) -> Tensor:
        """Return the orthonormal real FFT of ``x`` over ``self.dim``."""
        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")

    def _from_transform_domain(self, x: Tensor) -> Tensor:
        """Invert :meth:`_to_transform_domain`.

        ``s=self.kernel_size`` is passed explicitly so the output has the
        kernel's spatial size, which cannot be inferred from the half-size
        spectrum alone.
        """
        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")


class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
    r"""Spectral 2D convolution.

    Introduced in [Balle2018efficient].
    Reparameterizes the weights to be derived from weights stored in the
    frequency domain.
    In the original paper, this is referred to as "spectral Adam" or
    "Sadam" due to its effect on the Adam optimizer update rule.
    The motivation behind representing the weights in the frequency
    domain is that optimizer updates/steps may now affect all
    frequencies to an equal amount.
    This improves the gradient conditioning, thus leading to faster
    convergence and increased stability at larger learning rates.

    For comparison, see the TensorFlow Compression implementations of
    ``SignalConv2D`` and ``RDFTParameter``.

    [Balle2018efficient]: "Efficient Nonlinear Transforms for Lossy
    Image Compression", by Johannes Ballé, PCS 2018.

    Args:
        *args: Positional arguments forwarded to :class:`torch.nn.Conv2d`.
        **kwargs: Keyword arguments forwarded to :class:`torch.nn.Conv2d`.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # Build the ordinary convolution first; the mixin then converts its
        # initialized kernel (last two axes) to the frequency domain.
        super().__init__(*args, **kwargs)
        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))


class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin):
    r"""Spectral 2D transposed convolution.

    Transposed version of :class:`SpectralConv2d`.

    Args:
        *args: Positional arguments forwarded to
            :class:`torch.nn.ConvTranspose2d`.
        **kwargs: Keyword arguments forwarded to
            :class:`torch.nn.ConvTranspose2d`.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # Same two-step construction as SpectralConv2d.
        super().__init__(*args, **kwargs)
        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))