From e5e1a95cb0b374da86d4650a9680545322e09f72 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 26 Jan 2024 21:06:12 +0100 Subject: [PATCH] add freeze and typify unit tests. various minor fixes --- src/pystencils/nbackend/arrays.py | 3 + src/pystencils/nbackend/ast/nodes.py | 46 +++++++++++-- .../nbackend/kernelcreation/__init__.py | 20 ++++++ .../nbackend/kernelcreation/context.py | 17 +++-- .../nbackend/kernelcreation/freeze.py | 52 ++++++++------ .../nbackend/kernelcreation/kernelcreation.py | 7 -- .../nbackend/kernelcreation/typification.py | 8 +-- src/pystencils/nbackend/typed_expressions.py | 6 ++ src/pystencils/nbackend/types/__init__.py | 3 + src/pystencils/nbackend/types/basic_types.py | 5 +- src/pystencils/nbackend/types/parsing.py | 2 +- tests/nbackend/test_freeze.py | 67 +++++++++++++++++++ tests/nbackend/test_typification.py | 46 +++++++++++++ 13 files changed, 232 insertions(+), 50 deletions(-) create mode 100644 tests/nbackend/test_freeze.py create mode 100644 tests/nbackend/test_typification.py diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py index 49375209e..b04552001 100644 --- a/src/pystencils/nbackend/arrays.py +++ b/src/pystencils/nbackend/arrays.py @@ -156,6 +156,9 @@ class PsLinearizedArray: def __hash__(self) -> int: return hash(self._hashable_contents()) + + def __repr__(self) -> str: + return f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" class PsArrayAssocVar(PsTypedVariable, ABC): diff --git a/src/pystencils/nbackend/ast/nodes.py b/src/pystencils/nbackend/ast/nodes.py index 5a20e5835..865fa6dec 100644 --- a/src/pystencils/nbackend/ast/nodes.py +++ b/src/pystencils/nbackend/ast/nodes.py @@ -2,9 +2,11 @@ from __future__ import annotations from typing import Sequence, Iterable, cast, TypeAlias from types import NoneType +from pymbolic.primitives import Variable + from abc import ABC, abstractmethod -from ..typed_expressions import PsTypedVariable, ExprOrConstant +from ..typed_expressions import ExprOrConstant from ..arrays import PsArrayAccess from .util import failing_cast @@ -35,6 +37,15 @@ class PsAstNode(ABC): def set_child(self, idx: int, c: PsAstNode): ... + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsAstNode): + return False + + return type(self) is type(other) and self.children == other.children + + def __hash__(self) -> int: + return hash((type(self), self.children)) + class PsBlock(PsAstNode): __match_args__ = ("statements",) @@ -56,6 +67,10 @@ class PsBlock(PsAstNode): def statements(self, stm: Sequence[PsAstNode]): self._statements = list(stm) + def __repr__(self) -> str: + contents = ", ".join(repr(c) for c in self.children) + return f"PsBlock( {contents} )" + class PsLeafNode(PsAstNode): def get_children(self) -> tuple[PsAstNode, ...]: @@ -81,12 +96,23 @@ class PsExpression(PsLeafNode): def expression(self, expr: ExprOrConstant): self._expr = expr + def __repr__(self) -> str: + return repr(self._expr) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsExpression): + return False + return type(self) is type(other) and self._expr == other._expr + + def __hash__(self) -> int: + return hash((type(self), self._expr)) + class PsLvalueExpr(PsExpression): """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment""" def __init__(self, expr: PsLvalue): - if not isinstance(expr, (PsTypedVariable, PsArrayAccess)): + if not isinstance(expr, (Variable, PsArrayAccess)): raise TypeError("Expression was not a valid lvalue") super(PsLvalueExpr, self).__init__(expr) @@ -97,19 +123,19 @@ class PsSymbolExpr(PsLvalueExpr): __match_args__ = ("symbol",) - def __init__(self, symbol: PsTypedVariable): + def __init__(self, symbol: Variable): super().__init__(symbol) @property - def symbol(self) -> PsTypedVariable: - return cast(PsTypedVariable, self._expr) + def symbol(self) -> Variable: + return cast(Variable, self._expr) @symbol.setter - def symbol(self, symbol: PsTypedVariable): + def symbol(self, symbol: Variable): self._expr = symbol -PsLvalue: TypeAlias = PsTypedVariable | PsArrayAccess +PsLvalue: TypeAlias = Variable | PsArrayAccess """Types of expressions that may occur on the left-hand side of assignments.""" @@ -151,6 +177,9 @@ class PsAssignment(PsAstNode): else: assert False, "unreachable code" + def __repr__(self) -> str: + return f"PsAssignment({repr(self._lhs)}, {repr(self._rhs)})" + class PsDeclaration(PsAssignment): __match_args__ = ( @@ -186,6 +215,9 @@ class PsDeclaration(PsAssignment): else: assert False, "unreachable code" + def __repr__(self) -> str: + return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})" + class PsLoop(PsAstNode): __match_args__ = ("counter", "start", "stop", "step", "body") diff --git a/src/pystencils/nbackend/kernelcreation/__init__.py b/src/pystencils/nbackend/kernelcreation/__init__.py index e69de29bb..110acb818 100644 --- a/src/pystencils/nbackend/kernelcreation/__init__.py +++ b/src/pystencils/nbackend/kernelcreation/__init__.py @@ -0,0 +1,20 @@ +from .options import KernelCreationOptions +from .kernelcreation import create_kernel + +from .context import KernelCreationContext +from .analysis import KernelAnalysis +from .freeze import FreezeExpressions +from .typification import Typifier + +from .iteration_space import FullIterationSpace, SparseIterationSpace + +__all__ = [ + "KernelCreationOptions", + "create_kernel", + "KernelCreationContext", + "KernelAnalysis", + "FreezeExpressions", + "Typifier", + "FullIterationSpace", + "SparseIterationSpace", +] diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py index 40cd2448f..7e4fad9ba 100644 --- a/src/pystencils/nbackend/kernelcreation/context.py +++ b/src/pystencils/nbackend/kernelcreation/context.py @@ -1,6 +1,5 @@ from __future__ import annotations from typing import cast -from dataclasses import dataclass from ...field import Field, FieldType @@ -16,12 +15,12 @@ from .options import KernelCreationOptions from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace -@dataclass class FieldsInKernel: - domain_fields: set[Field] = set() - index_fields: set[Field] = set() - custom_fields: set[Field] = set() - buffer_fields: set[Field] = set() + def __init__(self) -> None: + self.domain_fields: set[Field] = set() + self.index_fields: set[Field] = set() + self.custom_fields: set[Field] = set() + self.buffer_fields: set[Field] = set() class KernelCreationContext: @@ -70,6 +69,8 @@ class KernelCreationContext: def constraints(self) -> tuple[PsKernelConstraint, ...]: return tuple(self._constraints) + # Fields and Arrays + @property def fields(self) -> FieldsInKernel: return self._fields_collection @@ -113,7 +114,9 @@ class KernelCreationContext: self._arrays[field] = arr - return self._arrays[field] + return self._arrays[field] + + # Iteration Space def set_iteration_space(self, ispace: IterationSpace): if self._ispace is not None: diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index e9fc6ccd7..ef4ad8ea3 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -2,11 +2,18 @@ import pymbolic.primitives as pb from pymbolic.interop.sympy import SympyToPymbolicMapper from ...field import Field, FieldType +from ...typing import BasicType from .context import KernelCreationContext -from ..ast.nodes import PsAssignment -from ..types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType +from ..ast.nodes import ( + PsAssignment, + PsDeclaration, + PsSymbolExpr, + PsLvalueExpr, + PsExpression, +) +from ..types import constify, make_type from ..typed_expressions import PsTypedVariable from ..arrays import PsArrayAccess @@ -18,19 +25,21 @@ class FreezeExpressions(SympyToPymbolicMapper): def map_Assignment(self, expr): # noqa lhs = self.rec(expr.lhs) rhs = self.rec(expr.rhs) - return PsAssignment(lhs, rhs) - - def map_BasicType(self, expr): - width = expr.numpy_dtype.itemsize * 8 - const = expr.const - if expr.is_float(): - return PsIeeeFloatType(width, const) - elif expr.is_uint(): - return PsUnsignedIntegerType(width, const) - elif expr.is_int(): - return PsSignedIntegerType(width, const) + + if isinstance(lhs, pb.Variable): + return PsDeclaration(PsSymbolExpr(lhs), PsExpression(rhs)) + elif isinstance(lhs, PsArrayAccess): + return PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs)) + else: + assert False, "That should not have happened." + + def map_BasicType(self, expr: BasicType): + # TODO: This should not be necessary; the frontend should use the new type system. + dtype = make_type(expr.numpy_dtype.type) + if expr.const: + return constify(dtype) else: - raise NotImplementedError("Data type not supported.") + return dtype def map_FieldShapeSymbol(self, expr): dtype = self.rec(expr.dtype) @@ -53,7 +62,10 @@ class FreezeExpressions(SympyToPymbolicMapper): case FieldType.GENERIC: # Add the iteration counters offsets = [ - i + o for i, o in zip(self._ctx.get_iteration_space().spatial_indices, offsets) + i + o + for i, o in zip( + self._ctx.get_iteration_space().spatial_indices, offsets + ) ] case FieldType.INDEXED: # flake8: noqa @@ -68,11 +80,11 @@ class FreezeExpressions(SympyToPymbolicMapper): f"Cannot translate accesses to field type {unknown} yet." ) - index = pb.Sum( - tuple( - idx * stride - for idx, stride in zip(offsets + indices, array.strides, strict=True) - ) + summands = tuple( + idx * stride + for idx, stride in zip(offsets + indices, array.strides, strict=True) ) + index = summands[0] if len(summands) == 1 else pb.Sum(summands) + return PsArrayAccess(ptr, index) diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index 732b03459..617a152fa 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -19,14 +19,11 @@ from .iteration_space import ( def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions): - # 1. Prepare context ctx = KernelCreationContext(options) - # 2. Check kernel constraints and collect knowledge analysis = KernelAnalysis(ctx) analysis(assignments) - # 3. Create iteration space ispace: IterationSpace = ( create_sparse_iteration_space(ctx, assignments) if len(ctx.fields.index_fields) > 0 @@ -35,13 +32,9 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti ctx.set_iteration_space(ispace) - # 4. Freeze assignments - # This call is the same for both domain and indexed kernels freeze = FreezeExpressions(ctx) kernel_body: PsBlock = freeze(assignments) - # 5. Typify - # Also the same for both types of kernels typify = Typifier(ctx) kernel_body = typify(kernel_body) diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index a34213623..f914fd0ce 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -105,9 +105,9 @@ class Typifier(Mapper): def map_array_access( self, access: PsArrayAccess, target_type: PsNumericType | None ) -> tuple[PsArrayAccess, PsNumericType]: - self._check_target_type(access, access.array.element_type, target_type) + self._check_target_type(access, access.dtype, target_type) index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype) - return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.array.element_type) + return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.dtype) # Arithmetic Expressions @@ -116,7 +116,7 @@ class Typifier(Mapper): expr: pb.Expression, args: Sequence[Any], target_type: PsNumericType | None, - ) -> tuple[Sequence[ExprOrConstant], PsNumericType]: + ) -> tuple[tuple[ExprOrConstant], PsNumericType]: """Typify all arguments of a multi-argument expression with the same type.""" new_args = [None] * len(args) common_type: PsNumericType | None = None @@ -134,7 +134,7 @@ class Typifier(Mapper): assert common_type is not None - return cast(Sequence[ExprOrConstant], new_args), common_type + return cast(tuple[ExprOrConstant], tuple(new_args)), common_type def map_sum( self, expr: pb.Sum, target_type: PsNumericType | None diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py index 94aa75cf4..2b1f3f17d 100644 --- a/src/pystencils/nbackend/typed_expressions.py +++ b/src/pystencils/nbackend/typed_expressions.py @@ -80,6 +80,8 @@ class PsTypedConstant: Usage of `//` and the pymbolic `FloorDiv` is illegal. """ + __match_args__ = ("value", "dtype") + @staticmethod def try_create(value: Any, dtype: PsNumericType): try: @@ -100,6 +102,10 @@ class PsTypedConstant: self._dtype = constify(dtype) self._value = self._dtype.create_constant(value) + @property + def value(self) -> Any: + return self._value + @property def dtype(self) -> PsNumericType: return self._dtype diff --git a/src/pystencils/nbackend/types/__init__.py b/src/pystencils/nbackend/types/__init__.py index 1f15c4516..c398aea9d 100644 --- a/src/pystencils/nbackend/types/__init__.py +++ b/src/pystencils/nbackend/types/__init__.py @@ -12,6 +12,8 @@ from .basic_types import ( deconstify, ) +from .quick import make_type + from .exception import PsTypeError __all__ = [ @@ -26,5 +28,6 @@ __all__ = [ "PsIeeeFloatType", "constify", "deconstify", + "make_type", "PsTypeError", ] diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index ad123148e..e6b918080 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -381,10 +381,7 @@ class PsIeeeFloatType(PsScalarType): def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] - if isinstance(value, int) and value in (0, 1, -1): - return np_type(value) - - if isinstance(value, float): + if isinstance(value, int) or isinstance(value, float): return np_type(value) if isinstance(value, np_type): diff --git a/src/pystencils/nbackend/types/parsing.py b/src/pystencils/nbackend/types/parsing.py index 8a5e687aa..14db20a92 100644 --- a/src/pystencils/nbackend/types/parsing.py +++ b/src/pystencils/nbackend/types/parsing.py @@ -68,7 +68,7 @@ def parse_type_string(s: str) -> PsAbstractType: raise ValueError(f"Could not parse token '{s}' as C type.") case _: - raise ValueError(f"Could not parse token '{s}`' as C type.") + raise ValueError(f"Could not parse token '{s}' as C type.") def parse_type_name(typename: str, const: bool): diff --git a/tests/nbackend/test_freeze.py b/tests/nbackend/test_freeze.py new file mode 100644 index 000000000..db8f4feb2 --- /dev/null +++ b/tests/nbackend/test_freeze.py @@ -0,0 +1,67 @@ +import sympy as sp +import pymbolic.primitives as pb + +from pystencils import Assignment, fields + +from pystencils.nbackend.ast import ( + PsAssignment, + PsDeclaration, + PsExpression, + PsSymbolExpr, + PsLvalueExpr, +) +from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable +from pystencils.nbackend.arrays import PsArrayAccess +from pystencils.nbackend.kernelcreation import ( + KernelCreationOptions, + KernelCreationContext, + FreezeExpressions, + FullIterationSpace, +) + + +def test_freeze_simple(): + options = KernelCreationOptions() + ctx = KernelCreationContext(options) + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + asm = Assignment(z, 2 * x + y) + + fasm = freeze(asm) + + pb_x, pb_y, pb_z = pb.variables("x y z") + + assert fasm == PsDeclaration(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x)) + assert fasm != PsAssignment(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x)) + + +def test_freeze_fields(): + options = KernelCreationOptions() + ctx = KernelCreationContext(options) + + start = PsTypedConstant(0, ctx.index_dtype) + stop = PsTypedConstant(42, ctx.index_dtype) + step = PsTypedConstant(1, ctx.index_dtype) + counter = PsTypedVariable("ctr", ctx.index_dtype) + ispace = FullIterationSpace( + ctx, [FullIterationSpace.Dimension(start, stop, step, counter)] + ) + ctx.set_iteration_space(ispace) + + freeze = FreezeExpressions(ctx) + + f, g = fields("f, g : [1D]") + asm = Assignment(f.center(0), g.center(0)) + + f_arr = ctx.get_array(f) + g_arr = ctx.get_array(g) + + fasm = freeze(asm) + + lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0]) + rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0]) + + should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs)) + + assert fasm == should diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/test_typification.py new file mode 100644 index 000000000..f9e8ab517 --- /dev/null +++ b/tests/nbackend/test_typification.py @@ -0,0 +1,46 @@ +import pytest +import sympy as sp +import pymbolic.primitives as pb + +from pystencils import Assignment + +from pystencils.nbackend.ast import PsDeclaration +from pystencils.nbackend.types import constify +from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable +from pystencils.nbackend.kernelcreation.options import KernelCreationOptions +from pystencils.nbackend.kernelcreation.context import KernelCreationContext +from pystencils.nbackend.kernelcreation.freeze import FreezeExpressions +from pystencils.nbackend.kernelcreation.typification import Typifier + + +def test_typify_simple(): + options = KernelCreationOptions() + ctx = KernelCreationContext(options) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + asm = Assignment(z, 2 * x + y) + + fasm = freeze(asm) + fasm = typify(fasm) + + assert isinstance(fasm, PsDeclaration) + + def check(expr): + match expr: + case PsTypedConstant(value, dtype): + assert value == 2 + assert dtype == constify(ctx.options.default_dtype) + case PsTypedVariable(name, dtype): + assert name in "xyz" + assert dtype == ctx.options.default_dtype + case pb.Variable: + pytest.fail("Encountered untyped variable") + case pb.Sum(cs) | pb.Product(cs): + [check(c) for c in cs] + case _: + pytest.fail("Non-exhaustive pattern matcher.") + + check(fasm.lhs.expression) + check(fasm.rhs.expression) -- GitLab