From bcd147bbb7284642d187fe3713634f1a54670b06 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 16 Jan 2024 18:51:17 +0100
Subject: [PATCH] first compilation and execution test

---
 src/pystencils/nbackend/arrays.py             | 39 ++++++++-
 src/pystencils/nbackend/ast/kernelfunction.py | 17 ++--
 .../nbackend/jit/cpu_extension_module.py      | 49 ++++++++++-
 src/pystencils/nbackend/typed_expressions.py  | 16 +++-
 tests/nbackend/test_basic_printing.py         |  2 +-
 tests/nbackend/test_cpujit.py                 | 82 +++++++++++++++++++
 tests/nbackend/test_expressions.py            | 46 +++++++++++
 7 files changed, 235 insertions(+), 16 deletions(-)
 create mode 100644 tests/nbackend/test_cpujit.py
 create mode 100644 tests/nbackend/test_expressions.py

diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index 368c2940b..5cab54abc 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -56,7 +56,7 @@ from .types import (
     constify,
 )
 
-from .typed_expressions import PsTypedVariable, PsTypedConstant
+from .typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
 
 
 class PsLinearizedArray:
@@ -86,7 +86,9 @@ class PsLinearizedArray:
         if offsets is None:
             offsets = (0,) * dim
 
+        self._dim = dim
         self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets)
+        self._index_dtype = index_dtype
 
     @property
     def name(self):
@@ -100,6 +102,10 @@ class PsLinearizedArray:
     def strides(self):
         return self._strides
 
+    @property
+    def dim(self):
+        return self._dim
+
     @property
     def element_type(self):
         return self._element_type
@@ -108,6 +114,35 @@ class PsLinearizedArray:
     def offsets(self) -> tuple[PsTypedConstant, ...]:
         return self._offsets
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsLinearizedArray):
+            return False
+
+        return (
+            self._name,
+            self._element_type,
+            self._dim,
+            self._offsets,
+            self._index_dtype,
+        ) == (
+            other._name,
+            other._element_type,
+            other._dim,
+            other._offsets,
+            other._index_dtype,
+        )
+
+    def __hash__(self) -> int:
+        return hash(
+            (
+                self._name,
+                self._element_type,
+                self._dim,
+                self._offsets,
+                self._index_dtype,
+            )
+        )
+
 
 class PsArrayAssocVar(PsTypedVariable, ABC):
     """A variable that is associated to an array.
@@ -180,7 +215,7 @@ class PsArrayStrideVar(PsArrayAssocVar):
 
 
 class PsArrayAccess(pb.Subscript):
-    def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
+    def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant):
         super(PsArrayAccess, self).__init__(base_ptr, index)
         self._base_ptr = base_ptr
         self._index = index
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index 4aa963c1b..060947a82 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -16,7 +16,7 @@ from ...enums import Target
 @dataclass
 class PsKernelParametersSpec:
     """Specification of a kernel function's parameters.
-    
+
     Contains:
         - Verbatim parameter list, a list of `PsTypedVariables`
         - List of Arrays used in the kernel, in canonical order
@@ -31,9 +31,9 @@ class PsKernelParametersSpec:
     def params_for_array(self, arr: PsLinearizedArray):
         def pred(p: PsTypedVariable):
             return isinstance(p, PsArrayAssocVar) and p.array == arr
-        
+
         return tuple(filter(pred, self.params))
-    
+
     def __post_init__(self):
         dep_mapper = DependencyMapper(False, False, False, False)
 
@@ -59,7 +59,7 @@ class PsKernelParametersSpec:
 
 class PsKernelFunction(PsAstNode):
     """A pystencils kernel function.
-    
+
     Objects of this class represent a full pystencils kernel and should provide all information required for
     export, compilation, and inclusion of the kernel into a runtime system.
     """
@@ -98,13 +98,12 @@ class PsKernelFunction(PsAstNode):
     def function_name(self) -> str:
         """For backward compatibility."""
         return self._name
-    
+
     @property
     def instruction_set(self) -> str | None:
         """For backward compatibility"""
         return None
 
-
     def num_children(self) -> int:
         return 1
 
@@ -136,8 +135,10 @@ class PsKernelFunction(PsAstNode):
         params_list = sorted(params_set, key=lambda p: p.name)
 
         arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
-        return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints))
-    
+        return PsKernelParametersSpec(
+            tuple(params_list), tuple(arrays), tuple(self._constraints)
+        )
+
     def get_required_headers(self) -> set[str]:
         #   TODO: Headers from types, vectorizer, ...
         return set()
diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py
index c70c45a39..f07172e3d 100644
--- a/src/pystencils/nbackend/jit/cpu_extension_module.py
+++ b/src/pystencils/nbackend/jit/cpu_extension_module.py
@@ -4,10 +4,11 @@ from typing import Any
 
 from os import path
 import hashlib
-
 from itertools import chain
 from textwrap import indent
 
+import numpy as np
+
 from ..exceptions import PsInternalCompilerError
 from ..ast import PsKernelFunction
 from ..ast.constraints import PsParamConstraint
@@ -19,7 +20,13 @@ from ..arrays import (
     PsArrayShapeVar,
     PsArrayStrideVar,
 )
-from ..types import PsAbstractType
+from ..types import (
+    PsAbstractType,
+    PsScalarType,
+    PsUnsignedIntegerType,
+    PsSignedIntegerType,
+    PsIeeeFloatType,
+)
 from ..types.quick import Fp, SInt, UInt
 from ..emission import emit_code
 
@@ -150,6 +157,13 @@ if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '
 Py_buffer buffer_{name};
 int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
 if (buffer_{name}_res == -1) {{ return NULL; }}
+"""
+
+    TMPL_CHECK_ARRAY_TYPE = """
+if(!({cond})) {{ 
+    PyErr_SetString(PyExc_TypeError, "Wrong {what} of array {name}. Expected {expected}"); 
+    return NULL; 
+}}
 """
 
     KWCHECK = """
@@ -185,10 +199,38 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                     f"Don't know how to cast Python objects to {dtype}"
                 )
 
+    def _type_char(self, dtype: PsScalarType) -> str | None:
+        if isinstance(
+            dtype, (PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType)
+        ):
+            np_dtype = dtype.NUMPY_TYPES[dtype.width]
+            return np.dtype(np_dtype).char
+        else:
+            return None
+
     def extract_array(self, arr: PsLinearizedArray) -> str:
         """Adds an array, and returns the name of the underlying Py_Buffer."""
         if arr not in self._array_extractions:
             extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=arr.name)
+
+            #   Check array type
+            type_char = self._type_char(arr.element_type)
+            if type_char is not None:
+                dtype_cond = f"buffer_{arr.name}.format[0] == '{type_char}'"
+                extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
+                    cond=dtype_cond,
+                    what="data type",
+                    name=arr.name,
+                    expected=str(arr.element_type),
+                )
+
+            #   Check item size
+            itemsize = arr.element_type.itemsize
+            item_size_cond = f"buffer_{arr.name}.itemsize == {itemsize}"
+            extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
+                cond=item_size_cond, what="itemsize", name=arr.name, expected=itemsize
+            )
+
             self._array_buffers[arr] = f"buffer_{arr.name}"
             self._array_extractions[arr] = extraction_code
 
@@ -219,8 +261,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                 case PsArrayShapeVar():
                     coord = variable.coordinate
                     code = (
-                        f"{variable.dtype} {variable.name} = "
-                        f"{buffer}.shape[{coord}] / {arr.element_type.itemsize};"
+                        f"{variable.dtype} {variable.name} = {buffer}.shape[{coord}];"
                     )
                 case PsArrayStrideVar():
                     coord = variable.coordinate
diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py
index 657eaf4f0..5bfd0fcb1 100644
--- a/src/pystencils/nbackend/typed_expressions.py
+++ b/src/pystencils/nbackend/typed_expressions.py
@@ -13,7 +13,6 @@ from .types import (
 
 
 class PsTypedVariable(pb.Variable):
-
     init_arg_names: tuple[str, ...] = ("name", "dtype")
 
     __match_args__ = ("name", "dtype")
@@ -130,18 +129,27 @@ class PsTypedConstant:
             return v
 
     def __add__(self, other: Any):
+        if isinstance(other, pb.Expression):  # let pymbolic handle this case
+            return NotImplemented
+
         return PsTypedConstant(self._value + self._fix(other)._value, self._dtype)
 
     def __radd__(self, other: Any):
         return PsTypedConstant(self._rfix(other)._value + self._value, self._dtype)
 
     def __mul__(self, other: Any):
+        if isinstance(other, pb.Expression):  # let pymbolic handle this case
+            return NotImplemented
+
         return PsTypedConstant(self._value * self._fix(other)._value, self._dtype)
 
     def __rmul__(self, other: Any):
         return PsTypedConstant(self._rfix(other)._value * self._value, self._dtype)
 
     def __sub__(self, other: Any):
+        if isinstance(other, pb.Expression):  # let pymbolic handle this case
+            return NotImplemented
+
         return PsTypedConstant(self._value - self._fix(other)._value, self._dtype)
 
     def __rsub__(self, other: Any):
@@ -156,6 +164,9 @@ class PsTypedConstant:
         return quotient, rem
 
     def __truediv__(self, other: Any):
+        if isinstance(other, pb.Expression):  # let pymbolic handle this case
+            return NotImplemented
+
         if self._dtype.is_float():
             return PsTypedConstant(self._value / self._fix(other)._value, self._dtype)
         elif self._dtype.is_uint():
@@ -183,6 +194,9 @@ class PsTypedConstant:
             return NotImplemented
 
     def __mod__(self, other: Any):
+        if isinstance(other, pb.Expression):  # let pymbolic handle this case
+            return NotImplemented
+
         if self._dtype.is_uint():
             return PsTypedConstant(self._value % self._fix(other)._value, self._dtype)
         else:
diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py
index adf7b8d2e..ba2f7770d 100644
--- a/tests/nbackend/test_basic_printing.py
+++ b/tests/nbackend/test_basic_printing.py
@@ -10,8 +10,8 @@ from pystencils.nbackend.emission import CPrinter
 
 def test_basic_kernel():
 
-    u_size = PsTypedVariable("u_length", UInt(32, True))
     u_arr = PsLinearizedArray("u", Fp(64), 1)
+    u_size = u_arr.shape[0]
     u_base = PsArrayBasePointer("u_data", u_arr)
 
     loop_ctr = PsTypedVariable("ctr", UInt(32))
diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py
new file mode 100644
index 000000000..b77ad6fff
--- /dev/null
+++ b/tests/nbackend/test_cpujit.py
@@ -0,0 +1,82 @@
+import pytest
+
+from pystencils import Target
+
+from pystencils.nbackend.ast import *
+from pystencils.nbackend.ast.constraints import PsParamConstraint
+from pystencils.nbackend.typed_expressions import *
+from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
+from pystencils.nbackend.types.quick import *
+
+import numpy as np
+
+from pystencils.cpu.cpujit import compile_and_load
+
+def test_pairwise_addition():
+    idx_type = SInt(64)
+
+    u = PsLinearizedArray("u", Fp(64, const=True), 2, index_dtype=idx_type)
+    v = PsLinearizedArray("v", Fp(64), 2, index_dtype=idx_type)
+
+    u_data = PsArrayBasePointer("u_data", u)
+    v_data = PsArrayBasePointer("v_data", v)
+
+    loop_ctr = PsTypedVariable("ctr", idx_type)
+    
+    zero = PsTypedConstant(0, idx_type)
+    one = PsTypedConstant(1, idx_type)
+    two = PsTypedConstant(2, idx_type)
+
+    update = PsAssignment(
+        PsLvalueExpr(PsArrayAccess(v_data, loop_ctr)),
+        PsExpression(PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one))
+    )
+
+    loop = PsLoop(
+        PsSymbolExpr(loop_ctr),
+        PsExpression(zero),
+        PsExpression(v.shape[0]),
+        PsExpression(one),
+        PsBlock([update])
+    )
+
+    func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
+
+    sizes_constraint = PsParamConstraint(
+        u.shape[0].eq(2 * v.shape[0]),
+        "Array `u` must have twice the length of array `v`"
+    )
+
+    func.add_constraints(sizes_constraint)
+
+    kernel = compile_and_load(func)
+
+    #   Positive case
+    N = 21
+    u_arr = np.arange(2 * N, dtype=np.float64)
+    v_arr = np.zeros((N,), dtype=np.float64)
+
+    assert u_arr.shape[0] == 2 * v_arr.shape[0]
+
+    kernel(u=u_arr, v=v_arr)
+
+    v_expected = np.zeros_like(v_arr)
+    for i in range(N):
+        v_expected[i] = u_arr[2 * i] + u_arr[2*i + 1]
+
+    np.testing.assert_allclose(v_arr, v_expected)
+
+    #   Negative case - mismatched array sizes
+    u_arr = np.zeros((N + 2,), dtype=np.float64)
+    v_arr = np.zeros((N,), dtype=np.float64)
+
+    with pytest.raises(ValueError):
+        kernel(u=u_arr, v=v_arr)
+
+    #   Negative case - mismatched types
+    u_arr = np.arange(2 * N, dtype=np.float64)
+    v_arr = np.zeros((N,), dtype=np.float32)
+
+    with pytest.raises(TypeError):
+        kernel(u=u_arr, v=v_arr)
+    
diff --git a/tests/nbackend/test_expressions.py b/tests/nbackend/test_expressions.py
new file mode 100644
index 000000000..b3485b267
--- /dev/null
+++ b/tests/nbackend/test_expressions.py
@@ -0,0 +1,46 @@
+from pystencils.nbackend.typed_expressions import PsTypedVariable
+from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayShapeVar, PsArrayStrideVar
+
+from pystencils.nbackend.types.quick import *
+
+def test_variable_equality():
+    var1 = PsTypedVariable("x", Fp(32))
+    var2 = PsTypedVariable("x", Fp(32))
+    assert var1 == var2
+
+    arr = PsLinearizedArray("arr", Fp(64), 3)
+    bp1 = PsArrayBasePointer("arr_data", arr)
+    bp2 = PsArrayBasePointer("arr_data", arr)
+    assert bp1 == bp2
+
+    arr1 = PsLinearizedArray("arr", Fp(64), 3)
+    bp1 = PsArrayBasePointer("arr_data", arr1)
+
+    arr2 = PsLinearizedArray("arr", Fp(64), 3)
+    bp2 = PsArrayBasePointer("arr_data", arr2)
+    assert bp1 == bp2
+
+    for v1, v2 in zip(arr1.shape, arr2.shape):
+        assert v1 == v2
+
+    for v1, v2 in zip(arr1.strides, arr2.strides):
+        assert v1 == v2
+
+
+def test_variable_inequality():
+    var1 = PsTypedVariable("x", Fp(32))
+    var2 = PsTypedVariable("x", Fp(64))
+    assert var1 != var2
+
+    var1 = PsTypedVariable("x", Fp(32, True))
+    var2 = PsTypedVariable("x", Fp(32, False))
+    assert var1 != var2
+
+    #   Arrays 
+    arr1 = PsLinearizedArray("arr", Fp(64), 3)
+    bp1 = PsArrayBasePointer("arr_data", arr1)
+
+    arr2 = PsLinearizedArray("arr", Fp(32), 3)
+    bp2 = PsArrayBasePointer("arr_data", arr2)
+    assert bp1 != bp2
+
-- 
GitLab