Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions compressai/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any, Tuple
import math

from typing import Any

import torch
import torch.nn as nn

Expand All @@ -46,14 +45,74 @@
"ResidualBlock",
"ResidualBlockUpsample",
"ResidualBlockWithStride",
"conv1x1",
"SpectralConv2d",
"SpectralConvTranspose2d",
"conv3x3",
"subpel_conv3x3",
"QReLU",
"sequential_channel_ramp",
]


class _SpectralConvNdMixin:
    r"""Mixin that keeps a convolution kernel in the frequency domain.

    The learnable parameter is ``weight_transformed``, the real-input FFT
    (``torch.fft.rfftn``) of the kernel taken over the axes in ``dim``.
    ``weight`` becomes a read-only property that rebuilds the kernel from
    that parameter with ``torch.fft.irfftn`` on every access.

    The host class is expected to run its own ``__init__`` first, so that a
    ``weight`` entry exists in ``self._parameters`` and ``self.kernel_size``
    is set, and then to call :meth:`__init__` of this mixin explicitly.
    """

    def __init__(self, dim: Tuple[int, ...]):
        """Move the already-registered ``weight`` into the transform domain.

        Args:
            dim: Axes of the kernel over which the FFT is taken.

        Raises:
            KeyError: If no ``weight`` entry is registered in
                ``self._parameters``.
        """
        self.dim = dim
        # Fetch the kernel from the parameter registry rather than through
        # ``self.weight``: that name is shadowed by the property below, which
        # would only resolve here because its getter raises AttributeError
        # (``weight_transformed`` does not exist yet) and attribute lookup
        # then falls back to ``__getattr__``.
        kernel = self._parameters["weight"]
        self.weight_transformed = nn.Parameter(self._to_transform_domain(kernel))
        del self._parameters["weight"]  # Unregister weight, and fallback to property.

    @property
    def weight(self) -> Tensor:
        """Kernel reconstructed from ``weight_transformed``."""
        return self._from_transform_domain(self.weight_transformed)

    def _to_transform_domain(self, x: Tensor) -> Tensor:
        """Return the orthonormal ``rfftn`` of ``x`` over ``self.dim``."""
        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")

    def _from_transform_domain(self, x: Tensor) -> Tensor:
        """Return the orthonormal ``irfftn`` of ``x`` over ``self.dim``.

        ``s=self.kernel_size`` is passed so the output has the kernel's
        signal size along the transformed axes.
        """
        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")


class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
    r"""2D convolution whose kernel is parameterized in the frequency domain.

    The weights used by the convolution are derived from weights that are
    stored in the frequency domain. This reparameterization was introduced
    in [Balle2018efficient], where it is called "spectral Adam" ("Sadam")
    because of the way it changes the Adam optimizer update rule.

    The motivation for keeping the weights in the frequency domain is
    that optimizer updates/steps may then affect all frequencies to an
    equal amount. This improves the gradient conditioning, which leads to
    faster convergence and increased stability at larger learning rates.

    For comparison, see the TensorFlow Compression implementations of
    `SignalConv2D
    <https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61>`_
    and
    `RDFTParameter
    <https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71>`_.

    [Balle2018efficient]: `"Efficient Nonlinear Transforms for Lossy
    Image Compression" <https://arxiv.org/abs/1802.00847>`_,
    by Johannes Ballé, PCS 2018.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # All arguments are forwarded unchanged to the parent constructor.
        # The mixin is initialized afterwards, explicitly, because it
        # transforms the ``weight`` that must already exist by then.
        super().__init__(*args, **kwargs)
        transform_axes = (-2, -1)  # the last two axes of the kernel
        _SpectralConvNdMixin.__init__(self, dim=transform_axes)


class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin):
    r"""2D transposed convolution with a frequency-domain kernel.

    Transposed counterpart of :class:`SpectralConv2d`.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # All arguments are forwarded unchanged to the parent constructor.
        # The mixin is initialized afterwards, explicitly, because it
        # transforms the ``weight`` that must already exist by then.
        super().__init__(*args, **kwargs)
        transform_axes = (-2, -1)  # the last two axes of the kernel
        _SpectralConvNdMixin.__init__(self, dim=transform_axes)


class MaskedConv2d(nn.Conv2d):
r"""Masked 2D convolution implementation, mask future "unseen" pixels.
Useful for building auto-regressive network components.
Expand Down