diff --git a/data/transforms.py b/data/transforms.py
index 412b47d..bb22bb6 100644
--- a/data/transforms.py
+++ b/data/transforms.py
@@ -70,6 +70,41 @@ def apply_mask(data, mask_func, seed=None):
     return data * mask, mask
 
 
+def fft2c(x, dim=(-2, -1)):
+    """ Centered 2D Fast Fourier Transform
+
+    Args:
+        x (torch.Tensor): Complex valued input data containing at least 3
+            dimensions: dimensions -2 & -1 are spatial dimensions. All other
+            dimensions are assumed to be batch dimensions.
+
+        dim (tuple): Dimensions to apply the FFT along. Default is (-2, -1)
+
+    Returns:
+        torch.Tensor: The FFT of the input.
+    """
+    x = torch.fft.ifftshift(x, dim=dim)
+    x = torch.fft.fft2(x, dim=dim)
+    return torch.fft.fftshift(x, dim=dim)
+
+
+def ifft2c(x, dim=(-2, -1)):
+    """ Centered 2D Inverse Fast Fourier Transform
+
+    Args:
+        x (torch.Tensor): Complex valued input data containing at least 3
+            dimensions: dimensions -2 & -1 are spatial dimensions. All other
+            dimensions are assumed to be batch dimensions.
+        dim (tuple): Dimensions to apply the IFFT along. Default is (-2, -1)
+
+    Returns:
+        torch.Tensor: The IFFT of the input.
+    """
+    x = torch.fft.ifftshift(x, dim=dim)
+    x = torch.fft.ifft2(x, dim=dim)
+    return torch.fft.fftshift(x, dim=dim)
+
+
 def fft2(data, normalized=True):
     """
     Apply centered 2 dimensional Fast Fourier Transform.
diff --git a/models/Recurrent_Transformer.py b/models/Recurrent_Transformer.py
index 9d87819..780f9ef 100644
--- a/models/Recurrent_Transformer.py
+++ b/models/Recurrent_Transformer.py
@@ -10,22 +10,17 @@ from torch.nn import functional as F
 import numpy as np
 
 
-class DataConsistencyInKspace(nn.Module):
-    """ Create data consistency operator
-
-    Warning: note that FFT2 (by the default of torch.fft) is applied to the last 2 axes of the input.
-    This method detects if the input tensor is 4-dim (2D data) or 5-dim (3D data)
-    and applies FFT2 to the (nx, ny) axis.
-    """
+class DataConsistencyInKspace(nn.Module):
+    """ Data consistency layer in k-space. """
 
     def __init__(self):
         super(DataConsistencyInKspace, self).__init__()
 
     def forward(self, *input, **kwargs):
         return self.perform(*input)
-
-    def data_consistency(self,k, k0, mask):
+
+    def data_consistency(self, k, k0, mask):
         """
         k    - input in k-space
         k0   - initially sampled elements in k-space
@@ -36,23 +31,33 @@ def data_consistency(self,k, k0, mask):
         return out
 
     def perform(self, x, k0, mask):
+        """ Forward pass to enforce data consistency in k-space.
+
+        Args:
+            x (torch.Tensor): Input image in spatial domain (batch_size, 2, height, width).
+            k0 (torch.Tensor): Measured k-space data (batch_size, 2, height, width).
+            mask (torch.Tensor): Binary mask indicating sampled k-space locations (batch_size, 1, height, width).
+
+        Returns:
+            torch.Tensor: Corrected image with the same shape as input.
         """
-        x    - input in image domain, of shape (n, 2, nx, ny[, nt])
-        k0   - initially sampled elements in k-space
-        mask - corresponding nonzero location
-        """
-        x = x.permute(0, 2, 3, 1)
-        k0 = k0.permute(0, 2, 3, 1)
-        mask = mask.permute(0, 2, 3, 1)
+        x_cx = torch.complex(x[:, 0], x[:, 1]).unsqueeze(1)
+        k0_cx = torch.complex(k0[:, 0], k0[:, 1]).unsqueeze(1)
 
-        k = transforms.fft2(x)
+        # Fourier transform
+        x_kspace = transforms.fft2c(x_cx)
 
-        out = self.data_consistency(k, k0, mask)
-        x_res = transforms.ifft2(out)
+        # Fill in k-space
+        x_kspace = self.data_consistency(x_kspace, k0_cx, mask)
 
-        x_res = x_res.permute(0, 3, 1, 2)
+        # Inverse Fourier transform
+        out = transforms.ifft2c(x_kspace)
+
+        # Stack real and imaginary parts
+        out = torch.cat((out.real, out.imag), dim=1)
+
+        return out
 
-        return x_res
 
 class RFB(nn.Module):
     """