Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demos/boussinesq/boussinesq.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ implements a boundary condition that fixes a field at a single point. ::

# Take the basis function with the largest abs value at bc_point
v = TestFunction(V)
F = assemble(Interpolate(inner(v, v), Fvom))
F = assemble(interpolate(inner(v, v), Fvom))
with F.dat.vec as Fvec:
max_index, _ = Fvec.max()
nodes = V.dof_dset.lgmap.applyInverse([max_index])
Expand Down
2 changes: 1 addition & 1 deletion demos/multicomponent/multicomponent.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ mathematically valid to do this)::

# Take the basis function with the largest abs value at bc_point
v = TestFunction(V)
F = assemble(Interpolate(inner(v, v), Fvom))
F = assemble(interpolate(inner(v, v), Fvom))
with F.dat.vec as Fvec:
max_index, _ = Fvec.max()
nodes = V.dof_dset.lgmap.applyInverse([max_index])
Expand Down
6 changes: 3 additions & 3 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def interpolate(self,
Parameters
----------
expression
A dual UFL expression to interpolate.
A UFL BaseForm to adjoint interpolate.
ad_block_tag
An optional string for tagging the resulting assemble
block on the Pyadjoint tape.
Expand All @@ -353,9 +353,9 @@ def interpolate(self,
firedrake.cofunction.Cofunction
Returns `self`
"""
from firedrake import interpolation, assemble
from firedrake import interpolate, assemble
v, = self.arguments()
interp = interpolation.Interpolate(v, expression, **kwargs)
interp = interpolate(v, expression, **kwargs)
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)

@property
Expand Down
10 changes: 5 additions & 5 deletions firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import firedrake.ufl_expr as ufl_expr
from firedrake.assemble import assemble
from firedrake.interpolation import Interpolate
from firedrake.interpolation import interpolate
from firedrake.external_operators import AbstractExternalOperator, assemble_method


Expand Down Expand Up @@ -58,7 +58,7 @@ def assemble_operator(self, *args, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
if len(V) < 2:
interp = Interpolate(expr, self.function_space())
interp = interpolate(expr, self.function_space())
return assemble(interp)
# Interpolation of UFL expressions for mixed functions is not yet supported
# -> `Function.assign` might be enough in some cases.
Expand All @@ -72,7 +72,7 @@ def assemble_operator(self, *args, **kwargs):
def assemble_Jacobian_action(self, *args, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
interp = Interpolate(expr, V)
interp = interpolate(expr, V)

u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
w = self.argument_slots()[-1]
Expand All @@ -83,7 +83,7 @@ def assemble_Jacobian_action(self, *args, **kwargs):
def assemble_Jacobian(self, *args, assembly_opts, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
interp = Interpolate(expr, V)
interp = interpolate(expr, V)

u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
jac = ufl_expr.derivative(interp, u)
Expand All @@ -99,7 +99,7 @@ def assemble_Jacobian_adjoint(self, *args, assembly_opts, **kwargs):
def assemble_Jacobian_adjoint_action(self, *args, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
interp = Interpolate(expr, V)
interp = interpolate(expr, V)

u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
ustar = self.argument_slots()[0]
Expand Down
6 changes: 3 additions & 3 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,9 @@ def interpolate(self,
firedrake.function.Function
Returns `self`
"""
from firedrake import interpolation, assemble
from firedrake import interpolate, assemble
V = self.function_space()
interp = interpolation.Interpolate(expression, V, **kwargs)
interp = interpolate(expression, V, **kwargs)
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)

def zero(self, subset=None):
Expand Down Expand Up @@ -715,7 +715,7 @@ def __init__(self, domain, point):
self.point = point

def __str__(self):
return "domain %s does not contain point %s" % (self.domain, self.point)
return f"Domain {self.domain} does not contain point {self.point}"


class PointEvaluator:
Expand Down
67 changes: 20 additions & 47 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import abc
import warnings
from collections.abc import Iterable
from typing import Literal
from functools import partial, singledispatch
from typing import Hashable
from typing import Hashable, Literal

import FIAT
import ufl
import finat.ufl
from ufl.algorithms import extract_arguments, extract_coefficients, replace
from ufl.algorithms import extract_arguments, extract_coefficients
from ufl.domain import as_domain, extract_unique_domain

from pyop2 import op2
Expand All @@ -25,13 +24,11 @@
import finat

import firedrake
import firedrake.bcs
from firedrake import tsfc_interface, utils, functionspaceimpl
from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology
from firedrake.petsc import PETSc
from firedrake.halo import _get_mtype as get_dat_mpi_type
from firedrake.cofunction import Cofunction
from mpi4py import MPI

from pyadjoint import stop_annotating, no_annotations
Expand All @@ -48,7 +45,7 @@

class Interpolate(ufl.Interpolate):

def __init__(self, expr, v,
def __init__(self, expr, V,
Copy link
Contributor

@pbrubeck pbrubeck Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very unrelated, but I think that a much more friendly interface is to allow either or both left and right arguments to be a primal FunctionSpace.

Right now we do this under the hood

Interpolate(Function(V1), V2) -> Interpolate(Function(V1), Argument(V2.dual(), 0))

It'd be reasonable to have a similar shortcut for the adjoint. When the left argument is a FunctionSpace, we would then automatically create the Argument for it.

Interpolate(V1, Cofunction(V2.dual())) -> Interpolate(Argument(V1, 0), Cofunction(V2.dual()))

And supplying two FunctionSpaces is a perfectly natural interface:

Interpolate(V1, V2) -> Interpolate(Argument(V1, 1), Argument(V2.dual(), 0))

Of course we need to arbitrarily decide who gets the lowest number, the more intuitive numbering that produces the forward Interpolation is to go from right to left.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts @dham ?

subset=None,
access=None,
allow_missing_dofs=False,
Expand All @@ -60,7 +57,7 @@ def __init__(self, expr, v,
----------
expr : ufl.core.expr.Expr or ufl.BaseForm
The UFL expression to interpolate.
v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
The function space to interpolate into or the coargument defined
on the dual of the function space to interpolate into.
subset : pyop2.types.set.Subset
Expand Down Expand Up @@ -95,20 +92,18 @@ def __init__(self, expr, v,
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
and reduce operations.
"""
# Check function space
expr = ufl.as_ufl(expr)
if isinstance(v, functionspaceimpl.WithGeometry):
expr_args = extract_arguments(expr)
is_adjoint = len(expr_args) and expr_args[0].number() == 0
v = Argument(v.dual(), 1 if is_adjoint else 0)
if isinstance(V, functionspaceimpl.WithGeometry):
expr_args = expr.arguments()[1:] if isinstance(expr, ufl.BaseForm) else extract_arguments(expr)
expr_arg_numbers = {arg.number() for arg in expr_args}
# Need to create a Firedrake Argument so that it has a .function_space() method
V = Argument(V.dual(), 1 if expr_arg_numbers == {0} else 0)

V = v.arguments()[0].function_space()
if len(expr.ufl_shape) != len(V.value_shape):
raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}')
target_shape = V.arguments()[0].function_space().value_shape
if expr.ufl_shape != target_shape:
raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.")

if expr.ufl_shape != V.value_shape:
raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}')
super().__init__(expr, v)
super().__init__(expr, V)

# -- Interpolate data (e.g. `subset` or `access`) -- #
self.interp_data = {"subset": subset,
Expand Down Expand Up @@ -174,32 +169,10 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def
reduction (hence using MIN will compute the MIN between the
existing values and any new values).
"""
if isinstance(V, (Cofunction, Coargument)):
dual_arg = V
elif isinstance(V, ufl.BaseForm):
rank = len(V.arguments())
if rank == 1:
dual_arg = V
else:
raise TypeError(f"Expected a one-form, provided form had {rank} arguments")
elif isinstance(V, functionspaceimpl.WithGeometry):
dual_arg = Coargument(V.dual(), 0)
expr_args = extract_arguments(ufl.as_ufl(expr))
if expr_args and expr_args[0].number() == 0:
warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. "
"Use a TrialFunction in the expression.")
v, = expr_args
expr = replace(expr, {v: v.reconstruct(number=1)})
else:
raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}")

interp = Interpolate(expr, dual_arg,
subset=subset, access=access,
allow_missing_dofs=allow_missing_dofs,
default_missing_val=default_missing_val,
matfree=matfree)

return interp
return Interpolate(
expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs,
default_missing_val=default_missing_val, matfree=matfree
)


class Interpolator(abc.ABC):
Expand Down Expand Up @@ -528,7 +501,7 @@ def __init__(

from firedrake.assemble import assemble
V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element)
f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec)
f_dest_node_coords = interpolate(dest_mesh.coordinates, V_dest_vec)
f_dest_node_coords = assemble(f_dest_node_coords)
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim)
try:
Expand All @@ -553,15 +526,15 @@ def __init__(
else:
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape)
P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0)
self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom)
self.point_eval_interpolate = interpolate(self.expr_renumbered, P0DG_vom)
# The parallel decomposition of the nodes of V_dest in the DESTINATION
# mesh (dest_mesh) is retrieved using the input_ordering attribute of the
# VOM. This again is an interpolation operation, which, under the hood
# is a PETSc SF reduce.
P0DG_vom_i_o = fs_type(
self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0
)
self.to_input_ordering_interpolate = Interpolate(
self.to_input_ordering_interpolate = interpolate(
firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o
)
# The P0DG function outputted by the above interpolation has the
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4274,7 +4274,7 @@ def _parent_mesh_embedding(
# nessesary, to other processes.
P0DG = functionspace.FunctionSpace(parent_mesh, "DG", 0)
with stop_annotating():
visible_ranks = interpolation.Interpolate(
visible_ranks = interpolation.interpolate(
constant.Constant(parent_mesh.comm.rank), P0DG
)
visible_ranks = assemble(visible_ranks).dat.data_ro_with_halos.real
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def physical_node_locations(V):
Vc = V.collapse().reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension))

# FIXME: This is unsafe for DG coordinates and CG target spaces.
locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc))
locations = firedrake.assemble(firedrake.interpolate(firedrake.SpatialCoordinate(mesh), Vc))
return cache.setdefault(key, locations)


Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/gtmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from firedrake.petsc import PETSc
from firedrake.preconditioners.base import PCBase
from firedrake.parameters import parameters
from firedrake.interpolation import Interpolate
from firedrake.interpolation import interpolate
from firedrake.solving_utils import _SNESContext
from firedrake.matrix_free.operators import ImplicitMatrixContext
import firedrake.dmhooks as dmhooks
Expand Down Expand Up @@ -155,7 +155,7 @@ def initialize(self, pc):
# Create interpolation matrix from coarse space to fine space
fine_space = ctx.J.arguments()[0].function_space()
coarse_test, coarse_trial = coarse_operator.arguments()
interp = assemble(Interpolate(coarse_trial, fine_space))
interp = assemble(interpolate(coarse_trial, fine_space))
interp_petscmat = interp.petscmat
restr_petscmat = appctx.get("restriction_matrix", None)

Expand Down
6 changes: 3 additions & 3 deletions firedrake/preconditioners/hypre_ads.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from firedrake.preconditioners.base import PCBase
from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.ufl_expr import TestFunction
from firedrake.ufl_expr import TrialFunction
from firedrake.dmhooks import get_function_space
from firedrake.preconditioners.hypre_ams import chop
from firedrake.interpolation import interpolate
Expand Down Expand Up @@ -31,12 +31,12 @@ def initialize(self, obj):
NC1 = V.reconstruct(family="N1curl" if mesh.ufl_cell().is_simplex else "NCE", degree=1)
G_callback = appctx.get("get_gradient", None)
if G_callback is None:
G = chop(assemble(interpolate(grad(TestFunction(P1)), NC1)).petscmat)
G = chop(assemble(interpolate(grad(TrialFunction(P1)), NC1)).petscmat)
else:
G = G_callback(P1, NC1)
C_callback = appctx.get("get_curl", None)
if C_callback is None:
C = chop(assemble(interpolate(curl(TestFunction(NC1)), V)).petscmat)
C = chop(assemble(interpolate(curl(TrialFunction(NC1)), V)).petscmat)
else:
C = C_callback(NC1, V)

Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/hypre_ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from firedrake.preconditioners.base import PCBase
from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.ufl_expr import TestFunction
from firedrake.ufl_expr import TrialFunction
from firedrake.dmhooks import get_function_space
from firedrake.utils import complex_mode
from firedrake.interpolation import interpolate
Expand Down Expand Up @@ -51,7 +51,7 @@ def initialize(self, obj):
P1 = V.reconstruct(family="Lagrange", degree=1)
G_callback = appctx.get("get_gradient", None)
if G_callback is None:
G = chop(assemble(interpolate(grad(TestFunction(P1)), V)).petscmat)
G = chop(assemble(interpolate(grad(TrialFunction(P1)), V)).petscmat)
else:
G = G_callback(P1, V)

Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from firedrake.solving_utils import _SNESContext
from firedrake.utils import cached_property, complex_mode, IntType
from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
from firedrake.interpolation import Interpolate
from firedrake.interpolation import interpolate
from firedrake.ufl_expr import extract_domains

from collections import namedtuple
Expand Down Expand Up @@ -668,7 +668,7 @@ def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None):
# with access descriptor MAX to define a consistent opinion
# about where the vertices are.
CGk = V.reconstruct(family="Lagrange")
coordinates = assemble(Interpolate(coordinates, CGk, access=op2.MAX))
coordinates = assemble(interpolate(coordinates, CGk, access=op2.MAX))

select = partial(select_entity, dm=dm, exclude="pyop2_ghost")
entities = [(p, self.coords(dm, p, coordinates)) for p in
Expand Down
12 changes: 6 additions & 6 deletions firedrake/pyplot/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import mpl_toolkits.mplot3d
from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection
from math import factorial
from firedrake import (Interpolate, sqrt, inner, Function, SpatialCoordinate,
from firedrake import (interpolate, sqrt, inner, Function, SpatialCoordinate,
FunctionSpace, VectorFunctionSpace, PointNotInDomainError,
Constant, assemble, dx)
from firedrake.mesh import MeshGeometry
Expand Down Expand Up @@ -120,7 +120,7 @@ def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}):
if element.degree() != 1:
# Interpolate to piecewise linear.
V = VectorFunctionSpace(mesh, element.family(), 1)
coordinates = assemble(Interpolate(coordinates, V))
coordinates = assemble(interpolate(coordinates, V))

coords = toreal(coordinates.dat.data_ro_with_halos, "real")
result = []
Expand Down Expand Up @@ -215,7 +215,7 @@ def _plot_2d_field(method_name, function, *args, complex_component="real", **kwa
if len(function.ufl_shape) == 1:
element = function.ufl_element().sub_elements[0]
Q = FunctionSpace(mesh, element)
function = assemble(Interpolate(sqrt(inner(function, function)), Q))
function = assemble(interpolate(sqrt(inner(function, function)), Q))

num_sample_points = kwargs.pop("num_sample_points", 10)
function_plotter = FunctionPlotter(mesh, num_sample_points)
Expand Down Expand Up @@ -326,7 +326,7 @@ def trisurf(function, *args, complex_component="real", **kwargs):
if len(function.ufl_shape) == 1:
element = function.ufl_element().sub_elements[0]
Q = FunctionSpace(mesh, element)
function = assemble(Interpolate(sqrt(inner(function, function)), Q))
function = assemble(interpolate(sqrt(inner(function, function)), Q))

num_sample_points = kwargs.pop("num_sample_points", 10)
function_plotter = FunctionPlotter(mesh, num_sample_points)
Expand Down Expand Up @@ -355,7 +355,7 @@ def quiver(function, *, complex_component="real", **kwargs):

coords = toreal(extract_unique_domain(function).coordinates.dat.data_ro, "real")
V = extract_unique_domain(function).coordinates.function_space()
function_interp = assemble(Interpolate(function, V))
function_interp = assemble(interpolate(function, V))
vals = toreal(function_interp.dat.data_ro, complex_component)
C = np.linalg.norm(vals, axis=1)
return axes.quiver(*(coords.T), *(vals.T), C, **kwargs)
Expand Down Expand Up @@ -816,7 +816,7 @@ def _bezier_plot(function, axes, complex_component="real", **kwargs):
mesh = function.function_space().mesh()
if deg == 0:
V = FunctionSpace(mesh, "DG", 1)
interp = assemble(Interpolate(function, V))
interp = assemble(interpolate(function, V))
return _bezier_plot(interp, axes, complex_component=complex_component,
**kwargs)
y_vals = _bezier_calculate_points(function)
Expand Down
Loading
Loading