Skip to content

Commit e04c9fe

Browse files
committed
fixup
1 parent 3e28c5d commit e04c9fe

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

firedrake/interpolation.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,6 @@ def assemble(
304304
The function, cofunction, matrix, or scalar resulting from the
305305
interpolation.
306306
"""
307-
# Set default mat_types if not provided and check it's valid
308-
mat_type = mat_type or "aij"
309-
sub_mat_type = sub_mat_type or "baij"
310307
self._check_mat_type(mat_type)
311308

312309
result = self._get_callable(tensor=tensor, bcs=bcs, mat_type=mat_type, sub_mat_type=sub_mat_type)()
@@ -324,7 +321,7 @@ def assemble(
324321

325322
def _check_mat_type(
326323
self,
327-
mat_type: Literal["aij", "baij", "nest", "matfree"],
324+
mat_type: Literal["aij", "baij", "nest", "matfree"] | None,
328325
) -> None:
329326
"""Check that the given mat_type is valid for this Interpolator.
330327
@@ -530,6 +527,7 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None)
530527
from firedrake.assemble import assemble
531528
if bcs:
532529
raise NotImplementedError("bcs not implemented for cross-mesh interpolation.")
530+
mat_type = mat_type or "aij"
533531

534532
# self.ufl_interpolate.function_space() is None in the 0-form case
535533
V_dest = self.ufl_interpolate.function_space() or self.target_space
@@ -603,7 +601,7 @@ def callable() -> Function | Number:
603601

604602
@property
605603
def _allowed_mat_types(self):
606-
return {"aij"}
604+
return {"aij", None}
607605

608606

609607
class SameMeshInterpolator(Interpolator):
@@ -709,6 +707,7 @@ def _get_monolithic_sparsity(self, mat_type: Literal["aij", "baij"]) -> op2.Spar
709707
return sparsity
710708

711709
def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None):
710+
mat_type = mat_type or "aij"
712711
if (isinstance(tensor, Cofunction) and isinstance(self.dual_arg, Cofunction)) and set(tensor.dat).intersection(set(self.dual_arg.dat)):
713712
# adjoint one-form case: we need an empty tensor, so if it shares dats with
714713
# the dual_arg we cannot use it directly, so we store it
@@ -766,7 +765,7 @@ def callable() -> Function | Cofunction | PETSc.Mat | Number:
766765

767766
@property
768767
def _allowed_mat_types(self):
769-
return {"aij", "baij"}
768+
return {"aij", "baij", None}
770769

771770

772771
class VomOntoVomInterpolator(SameMeshInterpolator):
@@ -777,7 +776,7 @@ def __init__(self, expr: Interpolate):
777776
def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None):
778777
if bcs:
779778
raise NotImplementedError("bcs not implemented for vom-to-vom interpolation.")
780-
779+
mat_type = mat_type or "matfree"
781780
self.mat = VomOntoVomMat(self, mat_type=mat_type)
782781
if self.rank == 1:
783782
f = tensor or self._get_tensor(mat_type)
@@ -816,7 +815,7 @@ def callable() -> PETSc.Mat:
816815

817816
@property
818817
def _allowed_mat_types(self):
819-
return {"aij", "baij", "matfree"}
818+
return {"aij", "baij", "matfree", None}
820819

821820

822821
@known_pyop2_safe
@@ -1697,6 +1696,8 @@ def _build_aij(
16971696
return matnest.convert("aij")
16981697

16991698
def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None):
1699+
mat_type = mat_type or "aij"
1700+
sub_mat_type = sub_mat_type or "baij"
17001701
Isub = self._get_sub_interpolators(bcs=bcs)
17011702
V_dest = self.ufl_interpolate.function_space() or self.target_space
17021703
f = tensor or Function(V_dest)
@@ -1720,4 +1721,4 @@ def callable() -> Number:
17201721

17211722
@property
17221723
def _allowed_mat_types(self):
1723-
return {"aij", "nest"}
1724+
return {"aij", "nest", None}

tests/firedrake/regression/test_interpolator_types.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ def test_vomtovom_mattype(value_shape, mat_type):
8888
interp = interpolate(u, P0DG_io)
8989
assert isinstance(get_interpolator(interp), VomOntoVomInterpolator)
9090
res = assemble(interp, mat_type=mat_type)
91-
if not mat_type or mat_type == "aij":
92-
assert res.petscmat.type == "seqaij"
93-
elif mat_type == "matfree":
91+
if not mat_type or mat_type == "matfree":
9492
assert res.petscmat.type == "python"
9593
else:
9694
if value_shape == "scalar":

0 commit comments

Comments
 (0)