diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 1f8da30b5a..9d3239427a 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -341,3 +341,15 @@ class CofunctionMixin(FunctionMixin): def _ad_dot(self, other): return firedrake.assemble(firedrake.action(self, other)) + + @classmethod + def _ad_init_object(cls, obj): + from firedrake import Cofunction + return Cofunction(obj.function_space()).assign(obj) + + def _ad_init_zero(self, dual=False): + from firedrake import Function, Cofunction + if dual: + return Function(self.function_space().dual()) + else: + return Cofunction(self.function_space()) diff --git a/tests/firedrake/adjoint/test_optimisation.py b/tests/firedrake/adjoint/test_optimisation.py index 2189a7882b..3073f68eb2 100644 --- a/tests/firedrake/adjoint/test_optimisation.py +++ b/tests/firedrake/adjoint/test_optimisation.py @@ -179,6 +179,34 @@ def test_tao_simple_inversion(minimize, riesz_representation): assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2) +@pytest.mark.parametrize("minimize", [minimize_tao_lmvm, + pytest.param(minimize_tao_nls, marks=pytest.mark.xfail)]) +@pytest.mark.parametrize("riesz_representation", ["L2", "H1"]) +@pytest.mark.skipcomplex +def test_tao_cofunction_control(minimize, riesz_representation): + """Test inversion of source term in helmholtz eqn using TAO.""" + mesh = UnitIntervalMesh(10) + V = FunctionSpace(mesh, "CG", 1) + source_ref = Function(V) + x = SpatialCoordinate(mesh) + source_ref.interpolate(cos(pi*x**2)) + + # compute reference solution + with stop_annotating(): + u_ref = _simple_helmholz_model(V, source_ref) + + # now rerun annotated model with zero source + source = Cofunction(V.dual()) + c = Control(source, riesz_map=riesz_representation) + u = _simple_helmholz_model(V, source.riesz_representation(riesz_representation)) + + J = assemble(1e6 * (u - u_ref)**2*dx) + rf = ReducedFunctional(J, c) + + x = minimize(rf).riesz_representation(riesz_representation) + assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2) + + class TransformType(Enum): PRIMAL = auto() DUAL = auto()