From e5e1a95cb0b374da86d4650a9680545322e09f72 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 26 Jan 2024 21:06:12 +0100
Subject: [PATCH] add freeze and typify unit tests. various minor fixes

---
 src/pystencils/nbackend/arrays.py             |  3 +
 src/pystencils/nbackend/ast/nodes.py          | 46 +++++++++++--
 .../nbackend/kernelcreation/__init__.py       | 20 ++++++
 .../nbackend/kernelcreation/context.py        | 17 +++--
 .../nbackend/kernelcreation/freeze.py         | 52 ++++++++------
 .../nbackend/kernelcreation/kernelcreation.py |  7 --
 .../nbackend/kernelcreation/typification.py   |  8 +--
 src/pystencils/nbackend/typed_expressions.py  |  6 ++
 src/pystencils/nbackend/types/__init__.py     |  3 +
 src/pystencils/nbackend/types/basic_types.py  |  5 +-
 src/pystencils/nbackend/types/parsing.py      |  2 +-
 tests/nbackend/test_freeze.py                 | 67 +++++++++++++++++++
 tests/nbackend/test_typification.py           | 46 +++++++++++++
 13 files changed, 232 insertions(+), 50 deletions(-)
 create mode 100644 tests/nbackend/test_freeze.py
 create mode 100644 tests/nbackend/test_typification.py

diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index 49375209e..b04552001 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -156,6 +156,9 @@ class PsLinearizedArray:
 
     def __hash__(self) -> int:
         return hash(self._hashable_contents())
+    
+    def __repr__(self) -> str:
+        return f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
 
 
 class PsArrayAssocVar(PsTypedVariable, ABC):
diff --git a/src/pystencils/nbackend/ast/nodes.py b/src/pystencils/nbackend/ast/nodes.py
index 5a20e5835..865fa6dec 100644
--- a/src/pystencils/nbackend/ast/nodes.py
+++ b/src/pystencils/nbackend/ast/nodes.py
@@ -2,9 +2,11 @@ from __future__ import annotations
 from typing import Sequence, Iterable, cast, TypeAlias
 from types import NoneType
 
+from pymbolic.primitives import Variable
+
 from abc import ABC, abstractmethod
 
-from ..typed_expressions import PsTypedVariable, ExprOrConstant
+from ..typed_expressions import ExprOrConstant
 from ..arrays import PsArrayAccess
 from .util import failing_cast
 
@@ -35,6 +37,15 @@ class PsAstNode(ABC):
     def set_child(self, idx: int, c: PsAstNode):
         ...
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsAstNode):
+            return False
+        
+        return type(self) is type(other) and self.children == other.children
+    
+    def __hash__(self) -> int:
+        return hash((type(self), self.children))
+
 
 class PsBlock(PsAstNode):
     __match_args__ = ("statements",)
@@ -56,6 +67,10 @@ class PsBlock(PsAstNode):
     def statements(self, stm: Sequence[PsAstNode]):
         self._statements = list(stm)
 
+    def __repr__(self) -> str:
+        contents = ", ".join(repr(c) for c in self.children)
+        return f"PsBlock( {contents} )"
+
 
 class PsLeafNode(PsAstNode):
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -81,12 +96,23 @@ class PsExpression(PsLeafNode):
     def expression(self, expr: ExprOrConstant):
         self._expr = expr
 
+    def __repr__(self) -> str:
+        return repr(self._expr)
+    
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsExpression):
+            return False
+        return type(self) is type(other) and self._expr == other._expr
+    
+    def __hash__(self) -> int:
+        return hash((type(self), self._expr))
+
 
 class PsLvalueExpr(PsExpression):
     """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
 
     def __init__(self, expr: PsLvalue):
-        if not isinstance(expr, (PsTypedVariable, PsArrayAccess)):
+        if not isinstance(expr, (Variable, PsArrayAccess)):
             raise TypeError("Expression was not a valid lvalue")
 
         super(PsLvalueExpr, self).__init__(expr)
@@ -97,19 +123,19 @@ class PsSymbolExpr(PsLvalueExpr):
 
     __match_args__ = ("symbol",)
 
-    def __init__(self, symbol: PsTypedVariable):
+    def __init__(self, symbol: Variable):
         super().__init__(symbol)
 
     @property
-    def symbol(self) -> PsTypedVariable:
-        return cast(PsTypedVariable, self._expr)
+    def symbol(self) -> Variable:
+        return cast(Variable, self._expr)
 
     @symbol.setter
-    def symbol(self, symbol: PsTypedVariable):
+    def symbol(self, symbol: Variable):
         self._expr = symbol
 
 
-PsLvalue: TypeAlias = PsTypedVariable | PsArrayAccess
+PsLvalue: TypeAlias = Variable | PsArrayAccess
 """Types of expressions that may occur on the left-hand side of assignments."""
 
 
@@ -151,6 +177,9 @@ class PsAssignment(PsAstNode):
         else:
             assert False, "unreachable code"
 
+    def __repr__(self) -> str:
+        return f"PsAssignment({repr(self._lhs)}, {repr(self._rhs)})"
+
 
 class PsDeclaration(PsAssignment):
     __match_args__ = (
@@ -186,6 +215,9 @@ class PsDeclaration(PsAssignment):
         else:
             assert False, "unreachable code"
 
+    def __repr__(self) -> str:
+        return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
+
 
 class PsLoop(PsAstNode):
     __match_args__ = ("counter", "start", "stop", "step", "body")
diff --git a/src/pystencils/nbackend/kernelcreation/__init__.py b/src/pystencils/nbackend/kernelcreation/__init__.py
index e69de29bb..110acb818 100644
--- a/src/pystencils/nbackend/kernelcreation/__init__.py
+++ b/src/pystencils/nbackend/kernelcreation/__init__.py
@@ -0,0 +1,20 @@
+from .options import KernelCreationOptions
+from .kernelcreation import create_kernel
+
+from .context import KernelCreationContext
+from .analysis import KernelAnalysis
+from .freeze import FreezeExpressions
+from .typification import Typifier
+
+from .iteration_space import FullIterationSpace, SparseIterationSpace
+
+__all__ = [
+    "KernelCreationOptions",
+    "create_kernel",
+    "KernelCreationContext",
+    "KernelAnalysis",
+    "FreezeExpressions",
+    "Typifier",
+    "FullIterationSpace",
+    "SparseIterationSpace",
+]
diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py
index 40cd2448f..7e4fad9ba 100644
--- a/src/pystencils/nbackend/kernelcreation/context.py
+++ b/src/pystencils/nbackend/kernelcreation/context.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 from typing import cast
-from dataclasses import dataclass
 
 
 from ...field import Field, FieldType
@@ -16,12 +15,12 @@ from .options import KernelCreationOptions
 from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
 
 
-@dataclass
 class FieldsInKernel:
-    domain_fields: set[Field] = set()
-    index_fields: set[Field] = set()
-    custom_fields: set[Field] = set()
-    buffer_fields: set[Field] = set()
+    def __init__(self) -> None:
+        self.domain_fields: set[Field] = set()
+        self.index_fields: set[Field] = set()
+        self.custom_fields: set[Field] = set()
+        self.buffer_fields: set[Field] = set()
 
 
 class KernelCreationContext:
@@ -70,6 +69,8 @@ class KernelCreationContext:
     def constraints(self) -> tuple[PsKernelConstraint, ...]:
         return tuple(self._constraints)
 
+    #   Fields and Arrays
+
     @property
     def fields(self) -> FieldsInKernel:
         return self._fields_collection
@@ -113,7 +114,9 @@ class KernelCreationContext:
 
             self._arrays[field] = arr
 
-        return self._arrays[field]
+        return self._arrays[field]    
+
+    #   Iteration Space
 
     def set_iteration_space(self, ispace: IterationSpace):
         if self._ispace is not None:
diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py
index e9fc6ccd7..ef4ad8ea3 100644
--- a/src/pystencils/nbackend/kernelcreation/freeze.py
+++ b/src/pystencils/nbackend/kernelcreation/freeze.py
@@ -2,11 +2,18 @@ import pymbolic.primitives as pb
 from pymbolic.interop.sympy import SympyToPymbolicMapper
 
 from ...field import Field, FieldType
+from ...typing import BasicType
 
 from .context import KernelCreationContext
 
-from ..ast.nodes import PsAssignment
-from ..types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
+from ..ast.nodes import (
+    PsAssignment,
+    PsDeclaration,
+    PsSymbolExpr,
+    PsLvalueExpr,
+    PsExpression,
+)
+from ..types import constify, make_type
 from ..typed_expressions import PsTypedVariable
 from ..arrays import PsArrayAccess
 
@@ -18,19 +25,21 @@ class FreezeExpressions(SympyToPymbolicMapper):
     def map_Assignment(self, expr):  # noqa
         lhs = self.rec(expr.lhs)
         rhs = self.rec(expr.rhs)
-        return PsAssignment(lhs, rhs)
-
-    def map_BasicType(self, expr):
-        width = expr.numpy_dtype.itemsize * 8
-        const = expr.const
-        if expr.is_float():
-            return PsIeeeFloatType(width, const)
-        elif expr.is_uint():
-            return PsUnsignedIntegerType(width, const)
-        elif expr.is_int():
-            return PsSignedIntegerType(width, const)
+
+        if isinstance(lhs, pb.Variable):
+            return PsDeclaration(PsSymbolExpr(lhs), PsExpression(rhs))
+        elif isinstance(lhs, PsArrayAccess):
+            return PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
+        else:
+            assert False, "That should not have happened."
+
+    def map_BasicType(self, expr: BasicType):
+        #   TODO: This should not be necessary; the frontend should use the new type system.
+        dtype = make_type(expr.numpy_dtype.type)
+        if expr.const:
+            return constify(dtype)
         else:
-            raise NotImplementedError("Data type not supported.")
+            return dtype
 
     def map_FieldShapeSymbol(self, expr):
         dtype = self.rec(expr.dtype)
@@ -53,7 +62,10 @@ class FreezeExpressions(SympyToPymbolicMapper):
                 case FieldType.GENERIC:
                     #   Add the iteration counters
                     offsets = [
-                        i + o for i, o in zip(self._ctx.get_iteration_space().spatial_indices, offsets)
+                        i + o
+                        for i, o in zip(
+                            self._ctx.get_iteration_space().spatial_indices, offsets
+                        )
                     ]
                 case FieldType.INDEXED:
                     # flake8: noqa
@@ -68,11 +80,11 @@ class FreezeExpressions(SympyToPymbolicMapper):
                         f"Cannot translate accesses to field type {unknown} yet."
                     )
 
-        index = pb.Sum(
-            tuple(
-                idx * stride
-                for idx, stride in zip(offsets + indices, array.strides, strict=True)
-            )
+        summands = tuple(
+            idx * stride
+            for idx, stride in zip(offsets + indices, array.strides, strict=True)
         )
 
+        index = summands[0] if len(summands) == 1 else pb.Sum(summands)
+
         return PsArrayAccess(ptr, index)
diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
index 732b03459..617a152fa 100644
--- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py
+++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
@@ -19,14 +19,11 @@ from .iteration_space import (
 
 
 def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions):
-    #   1. Prepare context
     ctx = KernelCreationContext(options)
 
-    #   2. Check kernel constraints and collect knowledge
     analysis = KernelAnalysis(ctx)
     analysis(assignments)
 
-    #   3. Create iteration space
     ispace: IterationSpace = (
         create_sparse_iteration_space(ctx, assignments)
         if len(ctx.fields.index_fields) > 0
@@ -35,13 +32,9 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
 
     ctx.set_iteration_space(ispace)
 
-    #   4. Freeze assignments
-    #   This call is the same for both domain and indexed kernels
     freeze = FreezeExpressions(ctx)
     kernel_body: PsBlock = freeze(assignments)
 
-    #   5. Typify
-    #   Also the same for both types of kernels
     typify = Typifier(ctx)
     kernel_body = typify(kernel_body)
 
diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py
index a34213623..f914fd0ce 100644
--- a/src/pystencils/nbackend/kernelcreation/typification.py
+++ b/src/pystencils/nbackend/kernelcreation/typification.py
@@ -105,9 +105,9 @@ class Typifier(Mapper):
     def map_array_access(
         self, access: PsArrayAccess, target_type: PsNumericType | None
     ) -> tuple[PsArrayAccess, PsNumericType]:
-        self._check_target_type(access, access.array.element_type, target_type)
+        self._check_target_type(access, access.dtype, target_type)
         index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
-        return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.array.element_type)
+        return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.dtype)
 
     #   Arithmetic Expressions
 
@@ -116,7 +116,7 @@ class Typifier(Mapper):
         expr: pb.Expression,
         args: Sequence[Any],
         target_type: PsNumericType | None,
-    ) -> tuple[Sequence[ExprOrConstant], PsNumericType]:
+    ) -> tuple[tuple[ExprOrConstant], PsNumericType]:
         """Typify all arguments of a multi-argument expression with the same type."""
         new_args = [None] * len(args)
         common_type: PsNumericType | None = None
@@ -134,7 +134,7 @@ class Typifier(Mapper):
 
         assert common_type is not None
 
-        return cast(Sequence[ExprOrConstant], new_args), common_type
+        return cast(tuple[ExprOrConstant], tuple(new_args)), common_type
 
     def map_sum(
         self, expr: pb.Sum, target_type: PsNumericType | None
diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py
index 94aa75cf4..2b1f3f17d 100644
--- a/src/pystencils/nbackend/typed_expressions.py
+++ b/src/pystencils/nbackend/typed_expressions.py
@@ -80,6 +80,8 @@ class PsTypedConstant:
     Usage of `//` and the pymbolic `FloorDiv` is illegal.
     """
 
+    __match_args__ = ("value", "dtype")
+
     @staticmethod
     def try_create(value: Any, dtype: PsNumericType):
         try:
@@ -100,6 +102,10 @@ class PsTypedConstant:
         self._dtype = constify(dtype)
         self._value = self._dtype.create_constant(value)
 
+    @property
+    def value(self) -> Any:
+        return self._value
+
     @property
     def dtype(self) -> PsNumericType:
         return self._dtype
diff --git a/src/pystencils/nbackend/types/__init__.py b/src/pystencils/nbackend/types/__init__.py
index 1f15c4516..c398aea9d 100644
--- a/src/pystencils/nbackend/types/__init__.py
+++ b/src/pystencils/nbackend/types/__init__.py
@@ -12,6 +12,8 @@ from .basic_types import (
     deconstify,
 )
 
+from .quick import make_type
+
 from .exception import PsTypeError
 
 __all__ = [
@@ -26,5 +28,6 @@ __all__ = [
     "PsIeeeFloatType",
     "constify",
     "deconstify",
+    "make_type",
     "PsTypeError",
 ]
diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py
index ad123148e..e6b918080 100644
--- a/src/pystencils/nbackend/types/basic_types.py
+++ b/src/pystencils/nbackend/types/basic_types.py
@@ -381,10 +381,7 @@ class PsIeeeFloatType(PsScalarType):
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
 
-        if isinstance(value, int) and value in (0, 1, -1):
-            return np_type(value)
-
-        if isinstance(value, float):
+        if isinstance(value, int) or isinstance(value, float):
             return np_type(value)
 
         if isinstance(value, np_type):
diff --git a/src/pystencils/nbackend/types/parsing.py b/src/pystencils/nbackend/types/parsing.py
index 8a5e687aa..14db20a92 100644
--- a/src/pystencils/nbackend/types/parsing.py
+++ b/src/pystencils/nbackend/types/parsing.py
@@ -68,7 +68,7 @@ def parse_type_string(s: str) -> PsAbstractType:
                     raise ValueError(f"Could not parse token '{s}' as C type.")
 
         case _:
-            raise ValueError(f"Could not parse token '{s}`' as C type.")
+            raise ValueError(f"Could not parse token '{s}' as C type.")
 
 
 def parse_type_name(typename: str, const: bool):
diff --git a/tests/nbackend/test_freeze.py b/tests/nbackend/test_freeze.py
new file mode 100644
index 000000000..db8f4feb2
--- /dev/null
+++ b/tests/nbackend/test_freeze.py
@@ -0,0 +1,67 @@
+import sympy as sp
+import pymbolic.primitives as pb
+
+from pystencils import Assignment, fields
+
+from pystencils.nbackend.ast import (
+    PsAssignment,
+    PsDeclaration,
+    PsExpression,
+    PsSymbolExpr,
+    PsLvalueExpr,
+)
+from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
+from pystencils.nbackend.arrays import PsArrayAccess
+from pystencils.nbackend.kernelcreation import (
+    KernelCreationOptions,
+    KernelCreationContext,
+    FreezeExpressions,
+    FullIterationSpace,
+)
+
+
+def test_freeze_simple():
+    options = KernelCreationOptions()
+    ctx = KernelCreationContext(options)
+    freeze = FreezeExpressions(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    asm = Assignment(z, 2 * x + y)
+
+    fasm = freeze(asm)
+
+    pb_x, pb_y, pb_z = pb.variables("x y z")
+
+    assert fasm == PsDeclaration(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
+    assert fasm != PsAssignment(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
+
+
+def test_freeze_fields():
+    options = KernelCreationOptions()
+    ctx = KernelCreationContext(options)
+
+    start = PsTypedConstant(0, ctx.index_dtype)
+    stop = PsTypedConstant(42, ctx.index_dtype)
+    step = PsTypedConstant(1, ctx.index_dtype)
+    counter = PsTypedVariable("ctr", ctx.index_dtype)
+    ispace = FullIterationSpace(
+        ctx, [FullIterationSpace.Dimension(start, stop, step, counter)]
+    )
+    ctx.set_iteration_space(ispace)
+
+    freeze = FreezeExpressions(ctx)
+
+    f, g = fields("f, g : [1D]")
+    asm = Assignment(f.center(0), g.center(0))
+
+    f_arr = ctx.get_array(f)
+    g_arr = ctx.get_array(g)
+
+    fasm = freeze(asm)
+
+    lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0])
+    rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0])
+
+    should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
+
+    assert fasm == should
diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/test_typification.py
new file mode 100644
index 000000000..f9e8ab517
--- /dev/null
+++ b/tests/nbackend/test_typification.py
@@ -0,0 +1,46 @@
+import pytest
+import sympy as sp
+import pymbolic.primitives as pb
+
+from pystencils import Assignment
+
+from pystencils.nbackend.ast import PsDeclaration
+from pystencils.nbackend.types import constify
+from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
+from pystencils.nbackend.kernelcreation.options import KernelCreationOptions
+from pystencils.nbackend.kernelcreation.context import KernelCreationContext
+from pystencils.nbackend.kernelcreation.freeze import FreezeExpressions
+from pystencils.nbackend.kernelcreation.typification import Typifier
+
+
+def test_typify_simple():
+    options = KernelCreationOptions()
+    ctx = KernelCreationContext(options)
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    asm = Assignment(z, 2 * x + y)
+
+    fasm = freeze(asm)
+    fasm = typify(fasm)
+
+    assert isinstance(fasm, PsDeclaration)
+
+    def check(expr):
+        match expr:
+            case PsTypedConstant(value, dtype):
+                assert value == 2
+                assert dtype == constify(ctx.options.default_dtype)
+            case PsTypedVariable(name, dtype):
+                assert name in "xyz"
+                assert dtype == ctx.options.default_dtype
+            case pb.Variable:
+                pytest.fail("Encountered untyped variable")
+            case pb.Sum(cs) | pb.Product(cs):
+                [check(c) for c in cs]
+            case _:
+                pytest.fail("Non-exhaustive pattern matcher.")
+
+    check(fasm.lhs.expression)
+    check(fasm.rhs.expression)
-- 
GitLab