@@ -304,9 +304,6 @@ def assemble(
304304 The function, cofunction, matrix, or scalar resulting from the
305305 interpolation.
306306 """
307- # Set default mat_types if not provided and check it's valid
308- mat_type = mat_type or "aij"
309- sub_mat_type = sub_mat_type or "baij"
310307 self ._check_mat_type (mat_type )
311308
312309 result = self ._get_callable (tensor = tensor , bcs = bcs , mat_type = mat_type , sub_mat_type = sub_mat_type )()
@@ -324,7 +321,7 @@ def assemble(
324321
325322 def _check_mat_type (
326323 self ,
327- mat_type : Literal ["aij" , "baij" , "nest" , "matfree" ],
324+ mat_type : Literal ["aij" , "baij" , "nest" , "matfree" ] | None ,
328325 ) -> None :
329326 """Check that the given mat_type is valid for this Interpolator.
330327
@@ -530,6 +527,7 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None)
530527 from firedrake .assemble import assemble
531528 if bcs :
532529 raise NotImplementedError ("bcs not implemented for cross-mesh interpolation." )
530+ mat_type = mat_type or "aij"
533531
534532 # self.ufl_interpolate.function_space() is None in the 0-form case
535533 V_dest = self .ufl_interpolate .function_space () or self .target_space
@@ -603,7 +601,7 @@ def callable() -> Function | Number:
603601
604602 @property
605603 def _allowed_mat_types (self ):
606- return {"aij" }
604+ return {"aij" , None }
607605
608606
609607class SameMeshInterpolator (Interpolator ):
@@ -709,6 +707,7 @@ def _get_monolithic_sparsity(self, mat_type: Literal["aij", "baij"]) -> op2.Spar
709707 return sparsity
710708
711709 def _get_callable (self , tensor = None , bcs = None , mat_type = None , sub_mat_type = None ):
710+ mat_type = mat_type or "aij"
712711 if (isinstance (tensor , Cofunction ) and isinstance (self .dual_arg , Cofunction )) and set (tensor .dat ).intersection (set (self .dual_arg .dat )):
713712 # adjoint one-form case: we need an empty tensor, so if it shares dats with
714713 # the dual_arg we cannot use it directly, so we store it
@@ -766,7 +765,7 @@ def callable() -> Function | Cofunction | PETSc.Mat | Number:
766765
767766 @property
768767 def _allowed_mat_types (self ):
769- return {"aij" , "baij" }
768+ return {"aij" , "baij" , None }
770769
771770
772771class VomOntoVomInterpolator (SameMeshInterpolator ):
@@ -777,7 +776,7 @@ def __init__(self, expr: Interpolate):
777776 def _get_callable (self , tensor = None , bcs = None , mat_type = None , sub_mat_type = None ):
778777 if bcs :
779778 raise NotImplementedError ("bcs not implemented for vom-to-vom interpolation." )
780-
779+ mat_type = mat_type or "matfree"
781780 self .mat = VomOntoVomMat (self , mat_type = mat_type )
782781 if self .rank == 1 :
783782 f = tensor or self ._get_tensor (mat_type )
@@ -816,7 +815,7 @@ def callable() -> PETSc.Mat:
816815
817816 @property
818817 def _allowed_mat_types (self ):
819- return {"aij" , "baij" , "matfree" }
818+ return {"aij" , "baij" , "matfree" , None }
820819
821820
822821@known_pyop2_safe
@@ -1697,6 +1696,8 @@ def _build_aij(
16971696 return matnest .convert ("aij" )
16981697
16991698 def _get_callable (self , tensor = None , bcs = None , mat_type = None , sub_mat_type = None ):
1699+ mat_type = mat_type or "aij"
1700+ sub_mat_type = sub_mat_type or "baij"
17001701 Isub = self ._get_sub_interpolators (bcs = bcs )
17011702 V_dest = self .ufl_interpolate .function_space () or self .target_space
17021703 f = tensor or Function (V_dest )
@@ -1720,4 +1721,4 @@ def callable() -> Number:
17201721
17211722 @property
17221723 def _allowed_mat_types (self ):
1723- return {"aij" , "nest" }
1724+ return {"aij" , "nest" , None }
0 commit comments