diff --git a/firedrake/adjoint_utils/blocks/__init__.py b/firedrake/adjoint_utils/blocks/__init__.py index bf83b896cc..d293d4b5d9 100644 --- a/firedrake/adjoint_utils/blocks/__init__.py +++ b/firedrake/adjoint_utils/blocks/__init__.py @@ -1,5 +1,5 @@ from .assembly import AssembleBlock # NOQA F401 -from .solving import GenericSolveBlock, SolveLinearSystemBlock, \ +from .solving import CachedSolverBlock, GenericSolveBlock, SolveLinearSystemBlock, \ ProjectBlock, SupermeshProjectBlock, SolveVarFormBlock, \ NonlinearVariationalSolveBlock # NOQA F401 from .function import FunctionAssignBlock, FunctionMergeBlock, \ diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 2c1fb4f876..40913e774a 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -25,10 +25,293 @@ def extract_subfunction(u, V): return u -class Solver(Enum): +class SolverType(Enum): """Enum for solver types.""" FORWARD = 0 ADJOINT = 1 + TLM = 2 + HESSIAN = 3 + + +FORWARD = SolverType.FORWARD +ADJOINT = SolverType.ADJOINT +TLM = SolverType.TLM +HESSIAN = SolverType.HESSIAN + + +class CachedSolverBlock(Block): + def __init__(self, func, bcs, cached_solvers, + is_linear, replaced_dependencies, + tlm_rhs, replaced_tlms, tlm_dFdm_forms, + adj_rhs, adj_dFdm_forms, adj_residual, + adj_sol, adj2_sol, tlm_output, + d2Fdu2_form, d2Fdmdu_forms, + dFdm_adj2_forms, d2Fdm2_adj_forms, + d2Fdudm_forms, + ad_block_tag=None): + super().__init__(ad_block_tag=ad_block_tag) + + self.func = func + self.bcs = bcs + self.cached_solvers = cached_solvers + self.replaced_dependencies = replaced_dependencies + self.is_linear = is_linear + + self.tlm_rhs = tlm_rhs + self.replaced_tlms = replaced_tlms + self.tlm_dFdm_forms = tlm_dFdm_forms + + self.adj_rhs = adj_rhs + self.adj_dFdm_forms = adj_dFdm_forms + self.adj_residual = adj_residual + + self.adj_sol = adj_sol + self.adj_sol_buf = adj_sol.copy(deepcopy=True) + self.adj2_sol = adj2_sol + self.tlm_output = tlm_output + self.d2Fdu2_form = d2Fdu2_form + self.d2Fdmdu_forms = d2Fdmdu_forms + self.dFdm_adj2_forms = dFdm_adj2_forms + self.d2Fdm2_adj_forms = d2Fdm2_adj_forms + self.d2Fdudm_forms = d2Fdudm_forms + + def _coefficient_dependencies(self, dependencies=None): + dependencies = dependencies or self.get_dependencies() + return dependencies[:len(self.replaced_dependencies)] + + def _bc_dependencies(self, dependencies=None): + dependencies = dependencies or self.get_dependencies() + if len(self.bcs) > 0: + return dependencies[-len(self.bcs):] + else: + return [] + + def update_dependencies(self, use_output=False): + """Update all dependencies of the forward solve. + """ + # Update the coefficients in the form. + # Use the fact that zip will use the shorter length. + for replaced_dep, dep in zip(self.replaced_dependencies, + self._coefficient_dependencies()): + replaced_dep.assign(dep.saved_output) + + # 1. For forward recomputation the unknown Function should use + # the incoming value of the dependency as the initial guess. + # 2. For the adjoint, TLM, and Hessian, the unknown Function + # should use the computed value so that the linearised + # Jacobian is correct. + if use_output: + output = self.get_outputs()[0].saved_output + self.cached_solvers[FORWARD]._problem.u.assign(output) + + # Update the boundary conditions + for replaced_dep, dep in zip(self.bcs, self._bc_dependencies()): + replaced_dep.set_value(dep.saved_output.function_arg) + + def update_tlm_dependencies(self): + """Update all dependencies of the tlm solve. + """ + for replaced_dep, dep in zip(self.replaced_tlms, + self._coefficient_dependencies()): + if dep.output == self.func and not self.is_linear: + continue + if dep.tlm_value is None: # This dependency doesn't depend on the controls + continue + replaced_dep.assign(dep.tlm_value) + + for replaced_dep, dep in zip(self.bcs, self._bc_dependencies()): + if dep.tlm_value is None: # This dependency doesn't depend on the controls + bc_val = 0 + else: + bc_val = dep.tlm_value.function_arg + replaced_dep.set_value(bc_val) + + def update_adj_dependencies(self): + # TODO: Anything to do here? + pass + + def update_hessian_dependencies(self): + # TODO: Anything else to do here? + self.update_tlm_dependencies() + + def _compute_boundary(self, relevant_dependencies): + return any(isinstance(dep.output, firedrake.DirichletBC) + for _, dep in relevant_dependencies) + + def prepare_recompute_component(self, inputs, relevant_outputs): + return + + def recompute_component(self, inputs, block_variable, idx, prepared): + self.update_dependencies(use_output=False) + + solver = self.cached_solvers[FORWARD] + solver.solve() + result = solver._problem.u.copy(deepcopy=True) + + # Possibly checkpoint the result for the adjoint solve later. + if isinstance(block_variable.checkpoint, firedrake.Function): + result = block_variable.checkpoint.assign(result) + + return maybe_disk_checkpoint(result) + + def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): + return + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): + self.update_dependencies(use_output=True) + self.update_tlm_dependencies() + + # Assemble the rhs of (dF/du)(du/dm) = -dF/dm + self.tlm_rhs.zero() + for dFdm, dep in zip(self.tlm_dFdm_forms, self.get_dependencies()): + if dep.tlm_value is None: # This dependency doesn't depend on the controls + continue + if dep.output is self.func and not self.is_linear: # Can't compute dependence on initial guess + continue + self.tlm_rhs += firedrake.assemble(dFdm) + + # Solve for dudm + solver = self.cached_solvers[TLM] + solver._problem.u.zero() + solver.solve() + result = solver._problem.u.copy(deepcopy=True) + return result + + def solve_adj_equation(self, rhs, compute_boundary): + for bc in self.bcs: + bc.homogenize() + + solver = self.cached_solvers[ADJOINT] + adj_sol = solver._problem.u + + self.adj_rhs.assign(rhs) + adj_sol.zero() + + solver.solve() + + if compute_boundary: + adj_sol_bc = firedrake.assemble(self.adj_residual) + adj_sol_bc = adj_sol_bc.riesz_representation("l2") + else: + adj_sol_bc = None + + return adj_sol, adj_sol_bc + + def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): + self.update_dependencies(use_output=True) + self.update_adj_dependencies() + + dJdu = adj_inputs[0] + + compute_boundary = self._compute_boundary(relevant_dependencies) + + adj_sol, adj_sol_bc = self.solve_adj_equation(dJdu, compute_boundary) + + # store adj_sol for Hessian computation later. + # self.adj_sol is shared between all blocks that this NLVS + # generates so we can't store it there. Instead store it + # in self.adj_sol_buf which is owned by this block only. + self.adj_sol_buf.assign(adj_sol) + + prepared = { + "adj_sol": adj_sol.copy(deepcopy=True), + "adj_sol_bc": adj_sol_bc + } + return prepared + + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): + if block_variable.output == self.func and not self.is_linear: + return None + + if isinstance(block_variable.output, firedrake.DirichletBC): + bc = block_variable.output + adj_sol_bc = prepared["adj_sol_bc"] + return bc.reconstruct( + g=extract_subfunction(adj_sol_bc, bc.function_space()) + ) + + # assemble sensititivy comment + dFdm = firedrake.assemble(self.adj_dFdm_forms[idx]) + + return dFdm + + def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): + self.adj_sol.assign(self.adj_sol_buf) + self.update_dependencies(use_output=True) + self.update_hessian_dependencies() + + hessian_input = hessian_inputs[0] + tlm_output = self.get_outputs()[0].tlm_value + + if hessian_input is None: + return + if tlm_output is None: + return + + # 1. Assemble rhs + + # hessian input contribution + hessian_rhs = hessian_input.copy(deepcopy=True) + + # tlm_output contribution + self.tlm_output.assign(tlm_output) + hessian_rhs -= firedrake.assemble(self.d2Fdu2_form) + + # tlm_input contribution + for d2Fdmdu, dep in zip(self.d2Fdmdu_forms, + self._coefficient_dependencies()): + if dep.tlm_value is None: # This dependency doesn't depend on the controls + continue + if dep.output is self.func and not self.is_linear: # Can't compute dependence on initial guess + continue + if len(d2Fdmdu.integrals()) > 0: + hessian_rhs -= firedrake.assemble(d2Fdmdu) + + # 2. Solve adjoint system + compute_boundary = self._compute_boundary(relevant_dependencies) + adj2_sol, adj2_sol_bc = self.solve_adj_equation(hessian_rhs, compute_boundary) + + self.adj2_sol.assign(adj2_sol) + + prepared = { + "adj2_sol": adj2_sol.copy(deepcopy=True), + "adj2_sol_bc": adj2_sol_bc, + } + + return prepared + + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): + m = block_variable.output + + if m is self.func and not self.is_linear: + return None + + if isinstance(m, firedrake.DirichletBC): + bc = block_variable.output + adj2_sol_bc = prepared["adj2_sol_bc"] + return bc.reconstruct( + g=extract_subfunction(adj2_sol_bc, bc.function_space()) + ) + + relevant_d2Fdm2_forms = [] + for i, dep in relevant_dependencies: + if i >= len(self._coefficient_dependencies()): + continue + if dep.tlm_value is None: + continue + if dep.output is self.func and not self.is_linear: + continue + relevant_d2Fdm2_forms.append(self.d2Fdm2_adj_forms[idx][i]) + + hessian_output = 0 + + for form in (self.d2Fdudm_forms[idx], + self.dFdm_adj2_forms[idx], + *relevant_d2Fdm2_forms): + if not form.empty(): + hessian_output += firedrake.assemble(-form) + + return hessian_output class GenericSolveBlock(Block): @@ -56,6 +339,8 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs): # Solution function self.func = func self.function_space = self.func.function_space() + # Storage for adjoint solution of this block + self.adj_state_buf = func.copy(deepcopy=True) # Boundary conditions self.bcs = [] if bcs is not None: @@ -188,6 +473,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): dFdu_form, dJdu, compute_bdy ) self.adj_state = adj_sol + self.adj_state_buf.assign(adj_sol) if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: @@ -199,31 +485,6 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): r["adj_sol_bdy"] = adj_sol_bdy return r - def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): - return firedrake.assemble(dFdu_adj_form, **kwargs) - - def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): - dJdu_copy = dJdu.copy() - # Homogenize and apply boundary conditions on adj_dFdu. - bcs = self._homogenize_bcs() - dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs) - - adj_sol = firedrake.Function(self.function_space) - firedrake.solve( - dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs - ) - - adj_sol_bdy = None - if compute_bdy: - adj_sol_bdy = self._compute_adj_bdy( - adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy) - - return adj_sol, adj_sol_bdy - - def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): - adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol)) - return adj_sol_bdy.riesz_representation("l2") - def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): if not self.linear and self.func == block_variable.output: @@ -272,7 +533,36 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs) return dFdm + def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): + return firedrake.assemble(dFdu_adj_form, **kwargs) + + def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): + dJdu_copy = dJdu.copy() + # Homogenize and apply boundary conditions on adj_dFdu. + bcs = self._homogenize_bcs() + dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs) + + adj_sol = firedrake.Function(self.function_space) + firedrake.solve( + dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs + ) + + adj_sol_bdy = None + if compute_bdy: + adj_sol_bdy = self._compute_adj_bdy( + adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy) + + return adj_sol, adj_sol_bdy + + def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): + adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol)) + return adj_sol_bdy.riesz_representation("l2") + def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): + pass + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, + prepared=None): fwd_block_variable = self.get_outputs()[0] u = fwd_block_variable.output @@ -284,16 +574,6 @@ def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): fwd_block_variable.saved_output, firedrake.TrialFunction(u.function_space()) ) - - return { - "form": F_form, - "dFdu": dFdu - } - - def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, - prepared=None): - F_form = prepared["form"] - dFdu = prepared["dFdu"] V = self.get_outputs()[idx].output.function_space() bcs = [] @@ -330,10 +610,11 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, dFdm = ufl.algorithms.expand_derivatives(dFdm) dFdm = firedrake.assemble(dFdm) dudm = firedrake.Function(V) - return self._assemble_and_solve_tlm_eq( + result = self._assemble_and_solve_tlm_eq( firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs), dFdm, dudm, bcs ) + return result def _assemble_and_solve_tlm_eq(self, dFdu, dFdm, dudm, bcs): return self._assembled_solve(dFdu, dFdm, dudm, bcs) @@ -365,7 +646,10 @@ def _assemble_soa_eq_rhs(self, dFdu_form, adj_sol, hessian_input, d2Fdu2): elif not isinstance(c, firedrake.DirichletBC): dFdu_adj = firedrake.action(firedrake.adjoint(dFdu_form), adj_sol) - b_form += firedrake.derivative(dFdu_adj, c_rep, tlm_input) + # b_form += firedrake.derivative(dFdu_adj, c_rep, tlm_input) + bo_form = ufl.algorithms.expand_derivatives( + firedrake.derivative(dFdu_adj, c_rep, tlm_input)) + b_form += bo_form b_form = ufl.algorithms.expand_derivatives(b_form) if len(b_form.integrals()) > 0: @@ -393,6 +677,8 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, hessian_input = hessian_inputs[0] tlm_output = fwd_block_variable.tlm_value + self.adj_state = self.adj_state_buf.copy(deepcopy=True) + if hessian_input is None: return @@ -421,6 +707,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, r["adj_sol2_bdy"] = adj_sol2_bdy r["form"] = F_form r["adj_sol"] = adj_sol + return r def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, @@ -569,6 +856,12 @@ def solve_init_params(self, args, kwargs, varform): ) self.adj_kwargs.pop("appctx", None) + if hasattr(self, "tlm_args") and len(self.tlm_args) <= 0: + self.tlm_args = self.adj_args + + if hasattr(self, "tlm_kwargs") and len(self.tlm_kwargs) <= 0: + self.tlm_kwargs = self.adj_kwargs.copy() + solver_params = kwargs.get("solver_parameters", None) if solver_params is not None and "mat_type" in solver_params: self.assemble_kwargs["mat_type"] = solver_params["mat_type"] @@ -619,6 +912,8 @@ def __init__(self, equation, func, bcs, adj_cache, problem_J, self.problem_J = problem_J self.solver_kwargs = solver_kwargs + self.adj_state_buf = func.copy(deepcopy=True) + super().__init__(lhs, rhs, func, bcs, **{**solver_kwargs, **kwargs}) if self.problem_J is not None: @@ -656,12 +951,12 @@ def _adjoint_solve(self, dJdu, compute_bdy): and self._ad_solvers["update_adjoint"] ): # Update left hand side of the adjoint equation. - self._ad_solver_replace_forms(Solver.ADJOINT) + self._ad_solver_replace_forms(SolverType.ADJOINT) self._ad_solvers["adjoint_lvs"].invalidate_jacobian() self._ad_solvers["update_adjoint"] = False elif not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian: # Update left hand side of the adjoint equation. - self._ad_solver_replace_forms(Solver.ADJOINT) + self._ad_solver_replace_forms(SolverType.ADJOINT) # Update the right hand side of the adjoint equation. # problem.F._component[1] is the right hand side of the adjoint. @@ -679,7 +974,7 @@ def _adjoint_solve(self, dJdu, compute_bdy): return u_sol, adj_sol_bdy def _ad_assign_map(self, form, solver): - if solver == Solver.FORWARD: + if solver == SolverType.FORWARD: count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map else: count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map @@ -697,7 +992,7 @@ def _ad_assign_map(self, form, solver): block_variable.saved_output if ( - solver == Solver.ADJOINT + solver == SolverType.ADJOINT and not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian ): block_variable = self.get_outputs()[0] @@ -712,8 +1007,8 @@ def _ad_assign_coefficients(self, form, solver): for coeff, value in assign_map.items(): coeff.assign(value) - def _ad_solver_replace_forms(self, solver=Solver.FORWARD): - if solver == Solver.FORWARD: + def _ad_solver_replace_forms(self, solver=SolverType.FORWARD): + if solver == SolverType.FORWARD: problem = self._ad_solvers["forward_nlvs"]._problem self._ad_assign_coefficients(problem.F, solver) self._ad_assign_coefficients(problem.J, solver) @@ -727,6 +1022,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): ) adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy) self.adj_state = adj_sol + self.adj_state_buf.assign(adj_sol) if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index d1b0af22ca..6bbe65d8f3 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -1,9 +1,13 @@ import copy from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations -from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock -from firedrake.ufl_expr import derivative, adjoint -from ufl import replace +from firedrake.adjoint_utils.blocks import ( + NonlinearVariationalSolveBlock, CachedSolverBlock) +from firedrake.adjoint_utils.blocks.solving import solve_init_params +from firedrake.ufl_expr import derivative, adjoint, action +from ufl import replace, Action +from ufl.algorithms import expand_derivatives +from types import SimpleNamespace class NonlinearVariationalProblemMixin: @@ -52,10 +56,367 @@ def wrapper(self, problem, *args, **kwargs): "recompute_count": 0} self._ad_adj_cache = {} + self._ad_solver_cache = {} + + # process args/kwargs for cached solvers + self._ad_args_kwargs = SimpleNamespace( + forward_args=kwargs.pop("forward_args", []), + forward_kwargs=kwargs.pop("forward_kwargs", {}), + adj_args=kwargs.pop("adj_args", []), + adj_kwargs=kwargs.pop("adj_kwargs", {}), + tlm_args=kwargs.pop("tlm_args", []), + tlm_kwargs=kwargs.pop("tlm_kwargs", {}), + assemble_kwargs={} + ) + solve_init_params(self._ad_args_kwargs, + args, kwargs, varform=True) + return wrapper + def _ad_cache_forward_solver(self): + from firedrake import ( + DirichletBC, + NonlinearVariationalProblem, + NonlinearVariationalSolver) + from firedrake.adjoint_utils.blocks.solving import FORWARD + + problem = self._ad_problem + + # Build a new form so that we can update the coefficient + # values without affecting user code. + # We do this by copying all coefficients in the form and + # symbolically replacing the old values with the new. + F = problem.F + replace_map = {} + for old_coeff in F.coefficients(): + new_coeff = old_coeff.copy(deepcopy=True) + replace_map[old_coeff] = new_coeff + + # We need a handle to the new Function being + # solved for so that we can create an NLVS. + Fnew = replace(F, replace_map) + unew = replace_map[problem.u] + + # We also need to "replace" all the bcs in + # the new NLVS so we can modify those values + # without affecting user code. + # Note that ``DirichletBC.reconstruct`` will + # return ``self`` if V, g, and sub_domain are + # all unchanged, so we need to explicitly + # instantiate a new object. + bcs = problem.bcs + bcs_new = [ + DirichletBC(V=bc.function_space(), + g=bc.function_arg, + sub_domain=bc.sub_domain) + for bc in bcs + ] + + # This NLVS will be used to recompute the solve. + # TODO: solver_parameters + nlvp = NonlinearVariationalProblem(Fnew, unew, bcs=bcs_new) + nlvs = NonlinearVariationalSolver( + nlvp, + *self._ad_args_kwargs.forward_args, + **self._ad_args_kwargs.forward_kwargs) + + # The original coefficients will be added as + # dependencies to all solve blocks. + # The block need handles to the newly created + # objects to update their values when recomputing. + self._ad_dependencies_to_add = (*replace_map.keys(), *bcs) + self._ad_replaced_dependencies = tuple(replace_map.values()) + self._ad_bcs = bcs_new + self._ad_solver_cache[FORWARD] = nlvs + + def _ad_cache_tlm_solver(self): + from firedrake import ( + Function, Cofunction, derivative, TrialFunction, + LinearVariationalProblem, LinearVariationalSolver) + from firedrake.adjoint_utils.blocks.solving import FORWARD, TLM + + # If we build the TLM form from the cached + # forward solve form then we can update exactly + # the same coefficients/boundary conditions. + nlvp = self._ad_solver_cache[FORWARD]._problem + + F = nlvp.F + u = nlvp.u + V = u.function_space() + + # We need gradient of output/input i.e. du/dm. + # We know F(u; m) = 0 and _total_ dF/dm = 0. + # Then for the _partial_ derivatives: + # (dF/du)*(du/dm) + dF/dm = 0 so we calculate: + # (dF/du)*(du/dm) = -dF/dm + dFdu = derivative(F, u, TrialFunction(V)) + dFdm = Cofunction(V.dual()) + dudm = Function(V) + + self._ad_dFdu = dFdu + + # Reuse the same bcs as the forward problem. + # TODO: Think about if we should use new bcs. + # TODO: solver_parameters + lvp = LinearVariationalProblem(dFdu, dFdm, dudm, bcs=self._ad_bcs) + lvs = LinearVariationalSolver( + lvp, + *self._ad_args_kwargs.tlm_args, + **self._ad_args_kwargs.tlm_kwargs) + + self._ad_solver_cache[TLM] = lvs + self._ad_tlm_rhs = dFdm + + # Do all the symbolic work for calculating dF/dm up front + # so we only pay for the numeric calculations at run time. + replaced_tlms = [] + dFdm_tlm_forms = [] + for m in self._ad_replaced_dependencies: + mtlm = m.copy(deepcopy=True) + replaced_tlms.append(mtlm) + + dFdm = derivative(-F, m, mtlm) + # TODO: Do we need expand_derivatives here? If so, why? + dFdm = expand_derivatives(dFdm) + dFdm_tlm_forms.append(dFdm) + + # We'll need to update the replaced_tlm + # values and assemble the dFdm forms + self._ad_tlm_dFdm_forms = dFdm_tlm_forms + self._ad_replaced_tlms = replaced_tlms + + def _ad_cache_adj_solver(self): + from firedrake import ( + Function, Cofunction, TrialFunction, Argument, + LinearVariationalProblem, LinearVariationalSolver) + from firedrake.adjoint_utils.blocks.solving import FORWARD, ADJOINT + + # If we build the adjoint form from the cached + # forward solve form then we can update exactly + # the same coefficients/boundary conditions. + nlvp = self._ad_solver_cache[FORWARD]._problem + + F = nlvp.F + u = nlvp.u + V = u.function_space() + + # TODO: rewrite for adjoint not TLM + # We need gradient of output/input i.e. du/dm. + # We know F(u; m) = 0 and _total_ dF/dm = 0. + # Then for the _partial_ derivatives: + # (dF/du)*(du/dm) + dF/dm = 0 so we calculate: + # (dF/du)*(du/dm) = -dF/dm + dFdu = self._ad_dFdu + try: + dFdu_adj = adjoint(dFdu) + except ValueError: + # Try again without expanding derivatives, + # as dFdu might have been simplied to an empty Form + dFdu_adj = adjoint(dFdu, derivatives_expanded=True) + + self._ad_dFdu_adj = dFdu_adj + + # This will be the rhs of the adjoint problem + dJdu = Cofunction(V.dual()) + adj_sol = Function(V) + + # Reuse the same bcs as the forward problem. + # TODO: Think about if we should use new bcs. + # TODO: solver_parameters + lvp = LinearVariationalProblem(dFdu_adj, dJdu, adj_sol, bcs=self._ad_bcs) + lvs = LinearVariationalSolver( + lvp, + *self._ad_args_kwargs.adj_args, + **self._ad_args_kwargs.adj_kwargs) + + self._ad_solver_cache[ADJOINT] = lvs + self._ad_adj_rhs = dJdu + + # Do all the symbolic work for calculating dJ/du up front + # so we only pay for the numeric calculations at run time. + dFdm_adj_forms = [] + for m in self._ad_replaced_dependencies: + # Action of adjoint solution on dFdm + # TODO: Which of the two implementations should we use? + dFdm = derivative(-F, m, TrialFunction(m.function_space())) + + # 1. from previous cached implementation + dFdm = adjoint(dFdm) + if isinstance(dFdm, Argument): + # Corner case. Should be fixed more permanently upstream in UFL. + # See: https://github.com/FEniCS/ufl/issues/395 + dFdm = Action(dFdm, adj_sol) + else: + dFdm = dFdm * adj_sol + + # 2. from GenericSolveBlock + # if isinstance(dFdm, ufl.Form): + # dFdm = adjoint(dFdm) + # dFdm = action(dFdm, adj_sol) + # else: + # dFdm = dFdm(adj_sol) + + dFdm_adj_forms.append(dFdm) + + # To calculate the adjoint component of each DirichletBC + # we'll need the residual of the adjoint equation without + # any DirichletBC using the solution calculated with + # homogeneous DirichletBCs. + self._ad_adj_residual = dJdu - action(dFdu_adj, adj_sol) + + # We'll need to assemble these forms to calculate + # the adj_component for each dependency. + self._ad_adj_dFdm_forms = dFdm_adj_forms + + def _ad_cache_hessian_solver(self): + from firedrake import ( + Function, TestFunction) + from firedrake.adjoint_utils.blocks.solving import FORWARD + + nlvp = self._ad_solver_cache[FORWARD]._problem + F = nlvp.F + u = nlvp.u + V = u.function_space() + + # 1. Forms to calculate rhs of Hessian solve + + # Calculate d^2F/du^2 * du/dm * dm + # where dm is direction for tlm action so du/dm * dm is tlm output + dFdu = self._ad_dFdu + tlm_output = Function(V) + d2Fdu2 = derivative(dFdu, u, tlm_output) + # print() + # print(f"{dFdu = }") + # print() + # print(f"{d2Fdu2 = }") + # print() + d2Fdu2 = expand_derivatives(d2Fdu2) + + self._ad_tlm_output = tlm_output + + adj_sol = Function(V) + self._ad_adj_sol = adj_sol + + # Contribution from tlm_output + if len(d2Fdu2.integrals()) > 0: + d2Fdu2_form = action(adjoint(d2Fdu2), adj_sol) + else: + d2Fdu2_form = d2Fdu2 + self._ad_d2Fdu2_form = d2Fdu2_form + + # Contributions from each tlm_input + dFdu_adj = action(self._ad_dFdu_adj, adj_sol) + d2Fdmdu_forms = [] + for m, dm in zip(self._ad_replaced_dependencies, + self._ad_replaced_tlms): + d2Fdmdu = expand_derivatives( + derivative(dFdu_adj, m, dm)) + + d2Fdmdu_forms.append(d2Fdmdu) + + self._ad_d2Fdmdu_forms = d2Fdmdu_forms + + # 2. Forms to calculate contribution from each control + adj2_sol = Function(V) + self._ad_adj2_sol = adj2_sol + + Fadj = action(F, adj_sol) + Fadj2 = action(F, adj2_sol) + + dFdm_adj2_forms = [] + d2Fdm2_adj_forms = [] + d2Fdudm_forms = [] + for m in self._ad_replaced_dependencies: + dm = TestFunction(m.function_space()) + dFdm_adj2 = expand_derivatives( + derivative(Fadj2, m, dm)) + + dFdm_adj2_forms.append(dFdm_adj2) + + dFdm_adj = derivative(Fadj, m, dm) + + d2Fdudm = expand_derivatives( + derivative(dFdm_adj, u, tlm_output)) + + d2Fdudm_forms.append(d2Fdudm) + + d2Fdm2_adj_forms_k = [] + for m2, dm2 in zip(self._ad_replaced_dependencies, + self._ad_replaced_tlms): + d2Fdm2_adj = expand_derivatives( + derivative(dFdm_adj, m2, dm2)) + d2Fdm2_adj_forms_k.append(d2Fdm2_adj) + + d2Fdm2_adj_forms.append(d2Fdm2_adj_forms_k) + + self._ad_dFdm_adj2_forms = dFdm_adj2_forms + self._ad_d2Fdm2_adj_forms = d2Fdm2_adj_forms + self._ad_d2Fdudm_forms = d2Fdudm_forms + @staticmethod def _ad_annotate_solve(solve): + @wraps(solve) + def wrapper(self, **kwargs): + """To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the + Firedrake solve call. This is useful in cases where the solve is known to be irrelevant or diagnostic + for the purposes of the adjoint computation (such as projecting fields to other function spaces + for the purposes of visualisation).""" + annotate = annotate_tape(kwargs) + if annotate: + if kwargs.pop("bounds", None) is not None: + raise ValueError( + "MissingMathsError: we do not know how to differentiate through a variational inequality") + + if len(self._ad_solver_cache) == 0: + self._ad_cache_forward_solver() + self._ad_cache_tlm_solver() + self._ad_cache_adj_solver() + self._ad_cache_hessian_solver() + + block = CachedSolverBlock(self._ad_problem.u, + self._ad_bcs, + self._ad_solver_cache, + self._ad_problem.is_linear, + self._ad_replaced_dependencies, + + self._ad_tlm_rhs, + self._ad_replaced_tlms, + self._ad_tlm_dFdm_forms, + + self._ad_adj_rhs, + self._ad_adj_dFdm_forms, + self._ad_adj_residual, + + self._ad_adj_sol, + self._ad_adj2_sol, + self._ad_tlm_output, + self._ad_d2Fdu2_form, + self._ad_d2Fdmdu_forms, + self._ad_dFdm_adj2_forms, + self._ad_d2Fdm2_adj_forms, + self._ad_d2Fdudm_forms, + + ad_block_tag=self.ad_block_tag) + + for dep in self._ad_dependencies_to_add: + block.add_dependency(dep, no_duplicates=True) + # mesh = self._ad_problem.u.function_space().mesh() + # block.add_dependency(mesh, no_duplicates=True) + + get_working_tape().add_block(block) + + with stop_annotating(): + out = solve(self, **kwargs) + + if annotate: + block.add_output(self._ad_problem._ad_u.create_block_variable()) + + return out + + return wrapper + + @staticmethod + def _ad_annotate_solve_old(solve): @wraps(solve) def wrapper(self, **kwargs): """To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the diff --git a/tests/firedrake/adjoint/test_assemble.py b/tests/firedrake/adjoint/test_assemble.py index b10d68cfcc..9044fe7ab8 100644 --- a/tests/firedrake/adjoint/test_assemble.py +++ b/tests/firedrake/adjoint/test_assemble.py @@ -89,7 +89,7 @@ def test_assemble_1_forms_tlm(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) v = TestFunction(V) - f = Function(V).assign(1) + f = Function(V).assign(1.) w1 = assemble(inner(f, v) * dx) w2 = assemble(inner(f**2, v) * dx) @@ -101,9 +101,11 @@ def test_assemble_1_forms_tlm(rg): h = rg.uniform(V) g = f.copy(deepcopy=True) - f.block_variable.tlm_value = h - tape.evaluate_tlm() - assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9) + Jhat(g) + assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9) + # f.block_variable.tlm_value = h + # tape.evaluate_tlm() + # assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9) @pytest.mark.skipcomplex diff --git a/tests/firedrake/adjoint/test_hessian.py b/tests/firedrake/adjoint/test_hessian.py index 294680a554..79142bf046 100644 --- a/tests/firedrake/adjoint/test_hessian.py +++ b/tests/firedrake/adjoint/test_hessian.py @@ -37,7 +37,7 @@ def test_simple_solve(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(2) + f = Function(V).assign(2.) u = TrialFunction(V) v = TestFunction(V) @@ -76,10 +76,10 @@ def test_mixed_derivatives(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(2) + f = Function(V).assign(2.) control_f = Control(f) - g = Function(V).assign(3) + g = Function(V).assign(3.) control_g = Control(g) u = TrialFunction(V) @@ -126,7 +126,7 @@ def test_function(rg): R = FunctionSpace(mesh, "R", 0) c = Function(R, val=4) control_c = Control(c) - f = Function(V).assign(3) + f = Function(V).assign(3.) control_f = Control(f) u = Function(V) @@ -139,14 +139,13 @@ def test_function(rg): J = assemble(c ** 2 * u ** 2 * dx) Jhat = ReducedFunctional(J, [control_c, control_f]) - dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True) # Step direction for derivatives and convergence test h_c = Function(R, val=1.0) h_f = rg.uniform(V, 0, 10) # Total derivative - dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True) + dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True) dJdm = assemble(dJdc * h_c * dx + dJdf * h_f * dx) # Hessian @@ -163,7 +162,7 @@ def test_nonlinear(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(5) + f = Function(V).assign(5.) u = Function(V) v = TestFunction(V) @@ -176,6 +175,7 @@ def test_nonlinear(rg): Jhat = ReducedFunctional(J, Control(f)) h = rg.uniform(V, 0, 10) + g = f.copy(deepcopy=True) J.block_variable.adj_value = 1.0 f.block_variable.tlm_value = h @@ -186,8 +186,6 @@ def test_nonlinear(rg): J.block_variable.hessian_value = 0 tape.evaluate_hessian() - g = f.copy(deepcopy=True) - dJdm = J.block_variable.tlm_value Hm = f.block_variable.hessian_value.dat.inner(h.dat) assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.8 @@ -201,11 +199,11 @@ def test_dirichlet(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(30) + f = Function(V).assign(30.) u = Function(V) v = TestFunction(V) - c = Function(V).assign(1) + c = Function(V).assign(1.) bc = DirichletBC(V, c, "on_boundary") F = inner(grad(u), grad(v)) * dx + u**4*v*dx - f**2 * v * dx @@ -238,24 +236,24 @@ def test_dirichlet(rg): def test_burgers(solve_type, rg): tape = Tape() set_working_tape(tape) - n = 100 - mesh = UnitIntervalMesh(n) - V = FunctionSpace(mesh, "CG", 2) + nx = 50 + nt = 5 + mesh = UnitIntervalMesh(nx) + V = FunctionSpace(mesh, "CG", 1) - def Dt(u, u_, timestep): - return (u - u_)/timestep + def Dt(u, u_, dt): + return (u - u_)/dt x, = SpatialCoordinate(mesh) - pr = project(sin(2*pi*x), V, annotate=False) - ic = Function(V).assign(pr) + ic = Function(V).project(sin(2*pi*x)) - u_ = Function(V) - u = Function(V) + u_ = Function(V).assign(ic) + u = Function(V).assign(ic) v = TestFunction(V) - nu = Constant(0.0001) + nu = Constant(1/100) - timestep = Constant(1.0/n) + dt = Constant(1/nx) params = { 'snes_rtol': 1e-10, @@ -263,10 +261,9 @@ def Dt(u, u_, timestep): 'pc_type': 'lu', } - F = (Dt(u, ic, timestep)*v + F = (Dt(u, u_, dt)*v + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx bc = DirichletBC(V, 0.0, "on_boundary") - t = 0.0 if solve_type == "nlvs": use_nlvs = True @@ -280,39 +277,23 @@ def Dt(u, u_, timestep): NonlinearVariationalProblem(F, u), solver_parameters=params) - if use_nlvs: - solver.solve() - else: - solve(F == 0, u, bc, solver_parameters=params) - u_.assign(u) - t += float(timestep) - - F = (Dt(u, u_, timestep)*v - + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx - - end = 0.2 - while (t <= end): + for _ in range(nt): if use_nlvs: solver.solve() else: solve(F == 0, u, bc, solver_parameters=params) u_.assign(u) - t += float(timestep) - J = assemble(u_*u_*dx + ic*ic*dx) Jhat = ReducedFunctional(J, Control(ic)) + h = rg.uniform(V) g = ic.copy(deepcopy=True) - J.block_variable.adj_value = 1.0 - ic.block_variable.tlm_value = h - tape.evaluate_adj() - tape.evaluate_tlm() - J.block_variable.hessian_value = 0 - tape.evaluate_hessian() - - dJdm = J.block_variable.tlm_value - Hm = ic.block_variable.hessian_value.dat.inner(h.dat) - assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 + taylor = taylor_to_dict(Jhat, g, h) + from pprint import pprint + pprint(taylor) + assert min(taylor['R0']['Rate']) > 0.95, taylor['R0'] + assert min(taylor['R1']['Rate']) > 1.95, taylor['R1'] + assert min(taylor['R2']['Rate']) > 2.95, taylor['R2'] diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py new file mode 100644 index 0000000000..fdb96e3da1 --- /dev/null +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -0,0 +1,191 @@ +import pytest + +from firedrake import * +from firedrake.adjoint import * + + +@pytest.fixture(autouse=True) +def handle_taping(): + yield + tape = get_working_tape() + tape.clear_tape() + + +@pytest.fixture(autouse=True, scope="module") +def handle_annotation(): + if not annotate_tape(): + continue_annotation() + yield + # Ensure annotation is paused when we finish. + if annotate_tape(): + pause_annotation() + + +def forward(ic, dt, nt, bc_arg=None): + """Burgers equation solver.""" + V = ic.function_space() + + if bc_arg: + bc_val = bc_arg.copy(deepcopy=True) + bcs = [DirichletBC(V, bc_val, 1), + DirichletBC(V, 0, 2)] + else: + bcs = None + + nu = Function(dt.function_space()).assign(0.1) + + u0 = Function(V) + u1 = Function(V) + v = TestFunction(V) + + F = ((u1 - u0)*v + + dt*u1*u1.dx(0)*v + + dt*nu*u1.dx(0)*v.dx(0))*dx + + problem = NonlinearVariationalProblem(F, u1, bcs=bcs) + solver = NonlinearVariationalSolver(problem) + + u1.assign(ic) + + for i in range(nt): + u0.assign(u1) + solver.solve() + nu += dt + if bc_arg: + bc_val.assign(bc_val + dt/nt) + + J = assemble(u1*u1*dx) + return J + + +@pytest.mark.skipcomplex +@pytest.mark.parametrize("control_type", ["ic_control", + "dt_control", + "bc_control"]) +@pytest.mark.parametrize("bc_type", ["neumann_bc", + "dirichlet_bc"]) +def test_nlvs_adjoint(control_type, bc_type): + if control_type == 'bc_control' and bc_type == 'neumann_bc': + pytest.skip("Cannot use Neumann BCs as control") + + nx = 100000 + nt = 50 + + mesh = UnitIntervalMesh(nx) + x, = SpatialCoordinate(mesh) + + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + + dt = Function(R).assign(1/nx) + ic = Function(V).interpolate(cos(2*pi*x)) + + dt0 = dt.copy(deepcopy=True) + ic0 = ic.copy(deepcopy=True) + + if bc_type == 'neumann_bc': + bc_arg = None + bc_arg0 = None + elif bc_type == 'dirichlet_bc': + bc_arg = Function(R).assign(1.) + bc_arg0 = bc_arg.copy(deepcopy=True) + else: + raise ValueError(f"Unrecognised {bc_type=}") + + if control_type == 'ic_control': + control = ic0 + elif control_type == 'dt_control': + control = dt0 + elif control_type == 'bc_control': + control = bc_arg0 + else: + raise ValueError(f"Unrecognised {control_type=}") + + PETSc.Sys.Print("record tape") + continue_annotation() + with set_working_tape() as tape: + J = forward(ic0, dt0, nt, bc_arg=bc_arg0) + Jhat = ReducedFunctional(J, Control(control), tape=tape) + pause_annotation() + + if control_type == 'ic_control': + m = Function(V).assign(0.5*ic) + h = Function(V).interpolate(-0.5*cos(4*pi*x)) + + ic2 = m.copy(deepcopy=True) + dt2 = dt + bc_arg2 = bc_arg + + elif control_type == 'dt_control': + m = Function(R).assign(0.05) + h = Function(R).assign(0.01) + + ic2 = ic + dt2 = m.copy(deepcopy=True) + bc_arg2 = bc_arg + + elif control_type == 'bc_control': + m = Function(R).assign(0.5) + h = Function(R).assign(-0.1) + + ic2 = ic + dt2 = dt + bc_arg2 = m.copy(deepcopy=True) + + from mpi4py import MPI + + # # recompute component + # PETSc.Sys.Print("recompute test") + # assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 + + # # tlm + # PETSc.Sys.Print("tlm test") + # Jhat(m) + # assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + + # # adjoint + # PETSc.Sys.Print("adjoint test") + # Jhat(m) + # assert taylor_test(Jhat, m, h) > 1.95 + + # # hessian + # PETSc.Sys.Print("hessian test") + # Jhat(m) + # taylor = taylor_to_dict(Jhat, m, h) + # from pprint import pprint + # pprint(taylor) + + # assert min(taylor['R0']['Rate']) > 0.95 + # assert min(taylor['R1']['Rate']) > 1.95 + # assert min(taylor['R2']['Rate']) > 2.95 + + for _ in range(3): + stime = MPI.Wtime() + Jhat(m) + etime = MPI.Wtime() + PETSc.Sys.Print(f"Recompute time: {etime - stime:.4f}") + + for _ in range(3): + stime = MPI.Wtime() + Jhat.derivative() + etime = MPI.Wtime() + PETSc.Sys.Print(f"Derivative time: {etime - stime:.4f}") + + for _ in range(3): + stime = MPI.Wtime() + Jhat.tlm(h) + etime = MPI.Wtime() + PETSc.Sys.Print(f"TLM time: {etime - stime:.4f}") + + for _ in range(3): + stime = MPI.Wtime() + Jhat.hessian(h, evaluate_tlm=False) + etime = MPI.Wtime() + PETSc.Sys.Print(f"Hessian time: {etime - stime:.4f}") + + +if __name__ == "__main__": + control_type = "ic_control" + bc_type = "neumann_bc" + PETSc.Sys.Print(f"{control_type=} | {bc_type=}") + test_nlvs_adjoint(control_type, bc_type) diff --git a/tests/firedrake/adjoint/test_solving.py b/tests/firedrake/adjoint/test_solving.py index 5ce9b120ce..7997fe39f5 100644 --- a/tests/firedrake/adjoint/test_solving.py +++ b/tests/firedrake/adjoint/test_solving.py @@ -33,7 +33,7 @@ def test_linear_problem(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = TrialFunction(V) u_ = Function(V) @@ -60,7 +60,7 @@ def test_singular_linear_problem(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "CG", 1) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = TrialFunction(V) u_ = Function(V) @@ -85,7 +85,7 @@ def test_nonlinear_problem(pre_apply_bcs, rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = Function(V) v = TestFunction(V) @@ -116,7 +116,7 @@ def test_mixed_boundary(rg): g1 = Constant(2) g2 = Constant(1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(f): a = f*inner(grad(u), grad(v))*dx @@ -165,7 +165,7 @@ def xtest_wrt_function_dirichlet_boundary(): g1 = Constant(2) g2 = Constant(1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(bc): a = inner(grad(u), grad(v))*dx @@ -195,7 +195,7 @@ def test_wrt_function_neumann_boundary(): g1 = Function(R, val=2) g2 = Function(R, val=1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(g1): a = inner(grad(u), grad(v))*dx @@ -247,7 +247,7 @@ def test_wrt_constant_neumann_boundary(): g1 = Function(R, val=2) g2 = Function(R, val=1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(g1): a = inner(grad(u), grad(v))*dx @@ -283,7 +283,7 @@ def test_time_dependent(): f = Function(R, val=1) def J(f): - u_1 = Function(V).assign(1) + u_1 = Function(V).assign(1.) a = u_1*u*v*dx + dt*f*inner(grad(u), grad(v))*dx L = u_1*v*dx @@ -340,7 +340,7 @@ def _test_adjoint_function_boundary(J, bc, f): set_working_tape(tape) V = f.function_space() - h = Function(V).assign(1) + h = Function(V).assign(1.) g = Function(V) eps_ = [0.4/2.0**i for i in range(4)] residuals = [] diff --git a/tests/firedrake/adjoint/test_tlm.py b/tests/firedrake/adjoint/test_tlm.py index 723f84e481..b9123ad9bd 100644 --- a/tests/firedrake/adjoint/test_tlm.py +++ b/tests/firedrake/adjoint/test_tlm.py @@ -36,7 +36,7 @@ def test_tlm_assemble(rg): set_working_tape(tape) mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(5) + f = Function(V).assign(5.) u = TrialFunction(V) v = TestFunction(V) @@ -66,7 +66,7 @@ def test_tlm_bc(): V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) c = Function(R, val=1) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = Function(V) v = TestFunction(V) @@ -88,8 +88,8 @@ def test_tlm_func(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - c = Function(V).assign(1) - f = Function(V).assign(1) + c = Function(V).assign(1.) + f = Function(V).assign(1.) u = Function(V) v = TestFunction(V) @@ -130,9 +130,9 @@ def test_time_dependent(solve_type, rg): # Some variables T = 0.5 dt = 0.1 - f = Function(V).assign(1) + f = Function(V).assign(1.) - u_1 = Function(V).assign(1) + u_1 = Function(V).assign(1.) control = Control(u_1) a = u_1 * u * v * dx + dt * f * inner(grad(u), grad(v)) * dx