diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 64daa530..01717354 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -874,7 +874,8 @@ def __init__(self, domain, layout, broadcast_dims): self.layout = layout self.broadcast_dims = broadcast_dims # Determine deployment dimensions - deploy_dims_ext = np.array(broadcast_dims) & np.array(domain.constant) + constant_input_dims = np.array(domain.global_shape(layout, domain.dealias)) == 1 # includes None and size=1 bases + deploy_dims_ext = np.array(broadcast_dims) & constant_input_dims deploy_dims = deploy_dims_ext[~layout.local] # Build subcomm or skip casting if any(deploy_dims): diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index e7705f40..bd157583 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -9,10 +9,10 @@ from . import operators from ..libraries import spin_recombination from ..tools.array import kron, axslice, apply_matrix, permute_axis -from ..tools.cache import CachedAttribute, CachedMethod, CachedClass +from ..tools.cache import CachedAttribute, CachedMethod, CachedClass, CachedFunction from ..tools import jacobi from ..tools import clenshaw -from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices +from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices, sparse_block_diag from ..tools.dispatch import MultiClass, SkipDispatchException from ..tools.general import unify, DeferredTuple from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct @@ -36,6 +36,9 @@ 'RealFourier', 'ComplexFourier', 'Fourier', + 'EvenParity', + 'OddParity', + 'Parity', 'DiskBasis', 'AnnulusBasis', 'SphereBasis', @@ -694,8 +697,8 @@ def _group_matrix(group, input_basis, output_basis): unit_amplitude = 1 / output_basis.constant_mode_value return np.array([[unit_amplitude]]) else: - # Constructor should only loop over group 0. - raise ValueError("This should never happen.") + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") class DifferentiateJacobi(operators.Differentiate, operators.SpectralOperator1D): @@ -705,16 +708,26 @@ class DifferentiateJacobi(operators.Differentiate, operators.SpectralOperator1D) subaxis_dependence = [True] subaxis_coupling = [True] + @classmethod + def _check_args(cls, operand, coord, order=1, out=None): + # Only integer derivatives implemented + if float(order).is_integer() and order > 0: + return super()._check_args(operand, coord, order=order, out=out) + return False + @staticmethod - def _output_basis(input_basis): - return input_basis.derivative_basis(order=1) + def _output_basis(input_basis, order): + return input_basis.derivative_basis(order=int(order)) @staticmethod @CachedMethod - def _full_matrix(input_basis, output_basis): + def _full_matrix(input_basis, output_basis, order): N = input_basis.size a, b = input_basis.a, input_basis.b - matrix = jacobi.differentiation_matrix(N, a, b) / input_basis.COV.stretch + matrix = jacobi.differentiation_matrix(N, a, b) + for i in range(1, int(order)): + matrix = jacobi.differentiation_matrix(N, a+i, b+i) @ matrix + matrix *= input_basis.COV.stretch ** (-order) return matrix.tocsr() @@ -855,6 +868,7 @@ def __init__(self, coord, size, bounds, dealias=(1,), library=None): self.dealias = dealias self.library = library # Other attributes + self.grid_params = (coord, bounds, dealias, library) self.constant_mode_value = 1 # No permutations by default self.forward_coeff_permutation = None @@ -881,6 +895,18 @@ def __rmatmul__(self, other): def __pow__(self, other): return self + def derivative_basis(self, order=1): + return self + + def gradient_basis(self): + return self + + def component_basis(self): + return self + + def skew_basis(self): + return self + def elements_to_groups(self, grid_space, elements): if grid_space[0]: groups = elements @@ -1005,13 +1031,12 @@ class ConvertConstantComplexFourier(operators.ConvertConstant, operators.Spectra @staticmethod def _group_matrix(group, input_basis, output_basis): - # Rescale group (native wavenumber) to get physical wavenumber - k = group / output_basis.COV.stretch # 1 = exp(1j*0*x) - if k == 0: + if group == 0: unit_amplitude = 1 / output_basis.constant_mode_value return np.array([[unit_amplitude]]) else: + # Return zero-width column for subproblem construction return np.zeros(shape=(1, 0)) @@ -1022,6 +1047,46 @@ class DifferentiateComplexFourier(operators.Differentiate, operators.SpectralOpe subaxis_dependence = [True] subaxis_coupling = [False] + @staticmethod + def _output_basis(input_basis, order): + return input_basis + + @staticmethod + def _group_matrix(group, input_basis, output_basis, order): + # Rescale group (native wavenumber) to get physical wavenumber + k = group / input_basis.COV.stretch + # dx**n exp(1j*k*x) = (1j*k)**n * exp(1j*k*x) + S = (1j * k) ** order + return np.array([[S]]) + + +class RieszDerivativeComplexFourier(operators.RieszDerivative, operators.SpectralOperator1D): + """ComplexFourier Riesz derivative.""" + + input_basis_type = ComplexFourier + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis, order): + return input_basis + + @staticmethod + def _group_matrix(group, input_basis, output_basis, order): + # Rescale group (native wavenumber) to get physical wavenumber + k = group / input_basis.COV.stretch + # R_a exp(1j*k*x) = - |k|**a * exp(1j*k*x) + S = - abs(k) ** order + return np.array([[S]]) + + +class HilbertTransformComplexFourier(operators.HilbertTransform, operators.SpectralOperator1D): + """ComplexFourier Hilbert transform.""" + + input_basis_type = ComplexFourier + subaxis_dependence = [True] + subaxis_coupling = [False] + @staticmethod def _output_basis(input_basis): return input_basis @@ -1030,8 +1095,9 @@ def _output_basis(input_basis): def _group_matrix(group, input_basis, output_basis): # Rescale group (native wavenumber) to get physical wavenumber k = group / input_basis.COV.stretch - # dx exp(1j*k*x) = 1j * k * exp(1j*k*x) - return np.array([[1j*k]]) + # Hx exp(1j*k*x) = -1j * sgn(k) * exp(1j*k*x) + S = -1j * np.sign(k) + return np.array([[S]]) class InterpolateComplexFourier(operators.Interpolate, operators.SpectralOperator1D): @@ -1077,8 +1143,8 @@ def _group_matrix(group, input_basis, output_basis): L = input_basis.COV.problem_length return np.array([[L]]) else: - # Constructor should only loop over group 0. - raise ValueError("This should never happen.") + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") class AverageComplexFourier(operators.Average, operators.SpectralOperator1D): @@ -1101,8 +1167,8 @@ def _group_matrix(group, input_basis, output_basis): if k == 0: return np.array([[1]]) else: - # Constructor should only loop over group 0. - raise ValueError("This should never happen.") + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") class RealFourier(FourierBase, metaclass=CachedClass): @@ -1192,14 +1258,12 @@ class ConvertConstantRealFourier(operators.ConvertConstant, operators.SpectralOp @staticmethod def _group_matrix(group, input_basis, output_basis): - # Rescale group (native wavenumber) to get physical wavenumber - k = group / output_basis.COV.stretch # 1 = cos(0*x) - if k == 0: + if group == 0: unit_amplitude = 1 / output_basis.constant_mode_value - return np.array([[unit_amplitude], - [0]]) + return np.array([[unit_amplitude], [0]]) else: + # Return zero-width column for subproblem construction return np.zeros(shape=(2, 0)) @@ -1211,17 +1275,61 @@ class DifferentiateRealFourier(operators.Differentiate, operators.SpectralOperat subaxis_coupling = [False] @staticmethod - def _output_basis(input_basis): + def _output_basis(input_basis, order): return input_basis @staticmethod - def _group_matrix(group, input_basis, output_basis): + def _group_matrix(group, input_basis, output_basis, order): # Rescale group (native wavenumber) to get physical wavenumber k = group / input_basis.COV.stretch - # dx cos(k*x) = k * -sin(k*x) - # dx -sin(k*x) = -k * cos(k*x) - return np.array([[0, -k], - [k, 0]]) + # dx^n exp(ikx) = (ik)^n exp(ikx) = S exp(ikx) + # dx^n cos(kx) = R(S) cos(kx) - I(S) sin(kx) + # dx^n sin(kx) = R(S) sin(kx) + I(S) cos(kx) + S = (1j * k) ** order + return np.array([[S.real, -S.imag], + [S.imag, S.real]]) + + +class RieszDerivativeRealFourier(operators.RieszDerivative, operators.SpectralOperator1D): + """RealFourier Riesz derivative.""" + + input_basis_type = RealFourier + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis, order): + return input_basis + + @staticmethod + def _group_matrix(group, input_basis, output_basis, order): + # Rescale group (native wavenumber) to get physical wavenumber + k = group / input_basis.COV.stretch + # R_a exp(ikx) = - |k|^n exp(ikx) = S exp(ikx) + # R_a cos(kx) = S cos(kx) + # R_a sin(kx) = S sin(kx) + S = - abs(k) ** order + return np.array([[S, 0], + [0, S]]) + + +class HilbertTransformRealFourier(operators.HilbertTransform, operators.SpectralOperator1D): + """RealFourier Hilbert transform.""" + + input_basis_type = RealFourier + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis): + return input_basis + + @staticmethod + def _group_matrix(group, input_basis, output_basis): + # Hx cos(kx) = sin(kx) + # Hx sin(kx) = -cos(kx) + return np.array([[ 0, 1], + [-1, 0]]) class InterpolateRealFourier(operators.Interpolate, operators.SpectralOperator1D): @@ -1239,7 +1347,6 @@ def _output_basis(input_basis, position): @staticmethod def _full_matrix(input_basis, output_basis, position): # Build native interpolation vector - # Interleaved cos(k*x), -sin(k*x) x = input_basis.COV.native_coord(position) k = input_basis.native_wavenumbers interp_vector = np.zeros(k.size) @@ -1265,14 +1372,14 @@ def _output_basis(input_basis): def _group_matrix(group, input_basis, output_basis): # Rescale group (native wavenumber) to get physical wavenumber k = group / input_basis.COV.stretch - # integ cos(k*x) = L * δ(k, 0) - # integ -sin(k*x) = 0 + # integ cos(kx) dx = L (k==0) + # integ sin(kx) dx = 0 if k == 0: L = input_basis.COV.problem_length return np.array([[L, 0]]) else: - # Constructor should only loop over group 0. - raise ValueError("This should never happen.") + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") class AverageRealFourier(operators.Average, operators.SpectralOperator1D): @@ -1291,250 +1398,533 @@ def _output_basis(input_basis): def _group_matrix(group, input_basis, output_basis): # Rescale group (native wavenumber) to get physical wavenumber k = group / input_basis.COV.stretch - # integ cos(k*x) / L = δ(k, 0) - # integ -sin(k*x) / L = 0 + # integ cos(kx) dx / L = (k==0) + # integ sin(kx) dx / L = 0 if k == 0: return np.array([[1, 0]]) else: - # Constructor should only loop over group 0. - raise ValueError("This should never happen.") + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") + + +class ParityBase(FourierBase): + """Base class for parity-based bases.""" + + native_bounds = (0, np.pi) + default_library = "fftw" + group_shape = (1,) + transforms = {} + @CachedAttribute + def _native_wavenumbers(self): + # Excludes Nyquist mode + kmax = self.size - 1 + return np.arange(0, kmax+1) -# class HilbertTransformFourier(operators.HilbertTransform): -# """Fourier series Hilbert transform.""" - -# input_basis_type = Fourier -# bands = [-1, 1] -# separable = True - -# @staticmethod -# def output_basis(space, input_basis): -# return space.Fourier - -# @staticmethod -# def _build_subspace_entry(i, j, space, input_basis): -# # Hx(cos(n*x)) = sin(n*x) -# # Hx(sin(n*x)) = -cos(n*x) -# n = j // 2 -# if n == 0: -# return 0 -# elif (j % 2) == 0: -# # Hx(cos(n*x)) = sin(n*x) -# if i == (j + 1): -# return 1 -# else: -# return 0 -# else: -# # Hx(sin(n*x)) = -cos(n*x) -# if i == (j - 1): -# return (-1) -# else: -# return 0 - - -# class Sine(Basis, metaclass=CachedClass): -# """Sine series basis.""" -# space_type = ParityInterval -# const = None -# supported_dtypes = {np.float64, np.complex128} - -# def __add__(self, other): -# space = self.space -# if other is space.Sine: -# return space.Sine -# else: -# return NotImplemented - -# def __mul__(self, other): -# space = self.space -# if other is None: -# return space.Sine -# elif other is space.Sine: -# return space.Cosine -# elif other is space.Cosine: -# return space.Sine -# else: -# return NotImplemented - -# def __pow__(self, other): -# space = self.space -# if (other % 2) == 0: -# return space.Cosine -# elif (other % 2) == 1: -# return space.Sine -# else: -# return NotImplemented - -# def include_mode(self, mode): -# # Drop k=0 and Nyquist mode -# k = mode -# return (1 <= k <= self.space.kmax) - - -# class Cosine(Basis, metaclass=CachedClass): -# """Cosine series basis.""" -# space_type = ParityInterval -# const = 1 - -# def __add__(self, other): -# space = self.space -# if other is None: -# return space.Cosine -# elif other is space.Cosine: -# return space.Cosine -# else: -# return NotImplemented - -# def __mul__(self, other): -# space = self.space -# if other is None: -# return space.Cosine -# elif other is space.Sine: -# return space.Sine -# elif other is space.Cosine: -# return space.Cosine -# else: -# return NotImplemented - -# def __pow__(self, other): -# return self.space.Cosine - -# def include_mode(self, mode): -# # Drop Nyquist mode -# k = mode -# return (0 <= k <= self.space.kmax) - - -# class InterpolateSine(operators.Interpolate): -# """Sine series interpolation.""" - -# input_basis_type = Sine - -# @staticmethod -# def _build_subspace_entry(j, space, input_basis, position): -# # sin(n*x) -# x = space.COV.native_coord(position) -# return math.sin(j*x) - - -# class InterpolateCosine(operators.Interpolate): -# """Cosine series interpolation.""" - -# input_basis_type = Cosine - -# @staticmethod -# def _build_subspace_entry(j, space, input_basis, position): -# # cos(n*x) -# x = space.COV.native_coord(position) -# return math.cos(j*x) - - -# class IntegrateSine(operators.Integrate): -# """Sine series integration.""" - -# input_basis_type = Sine - -# @staticmethod -# def _build_subspace_entry(j, space, input_basis): -# # integral(sin(n*x), 0, pi) = (2 / n) * (n % 2) -# if (j % 2): -# return 0 -# else: -# return (2 / j) * space.COV.stretch - - -# class IntegrateCosine(operators.Integrate): -# """Cosine series integration.""" - -# input_basis_type = Cosine - -# @staticmethod -# def _build_subspace_entry(j, space, input_basis): -# # integral(cos(n*x), 0, pi) = pi * δ(n, 0) -# if j == 0: -# return np.pi * space.COV.stretch -# else: -# return 0 + def _native_grid(self, scale): + """Native flat global grid.""" + N, = self.grid_shape((scale,)) + return (np.pi / N) * (1/2 + np.arange(N)) + @CachedMethod + def _transform_plan(self, grid_size, coeff_size, parity, library): + """Caching layer to share plans across even/odd parity bases.""" + # Use matrix transforms for trivial cases + if (grid_size == 1 or coeff_size == 1) and (library != "matrix"): + return self._transform_plan(grid_size, coeff_size, parity, "matrix") + parity_name = {1: "cos", -1: "sin"}[parity] + return ParityBase.transforms[f"{library}-{parity_name}"](grid_size, coeff_size) + + def transform_plan(self, grid_size, parity): + """Build transform plan.""" + return self._transform_plan(grid_size, self.size, parity, self.library) -# class DifferentiateSine(operators.Differentiate): -# """Sine series differentiation.""" + def forward_transform(self, field, axis, gdata, cdata): + # Get plans + data_axis = len(field.tensorsig) + axis + grid_size = gdata.shape[data_axis] + P = self.component_parity(field.tensorsig) + parity_plans = {p: self.transform_plan(grid_size, p) for p in np.unique(P)} + # Transform component-by-component + gdata = gdata.copy() # Copy to avoid overwrite errors with dealiasing + for i, p in np.ndenumerate(P): + parity_plans[p].forward(gdata[i], cdata[i], axis) + # Permute coefficients + if self.forward_coeff_permutation is not None: + permute_axis(cdata, axis+len(field.tensorsig), self.forward_coeff_permutation, out=cdata) -# input_basis_type = Sine -# bands = [0] -# separable = True + def backward_transform(self, field, axis, cdata, gdata): + # Get plans + data_axis = len(field.tensorsig) + axis + grid_size = gdata.shape[data_axis] + P = self.component_parity(field.tensorsig) + parity_plans = {p: self.transform_plan(grid_size, p) for p in np.unique(P)} + # Permute coefficients + if self.backward_coeff_permutation is not None: + permute_axis(cdata, axis+len(field.tensorsig), self.backward_coeff_permutation, out=cdata) + # Transform component-by-component + cdata = cdata.copy() # Copy to avoid overwrite errors with dealiasing + for i, p in np.ndenumerate(P): + parity_plans[p].backward(cdata[i], gdata[i], axis) -# @staticmethod -# def output_basis(space, input_basis): -# return space.Cosine + def derivative_basis(self, order=1): + if order % 2 == 0: + return self + elif order % 2 == 1: + return self.opposite_parity() + else: + raise ValueError(f"Invalid derivative order: {order}") -# @staticmethod -# def _build_subspace_entry(i, j, space, input_basis): -# # dx(sin(n*x)) = n*cos(n*x) -# if i == j: -# return j / space.COV.stretch -# else: -# return 0 + def gradient_basis(self): + return self + def component_basis(self): + return self.opposite_parity() -# class DifferentiateCosine(operators.Differentiate): -# """Cosine series differentiation.""" + def skew_basis(self): + return self.opposite_parity() -# input_basis_type = Cosine -# bands = [0] -# separable = True -# @staticmethod -# def output_basis(space, input_basis): -# return space.Sine +def Parity(*args, parity=None, **kw): + """Factory function dispatching to EvenParity and OddParity based on provided parity.""" + if parity is None: + raise ValueError("parity must be specified") + elif parity == 1: + return EvenParity(*args, **kw) + elif parity == -1: + return OddParity(*args, **kw) + else: + raise ValueError(f"Unrecognized parity: {parity}") -# @staticmethod -# def _build_subspace_entry(i, j, space, input_basis): -# # dx(cos(n*x)) = -n*sin(n*x) -# if i == j: -# return (-j) / space.COV.stretch -# else: -# return 0 +class EvenParity(ParityBase, metaclass=CachedClass): + """Even parity basis (cosine series for scalars).""" -# class HilbertTransformSine(operators.HilbertTransform): -# """Sine series Hilbert transform.""" + def __add__(self, other): + if other is None: + return self + if other is self: + return self + if isinstance(other, EvenParity): + if self.grid_params == other.grid_params: + size = max(self.size, other.size) + return self.clone_with(size=size) + return NotImplemented -# input_basis_type = Sine -# bands = [0] -# separable = True + def __mul__(self, other): + if other is None: + return self + if other is self: + return self + if isinstance(other, OddParity): + if self.grid_params == other.grid_params: + size = max(self.size, other.size) + return OddParity(self.coord, size, self.bounds, self.dealias, self.library) + if isinstance(other, EvenParity): + if self.grid_params == other.grid_params: + size = max(self.size, other.size) + return EvenParity(self.coord, size, self.bounds, self.dealias, self.library) + return NotImplemented -# @staticmethod -# def output_basis(space, input_basis): -# return space.Cosine + def __pow__(self, other): + return NotImplemented -# @staticmethod -# def _build_subspace_entry(i, j, space, input_basis): -# # Hx(sin(n*x)) = -cos(n*x) -# if i == j: -# return (-1) -# else: -# return 0 + def valid_elements(self, tensorsig, grid_space, elements): + vshape = tuple(cs.dim for cs in tensorsig) + elements[0].shape + valid = np.ones(shape=vshape, dtype=bool) + return valid + def opposite_parity(self): + return OddParity(self.coord, self.size, self.bounds, self.dealias, self.library) -# class HilbertTransformCosine(operators.HilbertTransform): -# """Cosine series Hilbert transform.""" + @CachedMethod + def component_parity(self, tensorsig): + P = np.ones([cs.dim for cs in tensorsig], dtype=int) + coord = self.coord + for i, cs in enumerate(tensorsig): + if coord in cs.coords: + start = cs.coords.index(coord) + P[axslice(i, start, start+1)] *= -1 + return P -# input_basis_type = Cosine -# bands = [0] -# separable = True -# @staticmethod -# def output_basis(space, input_basis): -# return space.Sine +class OddParity(ParityBase, metaclass=CachedClass): + """Odd parity basis (sine series for scalars).""" -# @staticmethod -# def _build_subspace_entry(i, j, space, input_basis): -# # Hx(cos(n*x)) = sin(n*x) -# if i == j: -# return 1 -# else: -# return 0 + def __add__(self, other): + if other is self: + return self + if isinstance(other, OddParity): + if self.grid_params == other.grid_params: + size = max(self.size, other.size) + return self.clone_with(size=size) + return NotImplemented + + def __mul__(self, other): + if other is None: + return self + if other is self: + return EvenParity(self.coord, self.size, self.bounds, self.dealias, self.library) + if isinstance(other, OddParity): + if self.grid_params == other.grid_params: + size = max(self.size, other.size) + return EvenParity(self.coord, size, self.bounds, self.dealias, self.library) + if isinstance(other, EvenParity): + if self.grid_params == other.grid_params: + size = max(self.size, other.size) + return OddParity(self.coord, size, self.bounds, self.dealias, self.library) + return NotImplemented + + def __pow__(self, other): + return NotImplemented + + def valid_elements(self, tensorsig, grid_space, elements): + vshape = tuple(cs.dim for cs in tensorsig) + elements[0].shape + valid = np.ones(shape=vshape, dtype=bool) + if not grid_space[0]: + # Drop sine part of k=0 + groups = self.elements_to_groups(grid_space, elements) + allcomps = tuple(slice(None) for cs in tensorsig) + selection = (groups[0] == 0) + valid[allcomps + (selection,)] = False + return valid + + def opposite_parity(self): + return EvenParity(self.coord, self.size, self.bounds, self.dealias, self.library) + + @CachedMethod + def component_parity(self, tensorsig): + P = np.ones([cs.dim for cs in tensorsig], dtype=int) + coord = self.coord + for i, cs in enumerate(tensorsig): + if coord in cs.coords: + start = cs.coords.index(coord) + P[axslice(i, start, start+1)] *= -1 + return -P + + +class SpectralOperatorParity(operators.SpectralOperator1D): + """Base class for spectral operators that operate on parity-based bases.""" + + @CachedAttribute + def operand_component_parity(self): + return self.input_basis.component_parity(self.operand.tensorsig) + + def subproblem_matrix(self, subproblem): + """Build operator matrix for a specific subproblem.""" + axis = self.last_axis + group = subproblem.group[axis] + # Build matrices for each parity + P = self.operand_component_parity + if group is None: + parity_matrices = {p: self.subspace_matrix(self.dist.coeff_layout, p) for p in np.unique(P)} + else: + parity_matrices = {p: self.group_matrix(group, p) for p in np.unique(P)} + # Kronecker up to proper size + shape = subproblem.coeff_shape(self.domain) + N_before = prod(shape[:axis]) + if N_before > 1: + I_before = sparse.identity(N_before, format='coo') # COO faster for kron + for p in parity_matrices: + parity_matrices[p] = sparse.kron(I_before, parity_matrices[p]) + N_after = prod(shape[axis+1:]) + if N_after > 1: + I_after = sparse.identity(N_after, format='coo') # COO faster for kron + for p in parity_matrices: + parity_matrices[p] = sparse.kron(parity_matrices[p], I_after) + # Block diagonalize over components + blocks = [parity_matrices[p] for p in P.ravel()] + return sparse_block_diag(blocks) + + def group_matrix(self, group, parity): + return self._group_matrix(group, self.input_basis, self.output_basis, parity) + + def subspace_matrix(self, layout, parity): + """Build matrix operating on local subspace data.""" + # Caching layer to allow insertion of other arguments + return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis, parity) + + def operate(self, out): + operand = self.args[0] + input_basis = self.input_basis + data_axis = self.last_axis + # Set output layout + out.preset_layout(operand.layout) + # Apply operator to each component + P = self.operand_component_parity + parity_matrices = {p: self.subspace_matrix(operand.layout, p) for p in np.unique(P)} + for i, p in np.ndenumerate(P): + # TODO: check shapes on first try + apply_matrix(parity_matrices[p], operand.data[i], axis=data_axis, out=out.data[i]) + + +class ConvertConstantEvenParity(operators.ConvertConstant, SpectralOperatorParity): + """ + Upcast constants to EvenParity. + Note: this is a single implementation because tensor fields have both even and odd components. + """ + + output_basis_type = EvenParity + subaxis_dependence = [True] + subaxis_coupling = [False] + + def __init__(self, operand, output_basis, out=None): + super().__init__(operand, output_basis, out=out) + for cs in self.operand.tensorsig: + if self.coords in cs.coords: + raise NotImplementedError("Converting constant tensor fields to EvenParity is not supported.") + + @CachedAttribute + def operand_component_parity(self): + return self.output_basis.component_parity(self.operand.tensorsig) + + @staticmethod + def _group_matrix(group, input_basis, output_basis, parity): + if parity != 1: + raise ValueError(f"This should never happen: parity = {parity}") + # 1 = cos(0*x) + if group == 0: + unit_amplitude = 1 / output_basis.constant_mode_value + return np.array([[unit_amplitude]]) + else: + # Return zero-width column for subproblem construction + return np.zeros(shape=(1, 0)) + + +class InterpolateParity(operators.Interpolate, SpectralOperatorParity): + """ + Parity basis interpolation. + Note: this is a single implementation because tensor fields have both even and odd components. + """ + + input_basis_type = ParityBase + basis_subaxis = 0 + subaxis_dependence = [True] + subaxis_coupling = [True] + + @staticmethod + def _output_basis(input_basis, position): + return None + + def subspace_matrix(self, layout, parity): + """Build matrix operating on global subspace data.""" + return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis, self.position, parity) + + @staticmethod + def _full_matrix(input_basis, output_basis, position, parity): + # Build native interpolation vector + x = input_basis.COV.native_coord(position) + k = input_basis.native_wavenumbers + if parity == 1: + interp_vector = np.cos(k * x) + elif parity == -1: + interp_vector = np.sin(k * x) + else: + raise ValueError(f"This should never happen: parity = {parity}") + # Return with shape (1, N) + return interp_vector[None, :] + + +class DifferentiateParity(operators.Differentiate, SpectralOperatorParity): + """ + Parity basis differentiation. + Note: this is a single implementation because tensor fields have both even and odd components. + """ + + input_basis_type = ParityBase + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis, order): + return input_basis.derivative_basis(order) + + def subspace_matrix(self, layout, parity): + """Build matrix operating on local subspace data.""" + # Caching layer to allow insertion of other arguments + return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis, self.order, parity) + + def group_matrix(self, group, parity): + return self._group_matrix(group, self.input_basis, self.output_basis, self.order, parity) + + @staticmethod + def _group_matrix(group, input_basis, output_basis, order, parity): + # Rescale group (native wavenumber) to get physical wavenumber + k = group / input_basis.COV.stretch + if parity == 1: + # dx^1 cos(kx) = -k^1 sin(kx) + # dx^2 cos(kx) = -k^2 cos(kx) + # dx^3 cos(kx) = k^3 sin(kx) + # dx^4 cos(kx) = k^4 cos(kx) + S = k**order * (-1)**((order+1)//2) + return np.array([[S]]) + elif parity == -1: + # dx^1 sin(kx) = k^1 cos(kx) + # dx^2 sin(kx) = -k^2 sin(kx) + # dx^3 sin(kx) = -k^3 cos(kx) + # dx^4 sin(kx) = k^4 sin(kx) + S = k**order * (-1)**(order//2) + return np.array([[S]]) + else: + raise ValueError(f"This should never happen: parity = {parity}") + + +class RieszDerivativeParity(operators.RieszDerivative, SpectralOperatorParity): + """ + Parity basis Riesz derivative. + Note: this is a single implementation because tensor fields have both even and odd components. + """ + + input_basis_type = ParityBase + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis, order): + return input_basis + + def subspace_matrix(self, layout, parity): + """Build matrix operating on local subspace data.""" + # Caching layer to allow insertion of other arguments + return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis, self.order, parity) + + def group_matrix(self, group, parity): + return self._group_matrix(group, self.input_basis, self.output_basis, self.order, parity) + + @staticmethod + def _group_matrix(group, input_basis, output_basis, order, parity): + # Rescale group (native wavenumber) to get physical wavenumber + k = group / input_basis.COV.stretch + # R_a exp(ikx) = - |k|^n exp(ikx) = S exp(ikx) + S = - abs(k) ** order + if parity == 1: + # R_a cos(kx) = S cos(kx) + return np.array([[S]]) + elif parity == -1: + # R_a sin(kx) = S sin(kx) + return np.array([[S]]) + else: + raise ValueError(f"This should never happen: parity = {parity}") + + +class HilbertTransformParity(operators.HilbertTransform, SpectralOperatorParity): + """ + Parity basis Hilbert transform. + Note: this is a single implementation because tensor fields have both even and odd components. + """ + + input_basis_type = ParityBase + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis): + return input_basis.derivative_basis() + + @staticmethod + def _group_matrix(group, input_basis, output_basis, parity): + if parity == 1: + # Hx cos(kx) = sin(kx) + return np.array([[1]]) + elif parity == -1: + # Hx sin(kx) = -cos(kx) + return np.array([[-1]]) + else: + raise ValueError(f"This should never happen: parity = {parity}") + + +class IntegrateEvenParity(operators.Integrate, operators.SpectralOperator1D): + """EvenParity basis integration.""" + + input_coord_type = Coordinate + input_basis_type = EvenParity + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis): + return None + + @staticmethod + def _group_matrix(group, input_basis, output_basis): + # \int_{0}^{\pi} cos(kx) dx = \pi (k=0) + if group == 0: + L = input_basis.COV.problem_length + return np.array([[L]]) + else: + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") + + +class IntegrateOddParity(operators.Integrate, operators.SpectralOperator1D): + """OddParity basis integration.""" + + input_coord_type = Coordinate + input_basis_type = OddParity + subaxis_dependence = [True] + subaxis_coupling = [True] + + @staticmethod + def _output_basis(input_basis): + return None + + @staticmethod + def _full_matrix(input_basis, output_basis): + # Build native integration vector + # \int_{0}^{\pi} sin(kx) dx = (1 - (-1)^k) / k + k = input_basis.native_wavenumbers + k_odd = (k % 2 == 1) + integ_vector = np.zeros(k.size) + integ_vector[k_odd] = 2 / k[k_odd] * input_basis.COV.stretch + # Return with shape (1, N) + return integ_vector[None, :] + + +class AverageEvenParity(operators.Average, operators.SpectralOperator1D): + """EvenParity basis averaging.""" + + input_coord_type = Coordinate + input_basis_type = EvenParity + subaxis_dependence = [True] + subaxis_coupling = [False] + + @staticmethod + def _output_basis(input_basis): + return None + + @staticmethod + def _group_matrix(group, input_basis, output_basis): + # \int_{0}^{\pi} cos(kx) dx = \pi (k=0) + if group == 0: + return np.array([[1]]) + else: + # Constructor should only loop over group 0 + raise ValueError(f"This should never happen: group = {group}") + + +class AverageOddParity(operators.Average, operators.SpectralOperator1D): + """OddParity basis averaging.""" + + input_coord_type = Coordinate + input_basis_type = OddParity + subaxis_dependence = [True] + subaxis_coupling = [True] + + @staticmethod + def _output_basis(input_basis): + return None + + @staticmethod + def _full_matrix(input_basis, output_basis): + # Build native integration vector + # \int_{0}^{\pi} sin(kx) dx = (1 - (-1)^k) / k + k = input_basis.native_wavenumbers + k_odd = (k % 2 == 1) + integ_vector = np.zeros(k.size) + integ_vector[k_odd] = 2 / k[k_odd] + ave_vector = integ_vector / np.pi + # Return with shape (1, N) + return ave_vector[None, :] class MultidimensionalBasis(Basis): @@ -2402,6 +2792,9 @@ def __rmatmul__(self, other): return self.clone_with(shape=shape, k=k) return NotImplemented + def skew_basis(self): + return self + def global_grid_radius(self, dist, scale): r = self.radial_COV.problem_coord(self._native_radius_grid(scale)) return reshape_vector(r, dim=dist.dim, axis=dist.get_basis_axis(self)+1) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 9c5875a9..ef1205b8 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -39,6 +39,8 @@ 'Integrate', 'Average', 'Differentiate', + 'RieszDerivative', + 'HilbertTransform', 'Convert', 'TransposeComponents', 'RadialComponent', @@ -990,6 +992,7 @@ def operate(self, out): # Apply matrix if arg.data.size and out.data.size: data_axis = self.last_axis + len(arg.tensorsig) + # TODO: special case to vector multiply when group size is 1 apply_matrix(self.subspace_matrix(layout), arg.data, data_axis, out=out.data) else: out.data.fill(0) @@ -1196,7 +1199,7 @@ def __init__(self, operand, coord): SpectralOperator.__init__(self, operand) # Require integrand is a scalar if coord in operand.tensorsig: - raise ValueError("Can only integrate scalars.") + raise ValueError("Can only integrate scalar fields.") # SpectralOperator requirements self.coord = coord self.input_basis = operand.domain.get_basis(coord) @@ -1269,7 +1272,7 @@ def __init__(self, operand, coord): SpectralOperator.__init__(self, operand) # Require integrand is a scalar if coord in operand.tensorsig: - raise ValueError("Can only average scalars.") + raise ValueError("Can only average scalar fields.") # SpectralOperator requirements self.coord = coord self.input_basis = operand.domain.get_basis(coord) @@ -1341,6 +1344,7 @@ def __init__(self, operand, coord): # return arg +@alias("diff", "D") class Differentiate(SpectralOperator1D, metaclass=MultiClass): """ Differentiation along one dimension. @@ -1354,12 +1358,13 @@ class Differentiate(SpectralOperator1D, metaclass=MultiClass): name = "Diff" - def __init__(self, operand, coord, out=None): + def __init__(self, operand, coord, order=1, out=None): super().__init__(operand, out=out) + self.order = order # SpectralOperator requirements self.coord = coord self.input_basis = operand.domain.get_basis(coord) - self.output_basis = self._output_basis(self.input_basis) + self.output_basis = self._output_basis(self.input_basis, self.order) self.first_axis = self.dist.get_axis(coord) self.last_axis = self.first_axis self.axis = self.first_axis @@ -1371,7 +1376,7 @@ def __init__(self, operand, coord, out=None): self.dtype = operand.dtype @classmethod - def _check_args(cls, operand, coord, out=None): + def _check_args(cls, operand, coord, order=1, out=None): # Dispatch by operand basis if isinstance(operand, Operand): basis = operand.domain.get_basis(coord) @@ -1380,29 +1385,38 @@ def _check_args(cls, operand, coord, out=None): return False def new_operand(self, operand, **kw): - return Differentiate(operand, self.coord, **kw) + return Differentiate(operand, self.coord, self.order, **kw) + + def subspace_matrix(self, layout): + return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis, self.order) + + def group_matrix(self, group): + return self._group_matrix(group, self.input_basis, self.output_basis, self.order) @staticmethod - def _output_basis(input_basis): + def _output_basis(input_basis, order): # Subclasses must implement raise NotImplementedError() def __str__(self): - return 'd{!s}({!s})'.format(self.coord.name, self.operand) + if self.order == 1: + return 'd{!s}({!s})'.format(self.coord.name, self.operand) + else: + return 'd{!s}({!s},{!s})'.format(self.coord.name, self.operand, self.order) - def _expand_multiply(self, operand, vars): - """Expand over multiplication.""" - args = operand.args - # Apply product rule to factors - partial_diff = lambda i: prod([self.new_operand(arg) if i==j else arg for j,arg in enumerate(args)]) - return sum((partial_diff(i) for i in range(len(args)))) + # def _expand_multiply(self, operand, vars): + # """Expand over multiplication.""" + # args = operand.args + # # Apply product rule to factors + # partial_diff = lambda i: prod([self.new_operand(arg) if i==j else arg for j,arg in enumerate(args)]) + # return sum((partial_diff(i) for i in range(len(args)))) class DifferentiateConstant(Differentiate): """Constant differentiation.""" @classmethod - def _check_args(cls, operand, coord, out=None): + def _check_args(cls, operand, coord, order=1, out=None): # Dispatch for numbers of constant bases if isinstance(operand, Number): return True @@ -1411,25 +1425,11 @@ def _check_args(cls, operand, coord, out=None): return True return False - def __new__(cls, operand, coord, out=None): + def __new__(cls, operand, coord, order=1, out=None): return 0 -# @prefix('H') -# @parseable('hilbert_transform', 'hilbert', 'H') -# def hilbert_transform(arg, *spaces, **space_kw): -# # Parse space/order keywords into space list -# for space, order in space_kw.items(): -# spaces += (space,) * order -# # Identify domain -# domain = unify_attributes((arg,)+spaces, 'domain', require=False) -# # Apply iteratively -# for space in spaces: -# space = domain.get_space_object(space) -# arg = HilbertTransform(arg, space) -# return arg - - +@alias("hilbert", "H") class HilbertTransform(SpectralOperator1D, metaclass=MultiClass): """ Hilbert transform along one dimension. @@ -1441,36 +1441,122 @@ class HilbertTransform(SpectralOperator1D, metaclass=MultiClass): """ + name = "Hilbert" + + def __init__(self, operand, coord, out=None): + super().__init__(operand, out=out) + # SpectralOperator requirements + self.coord = coord + self.input_basis = operand.domain.get_basis(coord) + self.output_basis = self._output_basis(self.input_basis) + self.first_axis = self.dist.get_axis(coord) + self.last_axis = self.first_axis + self.axis = self.first_axis + # LinearOperator requirements + self.operand = operand + # FutureField requirements + self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) + self.tensorsig = operand.tensorsig + self.dtype = operand.dtype + @classmethod - def _check_args(cls, operand, space, out=None): + def _check_args(cls, operand, coord, out=None): # Dispatch by operand basis if isinstance(operand, Operand): - if isinstance(operand.get_basis(space), cls.input_basis_type): + basis = operand.domain.get_basis(coord) + if isinstance(basis, cls.input_basis_type): return True return False - @property - def base(self): - return HilbertTransform + def new_operand(self, operand, **kw): + return HilbertTransform(operand, self.coord, **kw) + + @staticmethod + def _output_basis(input_basis): + # Subclasses must implement + raise NotImplementedError() + + def __str__(self): + return 'H{!s}({!s})'.format(self.coord.name, self.operand) class HilbertTransformConstant(HilbertTransform): """Constant Hilbert transform.""" @classmethod - def _check_args(cls, operand, space, out=None): + def _check_args(cls, operand, coord, out=None): # Dispatch for numbers of constant bases if isinstance(operand, Number): return True if isinstance(operand, Operand): - if operand.get_basis(space) is None: + if operand.domain.get_basis(coord) is None: return True return False - def __new__(cls, operand, space, out=None): + def __new__(cls, operand, coord, out=None): return 0 +@alias("riesz", "R") +class RieszDerivative(SpectralOperator1D, metaclass=MultiClass): + """ + Riesz derivative along one dimension. + + Parameters + ---------- + operand : number or Operand object + space : Space object + + Notes + ----- + R_a exp(1j*k*x) = - |k|**a * exp(1j*k*x). + + """ + + name = "Riesz" + + def __init__(self, operand, coord, order=1, out=None): + super().__init__(operand, out=out) + self.order = order + # SpectralOperator requirements + self.coord = coord + self.input_basis = operand.domain.get_basis(coord) + self.output_basis = self._output_basis(self.input_basis, order) + self.first_axis = self.dist.get_axis(coord) + self.last_axis = self.first_axis + self.axis = self.first_axis + # LinearOperator requirements + self.operand = operand + # FutureField requirements + self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) + self.tensorsig = operand.tensorsig + self.dtype = operand.dtype + + @classmethod + def _check_args(cls, operand, coord, order=1, out=None): + # Dispatch by operand basis + if isinstance(operand, Operand): + basis = operand.domain.get_basis(coord) + if isinstance(basis, cls.input_basis_type): + return True + return False + + def new_operand(self, operand, **kw): + return RieszDerivative(operand, self.coord, self.order, **kw) + + def subspace_matrix(self, layout): + return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis, self.order) + + def group_matrix(self, group): + return self._group_matrix(group, self.input_basis, self.output_basis, self.order) + + def __str__(self): + if self.order == 1: + return 'R{!s}({!s})'.format(self.coord.name, self.operand) + else: + return 'R{!s}({!s},{!s})'.format(self.coord.name, self.operand, self.order) + + def convert(arg, bases): # Skip for numbers if isinstance(arg, Number): @@ -1621,10 +1707,6 @@ def replace(self, old, new): # # Simplify operand, skipping conversion # return self.operand.simplify(*vars) - def subspace_matrix(self, layout): - """Build matrix operating on global subspace data.""" - return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) - def operate(self, out): """Perform operation.""" arg = self.args[0] @@ -2073,10 +2155,22 @@ def __init__(self, operand, index=0, out=None): # LinearOperator requirements self.operand = operand # FutureField requirements - self.domain = operand.domain + bases = self._build_bases(operand, index) + self.domain = Domain(operand.dist, bases) self.tensorsig = operand.tensorsig self.dtype = operand.dtype + def _build_bases(self, operand, index): + input_bases = operand.domain.bases + output_bases = [] + coordsys = operand.tensorsig[index] + for basis in input_bases: + if (basis.coordsys is coordsys) or (basis.coord in coordsys.coords): + output_bases.append(basis.skew_basis()) + else: + output_bases.append(basis) + return tuple(output_bases) + def new_operand(self, operand, **kw): return Skew(operand, index=self.index, **kw) @@ -2143,7 +2237,7 @@ def subproblem_matrix(self, subproblem): factors[self.index] = 1j * factors[self.index] else: azimuth_index = len(self.tensorsig) + self.azimuth_axis - id_m = sparse.identity(shape[self.azimuth_axis]//2, format='csr') + id_m = sparse.identity(shape[azimuth_index]//2, format='csr') mul_1j = np.array([[0, -1], [1, 0]]) factors[azimuth_index] = sparse.kron(id_m, mul_1j) return reduce(sparse.kron, factors, 1).tocsr() @@ -2352,11 +2446,10 @@ def __init__(self, operand, coordsys, out=None): if args[i] == 0: args[i] = 2*operand args[i].args[0] = 0 - original_args = list(args[i].original_args) - original_args[0] = 0 - args[i].original_args = tuple(original_args) - bases = self._build_bases(*args) - args = [convert(arg, bases) for arg in args] + args[i].original_args = (0, args[i].original_args[1]) + # Convert along orthogonal bases + bases = self._build_bases(operand, coordsys) + args = self._convert_args(coordsys, args, bases) LinearOperator.__init__(self, *args, out=out) self.coordsys = coordsys # LinearOperator requirements @@ -2366,8 +2459,22 @@ def __init__(self, operand, coordsys, out=None): self.tensorsig = (coordsys,) + operand.tensorsig self.dtype = operand.dtype - def _build_bases(self, *args): - return sum(args).domain.bases + def _build_bases(self, operand, coordsys): + input_bases = operand.domain.bases + output_bases = [] + for basis in input_bases: + if basis.coord is coordsys or basis.coord in coordsys.coords: + output_bases.append(basis.gradient_basis()) + else: + output_bases.append(basis) + return tuple(output_bases) + + def _convert_args(self, coordsys, args, bases): + converted_args = [] + for coord, arg in zip(coordsys.coords, args): + arg_bases = [b for b in bases if b.coord != coord] + converted_args.append(convert(arg, arg_bases)) + return converted_args def matrix_dependence(self, *vars): arg_vals = [arg.matrix_dependence(self, *vars) for arg in self.args] @@ -3323,10 +3430,21 @@ def __init__(self, operand, index, comp, out=None): # LinearOperator requirements self.operand = operand # FutureField requirements - self.domain = operand.domain + bases = self._build_bases(operand, index, comp) + self.domain = Domain(operand.dist, bases) self.tensorsig = operand.tensorsig[:index] + operand.tensorsig[index+1:] self.dtype = operand.dtype + def _build_bases(self, operand, index, comp): + input_bases = operand.domain.bases + output_bases = [] + for basis in input_bases: + if basis.coord is comp: + output_bases.append(basis.component_basis()) + else: + output_bases.append(basis) + return tuple(output_bases) + def check_conditions(self): """Check that operands are in a proper layout.""" # Any layout @@ -4021,7 +4139,7 @@ def __init__(self, operand, coordsys, out=None): # Wrap to handle gradient wrt single coordinate if isinstance(coordsys, coords.Coordinate): coordsys = coords.CartesianCoordinates(coordsys.name) - parts = [Differentiate(Differentiate(operand, c), c) for c in coordsys.coords] + parts = [Differentiate(operand, c, order=2) for c in coordsys.coords] arg = sum(parts) LinearOperator.__init__(self, arg, out=out) self.coordsys = coordsys diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 00758fb2..8c1376b3 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -629,6 +629,177 @@ def backward(self, cdata, gdata, axis): plan.backward(temp, gdata) +class SineTransform(SeparableTransform): + """ + Abstract base class for sine transforms. + + Parameters + ---------- + grid_size : int + Grid size (N) along transform dimension. + coeff_size : int + Coefficient size (M) along transform dimension. + + Notes + ----- + Let KN = (N - 1) be the maximum fully resolved (non-Nyquist) mode on the grid. + Let KM = (M - 1) be the maximum retained mode in coeff space. + Then K = min(KN, KM) is the maximum wavenumber used in the transforms. + A unit-amplitude normalization is used. + + Grid: + x(n) = \pi (1/2 + n) / N, n = 0 .. (N-1) + + Forward transform: + if k == 0: + a(k) = 0 + elif k <= K: + a(k) = (2/N) \sum_{n=0}^{N-1} f(n) \sin(k x(n)) + else: + a(k) = 0 + + Backward transform: + f(n) = \sum_{k=1}^{K} a(k) \sin(k x(n)) + + Coefficient ordering: + The sine coefficients are ordered simply as + [0, a(1), a(2), ..., a(KM)] + """ + + def __init__(self, grid_size, coeff_size): + self.N = grid_size + self.M = coeff_size + self.KN = (self.N - 1) + self.KM = (self.M - 1) + self.Kmax = min(self.KN, self.KM) + + @property + def wavenumbers(self): + """One-dimensional global wavenumber array.""" + return np.arange(self.M) + + +@register_transform(basis.ParityBase, 'matrix-sin') +class SineMMT(SineTransform, SeparableMatrixTransform): + """Sine MMT.""" + + @CachedAttribute + def forward_matrix(self): + """Build forward transform matrix.""" + N = self.N + K = self.wavenumbers[:, None] + X = np.arange(N)[None, :] + 1/2 + dX = N / np.pi + quadrature = (2 / N) * np.sin(K*X/dX) + # Zero higher modes for transforms with grid_size < coeff_size + quadrature *= (K <= self.Kmax) + # Ensure C ordering for fast dot products + return np.asarray(quadrature, order='C') + + @CachedAttribute + def backward_matrix(self): + """Build backward transform matrix.""" + N = self.N + K = self.wavenumbers[None, :] + X = np.arange(N)[:, None] + 1/2 + dX = N / np.pi + functions = np.sin(K*X/dX) + # Zero higher modes for transforms with grid_size < coeff_size + functions *= (K <= self.Kmax) + # Ensure C ordering for fast dot products + return np.asarray(functions, order='C') + + +class FastSineTransform(SineTransform): + """Abstract base class for fast sine transforms.""" + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + # Standard scaling factors for unit-amplitude normalization from DST-II + self.forward_rescale = 1 / self.N + self.backward_rescale = 1 / 2 + + def resize_rescale_forward(self, data_in, data_out, axis, Kmax): + """Resize by padding/trunction and rescale to unit amplitude.""" + zerofreq = axslice(axis, 0, 1) + data_out[zerofreq] = 0 + if Kmax > 0: + # Shift to account for zero frequency in output + posfreq_in = axslice(axis, 0, Kmax) + posfreq_out = axslice(axis, 1, Kmax+1) + np.multiply(data_in[posfreq_in], self.forward_rescale, data_out[posfreq_out]) + if self.KM > Kmax: + badfreq = axslice(axis, Kmax+1, None) + data_out[badfreq] = 0 + + def resize_rescale_backward(self, data_in, data_out, axis, Kmax): + """Resize by padding/trunction and rescale to unit amplitude.""" + if Kmax == 0: + zerofreq = axslice(axis, 0, 1) + data_out[zerofreq] = 0 + else: + # Shift to account for zero frequency in input + posfreq_in = axslice(axis, 1, Kmax+1) + posfreq_out = axslice(axis, 0, Kmax) + np.multiply(data_in[posfreq_in], self.backward_rescale, data_out[posfreq_out]) + if self.KN >= Kmax: + badfreq = axslice(axis, Kmax, None) + data_out[badfreq] = 0 + + +@register_transform(basis.ParityBase, 'scipy-sin') +class ScipyDST(FastSineTransform): + """Fast sine transform using scipy.fft.""" + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call DST + # Avoid overwrite_x to prevent overwriting problems + temp = scipy.fft.dst(gdata, type=2, axis=axis) # Creates temporary + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = np.empty_like(gdata) # Creates temporary + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Call IDST + temp = scipy.fft.dst(temp, type=3, axis=axis, overwrite_x=True) + np.copyto(gdata, temp) + + +@register_transform(basis.ParityBase, 'fftw-sin') +class FFTWDST(FFTWBase, FastSineTransform): + """Fast sine transform using FFTW.""" + + @CachedMethod + def _build_fftw_plan(self, dtype, gshape, axis): + """Build FFTW plans and temporary arrays.""" + logger.debug("Building FFTW DST plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis)) + flags = ['FFTW_'+self.rigor.upper()] + plan = fftw.DiscreteSineTransform(dtype, gshape, axis, flags=flags) + temp = fftw.create_array(gshape, dtype) + return plan, temp + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + plan, temp = self._build_fftw_plan(gdata.dtype, gdata.shape, axis) + # Execute FFTW plan + plan.forward(gdata, temp) + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + plan, temp = self._build_fftw_plan(gdata.dtype, gdata.shape, axis) + # Resize and rescale for unit-amplitude normalization + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Execute FFTW plan + plan.backward(temp, gdata) + + class CosineTransform(SeparableTransform): """ Abstract base class for cosine transforms. @@ -642,21 +813,24 @@ class CosineTransform(SeparableTransform): Notes ----- - Let KN = (N - 1) be the maximum (Nyquist) mode on the grid. + Let KN = (N - 1) be the maximum fully resolved (non-Nyquist) mode on the grid. Let KM = (M - 1) be the maximum retained mode in coeff space. Then K = min(KN, KM) is the maximum wavenumber used in the transforms. A unit-amplitude normalization is used. + Grid: + x(n) = \pi (1/2 + n) / N, n = 0 .. (N-1) + Forward transform: if k == 0: - a(k) = (1/N) \sum_{x=0}^{N-1} f(x) + a(k) = (1/N) \sum_{n=0}^{N-1} f(n) elif k <= K: - a(k) = (2/N) \sum_{x=0}^{N-1} f(x) \cos(\pi k x / N) + a(k) = (2/N) \sum_{n=0}^{N-1} f(n) \cos(k x(n)) else: a(k) = 0 Backward transform: - f(x) = \sum_{k=0}^{K} a(k) \cos(\pi k x / N) + f(n) = \sum_{k=0}^{K} a(k) \cos(k x(n)) Coefficient ordering: The cosine coefficients are ordered simply as @@ -673,10 +847,10 @@ def __init__(self, grid_size, coeff_size): @property def wavenumbers(self): """One-dimensional global wavenumber array.""" - return np.arange(self.KM + 1) + return np.arange(self.M) -#@register_transform(basis.Cosine, 'matrix') +@register_transform(basis.ParityBase, 'matrix-cos') class CosineMMT(CosineTransform, SeparableMatrixTransform): """Cosine MMT.""" @@ -684,10 +858,8 @@ class CosineMMT(CosineTransform, SeparableMatrixTransform): def forward_matrix(self): """Build forward transform matrix.""" N = self.N - M = self.M - Kmax = self.Kmax K = self.wavenumbers[:, None] - X = np.arange(N)[None, :] + X = np.arange(N)[None, :] + 1/2 dX = N / np.pi quadrature = (2 / N) * np.cos(K*X/dX) quadrature[0] = 1 / N @@ -700,8 +872,6 @@ def forward_matrix(self): def backward_matrix(self): """Build backward transform matrix.""" N = self.N - M = self.M - Kmax = self.Kmax K = self.wavenumbers[None, :] X = np.arange(N)[:, None] + 1/2 dX = N / np.pi @@ -730,9 +900,9 @@ def resize_rescale_forward(self, data_in, data_out, axis, Kmax): if Kmax > 0: posfreq = axslice(axis, 1, Kmax+1) np.multiply(data_in[posfreq], self.forward_rescale_pos, data_out[posfreq]) - if self.KM > Kmax: - badfreq = axslice(axis, Kmax+1, None) - data_out[badfreq] = 0 + if self.KM > Kmax: + badfreq = axslice(axis, Kmax+1, None) + data_out[badfreq] = 0 def resize_rescale_backward(self, data_in, data_out, axis, Kmax): """Resize by padding/trunction and rescale to unit amplitude.""" @@ -741,20 +911,21 @@ def resize_rescale_backward(self, data_in, data_out, axis, Kmax): if Kmax > 0: posfreq = axslice(axis, 1, Kmax+1) np.multiply(data_in[posfreq], self.backward_rescale_pos, data_out[posfreq]) - if self.KN > Kmax: - badfreq = axslice(axis, Kmax+1, None) - data_out[badfreq] = 0 + if self.KN > Kmax: + badfreq = axslice(axis, Kmax+1, None) + data_out[badfreq] = 0 -#@register_transform(basis.Cosine, 'scipy') +@register_transform(basis.ParityBase, 'scipy-cos') class ScipyDCT(FastCosineTransform): """Fast cosine transform using scipy.fft.""" def forward(self, gdata, cdata, axis): """Apply forward transform along specified axis.""" # Call DCT + # Avoid overwrite_x to prevent overwriting problems temp = scipy.fft.dct(gdata, type=2, axis=axis) # Creates temporary - # Resize and rescale for unit-ampltidue normalization + # Resize and rescale for unit-amplitude normalization self.resize_rescale_forward(temp, cdata, axis, self.Kmax) def backward(self, cdata, gdata, axis): @@ -764,11 +935,11 @@ def backward(self, cdata, gdata, axis): temp = np.empty_like(gdata) # Creates temporary self.resize_rescale_backward(cdata, temp, axis, self.Kmax) # Call IDCT - temp = scipy.fft.dct(temp, type=3, axis=axis, overwrite_x=True) # Creates temporary + temp = scipy.fft.dct(temp, type=3, axis=axis, overwrite_x=True) np.copyto(gdata, temp) -#@register_transform(basis.Cosine, 'fftw') +@register_transform(basis.ParityBase, 'fftw-cos') class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" diff --git a/dedalus/tests/test_fourier_operators.py b/dedalus/tests/test_fourier_operators.py index 018861e3..dd596276 100644 --- a/dedalus/tests/test_fourier_operators.py +++ b/dedalus/tests/test_fourier_operators.py @@ -21,6 +21,15 @@ def build_fourier(N, bounds, dealias, dtype): return c, d, b, x +@CachedMethod +def build_parity(N, bounds, dealias, parity, dtype): + c = d3.Coordinate('x') + d = d3.Distributor(c, dtype=dtype) + b = d3.Parity(c, size=N, bounds=bounds, dealias=dealias, parity=parity) + x = d.local_grid(b, scale=1) + return c, d, b, x + + @pytest.mark.parametrize('N', N_range) @pytest.mark.parametrize('bounds', bounds_range) @pytest.mark.parametrize('dealias', dealias_range) @@ -40,14 +49,145 @@ def test_fourier_convert_constant(N, bounds, dealias, dtype, layout): @pytest.mark.parametrize('bounds', bounds_range) @pytest.mark.parametrize('dealias', dealias_range) @pytest.mark.parametrize('dtype', dtype_range) -def test_fourier_differentiate(N, bounds, dealias, dtype): +@pytest.mark.parametrize('layout', ['g', 'c']) +def test_parity_convert_constant(N, bounds, dealias, dtype, layout): + """Test conversion from constant to EvenParity basis.""" + c, d, b, x = build_parity(N, bounds, dealias, 1, dtype) + f = d.Field() + f['g'] = 1 + f.change_layout(layout) + g = d3.Convert(f, b).evaluate() + assert np.allclose(g['g'], f['g']) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('order', [1, 2, 3]) +def test_fourier_differentiate(N, bounds, dealias, dtype, order): """Test differentiation in Fourier basis.""" c, d, b, x = build_fourier(N, bounds, dealias, dtype) f = d.Field(bases=b) - k = 4 * np.pi / (bounds[1] - bounds[0]) + L = bounds[1] - bounds[0] + k = 4 * np.pi / L + f['g'] = 1 + np.sin(k*x+0.1) + g = d3.Differentiate(f, c, order=order).evaluate() + if order == 1: + assert np.allclose(g['g'], k*np.cos(k*x+0.1)) + elif order == 2: + assert np.allclose(g['g'], -k**2*np.sin(k*x+0.1)) + elif order == 3: + assert np.allclose(g['g'], -k**3*np.cos(k*x+0.1)) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('parity', [1, -1]) +@pytest.mark.parametrize('order', [1, 2, 3]) +def test_parity_differentiate(N, bounds, dealias, dtype, parity, order): + """Test differentiation in Parity bases.""" + c, d, b, x = build_parity(N, bounds, dealias, parity, dtype) + f = d.Field(bases=b) + x0 = bounds[0] + L = bounds[1] - bounds[0] + k = 3 * np.pi / L + if parity == 1: + f['g'] = 1 + np.cos(k*(x-x0)) + g = d3.Differentiate(f, c, order=order).evaluate() + if order == 1: + assert np.allclose(g['g'], -k*np.sin(k*(x-x0))) + elif order == 2: + assert np.allclose(g['g'], -k**2*np.cos(k*(x-x0))) + elif order == 3: + assert np.allclose(g['g'], k**3*np.sin(k*(x-x0))) + elif parity == -1: + f['g'] = np.sin(k*(x-x0)) + g = d3.Differentiate(f, c, order=order).evaluate() + if order == 1: + assert np.allclose(g['g'], k*np.cos(k*(x-x0))) + elif order == 2: + assert np.allclose(g['g'], -k**2*np.sin(k*(x-x0))) + elif order == 3: + assert np.allclose(g['g'], -k**3*np.cos(k*(x-x0))) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('order', [1, 1.5, 2]) +def test_fourier_riesz(N, bounds, dealias, dtype, order): + """Test Riesz derivative in Fourier basis.""" + c, d, b, x = build_fourier(N, bounds, dealias, dtype) + f = d.Field(bases=b) + L = bounds[1] - bounds[0] + k = 4 * np.pi / L + f['g'] = 1 + np.sin(k*x+0.1) + g = d3.RieszDerivative(f, c, order=order).evaluate() + assert np.allclose(g['g'], -abs(k)**order*np.sin(k*x+0.1)) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('parity', [1, -1]) +@pytest.mark.parametrize('order', [1, 1.5, 2]) +def test_parity_riesz(N, bounds, dealias, dtype, parity, order): + """Test Riesz derivative in Fourier basis.""" + c, d, b, x = build_parity(N, bounds, dealias, parity, dtype) + f = d.Field(bases=b) + x0 = bounds[0] + L = bounds[1] - bounds[0] + k = 4 * np.pi / L + if parity == 1: + f['g'] = 1 + np.cos(k*(x-x0)) + g = d3.RieszDerivative(f, c, order=order).evaluate() + assert np.allclose(g['g'], -abs(k)**order*np.cos(k*(x-x0))) + elif parity == -1: + f['g'] = np.sin(k*(x-x0)) + g = d3.RieszDerivative(f, c, order=order).evaluate() + assert np.allclose(g['g'], -abs(k)**order*np.sin(k*(x-x0))) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +def test_fourier_hilbert(N, bounds, dealias, dtype): + """Test Hilbert transform in Fourier basis.""" + c, d, b, x = build_fourier(N, bounds, dealias, dtype) + f = d.Field(bases=b) + L = bounds[1] - bounds[0] + k = 4 * np.pi / L f['g'] = 1 + np.sin(k*x+0.1) - g = d3.Differentiate(f, c).evaluate() - assert np.allclose(g['g'], k*np.cos(k*x+0.1)) + g = d3.HilbertTransform(f, c).evaluate() + assert np.allclose(g['g'], -np.cos(k*x+0.1)) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('parity', [1, -1]) +def test_parity_hilbert(N, bounds, dealias, dtype, parity): + """Test Hilbert transform in Parity bases.""" + c, d, b, x = build_parity(N, bounds, dealias, parity, dtype) + f = d.Field(bases=b) + x0 = bounds[0] + L = bounds[1] - bounds[0] + k = 4 * np.pi / L + if parity == 1: + f['g'] = 1 + np.cos(k*(x-x0)) + g = d3.HilbertTransform(f, c).evaluate() + assert np.allclose(g['g'], np.sin(k*(x-x0))) + elif parity == -1: + f['g'] = np.sin(k*(x-x0)) + g = d3.HilbertTransform(f, c).evaluate() + assert np.allclose(g['g'], -np.cos(k*(x-x0))) @pytest.mark.parametrize('N', N_range) @@ -58,12 +198,38 @@ def test_fourier_interpolate(N, bounds, dealias, dtype): """Test interpolation in Fourier basis.""" c, d, b, x = build_fourier(N, bounds, dealias, dtype) f = d.Field(bases=b) - k = 4 * np.pi / (bounds[1] - bounds[0]) - f['g'] = 1 + np.sin(k*x+0.1) + L = bounds[1] - bounds[0] + k = 4 * np.pi / L + f0 = lambda x: 1 + np.sin(k*x+0.1) + f['g'] = f0(x) + results = [] + for p in [bounds[0], bounds[1], bounds[0] + L*np.random.rand()]: + g = d3.Interpolate(f, c, p).evaluate() + results.append(np.allclose(g['g'], f0(p))) + assert all(results) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('parity', [1, -1]) +def test_parity_interpolate(N, bounds, dealias, dtype, parity): + """Test interpolation in Parity bases.""" + c, d, b, x = build_parity(N, bounds, dealias, parity, dtype) + f = d.Field(bases=b) + x0 = bounds[0] + L = bounds[1] - bounds[0] + k = 3 * np.pi / L + if parity == 1: + f0 = lambda x: 1 + np.cos(k*(x-x0)) + elif parity == -1: + f0 = lambda x: np.sin(k*(x-x0)) + f['g'] = f0(x) results = [] for p in [bounds[0], bounds[1], bounds[0] + (bounds[1] - bounds[0]) * np.random.rand()]: g = d3.Interpolate(f, c, p).evaluate() - results.append(np.allclose(g['g'], 1 + np.sin(k*p+0.1))) + results.append(np.allclose(g['g'], f0(p))) assert all(results) @@ -75,10 +241,33 @@ def test_fourier_integrate(N, bounds, dealias, dtype): """Test integration in Fourier basis.""" c, d, b, x = build_fourier(N, bounds, dealias, dtype) f = d.Field(bases=b) - k = 4 * np.pi / (bounds[1] - bounds[0]) + L = bounds[1] - bounds[0] + k = 4 * np.pi / L f['g'] = 1 + np.sin(k*x+0.1) g = d3.Integrate(f, c).evaluate() - assert np.allclose(g['g'], bounds[1] - bounds[0]) + assert np.allclose(g['g'], L) + + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('parity', [1, -1]) +def test_parity_integrate(N, bounds, dealias, dtype, parity): + """Test integration in Parity bases.""" + c, d, b, x = build_parity(N, bounds, dealias, parity, dtype) + f = d.Field(bases=b) + x0 = bounds[0] + L = bounds[1] - bounds[0] + k = 3 * np.pi / L + if parity == 1: + f['g'] = 1 + np.cos(k*(x-x0)) + g = d3.Integrate(f, c).evaluate() + assert np.allclose(g['g'], L) + elif parity == -1: + f['g'] = np.sin(k*(x-x0)) + g = d3.Integrate(f, c).evaluate() + assert np.allclose(g['g'], L*2/3/np.pi) @pytest.mark.parametrize('N', N_range) @@ -89,8 +278,30 @@ def test_fourier_average(N, bounds, dealias, dtype): """Test averaging in Fourier basis.""" c, d, b, x = build_fourier(N, bounds, dealias, dtype) f = d.Field(bases=b) - k = 4 * np.pi / (bounds[1] - bounds[0]) + L = bounds[1] - bounds[0] + k = 4 * np.pi / L f['g'] = 1 + np.sin(k*x+0.1) g = d3.Average(f, c).evaluate() assert np.allclose(g['g'], 1) + +@pytest.mark.parametrize('N', N_range) +@pytest.mark.parametrize('bounds', bounds_range) +@pytest.mark.parametrize('dealias', dealias_range) +@pytest.mark.parametrize('dtype', dtype_range) +@pytest.mark.parametrize('parity', [1, -1]) +def test_parity_average(N, bounds, dealias, dtype, parity): + """Test averaging in Parity bases.""" + c, d, b, x = build_parity(N, bounds, dealias, parity, dtype) + f = d.Field(bases=b) + x0 = bounds[0] + L = bounds[1] - bounds[0] + k = 3 * np.pi / L + if parity == 1: + f['g'] = 1 + np.cos(k*(x-x0)) + g = d3.Average(f, c).evaluate() + assert np.allclose(g['g'], 1) + elif parity == -1: + f['g'] = np.sin(k*(x-x0)) + g = d3.Average(f, c).evaluate() + assert np.allclose(g['g'], 2/3/np.pi) diff --git a/dedalus/tests/test_transforms.py b/dedalus/tests/test_transforms.py index c576ae95..df478109 100644 --- a/dedalus/tests/test_transforms.py +++ b/dedalus/tests/test_transforms.py @@ -57,6 +57,50 @@ def test_real_fourier_libraries_forward(N, dealias, dtype, library): assert np.allclose(u_mat['c'], u_lib['c']) +@pytest.mark.parametrize('N', [16]) +@pytest.mark.parametrize('parity', [1, -1]) +@pytest.mark.parametrize('dealias', [0.5, 1, 1.5]) +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) +@pytest.mark.parametrize('library', ['scipy', 'fftw']) +def test_parity_libraries_backward(N, parity, dealias, dtype, library): + """Tests that fast real Fourier transforms match matrix transforms.""" + c = coords.Coordinate('x') + d = distributor.Distributor([c]) + # Matrix + b_mat = basis.Parity(c, size=N, bounds=(0, 2*np.pi), parity=parity, dealias=dealias, library='matrix') + u_mat = field.Field(dist=d, bases=(b_mat,), dtype=dtype) + u_mat.preset_scales(dealias) + u_mat['c'] = np.random.randn(N) + # Library + b_lib = basis.Parity(c, size=N, bounds=(0, 2*np.pi), parity=parity, dealias=dealias, library=library) + u_lib = field.Field(dist=d, bases=(b_lib,), dtype=dtype) + u_lib.preset_scales(dealias) + u_lib['c'] = u_mat['c'] + assert np.allclose(u_mat['g'], u_lib['g']) + + +@pytest.mark.parametrize('N', [16]) +@pytest.mark.parametrize('parity', [1, -1]) +@pytest.mark.parametrize('dealias', [0.5, 1, 1.5]) +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) +@pytest.mark.parametrize('library', ['scipy', 'fftw']) +def test_parity_libraries_forward(N, parity, dealias, dtype, library): + """Tests that fast real Fourier transforms match matrix transforms.""" + c = coords.Coordinate('x') + d = distributor.Distributor([c]) + # Matrix + b_mat = basis.Parity(c, size=N, bounds=(0, 2*np.pi), parity=parity, dealias=dealias, library='matrix') + u_mat = field.Field(dist=d, bases=(b_mat,), dtype=dtype) + u_mat.preset_scales(dealias) + u_mat['g'] = np.random.randn(int(np.ceil(dealias * N))) + # Library + b_lib = basis.Parity(c, size=N, bounds=(0, 2*np.pi), parity=parity, dealias=dealias, library=library) + u_lib = field.Field(dist=d, bases=(b_lib,), dtype=dtype) + u_lib.preset_scales(dealias) + u_lib['g'] = u_mat['g'] + assert np.allclose(u_mat['c'], u_lib['c']) + + @pytest.mark.parametrize('N', N_range) @pytest.mark.parametrize('dealias', dealias_range) def test_CF_scalar_roundtrip(N, dealias): diff --git a/examples/ivp_2d_stress_free_rayleigh_benard/plot_snapshots.py b/examples/ivp_2d_stress_free_rayleigh_benard/plot_snapshots.py new file mode 100644 index 00000000..8fdca5da --- /dev/null +++ b/examples/ivp_2d_stress_free_rayleigh_benard/plot_snapshots.py @@ -0,0 +1,79 @@ +""" +Plot 2D cartesian snapshots. + +Usage: + plot_snapshots.py ... [--output=] + +Options: + --output= Output directory [default: ./frames] + +""" + +import h5py +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from dedalus.extras import plot_tools + + +def main(filename, start, count, output): + """Save plot of specified tasks for given range of analysis writes.""" + + # Plot settings + tasks = ['buoyancy', 'vorticity', 'ux', 'uz'] + scale = 1.5 + dpi = 200 + title_func = lambda sim_time: 't = {:.3f}'.format(sim_time) + savename_func = lambda write: 'write_{:06}.png'.format(write) + + # Layout + nrows, ncols = 4, 1 + image = plot_tools.Box(4, 1) + pad = plot_tools.Frame(0.3, 0, 0, 0) + margin = plot_tools.Frame(0.2, 0.1, 0, 0) + + # Create multifigure + mfig = plot_tools.MultiFigure(nrows, ncols, image, pad, margin, scale) + fig = mfig.figure + + # Plot writes + with h5py.File(filename, mode='r') as file: + for index in range(start, start+count): + for n, task in enumerate(tasks): + # Build subfigure axes + i, j = divmod(n, ncols) + axes = mfig.add_axes(i, j, [0, 0, 1, 1]) + # Call 3D plotting helper, slicing in time + dset = file['tasks'][task] + plot_tools.plot_bot_3d(dset, 0, index, axes=axes, title=task, even_scale=True, visible_axes=False) + # Add time title + title = title_func(file['scales/sim_time'][index]) + title_height = 1 - 0.5 * mfig.margin.top / mfig.fig.y + fig.suptitle(title, x=0.44, y=title_height, ha='left') + # Save figure + savename = savename_func(file['scales/write_number'][index]) + savepath = output.joinpath(savename) + fig.savefig(str(savepath), dpi=dpi) + fig.clear() + plt.close(fig) + + +if __name__ == "__main__": + + import pathlib + from docopt import docopt + from dedalus.tools import logging + from dedalus.tools import post + from dedalus.tools.parallel import Sync + + args = docopt(__doc__) + + output_path = pathlib.Path(args['--output']).absolute() + # Create output directory if needed + with Sync() as sync: + if sync.comm.rank == 0: + if not output_path.exists(): + output_path.mkdir() + post.visit_writes(args[''], main, output=output_path) + diff --git a/examples/ivp_2d_stress_free_rayleigh_benard/stress_free_rbc.py b/examples/ivp_2d_stress_free_rayleigh_benard/stress_free_rbc.py new file mode 100644 index 00000000..786254ed --- /dev/null +++ b/examples/ivp_2d_stress_free_rayleigh_benard/stress_free_rbc.py @@ -0,0 +1,114 @@ +""" +Dedalus script simulating 2D horizontally-periodic Rayleigh-Benard +convection with stress-free boundary conditions using Sine/Cosine +bases in z. This script demonstrates solving a 2D Cartesian initial +value problem. It can be ran serially or in parallel, and uses the +built-in analysis framework to save data snapshots to HDF5 files. The +`plot_snapshots.py` script can be used to produce plots from the saved +data. It should take about 5 cpu-minutes to run. + +The problem is non-dimensionalized using the box height and freefall time, so +the resulting thermal diffusivity and viscosity are related to the Prandtl +and Rayleigh numbers as: + + kappa = (Rayleigh * Prandtl)**(-1/2) + nu = (Rayleigh / Prandtl)**(-1/2) + +Note that unlike the no-slip Cheybshev example, here we solve for the +buoyancy *perturbation* instead of the total buoyancy field. + +To run and plot using e.g. 4 processes: + $ mpiexec -n 4 python3 rayleigh_benard.py + $ mpiexec -n 4 python3 plot_snapshots.py snapshots/*.h5 + +""" + +import numpy as np +import dedalus.public as d3 +import logging +logger = logging.getLogger(__name__) + + +# Parameters +Lx, Lz = 4, 1 +Nx, Nz = 256, 64 +Rayleigh = 1e5 +Prandtl = 1 +dealias = 3/2 +stop_sim_time = 40 +timestepper = d3.RK222 +timestep = 1e-2 +max_timestep = 0.125 +dtype = np.float64 + +# Bases +# cx = d3.Coordinate('x') +# cz = d3.Coordinate('z') +# coords = d3.DirectProduct(cx, cz) +coords = d3.CartesianCoordinates('x', 'z') +cx, cz = coords.coords +dist = d3.Distributor(coords, dtype=dtype) +xbasis = d3.RealFourier(cx, size=Nx, bounds=(0, Lx), dealias=dealias) +zobasis = d3.OddParity(cz, size=Nz, bounds=(0, Lz), dealias=dealias) +zebasis = d3.EvenParity(cz, size=Nz, bounds=(0, Lz), dealias=dealias) + +# Fields +p = dist.Field(name='p', bases=(xbasis,zebasis)) +b = dist.Field(name='b', bases=(xbasis,zobasis)) +u = dist.VectorField(coords, name='u', bases=(xbasis,zebasis)) +tau_p = dist.Field(name='tau_p') + +# Substitutions +kappa = (Rayleigh * Prandtl)**(-1/2) +nu = (Rayleigh / Prandtl)**(-1/2) +x, z = dist.local_grids(xbasis, zebasis) +ex = dist.VectorField(coords, bases=zebasis.clone_with(size=1)) +ex['g'][0] = 1 +ez = dist.VectorField(coords, bases=zobasis.clone_with(size=1)) +ez['g'][1] = 1 + +# Problem +problem = d3.IVP([p, b, u, tau_p], namespace=locals()) +problem.add_equation("div(u) + tau_p = 0") +problem.add_equation("dt(b) - kappa*lap(b) = - u@grad(b) + ez@u") +problem.add_equation("dt(u) - nu*lap(u) + grad(p) = - u@grad(u) + b*ez") +problem.add_equation("integ(p) = 0") # Pressure gauge + +# Solver +solver = problem.build_solver(timestepper) +solver.stop_sim_time = stop_sim_time + +# Initial conditions +b.fill_random('g', seed=42, distribution='normal', scale=1e-3) # Random noise +b['g'] *= z * (Lz - z) # Damp noise at walls + +# Analysis +snapshots = solver.evaluator.add_file_handler('snapshots', sim_dt=0.25, max_writes=50) +snapshots.add_task(b, name='buoyancy') +snapshots.add_task(ex@u, name='ux') +snapshots.add_task(ez@u, name='uz') +snapshots.add_task(-d3.div(d3.skew(u)), name='vorticity') + +# CFL +#CFL = d3.CFL(solver, initial_dt=max_timestep, cadence=10, safety=0.5, threshold=0.05, +# max_change=1.5, min_change=0.5, max_dt=max_timestep) +#CFL.add_velocity(u) + +# Flow properties +flow = d3.GlobalFlowProperty(solver, cadence=10) +flow.add_property(np.sqrt(u@u)/nu, name='Re') + +# Main loop +try: + logger.info('Starting main loop') + while solver.proceed: + #timestep = CFL.compute_timestep() + solver.step(timestep) + if (solver.iteration-1) % 10 == 0: + max_Re = flow.max('Re') + logger.info('Iteration=%i, Time=%e, dt=%e, max(Re)=%f' %(solver.iteration, solver.sim_time, timestep, max_Re)) +except: + logger.error('Exception raised, triggering end of main loop.') + raise +finally: + solver.log_stats()