Skip to content

Commit b788010

Browse files
authored
Make assembly of interpolation matrices respect mat_type (#4749)
* pass mat_type and sub_mat_type into Interpolator.assemble() * remove matfree kwarg from interpolateoptions
1 parent f20c148 commit b788010

File tree

7 files changed

+377
-69
lines changed

7 files changed

+377
-69
lines changed

firedrake/assemble.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
615615
if rank > 2:
616616
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
617617
interpolator = get_interpolator(expr)
618-
return interpolator.assemble(tensor=tensor, bcs=bcs)
618+
return interpolator.assemble(tensor=tensor, bcs=bcs, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type)
619619
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
620620
return tensor.assign(expr)
621621
elif tensor and isinstance(expr, ufl.ZeroBaseForm):

firedrake/interpolation.py

Lines changed: 189 additions & 56 deletions
Large diffs are not rendered by default.

firedrake/preconditioners/pmg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1543,7 +1543,8 @@ def prolongation_matrix_aij(Vc, Vf, Vc_bcs=(), Vf_bcs=()):
15431543
Vc = Vc.function_space()
15441544
bcs = Vc_bcs + Vf_bcs
15451545
interp = firedrake.interpolate(firedrake.TrialFunction(Vc), Vf)
1546-
mat = firedrake.assemble(interp, bcs=bcs)
1546+
mat_type = "nest" if len(Vc) > 1 or len(Vf) > 1 else None
1547+
mat = firedrake.assemble(interp, bcs=bcs, mat_type=mat_type)
15471548
return mat.petscmat
15481549

15491550

tests/firedrake/regression/test_interpolate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,8 @@ def test_interpolate_logical_not():
523523

524524

525525
@pytest.mark.parametrize("mode", ("forward", "adjoint"))
526-
def test_mixed_matrix(mode):
526+
@pytest.mark.parametrize("mat_type", (None, "nest"))
527+
def test_mixed_matrix(mode, mat_type):
527528
nx = 3
528529
mesh = UnitSquareMesh(nx, nx)
529530

@@ -537,11 +538,11 @@ def test_mixed_matrix(mode):
537538

538539
if mode == "forward":
539540
I = Interpolate(TrialFunction(Z), TestFunction(W.dual()))
540-
a = assemble(I)
541+
a = assemble(I, mat_type=mat_type)
541542
assert a.arguments()[0].function_space() == W.dual()
542543
assert a.arguments()[1].function_space() == Z
543544
assert a.petscmat.getSize() == (W.dim(), Z.dim())
544-
assert a.petscmat.getType() == "nest"
545+
assert a.petscmat.getType() == (mat_type if mat_type else "seqaij")
545546

546547
u = Function(Z)
547548
u.subfunctions[0].sub(0).assign(1)
@@ -550,11 +551,11 @@ def test_mixed_matrix(mode):
550551
result_matfree = assemble(Interpolate(u, TestFunction(W.dual())))
551552
elif mode == "adjoint":
552553
I = Interpolate(TestFunction(Z), TrialFunction(W.dual()))
553-
a = assemble(I)
554+
a = assemble(I, mat_type=mat_type)
554555
assert a.arguments()[1].function_space() == W.dual()
555556
assert a.arguments()[0].function_space() == Z
556557
assert a.petscmat.getSize() == (Z.dim(), W.dim())
557-
assert a.petscmat.getType() == "nest"
558+
assert a.petscmat.getType() == (mat_type if mat_type else "seqaij")
558559

559560
u = Function(W.dual())
560561
u.subfunctions[0].assign(1)

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,9 @@ def test_interpolate_matrix_cross_mesh():
666666
f_at_points2 = assemble(interpolate(f, P0DG))
667667
assert np.allclose(f_at_points.dat.data_ro, f_at_points2.dat.data_ro)
668668
# To get the points in the correct order in V we interpolate into vom.input_ordering
669-
# We pass matfree=False which constructs the permutation matrix instead of using SFs
669+
# We pass mat_type='aij' which constructs the permutation matrix instead of using SFs
670670
P0DG_io = FunctionSpace(vom.input_ordering, "DG", 0)
671-
B = assemble(interpolate(TrialFunction(P0DG), P0DG_io, matfree=False))
671+
B = assemble(interpolate(TrialFunction(P0DG), P0DG_io), mat_type='aij')
672672
f_at_points_correct_order = assemble(B @ f_at_points)
673673
f_at_points_correct_order2 = assemble(interpolate(f_at_points, P0DG_io))
674674
assert np.allclose(f_at_points_correct_order.dat.data_ro, f_at_points_correct_order2.dat.data_ro)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from firedrake import *
2+
from firedrake.interpolation import (
3+
MixedInterpolator, SameMeshInterpolator, CrossMeshInterpolator,
4+
get_interpolator, VomOntoVomInterpolator,
5+
)
6+
import pytest
7+
8+
9+
def params():
10+
params = []
11+
for mat_type in [None, "aij"]:
12+
params.append(pytest.param(mat_type, None, id=f"mat_type={mat_type}"))
13+
for sub_mat_type in [None, "aij", "baij"]:
14+
params.append(pytest.param("nest", sub_mat_type, id=f"nest_sub_mat_type={sub_mat_type}"))
15+
return params
16+
17+
18+
@pytest.mark.parallel([1, 2])
19+
@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
20+
@pytest.mark.parametrize("mat_type", [None, "aij", "baij"], ids=lambda v: f"mat_type={v}")
21+
def test_same_mesh_mattype(value_shape, mat_type):
22+
if COMM_WORLD.size > 1:
23+
prefix = "mpi"
24+
else:
25+
prefix = "seq"
26+
mesh = UnitSquareMesh(4, 4)
27+
if value_shape == "scalar":
28+
fs_type = FunctionSpace
29+
else:
30+
fs_type = VectorFunctionSpace
31+
V1 = fs_type(mesh, "CG", 1)
32+
V2 = fs_type(mesh, "CG", 2)
33+
34+
u = TrialFunction(V1)
35+
36+
interp = interpolate(u, V2)
37+
assert isinstance(get_interpolator(interp), SameMeshInterpolator)
38+
res = assemble(interp, mat_type=mat_type)
39+
40+
if value_shape == "scalar":
41+
# Always seqaij for scalar
42+
assert res.petscmat.type == prefix + "aij"
43+
else:
44+
# Defaults to seqaij
45+
assert res.petscmat.type == prefix + (mat_type if mat_type else "aij")
46+
47+
with pytest.raises(NotImplementedError):
48+
# MatNest only implemented for interpolation between MixedFunctionSpaces
49+
assemble(interp, mat_type="nest")
50+
51+
52+
@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
53+
@pytest.mark.parametrize("mat_type", [None, "aij"], ids=lambda v: f"mat_type={v}")
54+
def test_cross_mesh_mattype(value_shape, mat_type):
55+
mesh1 = UnitSquareMesh(1, 1)
56+
mesh2 = UnitSquareMesh(1, 1)
57+
if value_shape == "scalar":
58+
fs_type = FunctionSpace
59+
else:
60+
fs_type = VectorFunctionSpace
61+
V1 = fs_type(mesh1, "CG", 1)
62+
V2 = fs_type(mesh2, "CG", 1)
63+
64+
u = TrialFunction(V1)
65+
66+
interp = interpolate(u, V2)
67+
assert isinstance(get_interpolator(interp), CrossMeshInterpolator)
68+
res = assemble(interp, mat_type=mat_type)
69+
70+
# only aij for cross-mesh
71+
assert res.petscmat.type == "seqaij"
72+
73+
74+
@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
75+
@pytest.mark.parametrize("mat_type", [None, "aij", "baij", "matfree"], ids=lambda v: f"mat_type={v}")
76+
def test_vomtovom_mattype(value_shape, mat_type):
77+
mesh = UnitSquareMesh(1, 1)
78+
points = [[0.1, 0.1]]
79+
vom = VertexOnlyMesh(mesh, points)
80+
if value_shape == "scalar":
81+
fs_type = FunctionSpace
82+
else:
83+
fs_type = VectorFunctionSpace
84+
P0DG = fs_type(vom, "DG", 0)
85+
P0DG_io = fs_type(vom.input_ordering, "DG", 0)
86+
87+
u = TrialFunction(P0DG)
88+
interp = interpolate(u, P0DG_io)
89+
assert isinstance(get_interpolator(interp), VomOntoVomInterpolator)
90+
res = assemble(interp, mat_type=mat_type)
91+
if not mat_type or mat_type == "matfree":
92+
assert res.petscmat.type == "python"
93+
else:
94+
if value_shape == "scalar":
95+
# Always seqaij for scalar
96+
assert res.petscmat.type == "seqaij"
97+
else:
98+
# Defaults to seqaij
99+
assert res.petscmat.type == "seq" + (mat_type if mat_type else "aij")
100+
101+
102+
@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
103+
@pytest.mark.parametrize("mat_type", [None, "aij", "baij"], ids=lambda v: f"mat_type={v}")
104+
def test_point_eval_mattype(value_shape, mat_type):
105+
mesh = UnitSquareMesh(1, 1)
106+
points = [[0.1, 0.1], [0.5, 0.5], [0.9, 0.9]]
107+
vom = VertexOnlyMesh(mesh, points)
108+
if value_shape == "scalar":
109+
fs_type = FunctionSpace
110+
else:
111+
fs_type = VectorFunctionSpace
112+
P0DG = fs_type(vom, "DG", 0)
113+
V = fs_type(mesh, "CG", 1)
114+
115+
u = TrialFunction(V)
116+
interp = interpolate(u, P0DG)
117+
assert isinstance(get_interpolator(interp), SameMeshInterpolator)
118+
res = assemble(interp, mat_type=mat_type)
119+
120+
if value_shape == "scalar":
121+
# Always seqaij for scalar
122+
assert res.petscmat.type == "seqaij"
123+
else:
124+
# Defaults to seqaij
125+
assert res.petscmat.type == "seq" + (mat_type if mat_type else "aij")
126+
127+
128+
@pytest.mark.parametrize("value_shape", ["scalar", "vector"], ids=lambda v: f"fs_type={v}")
129+
@pytest.mark.parametrize("mat_type,sub_mat_type", params())
130+
def test_mixed_same_mesh_mattype(value_shape, mat_type, sub_mat_type):
131+
mesh = UnitSquareMesh(1, 1)
132+
if value_shape == "scalar":
133+
fs_type = FunctionSpace
134+
else:
135+
fs_type = VectorFunctionSpace
136+
V1 = fs_type(mesh, "CG", 1)
137+
V2 = fs_type(mesh, "CG", 2)
138+
V3 = fs_type(mesh, "CG", 3)
139+
V4 = fs_type(mesh, "CG", 4)
140+
141+
W = V1 * V2
142+
U = V3 * V4
143+
144+
w = TrialFunction(W)
145+
w0, w1 = split(w)
146+
if value_shape == "scalar":
147+
expr = as_vector([w0 + w1, w0 + w1])
148+
else:
149+
w00, w01 = split(w0)
150+
w10, w11 = split(w1)
151+
expr = as_vector([w00 + w10, w00 + w10, w01 + w11, w01 + w11])
152+
interp = interpolate(expr, U)
153+
assert isinstance(get_interpolator(interp), MixedInterpolator)
154+
res = assemble(interp, mat_type=mat_type, sub_mat_type=sub_mat_type)
155+
if not mat_type or mat_type == "aij":
156+
# Defaults to seqaij
157+
assert res.petscmat.type == "seqaij"
158+
else:
159+
assert res.petscmat.type == "nest"
160+
for (i, j) in [(0, 0), (0, 1), (1, 0), (1, 1)]:
161+
sub_mat = res.petscmat.getNestSubMatrix(i, j)
162+
if value_shape == "scalar":
163+
# Always seqaij for scalar
164+
assert sub_mat.type == "seqaij"
165+
else:
166+
# matnest sub_mat_type defaults to baij
167+
assert sub_mat.type == "seq" + (sub_mat_type if sub_mat_type else "baij")
168+
169+
with pytest.raises(NotImplementedError):
170+
assemble(interp, mat_type="baij")
171+
172+
with pytest.raises(NotImplementedError):
173+
assemble(interp, mat_type="matfree")

tests/firedrake/vertexonly/test_vertex_only_fs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def functionspace_tests(vm):
121121
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))
122122
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
123123
# Using permutation matrix
124-
perm_mat = assemble(interpolate(TrialFunction(V), W, matfree=False))
124+
perm_mat = assemble(interpolate(TrialFunction(V), W), mat_type="aij")
125125
h2 = assemble(perm_mat @ g)
126126
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
127127
h2 = assemble(interpolate(g, W))
@@ -214,7 +214,7 @@ def vectorfunctionspace_tests(vm):
214214
assert np.allclose(h.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
215215
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
216216
# Using permutation matrix
217-
perm_mat = assemble(interpolate(TrialFunction(V), W, matfree=False))
217+
perm_mat = assemble(interpolate(TrialFunction(V), W), mat_type="aij")
218218
h2 = assemble(perm_mat @ g)
219219
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
220220
# check other interpolation APIs work identically
@@ -369,9 +369,9 @@ def test_tensorfs_permutation(tensorfs_and_expr):
369369
f = Function(V)
370370
f.interpolate(expr)
371371
f_in_W = assemble(interpolate(f, W))
372-
python_mat = assemble(interpolate(TrialFunction(V), W, matfree=False))
372+
python_mat = assemble(interpolate(TrialFunction(V), W), mat_type="matfree")
373373
f_in_W_2 = assemble(python_mat @ f)
374374
assert np.allclose(f_in_W.dat.data_ro, f_in_W_2.dat.data_ro)
375-
petsc_mat = assemble(interpolate(TrialFunction(V), W, matfree=True))
375+
petsc_mat = assemble(interpolate(TrialFunction(V), W), mat_type="aij")
376376
f_in_W_petsc = assemble(petsc_mat @ f)
377377
assert np.allclose(f_in_W.dat.data_ro, f_in_W_petsc.dat.data_ro)

0 commit comments

Comments
 (0)