Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
from firedrake.interpolation import get_interpolator
from firedrake.petsc import PETSc
from firedrake.slate import slac, slate
from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg
Expand Down Expand Up @@ -613,17 +614,8 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
rank = len(expr.arguments())
if rank > 2:
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
# Get the target space
V = v.function_space().dual()

# Get the interpolator
interp_data = expr.interp_data.copy()
default_missing_val = interp_data.pop('default_missing_val', None)
if rank == 1 and isinstance(tensor, firedrake.Function):
V = tensor
interpolator = firedrake.Interpolator(expr, V, bcs=bcs, **interp_data)
# Assembly
return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val)
interpolator = get_interpolator(expr)
return interpolator.assemble(tensor=tensor, bcs=bcs)
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
return tensor.assign(expr)
elif tensor and isinstance(expr, ufl.ZeroBaseForm):
Expand Down
14 changes: 7 additions & 7 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# A module implementing strong (Dirichlet) boundary conditions.
import numpy as np

import functools
from functools import partial, reduce
import itertools

import ufl
Expand Down Expand Up @@ -167,7 +167,7 @@ def hermite_stride(bcnodes):
# Edge conditions have only been tested with Lagrange elements.
# Need to expand the list.
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
bcnodes1 = functools.reduce(np.intersect1d, bcnodes1)
bcnodes1 = reduce(np.intersect1d, bcnodes1)
bcnodes.append(bcnodes1)
return np.concatenate(bcnodes)

Expand Down Expand Up @@ -359,11 +359,11 @@ def function_arg(self, g):
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
try:
self._function_arg = firedrake.Function(V)
# Use `Interpolator` instead of assembling an `Interpolate` form
# as the expression compilation needs to happen at this stage to
# determine if we should use interpolation or projection
# -> e.g. interpolation may not be supported for the element.
self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate
interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V))
# Call this here to check if the element supports interpolation
# TODO: It's probably better to have a more explicit way of checking this
interpolator._get_callable()
self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg)
except (NotImplementedError, AttributeError):
# Element doesn't implement interpolation
self._function_arg = firedrake.Function(V).project(g)
Expand Down
1,957 changes: 898 additions & 1,059 deletions firedrake/interpolation.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions firedrake/preconditioners/hiptmair.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from firedrake.preconditioners.hypre_ams import chop
from firedrake.preconditioners.facet_split import restrict
from firedrake.parameters import parameters
from firedrake.interpolation import Interpolator
from firedrake.interpolation import interpolate
from ufl.algorithms.ad import expand_derivatives
import firedrake.dmhooks as dmhooks
import firedrake.utils as utils
Expand Down Expand Up @@ -202,7 +202,7 @@ def coarsen(self, pc):

coarse_space_bcs = tuple(coarse_space_bcs)
if G_callback is None:
interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle)
interp_petscmat = chop(assemble(interpolate(dminus(trial), V), bcs=bcs + coarse_space_bcs).mat())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not assemble(...).petscmat?

else:
interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs)

Expand Down
6 changes: 3 additions & 3 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,14 +1248,14 @@ def _kernels(self):
return self._build_custom_interpolators()

def _build_native_interpolators(self):
from firedrake.interpolation import interpolate, Interpolator
P = Interpolator(interpolate(self.uc, self.Vf), self.Vf)
from firedrake.interpolation import interpolate, get_interpolator
P = get_interpolator(interpolate(self.uc, self.Vf))
prolong = partial(P.assemble, tensor=self.uf)

rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat)
rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat)
vc = firedrake.TestFunction(self.Vc)
R = Interpolator(interpolate(vc, rf), self.Vf)
R = get_interpolator(interpolate(vc, rf))
restrict = partial(R.assemble, tensor=rc)
return prolong, restrict

Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def test_interpolator_reuse(family, degree, mode):
u = Function(V.dual())
expr = interpolate(TestFunction(V), u)

I = Interpolator(expr, V)
I = get_interpolator(expr)

for k in range(3):
u.assign(rg.uniform(u.function_space()))
Expand Down
38 changes: 37 additions & 1 deletion tests/firedrake/regression/test_interpolate_cross_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,18 @@ def test_exact_refinement():
expr_in_V_fine = x**2 + y**2 + 1
f_fine = Function(V_fine).interpolate(expr_in_V_fine)

# Build interpolation matrices in both directions
coarse_to_fine = assemble(interpolate(TrialFunction(V_coarse), V_fine))
coarse_to_fine_adjoint = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual())))

# If we now interpolate f_coarse into V_fine we should get a function
# which has no interpolation error versus f_fine because we were able to
# exactly represent expr_in_V_coarse in V_coarse and V_coarse is a subset
# of V_fine
f_coarse_on_fine = assemble(interpolate(f_coarse, V_fine))
assert np.allclose(f_coarse_on_fine.dat.data_ro, f_fine.dat.data_ro)
f_coarse_on_fine_mat = assemble(coarse_to_fine @ f_coarse)
assert np.allclose(f_coarse_on_fine_mat.dat.data_ro, f_fine.dat.data_ro)

# Adjoint interpolation takes us from V_fine^* to V_coarse^* so we should
# also get an exact result here.
Expand All @@ -354,6 +360,10 @@ def test_exact_refinement():
assert np.allclose(
cofunction_fine_on_coarse.dat.data_ro, cofunction_coarse.dat.data_ro
)
cofunction_fine_on_coarse_mat = assemble(action(coarse_to_fine_adjoint, cofunction_fine))
assert np.allclose(
cofunction_fine_on_coarse_mat.dat.data_ro, cofunction_coarse.dat.data_ro
)

# Now we test with expressions which are NOT exactly representable in the
# function spaces by introducing a cube term. This can't be represented
Expand Down Expand Up @@ -550,7 +560,7 @@ def test_missing_dofs():
V_src = FunctionSpace(m_src, "CG", 2)
V_dest = FunctionSpace(m_dest, "CG", 3)
with pytest.raises(DofNotDefinedError):
Interpolator(TestFunction(V_src), V_dest)
assemble(interpolate(TrialFunction(V_src), V_dest))
f_src = Function(V_src).interpolate(expr)
f_dest = assemble(interpolate(f_src, V_dest, allow_missing_dofs=True))
dest_eval = PointEvaluator(m_dest, coords)
Expand Down Expand Up @@ -680,6 +690,32 @@ def test_interpolate_matrix_cross_mesh():
f_interp2.dat.data_wo[:] = f_at_points_correct_order3.dat.data_ro[:]
assert np.allclose(f_interp2.dat.data_ro, g.dat.data_ro)

interp_mat2 = assemble(interpolate(TrialFunction(U), V))
assert interp_mat2.arguments() == (TestFunction(V.dual()), TrialFunction(U))
f_interp3 = assemble(interp_mat2 @ f)
assert f_interp3.function_space() == V
assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro)


@pytest.mark.parallel([1, 3])
def test_interpolate_matrix_cross_mesh_adjoint():
mesh_fine = UnitSquareMesh(4, 4)
mesh_coarse = UnitSquareMesh(2, 2)

V_coarse = FunctionSpace(mesh_coarse, "CG", 1)
V_fine = FunctionSpace(mesh_fine, "CG", 1)

cofunc_fine = assemble(conj(TestFunction(V_fine)) * dx)

interp = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual())))
cofunc_coarse = assemble(Action(interp, cofunc_fine))
assert interp.arguments() == (TestFunction(V_coarse), TrialFunction(V_fine.dual()))
assert cofunc_coarse.function_space() == V_coarse.dual()

# Compare cofunc_fine with direct interpolation
cofunc_coarse_direct = assemble(conj(TestFunction(V_coarse)) * dx)
assert np.allclose(cofunc_coarse.dat.data_ro, cofunc_coarse_direct.dat.data_ro)


@pytest.mark.parallel([2, 3, 4])
def test_voting_algorithm_edgecases():
Expand Down
42 changes: 19 additions & 23 deletions tests/firedrake/vertexonly/test_vertex_only_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def pseudo_random_coords(size):

# Function Space Generation Tests

def functionspace_tests(vm, petsc_raises):
def functionspace_tests(vm):
# Prep
num_cells = len(vm.coordinates.dat.data_ro)
num_cells_mpi_global = MPI.COMM_WORLD.allreduce(num_cells, op=MPI.SUM)
Expand Down Expand Up @@ -144,27 +144,25 @@ def functionspace_tests(vm, petsc_raises):
h_star = h.riesz_representation(riesz_map="l2")
g = assemble(interpolate(TestFunction(V), h_star))
assert np.allclose(g.dat.data_ro_with_halos, np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
with petsc_raises(NotImplementedError):
# Can't use adjoint on interpolates with expressions yet
g2 = assemble(interpolate(2 * TestFunction(V), h_star))
assert np.allclose(g2.dat.data_ro_with_halos, 2*np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))

g2 = assemble(interpolate(2 * TestFunction(V), h_star))
assert np.allclose(g2.dat.data_ro_with_halos, 2*np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))

h_star = assemble(interpolate(TestFunction(W), g))
h = h_star.riesz_representation(riesz_map="l2")
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == 0)
with petsc_raises(NotImplementedError):
# Can't use adjoint on interpolates with expressions yet
h2 = assemble(interpolate(2 * TestFunction(W), g))
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], 2*np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))

h2 = assemble(interpolate(2 * TestFunction(W), g))
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], 2*np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))

g = assemble(interpolate(h, V))
assert np.allclose(g.dat.data_ro_with_halos, np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
g2 = assemble(interpolate(2 * h, V))
assert np.allclose(g2.dat.data_ro_with_halos, 2*np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))


def vectorfunctionspace_tests(vm, petsc_raises):
def vectorfunctionspace_tests(vm):
# Prep
gdim = vm.geometric_dimension
num_cells = len(vm.coordinates.dat.data_ro)
Expand Down Expand Up @@ -240,18 +238,16 @@ def vectorfunctionspace_tests(vm, petsc_raises):
h_star = h.riesz_representation(riesz_map="l2")
g = assemble(interpolate(TestFunction(V), h_star))
assert np.allclose(g.dat.data_ro_with_halos, 2*vm.coordinates.dat.data_ro_with_halos)
with petsc_raises(NotImplementedError):
# Can't use adjoint on interpolate with expressions yet
g2 = assemble(interpolate(2 * TestFunction(V), h_star))
assert np.allclose(g2.dat.data_ro_with_halos, 4*vm.coordinates.dat.data_ro_with_halos)

g2 = assemble(interpolate(2 * TestFunction(V), h_star))
assert np.allclose(g2.dat.data_ro_with_halos, 4*vm.coordinates.dat.data_ro_with_halos)

h_star = assemble(interpolate(TestFunction(W), g))
assert np.allclose(h_star.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
assert np.all(h_star.dat.data_ro_with_halos[~idxs_to_include] == 0)
with petsc_raises(NotImplementedError):
# Can't use adjoint on interpolate with expressions yet
h2 = assemble(interpolate(2 * TestFunction(W), g))
assert np.allclose(h2.dat.data_ro[idxs_to_include], 4*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])

h2 = assemble(interpolate(2 * TestFunction(W), g))
assert np.allclose(h2.dat.data_ro[idxs_to_include], 4*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])

h = h_star.riesz_representation(riesz_map="l2")
g = assemble(interpolate(h, V))
Expand All @@ -261,12 +257,12 @@ def vectorfunctionspace_tests(vm, petsc_raises):


@pytest.mark.parallel([1, 3])
def test_functionspaces(parentmesh, vertexcoords, petsc_raises):
def test_functionspaces(parentmesh, vertexcoords):
vm = VertexOnlyMesh(parentmesh, vertexcoords, missing_points_behaviour="ignore")
functionspace_tests(vm, petsc_raises)
vectorfunctionspace_tests(vm, petsc_raises)
functionspace_tests(vm.input_ordering, petsc_raises)
vectorfunctionspace_tests(vm.input_ordering, petsc_raises)
functionspace_tests(vm)
vectorfunctionspace_tests(vm)
functionspace_tests(vm.input_ordering)
vectorfunctionspace_tests(vm.input_ordering)


@pytest.mark.parallel(nprocs=2)
Expand Down
Loading