diff --git a/firedrake/assemble.py b/firedrake/assemble.py index ca5836be95..e495995897 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -503,6 +503,14 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args): with lhs.dat.vec_ro as x, rhs.dat.vec_ro as y: res = x.dot(y) return res + elif isinstance(rhs, matrix.MatrixBase): + # Compute action(Cofunc, Mat) => Mat^* @ Cofunc + petsc_mat = rhs.petscmat + (_, col) = rhs.arguments() + res = tensor if tensor else firedrake.Function(col.function_space().dual()) + with lhs.dat.vec_ro as v_vec, res.dat.vec as res_vec: + petsc_mat.multHermitian(v_vec, res_vec) + return res else: raise TypeError("Incompatible RHS for Action.") else: @@ -830,6 +838,12 @@ def restructure_base_form(expr, visited=None): # Replace arguments return ufl.replace(right, replace_map) + # Action(Adjoint(A), w*) -> Action(w*, A) + if isinstance(left, ufl.Adjoint) and not isinstance(right, firedrake.Function) and is_rank_1(right): + # TODO: ufl.action(Coefficient, Form) currently fails. When it is fixed, we can remove the + # `not isinstance(right, firedrake.Function)` check. + return ufl.action(right, left.form()) + # -- Case (4) -- # if isinstance(expr, ufl.Adjoint) and isinstance(expr.form(), ufl.core.base_form_operator.BaseFormOperator): B = expr.form() diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index 000b1023b2..b0ce971a95 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -2,6 +2,7 @@ import numpy as np from firedrake import * from firedrake.utils import complex_mode +from firedrake.matrix import MatrixBase import ufl @@ -352,3 +353,45 @@ def test_interp_dual_mixed(source_space, target_space): assert result is tensor for x, y, in zip(result.subfunctions, expected.subfunctions): assert np.allclose(x.dat.data_ro, y.dat.data_ro) + + +def test_assemble_action_adjoint(V1, V2): + u = TrialFunction(V1) + + a = interpolate(u, V2) # V1 x V2^* -> R, equiv. V1 -> V2 + assert a.arguments() == (TestFunction(V2.dual()), TrialFunction(V1)) + + f_form = inner(1, TestFunction(V2)) * dx + + for f in (f_form, assemble(f_form)): + expr = action(adjoint(assemble(a)), f) + assert isinstance(expr, Action) + res = assemble(expr) + assert isinstance(res, Cofunction) + assert res.function_space() == V1.dual() + + expr2 = action(f, a) # This simplifies into an Interpolate + assert isinstance(expr2, Interpolate) + res2 = assemble(expr2) + assert isinstance(res2, Cofunction) + assert res2.function_space() == V1.dual() + assert np.allclose(res.dat.data, res2.dat.data) + + A = assemble(a) + assert isinstance(A, MatrixBase) + + # This doesn't explicitly assemble the adjoint of A, but uses multHermitian + expr3 = action(f, A) + assert isinstance(expr3, Action) + res3 = assemble(expr3) + assert isinstance(res3, Cofunction) + assert res3.function_space() == V1.dual() + assert np.allclose(res.dat.data, res3.dat.data) + + # This is simplified into action(f, A) to avoid explicit assembly of adjoint(A) + expr4 = action(adjoint(A), f) + assert isinstance(expr4, Action) + res4 = assemble(expr4) + assert isinstance(res4, Cofunction) + assert res4.function_space() == V1.dual() + assert np.allclose(res.dat.data, res4.dat.data)