diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index dd72c4312..bda23a2bf 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2016,6 +2016,22 @@ "endColumn": 25, "lineCount": 1 } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 34, + "endColumn": 44, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 56, + "endColumn": 66, + "lineCount": 1 + } } ], "./pytato/array.py": [ diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 00e36ce32..13a9bfd81 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,17 +27,21 @@ """ from collections import defaultdict -from typing import TYPE_CHECKING, Any, overload +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper +from pytools import product from pytato.array import ( Array, + ArrayOrScalar, Concatenate, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -45,17 +49,28 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Placeholder, ShapeType, Stack, ) +from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.scalar_expr import ( + FlopCounter as ScalarFlopCounter, +) +from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, + ArrayOrNamesOrFunctionDef, + ArrayOrNamesTc, CachedWalkMapper, CombineMapper, Mapper, VisitKeyT, + map_and_copy, ) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import has_taggable_materialization, is_materialized if TYPE_CHECKING: @@ -63,7 +78,6 @@ import pytools.tag - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.loopy import LoopyCall __doc__ = """ @@ -87,6 +101,13 @@ .. autoclass:: TagCountMapper .. autofunction:: get_num_tags_of_type + +.. autoclass:: UndefinedOpFlopCountError +.. autofunction:: get_default_op_name_to_num_flops +.. autofunction:: get_num_flops +.. autofunction:: get_materialized_node_flop_counts +.. autoclass:: UnmaterializedNodeFlopCounts +.. autofunction:: get_unmaterialized_node_flop_counts """ @@ -342,7 +363,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: class ListOfDirectPredecessorsGetter( Mapper[ - list[ArrayOrNames | FunctionDefinition], + list[ArrayOrNamesOrFunctionDef], list[ArrayOrNames], []]): """ @@ -425,8 +446,8 @@ def map_distributed_send_ref_holder(self, return [expr.send.data, expr.passthrough_data] def map_call( - self, expr: Call) -> list[ArrayOrNames | FunctionDefinition]: - result: list[ArrayOrNames | FunctionDefinition] = [] + self, expr: Call) -> list[ArrayOrNamesOrFunctionDef]: + result: list[ArrayOrNamesOrFunctionDef] = [] if self.include_functions: result.append(expr.function) result += list(expr.bindings.values()) @@ -463,7 +484,7 @@ def __init__(self, *, include_functions: bool = False) -> None: @overload def __call__( self, expr: ArrayOrNames - ) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + ) -> FrozenOrderedSet[ArrayOrNamesOrFunctionDef]: ... @overload @@ -472,9 +493,9 @@ def __call__(self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: def __call__( self, - expr: ArrayOrNames | FunctionDefinition, + expr: ArrayOrNamesOrFunctionDef, ) -> ( - FrozenOrderedSet[ArrayOrNames | FunctionDefinition] + FrozenOrderedSet[ArrayOrNamesOrFunctionDef] | FrozenOrderedSet[ArrayOrNames]): """Get the direct predecessors of *expr*.""" return FrozenOrderedSet(self._pred_getter(expr)) @@ -523,7 +544,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: _visited_functions=self._visited_functions) @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 @@ -586,7 +607,7 @@ def __init__(self, _visited_functions: set[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.expr_multiplicity_counts: \ - dict[ArrayOrNames | FunctionDefinition, int] = defaultdict(int) + dict[ArrayOrNamesOrFunctionDef, int] = defaultdict(int) @override def get_cache_key(self, expr: ArrayOrNames) -> int: @@ -599,13 +620,13 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_multiplicity_counts[expr] += 1 def get_node_multiplicities( - outputs: ArrayOrNames) -> dict[ArrayOrNames | FunctionDefinition, int]: + outputs: ArrayOrNames) -> dict[ArrayOrNamesOrFunctionDef, int]: """ Returns the multiplicity per `expr`. """ @@ -642,7 +663,7 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if isinstance(expr, Call): self.count += 1 @@ -755,4 +776,384 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # }}} + +# {{{ flop counting + +@dataclass +class UndefinedOpFlopCountError(ValueError): + op_name: str + + +class _PerEntryFlopCounter(CombineMapper[int, Never]): + def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() + self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( + op_name_to_num_flops) + self.node_to_nflops: dict[Array, int] = {} + + @override + def combine(self, *args: int) -> int: + return sum(args) + + def _get_own_flop_count(self, expr: Array) -> int: + if isinstance( + expr, + ( + DataWrapper, + Placeholder, + NamedArray, + DistributedRecv, + DistributedSendRefHolder)): + return 0 + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(nflops, int): + # Restricting to numerical result here because the flop counters that use + # this mapper subsequently multiply the result by things that are + # potentially arrays (e.g., shape components), and arrays and scalar + # expressions are not interoperable + from pytato.scalar_expr import OpFlops, OpFlopsCollector + op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) + if op_flops: + raise UndefinedOpFlopCountError(next(iter(op_flops)).op) + else: + raise AssertionError + return nflops + + @override + def rec(self, expr: ArrayOrNames) -> int: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: int + if isinstance(expr, Array) and not is_materialized(expr): + result = ( + self._get_own_flop_count(expr) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + + cast("int", Mapper.rec(self, expr))) + else: + result = 0 + if isinstance(expr, Array): + self.node_to_nflops[expr] = result + return self._cache_add(inputs, result) + + +class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the number of floating point operations of each materialized + expression in a DAG. + + .. note:: + + Flops from nodes inside function calls are accumulated onto the corresponding + call node. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + ) -> None: + super().__init__() + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: + if not is_materialized(expr): + return + assert isinstance(expr, Array) + if has_taggable_materialization(expr): + unmaterialized_expr = expr.without_tags(ImplStored()) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter(unmaterialized_expr)) + else: + self.materialized_node_to_nflops[expr] = 0 + + +class _UnmaterializedSubexpressionUseCounter(CombineMapper[dict[Array, int], Never]): + @override + def combine(self, *args: dict[Array, int]) -> dict[Array, int]: + result: dict[Array, int] = defaultdict(int) + for arg in args: + for ary, nuses in arg.items(): + result[ary] += nuses + return result + + @override + def rec(self, expr: ArrayOrNames) -> dict[Array, int]: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: dict[Array, int] + if isinstance(expr, Array) and not is_materialized(expr): + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self.combine( + {expr: 1}, cast("dict[Array, int]", Mapper.rec(self, expr))) + else: + result = {} + return self._cache_add(inputs, result) + + +@dataclass +class UnmaterializedNodeFlopCounts: + """ + Floating point operation counts for an unmaterialized node. See + :func:`get_unmaterialized_node_flop_counts` for details. + """ + materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] + nflops_if_materialized: ArrayOrScalar + + +class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the accumulated number of floating point operations that each + unmaterialized expression contributes to materialized expressions in the DAG. + + .. note:: + + This mapper does not descend into functions. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.unmaterialized_node_to_flop_counts: \ + dict[Array, UnmaterializedNodeFlopCounts] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: + if not is_materialized(expr) or not has_taggable_materialization(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + subexpr_to_nuses = _UnmaterializedSubexpressionUseCounter()( + unmaterialized_expr) + del subexpr_to_nuses[unmaterialized_expr] + self._per_entry_flop_counter(unmaterialized_expr) + for subexpr, nuses in subexpr_to_nuses.items(): + per_entry_nflops = self._per_entry_flop_counter.node_to_nflops[subexpr] + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = product(subexpr.shape) * per_entry_nflops + flop_counts = UnmaterializedNodeFlopCounts({}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses * product(expr.shape) * per_entry_nflops) + + +# FIXME: Should this be added to normalize_outputs? +def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + # Make sure outputs are materialized + if isinstance(expr, DictOfNamedArrays): + output_to_materialized_output: dict[Array, Array] = { + ary: ( + ary.tagged(ImplStored()) + if has_taggable_materialization(ary) + else ary) + for ary in expr._data.values()} + + def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: + if not isinstance(ary, Array): + return ary + try: + return output_to_materialized_output[ary] + except KeyError: + return ary + + expr = map_and_copy(expr, replace_with_materialized) + + return expr + + +def get_default_op_name_to_num_flops() -> dict[str, int]: + """ + Returns a mapping from operator name to floating point operation count for + operators that are almost always a single flop. + """ + return { + "+": 1, + "*": 1, + "==": 1, + "!=": 1, + "<": 1, + ">": 1, + "<=": 1, + ">=": 1, + "min": 1, + "max": 1} + + +def get_num_flops( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> ArrayOrScalar: + """ + Count the total number of floating point operations in the DAG *expr*. + + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. The total flop count is the sum + over all materialized nodes. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + + .. note:: + + Does not support functions. Function calls must be inlined before calling. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return sum(fc.materialized_node_to_nflops.values()) + + +def get_materialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, ArrayOrScalar]: + """ + Returns a dictionary mapping materialized nodes in DAG *expr* to their floating + point operation count. + + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + + .. note:: + + Does not support functions. Function calls must be inlined before calling. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.materialized_node_to_nflops + + +def get_unmaterialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, UnmaterializedNodeFlopCounts]: + """ + Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a + :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count + information. + + The :class:`UnmaterializedNodeFlopCounts` instance for each unmaterialized node + (i.e., a node that can be tagged with :class:`pytato.tags.ImplStored` but isn't) + contains `materialized_successor_to_contrib_nflops` and `nflops_if_materialized` + attributes. The former is a mapping from each materialized successor of the + unmaterialized node to the number of flops the node contributes to evaluating + that successor (this includes flops from the predecessors of the unmaterialized + node). The latter is the number of flops that would be required to evaluate the + unmaterialized node if it was materialized instead. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + + .. note:: + + Does not support functions. Function calls must be inlined before calling. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.unmaterialized_node_to_flop_counts + +# }}} + + # vim: fdm=marker diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index c7defba6f..69bb9d6d5 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -55,7 +55,11 @@ DistributedGraphPartition, PartId, ) -from pytato.transform import ArrayOrNames, CachedWalkMapper +from pytato.transform import ( + ArrayOrNames, + ArrayOrNamesOrFunctionDef, + CachedWalkMapper, +) logger = logging.getLogger(__name__) @@ -68,7 +72,6 @@ import numpy as np from pytato.distributed.nodes import CommTagType, DistributedRecv - from pytato.function import FunctionDefinition # {{{ data structures @@ -156,7 +159,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @override - def visit(self, expr: ArrayOrNames | FunctionDefinition) -> bool: + def visit(self, expr: ArrayOrNamesOrFunctionDef) -> bool: super().visit(expr) if isinstance(expr, ArrayOrNames): self.seen_nodes.add(expr) diff --git a/pytato/equality.py b/pytato/equality.py index 837530f02..1976be49d 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -63,6 +63,8 @@ ArrayOrNames = Array | AbstractResultWithNamedArrays +ArrayOrNamesOrFunctionDef = \ + Array | AbstractResultWithNamedArrays | FunctionDefinition # {{{ EqualityComparer @@ -87,7 +89,7 @@ def __init__(self) -> None: # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], bool] = {} - def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool: + def rec(self, expr1: ArrayOrNamesOrFunctionDef, expr2: object) -> bool: # These cases are simple enough that they don't need to be cached if expr1 is expr2: return True @@ -119,7 +121,7 @@ def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool: self._cache[cache_key] = result return result - def __call__(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool: + def __call__(self, expr1: ArrayOrNamesOrFunctionDef, expr2: object) -> bool: return self.rec(expr1, expr2) def handle_unsupported_array(self, expr1: Array, diff --git a/pytato/reductions.py b/pytato/reductions.py index 6efa45ac2..a229328c8 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -34,6 +34,7 @@ import numpy as np from constantdict import constantdict +from typing_extensions import override import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -80,22 +81,27 @@ class _NoValue: class ReductionOperation(ABC): """ + .. automethod:: scalar_op_name .. automethod:: neutral_element .. automethod:: __hash__ .. automethod:: __eq__ """ + @classmethod + @abstractmethod + def scalar_op_name(cls) -> str: + ... @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: - pass + ... @abstractmethod def __hash__(self) -> int: - pass + ... @abstractmethod def __eq__(self, other: object) -> bool: - pass + ... class _StatelessReductionOperation(ReductionOperation): @@ -110,16 +116,31 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "+" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 0 class ProductReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "*" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 1 class MaxReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "max" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("-inf")) @@ -130,6 +151,11 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "min" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("inf")) @@ -140,11 +166,21 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "or" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) class AnyReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "and" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e59cbdce8..e7a796faa 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -44,6 +44,7 @@ import re from collections.abc import Iterable, Mapping, Set as AbstractSet +from functools import reduce from typing import ( TYPE_CHECKING, Any, @@ -57,6 +58,7 @@ import pymbolic.primitives as prim from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass from pymbolic.mapper import ( + Collector, CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, P, @@ -70,6 +72,7 @@ ) from pymbolic.mapper.distributor import DistributeMapper as DistributeMapperBase from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase +from pymbolic.mapper.flop_counter import FlopCounterBase from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer @@ -241,6 +244,140 @@ def map_type_cast(self, inner_str = self.rec(expr.inner_expr, PREC_NONE, *args, **kwargs) return f"cast({expr.dtype}, {inner_str})" + +class InputGatherer(Collector[str, []]): + @override + def map_variable(self, expr: prim.Variable) -> set[str]: + return {expr.name} + + +class FlopCounter(FlopCounterBase): + op_name_to_num_flops: dict[str, ArithmeticExpression] + + def __init__( + self, + op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): + super().__init__() + if op_name_to_num_flops: + self.op_name_to_num_flops = dict(op_name_to_num_flops) + else: + self.op_name_to_num_flops = {} + + def _get_op_nflops(self, name: str) -> ArithmeticExpression: + try: + return self.op_name_to_num_flops[name] + except KeyError: + result = OpFlops(name) + self.op_name_to_num_flops[name] = result + return result + + @override + def map_call(self, expr: prim.Call) -> ArithmeticExpression: + assert isinstance(expr.function, prim.Variable) + return ( + self._get_op_nflops(expr.function.name) + + sum(self.rec(child) for child in expr.parameters)) + + @override + def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: + # Assume calculations inside subscripts are performed on non-floats + return 0 + + @override + def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + if expr.children: + return ( + self._get_op_nflops("+") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_product(self, expr: prim.Product) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("*") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_quotient(self, expr: prim.Quotient) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + return ( + self._get_op_nflops("/") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + return ( + self._get_op_nflops("//") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_power(self, expr: prim.Power) -> ArithmeticExpression: + if isinstance(expr.exponent, int): + if expr.exponent >= 0: + return ( + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("/") + + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("**") + + self.rec(expr.base) + + self.rec(expr.exponent)) + + @override + def map_comparison(self, expr: prim.Comparison) -> ArithmeticExpression: + return ( + self._get_op_nflops(expr.operator) + + self.rec(expr.left) + + self.rec(expr.right)) + + @override + def map_if(self, expr: prim.If) -> ArithmeticExpression: + return ( + self.rec(expr.condition) + + self.rec(expr.then) + + self.rec(expr.else_)) + + @override + def map_max(self, expr: prim.Max) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("max") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_min(self, expr: prim.Min) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("min") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_nan(self, expr: prim.NaN) -> ArithmeticExpression: + return 0 + + def map_reduce(self, expr: Reduce) -> ArithmeticExpression: + result = self.rec(expr.inner_expr) + nflops_op = self._get_op_nflops(expr.op.scalar_op_name()) + for lower_bd, upper_bd in expr.bounds.values(): + nops = upper_bd - lower_bd + result = result * nops + nflops_op * (nops-1) + + return result + # }}} @@ -343,9 +480,42 @@ class TypeCast(ExpressionBase): dtype: np.dtype[Any] inner_expr: ScalarExpression + +@expr_dataclass() +class OpFlops(prim.AlgebraicLeaf): + """ + Placeholder flop count for an operator. + + .. autoattribute:: op + """ + op: str + # }}} +class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]): + """ + Constructs a :class:`frozenset` containing all instances of + :class:`pytato.scalar_expr.OpFlops` found in a scalar expression. + """ + @override + def combine( + self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]: + return reduce( + lambda x, y: x.union(y), + values, + cast("frozenset[OpFlops]", frozenset())) + + @override + def map_algebraic_leaf( + self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]: + return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset() + + @override + def map_constant(self, expr: object) -> frozenset[OpFlops]: + return frozenset() + + class InductionVariableCollector(CombineMapper[AbstractSet[str], []]): def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]: from functools import reduce diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 0ed13a63f..27823a746 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -89,6 +89,8 @@ ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays +ArrayOrNamesOrFunctionDef: TypeAlias = \ + Array | AbstractResultWithNamedArrays | FunctionDefinition ArrayOrNamesTc = TypeVar("ArrayOrNamesTc", Array, AbstractResultWithNamedArrays, DictOfNamedArrays) ArrayOrNamesOrFunctionDefTc = TypeVar("ArrayOrNamesOrFunctionDefTc", @@ -150,6 +152,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. class:: ArrayOrNames +.. class:: ArrayOrNamesOrFunctionDef .. class:: ArrayOrNamesTc @@ -307,7 +310,7 @@ def __call__( def __call__( self, - expr: ArrayOrNames | FunctionDefinition, + expr: ArrayOrNamesOrFunctionDef, *args: P.args, **kwargs: P.kwargs) -> ResultT | FunctionResultT: """Handle the mapping of *expr*.""" @@ -1569,7 +1572,7 @@ def clone_for_callee( return type(self)() def visit( - self, expr: ArrayOrNames | FunctionDefinition, + self, expr: ArrayOrNamesOrFunctionDef, *args: P.args, **kwargs: P.kwargs) -> bool: """ If this method returns *True*, *expr* is traversed during the walk. @@ -1579,7 +1582,7 @@ def visit( return True def post_visit( - self, expr: ArrayOrNames | FunctionDefinition, + self, expr: ArrayOrNamesOrFunctionDef, *args: P.args, **kwargs: P.kwargs) -> None: """ Callback after *expr* has been traversed. @@ -1841,7 +1844,7 @@ def __init__( def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if isinstance(expr, Array): self.topological_order.append(expr) diff --git a/pytato/utils.py b/pytato/utils.py index 19dc08ef7..333929d8b 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import CachedMapper +from pytato.transform import ArrayOrNamesOrFunctionDef, CachedMapper if TYPE_CHECKING: @@ -80,6 +80,8 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str +.. autofunction:: is_materialized +.. autofunction:: has_taggable_materialization References ^^^^^^^^^^ @@ -735,4 +737,40 @@ def get_einsum_specification(expr: Einsum) -> str: for i in range(expr.ndim)) return f"{','.join(input_specs)}->{output_spec}" + + +def is_materialized(expr: ArrayOrNamesOrFunctionDef) -> bool: + """Returns *True* if *expr* is materialized.""" + from pytato.array import InputArgumentBase + from pytato.distributed.nodes import DistributedRecv + from pytato.tags import ImplStored + return ( + ( + isinstance(expr, Array) + and bool(expr.tags_of_type(ImplStored))) + or isinstance( + expr, + ( + InputArgumentBase, + DistributedRecv))) + + +def has_taggable_materialization(expr: ArrayOrNamesOrFunctionDef) -> bool: + """ + Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to + determine whether or not it is materialized. + """ + from pytato.array import InputArgumentBase, NamedArray + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder + return ( + isinstance(expr, Array) + and not isinstance( + expr, + ( + InputArgumentBase, + DistributedRecv, + NamedArray, + DistributedSendRefHolder))) + + # vim: fdm=marker diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 5b39f83bf..6dcab1d1b 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -59,7 +59,12 @@ ) from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.tags import FunctionIdentifier -from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer +from pytato.transform import ( + ArrayOrNames, + ArrayOrNamesOrFunctionDef, + CachedMapper, + InputGatherer, +) if TYPE_CHECKING: @@ -160,7 +165,7 @@ def emit_subgraph(sg: _SubgraphTree) -> None: class _DotNodeInfo: title: str fields: dict[str, Any] - edges: dict[str, ArrayOrNames | FunctionDefinition] + edges: dict[str, ArrayOrNamesOrFunctionDef] def stringify_tags(tags: frozenset[Tag | None]) -> str: @@ -193,7 +198,7 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: "non_equality_tags": expr.non_equality_tags, } - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNamesOrFunctionDef] = {} return _DotNodeInfo(title, fields, edges) # type-ignore-reason: incompatible with supertype @@ -297,7 +302,7 @@ def map_einsum(self, expr: Einsum) -> None: self.node_to_dot[expr] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNamesOrFunctionDef] = {} for name, val in expr._data.items(): edges[name] = val self.rec(val) @@ -308,7 +313,7 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNamesOrFunctionDef] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): edges[name] = arg diff --git a/test/test_pytato.py b/test/test_pytato.py index 8746810ae..56716132a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -770,6 +770,272 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_scalar_flop_count(): + from pytato.scalar_expr import FlopCounter + fc = FlopCounter({ + "+": 1, + "*": 1, + "/": 4, + "//": 4, + "%": 4, + "**": 8, + "<": 1, + "min": 1, + "max": 1, + "f": 32}) + + import pymbolic.primitives as prim + from pymbolic import Variable + + x = Variable("x") + y = Variable("y") + + assert fc(Variable("f")(x)) == 32 + + assert fc(x[0]) == 0 + + assert fc(x + 2) == 1 + assert fc(2 + y) == 1 + assert fc(x + y) == 1 + + assert fc(prim.Sum((2, x, y))) == 2 + + assert fc(x - 2) == 1 + assert fc(2 - y) == 2 + assert fc(x - y) == 2 + + assert fc(x * 2) == 1 + assert fc(2 * y) == 1 + assert fc(x * y) == 1 + + assert fc(prim.Product((2, x, y))) == 2 + + assert fc(x.or_(y)) == 0 + assert fc(x.and_(y)) == 0 + + assert fc(x / 2) == 4 + assert fc(2 / y) == 4 + assert fc(x / y) == 4 + + assert fc(x // 2) == 4 + + assert fc(x % 2) == 0 + + assert fc(x ** 3) == 3 + assert fc(x ** 0.3) == 8 + + assert fc(x.lt(y)) == 1 + + assert fc(prim.If(x, x, y)) == 0 + + assert fc(prim.Min((2, x, y))) == 2 + assert fc(prim.Max((2, x, y))) == 2 + + from constantdict import constantdict + + from pytato.reductions import SumReductionOperation + from pytato.scalar_expr import Reduce + + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) == 9 + + +def test_flop_count(): + from pytato.analysis import ( + UndefinedOpFlopCountError, + get_default_op_name_to_num_flops, + get_num_flops, + ) + from pytato.tags import ImplStored + + # {{{ basic expression + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = u - v + + # expr[i, j] = 2*(x[i, j] + y[i, j]) + (-1)*3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*6 + + # }}} + + # {{{ expression with operators that don't have default flop counts + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + expr = pt.cmath.exp(x / y) + + with pytest.raises(UndefinedOpFlopCountError): + get_num_flops(expr) + + op_name_to_num_flops = get_default_op_name_to_num_flops() + op_name_to_num_flops.update({ + "/": 4, + "pytato.c99.exp": 8}) + + assert get_num_flops(expr, op_name_to_num_flops) == 40*12 + + # }}} + + # {{{ multiple expressions + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = pt.make_dict_of_named_arrays({"u": u, "v": v}) + + # expr["u"][i, j] = 2*(x[i, j] + y[i, j]) + # expr["v"][i, j] = 3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*4 + + # }}} + + # {{{ subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # expr[i, j] = 2*(x[2*i, j] + y[2*i, j]) + (-1)*3*(x[2*i, j] + y[2*i, j]) + assert get_num_flops(expr) == 20*6 + + # }}} + + # {{{ materialized array + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert get_num_flops(expr) == 40 + 40*4 + + # }}} + + # {{{ materialized array and subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[2*i, j] + (-1)*3*z[2*i, j] + assert get_num_flops(expr) == 40 + 20*4 + + # }}} + + # {{{ einsum + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->ijk", x, y) + + # expr[i, j, k] = x[i, j, k] * y[j, k] + assert get_num_flops(expr) == 24 + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->i", x, y) + + # expr[i] = sum(sum(x[i, j, k] * y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*1 + 2) + 3) + + # }}} + + +def test_materialized_node_flop_counts(): + from pytato.analysis import get_materialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert len(materialized_node_to_flop_count) == 4 + assert x in materialized_node_to_flop_count + assert y in materialized_node_to_flop_count + assert z in materialized_node_to_flop_count + assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[x] == 0 + assert materialized_node_to_flop_count[y] == 0 + assert materialized_node_to_flop_count[z] == 40 + assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 + + +def test_unmaterialized_node_flop_counts(): + from pytato.analysis import get_unmaterialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + # Make a reduction over a bunch of expressions that reference z + z = x + y + w = [i*z for i in range(1, 11)] + s = [w[0]] + for w_i in w[1:-1]: + s.append(s[-1] + w_i) + expr = s[-1] + w[-1] + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + materialized_expr = expr.tagged(ImplStored()) + + # Everything except expr stays unmaterialized + assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 + assert z in unmaterialized_node_to_flop_counts + assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) + assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) + flop_counts = unmaterialized_node_to_flop_counts[z] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*10 + assert flop_counts.nflops_if_materialized == 40 + for w_i in w: + flop_counts = unmaterialized_node_to_flop_counts[w_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2 + assert flop_counts.nflops_if_materialized == 40*2 + for i, s_i in enumerate(s): + flop_counts = unmaterialized_node_to_flop_counts[s_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2*(i+1) + 40*i + assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4))