Skip to content

Commit 8b54083

Browse files
authored
Fix TransferManager reuse for mixed coefficients (#4600)
* Fix TransferManager reuse for mixed coefficients
1 parent 087d2d5 commit 8b54083

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

firedrake/variational_solver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,12 @@ def solve(self, bounds=None):
332332
problem = self._problem
333333
forms = (problem.F, problem.J, problem.Jp)
334334
coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None))
335-
spaces = chain.from_iterable(c.function_space() for c in coefficients)
335+
coefficients += problem.u.subfunctions
336336
solution_dm = self.snes.getDM()
337337
# Grab the unique DMs for this problem
338338
problem_dms = []
339-
for V in spaces:
340-
dm = V.dm
339+
for c in coefficients:
340+
dm = c.function_space().dm
341341
if dm == solution_dm:
342342
# Make sure the solution dm is visited last
343343
continue

tests/firedrake/multigrid/test_transfer_manager.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,20 @@ def test_transfer_manager_dat_version_cache(action, transfer_op, spaces):
134134
raise ValueError(f"Unrecognized action {action}")
135135

136136

137-
@pytest.mark.parametrize("family, degree", [("CG", 1), ("R", 0)])
138-
def test_cached_transfer(family, degree):
137+
@pytest.mark.parametrize("family, degree, coefficient", [
138+
("CG", 1, "repeated"),
139+
("CG", 1, "mixed"),
140+
("CG", 1, "bcs"),
141+
("R", 0, "repeated"),
142+
])
143+
def test_cached_transfer(family, degree, coefficient):
139144
# Test that we can properly reuse transfers within solve
140-
sp = {"mat_type": "matfree",
141-
"pc_type": "mg",
142-
"mg_coarse_pc_type": "none",
143-
"mg_levels_pc_type": "none"}
145+
sp = {
146+
"mat_type": "matfree",
147+
"pc_type": "mg",
148+
"mg_levels_pc_type": "none",
149+
"mg_coarse_pc_type": "none",
150+
}
144151

145152
base = UnitSquareMesh(1, 1)
146153
hierarchy = MeshHierarchy(base, 3)
@@ -149,13 +156,30 @@ def test_cached_transfer(family, degree):
149156
V = FunctionSpace(mesh, family, degree)
150157
u = Function(V)
151158

152-
R1 = FunctionSpace(mesh, "R", 0)
153-
R2 = FunctionSpace(mesh, "R", 0)
154-
c1 = Function(R1).assign(1)
155-
c2 = Function(R2).assign(1)
159+
bcs = None
160+
if coefficient == "mixed":
161+
R = FunctionSpace(mesh, "R", 0)
162+
R2 = R * R
163+
c = Function(R2).assign(1)
164+
c1 = c.subfunctions[0]
165+
c2 = c[1]
166+
elif coefficient == "repeated":
167+
R1 = FunctionSpace(mesh, "R", 0)
168+
R2 = FunctionSpace(mesh, "R", 0)
169+
c1 = Function(R1).assign(1)
170+
c2 = Function(R2).assign(1)
171+
elif coefficient == "bcs":
172+
c1 = 1
173+
c2 = 1
174+
R = FunctionSpace(mesh, "R", 0)
175+
R2 = R * R
176+
g = Function(R2).assign(1)
177+
bcs = [DirichletBC(V, g[0], (1, 2)), DirichletBC(V, g[1], (3, 4))]
178+
else:
179+
raise ValueError(f"Unrecognized coefficient type {coefficient}")
156180

157181
F = inner(u - 1, (c1 + c2)*TestFunction(V)) * dx
158-
problem = NonlinearVariationalProblem(F, u)
182+
problem = NonlinearVariationalProblem(F, u, bcs=bcs)
159183
solver = NonlinearVariationalSolver(problem, solver_parameters=sp)
160184

161185
transfer = TransferManager()

0 commit comments

Comments
 (0)