Skip to content
2 changes: 1 addition & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
if rank > 2:
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
interpolator = get_interpolator(expr)
return interpolator.assemble(tensor=tensor, bcs=bcs)
return interpolator.assemble(tensor=tensor, bcs=bcs, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type)
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
return tensor.assign(expr)
elif tensor and isinstance(expr, ufl.ZeroBaseForm):
Expand Down
245 changes: 189 additions & 56 deletions firedrake/interpolation.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,8 @@ def prolongation_matrix_aij(Vc, Vf, Vc_bcs=(), Vf_bcs=()):
Vc = Vc.function_space()
bcs = Vc_bcs + Vf_bcs
interp = firedrake.interpolate(firedrake.TrialFunction(Vc), Vf)
mat = firedrake.assemble(interp, bcs=bcs)
mat_type = "nest" if len(Vc) > 1 or len(Vf) > 1 else None
mat = firedrake.assemble(interp, bcs=bcs, mat_type=mat_type)
return mat.petscmat


Expand Down
11 changes: 6 additions & 5 deletions tests/firedrake/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,8 @@ def test_interpolate_logical_not():


@pytest.mark.parametrize("mode", ("forward", "adjoint"))
def test_mixed_matrix(mode):
@pytest.mark.parametrize("mat_type", (None, "nest"))
def test_mixed_matrix(mode, mat_type):
nx = 3
mesh = UnitSquareMesh(nx, nx)

Expand All @@ -537,11 +538,11 @@ def test_mixed_matrix(mode):

if mode == "forward":
I = Interpolate(TrialFunction(Z), TestFunction(W.dual()))
a = assemble(I)
a = assemble(I, mat_type=mat_type)
assert a.arguments()[0].function_space() == W.dual()
assert a.arguments()[1].function_space() == Z
assert a.petscmat.getSize() == (W.dim(), Z.dim())
assert a.petscmat.getType() == "nest"
assert a.petscmat.getType() == (mat_type if mat_type else "seqaij")

u = Function(Z)
u.subfunctions[0].sub(0).assign(1)
Expand All @@ -550,11 +551,11 @@ def test_mixed_matrix(mode):
result_matfree = assemble(Interpolate(u, TestFunction(W.dual())))
elif mode == "adjoint":
I = Interpolate(TestFunction(Z), TrialFunction(W.dual()))
a = assemble(I)
a = assemble(I, mat_type=mat_type)
assert a.arguments()[1].function_space() == W.dual()
assert a.arguments()[0].function_space() == Z
assert a.petscmat.getSize() == (Z.dim(), W.dim())
assert a.petscmat.getType() == "nest"
assert a.petscmat.getType() == (mat_type if mat_type else "seqaij")

u = Function(W.dual())
u.subfunctions[0].assign(1)
Expand Down
4 changes: 2 additions & 2 deletions tests/firedrake/regression/test_interpolate_cross_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,9 @@ def test_interpolate_matrix_cross_mesh():
f_at_points2 = assemble(interpolate(f, P0DG))
assert np.allclose(f_at_points.dat.data_ro, f_at_points2.dat.data_ro)
# To get the points in the correct order in V we interpolate into vom.input_ordering
# We pass matfree=False which constructs the permutation matrix instead of using SFs
# We pass mat_type='aij' which constructs the permutation matrix instead of using SFs
P0DG_io = FunctionSpace(vom.input_ordering, "DG", 0)
B = assemble(interpolate(TrialFunction(P0DG), P0DG_io, matfree=False))
B = assemble(interpolate(TrialFunction(P0DG), P0DG_io), mat_type='aij')
f_at_points_correct_order = assemble(B @ f_at_points)
f_at_points_correct_order2 = assemble(interpolate(f_at_points, P0DG_io))
assert np.allclose(f_at_points_correct_order.dat.data_ro, f_at_points_correct_order2.dat.data_ro)
Expand Down
173 changes: 173 additions & 0 deletions tests/firedrake/regression/test_interpolator_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from firedrake import *
from firedrake.interpolation import (
MixedInterpolator, SameMeshInterpolator, CrossMeshInterpolator,
get_interpolator, VomOntoVomInterpolator,
)
import pytest


def params():
params = []
for mat_type in [None, "aij"]:
params.append(pytest.param(mat_type, None, id=f"mat_type={mat_type}"))
for sub_mat_type in [None, "aij", "baij"]:
params.append(pytest.param("nest", sub_mat_type, id=f"nest_sub_mat_type={sub_mat_type}"))
return params


@pytest.mark.parallel([1, 2])
@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
@pytest.mark.parametrize("mat_type", [None, "aij", "baij"], ids=lambda v: f"mat_type={v}")
def test_same_mesh_mattype(value_shape, mat_type):
if COMM_WORLD.size > 1:
prefix = "mpi"
else:
prefix = "seq"
mesh = UnitSquareMesh(4, 4)
if value_shape == "scalar":
fs_type = FunctionSpace
else:
fs_type = VectorFunctionSpace
V1 = fs_type(mesh, "CG", 1)
V2 = fs_type(mesh, "CG", 2)

u = TrialFunction(V1)

interp = interpolate(u, V2)
assert isinstance(get_interpolator(interp), SameMeshInterpolator)
res = assemble(interp, mat_type=mat_type)

if value_shape == "scalar":
# Always seqaij for scalar
assert res.petscmat.type == prefix + "aij"
else:
# Defaults to seqaij
assert res.petscmat.type == prefix + (mat_type if mat_type else "aij")

with pytest.raises(NotImplementedError):
# MatNest only implemented for interpolation between MixedFunctionSpaces
assemble(interp, mat_type="nest")


@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
@pytest.mark.parametrize("mat_type", [None, "aij"], ids=lambda v: f"mat_type={v}")
def test_cross_mesh_mattype(value_shape, mat_type):
mesh1 = UnitSquareMesh(1, 1)
mesh2 = UnitSquareMesh(1, 1)
if value_shape == "scalar":
fs_type = FunctionSpace
else:
fs_type = VectorFunctionSpace
V1 = fs_type(mesh1, "CG", 1)
V2 = fs_type(mesh2, "CG", 1)

u = TrialFunction(V1)

interp = interpolate(u, V2)
assert isinstance(get_interpolator(interp), CrossMeshInterpolator)
res = assemble(interp, mat_type=mat_type)

# only aij for cross-mesh
assert res.petscmat.type == "seqaij"


@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
@pytest.mark.parametrize("mat_type", [None, "aij", "baij", "matfree"], ids=lambda v: f"mat_type={v}")
def test_vomtovom_mattype(value_shape, mat_type):
mesh = UnitSquareMesh(1, 1)
points = [[0.1, 0.1]]
vom = VertexOnlyMesh(mesh, points)
if value_shape == "scalar":
fs_type = FunctionSpace
else:
fs_type = VectorFunctionSpace
P0DG = fs_type(vom, "DG", 0)
P0DG_io = fs_type(vom.input_ordering, "DG", 0)

u = TrialFunction(P0DG)
interp = interpolate(u, P0DG_io)
assert isinstance(get_interpolator(interp), VomOntoVomInterpolator)
res = assemble(interp, mat_type=mat_type)
if not mat_type or mat_type == "matfree":
assert res.petscmat.type == "python"
else:
if value_shape == "scalar":
# Always seqaij for scalar
assert res.petscmat.type == "seqaij"
else:
# Defaults to seqaij
assert res.petscmat.type == "seq" + (mat_type if mat_type else "aij")


@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
@pytest.mark.parametrize("mat_type", [None, "aij", "baij"], ids=lambda v: f"mat_type={v}")
def test_point_eval_mattype(value_shape, mat_type):
mesh = UnitSquareMesh(1, 1)
points = [[0.1, 0.1], [0.5, 0.5], [0.9, 0.9]]
vom = VertexOnlyMesh(mesh, points)
if value_shape == "scalar":
fs_type = FunctionSpace
else:
fs_type = VectorFunctionSpace
P0DG = fs_type(vom, "DG", 0)
V = fs_type(mesh, "CG", 1)

u = TrialFunction(V)
interp = interpolate(u, P0DG)
assert isinstance(get_interpolator(interp), SameMeshInterpolator)
res = assemble(interp, mat_type=mat_type)

if value_shape == "scalar":
# Always seqaij for scalar
assert res.petscmat.type == "seqaij"
else:
# Defaults to seqaij
assert res.petscmat.type == "seq" + (mat_type if mat_type else "aij")


@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
@pytest.mark.parametrize("mat_type,sub_mat_type", params())
def test_mixed_same_mesh_mattype(value_shape, mat_type, sub_mat_type):
mesh = UnitSquareMesh(1, 1)
if value_shape == "scalar":
fs_type = FunctionSpace
else:
fs_type = VectorFunctionSpace
V1 = fs_type(mesh, "CG", 1)
V2 = fs_type(mesh, "CG", 2)
V3 = fs_type(mesh, "CG", 3)
V4 = fs_type(mesh, "CG", 4)

W = V1 * V2
U = V3 * V4

w = TrialFunction(W)
w0, w1 = split(w)
if value_shape == "scalar":
expr = as_vector([w0 + w1, w0 + w1])
else:
w00, w01 = split(w0)
w10, w11 = split(w1)
expr = as_vector([w00 + w10, w00 + w10, w01 + w11, w01 + w11])
interp = interpolate(expr, U)
assert isinstance(get_interpolator(interp), MixedInterpolator)
res = assemble(interp, mat_type=mat_type, sub_mat_type=sub_mat_type)
if not mat_type or mat_type == "aij":
# Defaults to seqaij
assert res.petscmat.type == "seqaij"
else:
assert res.petscmat.type == "nest"
for (i, j) in [(0, 0), (0, 1), (1, 0), (1, 1)]:
sub_mat = res.petscmat.getNestSubMatrix(i, j)
if value_shape == "scalar":
# Always seqaij for scalar
assert sub_mat.type == "seqaij"
else:
# matnest sub_mat_type defaults to baij
assert sub_mat.type == "seq" + (sub_mat_type if sub_mat_type else "baij")

with pytest.raises(NotImplementedError):
assemble(interp, mat_type="baij")

with pytest.raises(NotImplementedError):
assemble(interp, mat_type="matfree")
8 changes: 4 additions & 4 deletions tests/firedrake/vertexonly/test_vertex_only_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def functionspace_tests(vm):
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
# Using permutation matrix
perm_mat = assemble(interpolate(TrialFunction(V), W, matfree=False))
perm_mat = assemble(interpolate(TrialFunction(V), W), mat_type="aij")
h2 = assemble(perm_mat @ g)
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
h2 = assemble(interpolate(g, W))
Expand Down Expand Up @@ -214,7 +214,7 @@ def vectorfunctionspace_tests(vm):
assert np.allclose(h.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
# Using permutation matrix
perm_mat = assemble(interpolate(TrialFunction(V), W, matfree=False))
perm_mat = assemble(interpolate(TrialFunction(V), W), mat_type="aij")
h2 = assemble(perm_mat @ g)
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
# check other interpolation APIs work identically
Expand Down Expand Up @@ -369,9 +369,9 @@ def test_tensorfs_permutation(tensorfs_and_expr):
f = Function(V)
f.interpolate(expr)
f_in_W = assemble(interpolate(f, W))
python_mat = assemble(interpolate(TrialFunction(V), W, matfree=False))
python_mat = assemble(interpolate(TrialFunction(V), W), mat_type="matfree")
f_in_W_2 = assemble(python_mat @ f)
assert np.allclose(f_in_W.dat.data_ro, f_in_W_2.dat.data_ro)
petsc_mat = assemble(interpolate(TrialFunction(V), W, matfree=True))
petsc_mat = assemble(interpolate(TrialFunction(V), W), mat_type="aij")
f_in_W_petsc = assemble(petsc_mat @ f)
assert np.allclose(f_in_W.dat.data_ro, f_in_W_petsc.dat.data_ro)