From cd599101e09cd1ae253eb5a027bb91bb72f74a03 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 17 Jan 2024 16:09:27 +0100
Subject: [PATCH] first sketch of translation context and iteration domain

---
 src/pystencils/nbackend/arrays.py             | 150 ++++++++----------
 src/pystencils/nbackend/ast/kernelfunction.py |   8 +-
 .../nbackend/{ast => }/constraints.py         |   4 +-
 .../nbackend/jit/cpu_extension_module.py      |   4 +-
 .../nbackend/translation/__init__.py          |   0
 .../nbackend/translation/context.py           |  24 ++-
 .../nbackend/translation/field_array_pair.py  |  21 +++
 .../nbackend/translation/iteration_domain.py  | 130 +++++++++++++++
 src/pystencils/nbackend/typed_expressions.py  |   3 +-
 tests/nbackend/test_basic_printing.py         |   4 +-
 tests/nbackend/test_cpujit.py                 |   8 +-
 tests/nbackend/test_expressions.py            |  16 +-
 12 files changed, 267 insertions(+), 105 deletions(-)
 rename src/pystencils/nbackend/{ast => }/constraints.py (88%)
 create mode 100644 src/pystencils/nbackend/translation/__init__.py
 create mode 100644 src/pystencils/nbackend/translation/field_array_pair.py
 create mode 100644 src/pystencils/nbackend/translation/iteration_domain.py

diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index fd416a20c..e206b28fd 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -31,8 +31,8 @@ all occurences of the shape and stride variables with their constant value:
 
 ```
 constraints = (
-    [PsParamConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] 
-    + [PsParamConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)]
+    [PsKernelConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] 
+    + [PsKernelConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)]
 )
 
 kernel_function.add_constraints(*constraints)
@@ -43,6 +43,8 @@ kernel_function.add_constraints(*constraints)
 
 from __future__ import annotations
 
+from types import EllipsisType
+
 from abc import ABC
 
 import pymbolic.primitives as pb
@@ -56,78 +58,94 @@ from .types import (
     constify,
 )
 
-from .typed_expressions import PsTypedVariable, ExprOrConstant
+from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
 
 
 class PsLinearizedArray:
-    """N-dimensional contiguous array"""
+    """Class to model N-dimensional contiguous arrays.
+    
+    Memory Layout, Shape and Strides
+    --------------------------------
+
+    The memory layout of an array is defined by its shape and strides.
+    Both shape and stride entries may either be constants or special variables associated with
+    exactly one array.
+
+    Shape and strides may be specified at construction in the following way.
+    For constant entries, their value must be given as an integer.
+    For variable shape entries and strides, the Ellipsis `...` must be passed instead.
+    Internally, the passed `index_dtype` will be used to create typed constants (`PsTypedConstant`)
+    and variables (`PsArrayShapeVar` and `PsArrayStrideVar`) from the passed values.
+    """
 
     def __init__(
         self,
         name: str,
-        element_type: PsScalarType,
-        dim: int,
+        element_type: PsAbstractType,
+        shape: tuple[int | EllipsisType, ...],
+        strides: tuple[int | EllipsisType, ...],
         index_dtype: PsIntegerType = PsSignedIntegerType(64),
     ):
         self._name = name
+        self._element_type = element_type
+        self._index_dtype = index_dtype
 
-        self._shape = tuple(
-            PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim)
-        )
-        self._strides = tuple(
-            PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim)
+        if len(shape) != len(strides):
+            raise ValueError("Shape and stride tuples must have the same length")
+
+        self._shape: tuple[PsArrayShapeVar | PsTypedConstant, ...] = tuple(
+            (
+                PsArrayShapeVar(self, i, index_dtype)
+                if s == Ellipsis
+                else PsTypedConstant(s, index_dtype)
+            )
+            for i, s in enumerate(shape)
         )
 
-        self._element_type = element_type
-        self._dim = dim
-        self._index_dtype = index_dtype
+        self._strides: tuple[PsArrayStrideVar | PsTypedConstant, ...] = tuple(
+            (
+                PsArrayStrideVar(self, i, index_dtype)
+                if s == Ellipsis
+                else PsTypedConstant(s, index_dtype)
+            )
+            for i, s in enumerate(strides)
+        )
 
     @property
     def name(self):
         return self._name
 
     @property
-    def shape(self):
+    def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]:
         return self._shape
 
     @property
-    def strides(self):
+    def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]:
         return self._strides
 
-    @property
-    def dim(self):
-        return self._dim
-
     @property
     def element_type(self):
         return self._element_type
+    
+    def _hashable_contents(self):
+        """Contents by which to compare two instances of `PsLinearizedArray`.
+        
+        Since equality checks on shape and stride variables internally check equality of their associated arrays,
+        if these variables would occur in here, an infinite recursion would follow.
+        Hence they are filtered and replaced by the ellipsis.
+        """
+        shape_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._shape)
+        strides_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._strides)
+        return (self._name, self._element_type, self._index_dtype, shape_clean, strides_clean)
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsLinearizedArray):
             return False
 
-        return (
-            self._name,
-            self._element_type,
-            self._dim,
-            self._index_dtype,
-        ) == (
-            other._name,
-            other._element_type,
-            other._dim,
-            other._index_dtype,
-        )
+        return self._hashable_contents() == other._hashable_contents()
 
     def __hash__(self) -> int:
-        return hash(
-            (
-                self._name,
-                self._element_type,
-                self._dim,
-                self._index_dtype,
-            )
-        )
-
+        return hash(self._hashable_contents())
 
 class PsArrayAssocVar(PsTypedVariable, ABC):
     """A variable that is associated to an array.
@@ -166,6 +184,11 @@ class PsArrayBasePointer(PsArrayAssocVar):
 
 
 class PsArrayShapeVar(PsArrayAssocVar):
+    """Variable that represents an array's shape in one coordinate.
+    
+    Do not instantiate this class yourself, but only use its instances
+    as provided by `PsLinearizedArray.shape`.
+    """
     init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
     __match_args__ = ("array", "coordinate", "dtype")
 
@@ -183,6 +206,11 @@ class PsArrayShapeVar(PsArrayAssocVar):
 
 
 class PsArrayStrideVar(PsArrayAssocVar):
+    """Variable that represents an array's stride in one coordinate.
+    
+    Do not instantiate this class yourself, but only use its instances
+    as provided by `PsLinearizedArray.strides`.
+    """
     init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
     __match_args__ = ("array", "coordinate", "dtype")
 
@@ -217,45 +245,3 @@ class PsArrayAccess(pb.Subscript):
     def dtype(self) -> PsAbstractType:
         """Data type of this expression, i.e. the element type of the underlying array"""
         return self._base_ptr.array.element_type
-
-
-# class PsIterationDomain:
-#     """A factory for arrays spanning a given iteration domain."""
-
-#     def __init__(
-#         self,
-#         id: str,
-#         dim: int | None = None,
-#         fixed_shape: tuple[int, ...] | None = None,
-#         index_dtype: PsIntegerType = PsSignedIntegerType(64),
-#     ):
-#         if fixed_shape is not None:
-#             if dim is not None and len(fixed_shape) != dim:
-#                 raise ValueError(
-#                     "If both `dim` and `fixed_shape` are specified, `fixed_shape` must have exactly `dim` entries."
-#                 )
-
-#             shape = tuple(PsTypedConstant(s, index_dtype) for s in fixed_shape)
-#         elif dim is not None:
-#             shape = tuple(
-#                 PsTypedVariable(f"{id}_shape_{d}", index_dtype) for d in range(dim)
-#             )
-#         else:
-#             raise ValueError("Either `fixed_shape` or `dim` must be specified.")
-
-#         self._domain_shape: tuple[VarOrConstant, ...] = shape
-#         self._index_dtype = index_dtype
-
-#         self._archetype_array: PsLinearizedArray | None = None
-
-#         self._constraints: list[PsParamConstraint] = []
-
-#     @property
-#     def dim(self) -> int:
-#         return len(self._domain_shape)
-
-#     @property
-#     def shape(self) -> tuple[VarOrConstant, ...]:
-#         return self._domain_shape
-
-#     def create_array(self, ghost_layers: int = 0):
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index 060947a82..9aecb16ea 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -6,7 +6,7 @@ from dataclasses import dataclass
 from pymbolic.mapper.dependency import DependencyMapper
 
 from .nodes import PsAstNode, PsBlock, failing_cast
-from .constraints import PsParamConstraint
+from ..constraints import PsKernelConstraint
 from ..typed_expressions import PsTypedVariable
 from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
 from ..exceptions import PsInternalCompilerError
@@ -26,7 +26,7 @@ class PsKernelParametersSpec:
 
     params: tuple[PsTypedVariable, ...]
     arrays: tuple[PsLinearizedArray, ...]
-    constraints: tuple[PsParamConstraint, ...]
+    constraints: tuple[PsKernelConstraint, ...]
 
     def params_for_array(self, arr: PsLinearizedArray):
         def pred(p: PsTypedVariable):
@@ -71,7 +71,7 @@ class PsKernelFunction(PsAstNode):
         self._target = target
         self._name = name
 
-        self._constraints: list[PsParamConstraint] = []
+        self._constraints: list[PsKernelConstraint] = []
 
     @property
     def target(self) -> Target:
@@ -120,7 +120,7 @@ class PsKernelFunction(PsAstNode):
             raise IndexError(f"Child index out of bounds: {idx}")
         self._body = failing_cast(PsBlock, c)
 
-    def add_constraints(self, *constraints: PsParamConstraint):
+    def add_constraints(self, *constraints: PsKernelConstraint):
         self._constraints += constraints
 
     def get_parameters(self) -> PsKernelParametersSpec:
diff --git a/src/pystencils/nbackend/ast/constraints.py b/src/pystencils/nbackend/constraints.py
similarity index 88%
rename from src/pystencils/nbackend/ast/constraints.py
rename to src/pystencils/nbackend/constraints.py
index d11fe1195..0cda3f4dc 100644
--- a/src/pystencils/nbackend/ast/constraints.py
+++ b/src/pystencils/nbackend/constraints.py
@@ -4,11 +4,11 @@ import pymbolic.primitives as pb
 from pymbolic.mapper.c_code import CCodeMapper
 from pymbolic.mapper.dependency import DependencyMapper
 
-from ..typed_expressions import PsTypedVariable
+from .typed_expressions import PsTypedVariable
 
 
 @dataclass
-class PsParamConstraint:
+class PsKernelConstraint:
     condition: pb.Comparison
     message: str = ""
 
diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py
index f07172e3d..3b67fa45f 100644
--- a/src/pystencils/nbackend/jit/cpu_extension_module.py
+++ b/src/pystencils/nbackend/jit/cpu_extension_module.py
@@ -11,7 +11,7 @@ import numpy as np
 
 from ..exceptions import PsInternalCompilerError
 from ..ast import PsKernelFunction
-from ..ast.constraints import PsParamConstraint
+from ..constraints import PsKernelConstraint
 from ..typed_expressions import PsTypedVariable
 from ..arrays import (
     PsLinearizedArray,
@@ -285,7 +285,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
             case _:
                 assert False, "Invalid variable encountered."
 
-    def check_constraint(self, constraint: PsParamConstraint):
+    def check_constraint(self, constraint: PsKernelConstraint):
         variables = constraint.get_variables()
 
         for var in variables:
diff --git a/src/pystencils/nbackend/translation/__init__.py b/src/pystencils/nbackend/translation/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/pystencils/nbackend/translation/context.py b/src/pystencils/nbackend/translation/context.py
index 199315579..cc9cfc0fc 100644
--- a/src/pystencils/nbackend/translation/context.py
+++ b/src/pystencils/nbackend/translation/context.py
@@ -1,3 +1,9 @@
+from ...field import Field
+from ..arrays import PsLinearizedArray, PsArrayBasePointer
+from ..types import PsIntegerType
+from ..constraints import PsKernelConstraint
+
+from .iteration_domain import PsIterationDomain
 
 class PsTranslationContext:
     """The `PsTranslationContext` manages the translation process from the SymPy frontend
@@ -27,7 +33,6 @@ class PsTranslationContext:
       Domain fields can only be accessed by relative offsets, and therefore must always
       be associated with an *iteration domain* that provides a spatial index tuple.
       All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers.
-      A field and its array may be associated with multiple iteration domains.
     - `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index.
       If there is at least one indexed field present there must also exist an index source for that field
       (loop or device indexing).
@@ -36,6 +41,21 @@ class PsTranslationContext:
       Within a domain, a buffer may be either written to or read from, never both.
 
 
+    In the translator, frontend fields and backend arrays are managed together using the `PsFieldArrayPair` class.
+    """
+
+    def __init__(self, index_dtype: PsIntegerType):
+        self._index_dtype = index_dtype
+        self._constraints: list[PsKernelConstraint] = []
+
+    @property
+    def index_dtype(self) -> PsIntegerType:
+        return self._index_dtype
     
+    def add_constraints(self, *constraints: PsKernelConstraint):
+        self._constraints += constraints
+
+    @property
+    def constraints(self) -> tuple[PsKernelConstraint, ...]:
+        return tuple(self._constraints)
 
-    """
diff --git a/src/pystencils/nbackend/translation/field_array_pair.py b/src/pystencils/nbackend/translation/field_array_pair.py
new file mode 100644
index 000000000..720b5c1c7
--- /dev/null
+++ b/src/pystencils/nbackend/translation/field_array_pair.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+
+from ...field import Field
+from ..arrays import PsLinearizedArray, PsArrayBasePointer
+from ..types import PsIntegerType
+from ..constraints import PsKernelConstraint
+
+from .iteration_domain import PsIterationDomain
+
+@dataclass
+class PsFieldArrayPair:
+    field: Field
+    array: PsLinearizedArray
+    base_ptr: PsArrayBasePointer
+
+
+@dataclass
+class PsDomainFieldArrayPair(PsFieldArrayPair):
+    ghost_layers: int
+    interior_base_ptr: PsArrayBasePointer
+    domain: PsIterationDomain
diff --git a/src/pystencils/nbackend/translation/iteration_domain.py b/src/pystencils/nbackend/translation/iteration_domain.py
new file mode 100644
index 000000000..990a4ff67
--- /dev/null
+++ b/src/pystencils/nbackend/translation/iteration_domain.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+from types import EllipsisType
+
+from ...field import Field
+from ...typing import TypedSymbol, BasicType
+from ..arrays import PsLinearizedArray, PsArrayBasePointer
+from ..types.quick import make_type
+from ..typed_expressions import PsTypedVariable, PsTypedConstant, VarOrConstant
+from .field_array_pair import PsDomainFieldArrayPair
+
+if TYPE_CHECKING:
+    from .context import PsTranslationContext
+
+class PsIterationDomain:
+    """Represents the n-dimensonal spatial iteration domain of a pystencils kernel.
+    
+    Domain Shape
+    ------------
+
+    A domain may have either constant or variable, n-dimensional shape, where n = 1, 2, 3.
+    If the shape is variable, the domain object manages variables for each shape entry.
+
+    The domain provides index variables for each dimension which may be used to access fields
+    associated with the domain.
+    In the kernel, these index variables must be provided by some index source.
+    Index sources differ between two major types of domains: full and sparse domains.
+
+    In a full domain, it is guaranteed that each interior point is processed by the kernel.
+    The index source may therefore be a full n-fold loop nest, or a device index calculation.
+
+    In a sparse domain, the iteration is controlled by an index vector, which acts as the index
+    source.
+
+    Arrays
+    ------
+
+    Any number of domain arrays may be associated with each domain.
+    Each array is annotated with a number of ghost layers for each spatial coordinate.
+
+    ### Shape Compatibility
+
+    When an array is associated with a domain, it must be ensured that the array's shape
+    is compatible with the domain.
+    The first n shape entries are considered the array's spatial shape.
+    These spatial shapes, after subtracting ghost layers, must all be equal, and are further
+    constrained by a constant domain shape.
+    For each spatial coordinate, shape compatibility is ensured as described by the following table.
+
+    |                           |  Constant Array Shape       |   Variable Array Shape |
+    |---------------------------|-----------------------------|------------------------|
+    | **Constant Domain Shape** | Compile-Time Equality Check |  Kernel Constraints    |
+    | **Variable Domain Shape** | Invalid, Compiler Error     |  Kernel Constraints    |
+
+    ### Base Pointers and Array Accesses
+
+    In the kernel's public interface, each array is represented at least through its base pointer,
+    which represents the starting address of the array's data in memory.
+    Since the iteration domain models arrays as being surrounded by ghost layers, it provides for each
+    array a second, *interior* base pointer, which points to the first interior point after skipping the
+    ghost layers, e.g. in three dimensions with one index dimension:
+
+    ```
+    addr(interior_base_ptr[0, 0, 0, 0]) == addr(base_ptr[gls, gls, gls, 0])
+    ```
+
+    To access domain arrays using the domain's index variables, the interior base pointer should be used,
+    since the domain index variables always count up from zero.
+
+    """
+
+    def __init__(self, ctx: PsTranslationContext, shape: tuple[int | EllipsisType, ...]):
+        self._ctx = ctx
+        
+        if len(shape) == 0:
+            raise ValueError("Domain shape must be at least one-dimensional.")
+        
+        if len(shape) > 3:
+            raise ValueError("Iteration domain can be at most three-dimensional.")
+        
+        self._shape: tuple[VarOrConstant, ...] = tuple(
+            (
+                PsTypedVariable(f"domain_size_{i}", self._ctx.index_dtype)
+                if s == Ellipsis
+                else PsTypedConstant(s, self._ctx.index_dtype)
+            )
+            for i, s in enumerate(shape)
+        )
+
+        self._archetype_field: PsDomainFieldArrayPair | None = None
+        self._fields: dict[str, PsDomainFieldArrayPair] = dict()
+
+    @property
+    def shape(self) -> tuple[VarOrConstant, ...]:
+        return self._shape
+    
+    def add_field(self, field: Field, ghost_layers: int) -> PsDomainFieldArrayPair:
+        arr_shape = tuple(
+            (Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis
+            for s in field.shape
+        )
+
+        arr_strides = tuple(
+            (Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis
+            for s in field.strides
+        )
+
+        # TODO: frontend should use new type system
+        element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type) 
+
+        arr = PsLinearizedArray(field.name, element_type, arr_shape, arr_strides, self._ctx.index_dtype)
+
+        fa_pair = PsDomainFieldArrayPair(
+            field=field,
+            array=arr,
+            base_ptr=PsArrayBasePointer("arr_data", arr),
+            ghost_layers=ghost_layers,
+            interior_base_ptr=PsArrayBasePointer("arr_interior_data", arr),
+            domain=self
+        )
+        
+        #   Check shape compatibility
+        #   TODO
+        for domain_s, field_s in zip(self.shape, field.shape):
+            if isinstance(domain_s, PsTypedConstant):
+                pass
+
+        raise NotImplementedError()
+
diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py
index 5bfd0fcb1..b33114426 100644
--- a/src/pystencils/nbackend/typed_expressions.py
+++ b/src/pystencils/nbackend/typed_expressions.py
@@ -206,8 +206,7 @@ class PsTypedConstant:
             return PsTypedConstant(rem, self._dtype)
 
     def __neg__(self):
-        minus_one = PsTypedConstant(-1, self._dtype)
-        return pb.Product((minus_one, self))
+        return PsTypedConstant(- self._value, self._dtype)
 
     def __bool__(self):
         return bool(self._value)
diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py
index ba2f7770d..8d9fc6483 100644
--- a/tests/nbackend/test_basic_printing.py
+++ b/tests/nbackend/test_basic_printing.py
@@ -10,7 +10,7 @@ from pystencils.nbackend.emission import CPrinter
 
 def test_basic_kernel():
 
-    u_arr = PsLinearizedArray("u", Fp(64), 1)
+    u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, ))
     u_size = u_arr.shape[0]
     u_base = PsArrayBasePointer("u_data", u_arr)
 
@@ -40,5 +40,5 @@ def test_basic_kernel():
 
     assert code.find("(" + params_str + ")") >= 0
     
-    assert code.find("u_data[ctr] = u_data[ctr - 1] + u_data[ctr + 1];") >= 0
+    assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr + -1];") >= 0
 
diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py
index b77ad6fff..6c2a453c7 100644
--- a/tests/nbackend/test_cpujit.py
+++ b/tests/nbackend/test_cpujit.py
@@ -3,7 +3,7 @@ import pytest
 from pystencils import Target
 
 from pystencils.nbackend.ast import *
-from pystencils.nbackend.ast.constraints import PsParamConstraint
+from pystencils.nbackend.constraints import PsKernelConstraint
 from pystencils.nbackend.typed_expressions import *
 from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
 from pystencils.nbackend.types.quick import *
@@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load
 def test_pairwise_addition():
     idx_type = SInt(64)
 
-    u = PsLinearizedArray("u", Fp(64, const=True), 2, index_dtype=idx_type)
-    v = PsLinearizedArray("v", Fp(64), 2, index_dtype=idx_type)
+    u = PsLinearizedArray("u", Fp(64, const=True), (..., ...), (..., ...), index_dtype=idx_type)
+    v = PsLinearizedArray("v", Fp(64), (..., ...), (..., ...), index_dtype=idx_type)
 
     u_data = PsArrayBasePointer("u_data", u)
     v_data = PsArrayBasePointer("v_data", v)
@@ -42,7 +42,7 @@ def test_pairwise_addition():
 
     func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
 
-    sizes_constraint = PsParamConstraint(
+    sizes_constraint = PsKernelConstraint(
         u.shape[0].eq(2 * v.shape[0]),
         "Array `u` must have twice the length of array `v`"
     )
diff --git a/tests/nbackend/test_expressions.py b/tests/nbackend/test_expressions.py
index b3485b267..6c24a6442 100644
--- a/tests/nbackend/test_expressions.py
+++ b/tests/nbackend/test_expressions.py
@@ -8,15 +8,18 @@ def test_variable_equality():
     var2 = PsTypedVariable("x", Fp(32))
     assert var1 == var2
 
-    arr = PsLinearizedArray("arr", Fp(64), 3)
+    shape = (..., ..., ...)
+    strides = (..., ..., ...)
+
+    arr = PsLinearizedArray("arr", Fp(64), shape, strides)
     bp1 = PsArrayBasePointer("arr_data", arr)
     bp2 = PsArrayBasePointer("arr_data", arr)
     assert bp1 == bp2
 
-    arr1 = PsLinearizedArray("arr", Fp(64), 3)
+    arr1 = PsLinearizedArray("arr", Fp(64), shape, strides)
     bp1 = PsArrayBasePointer("arr_data", arr1)
 
-    arr2 = PsLinearizedArray("arr", Fp(64), 3)
+    arr2 = PsLinearizedArray("arr", Fp(64), shape, strides)
     bp2 = PsArrayBasePointer("arr_data", arr2)
     assert bp1 == bp2
 
@@ -28,6 +31,9 @@ def test_variable_equality():
 
 
 def test_variable_inequality():
+    shape = (..., ..., ...)
+    strides = (..., ..., ...)
+
     var1 = PsTypedVariable("x", Fp(32))
     var2 = PsTypedVariable("x", Fp(64))
     assert var1 != var2
@@ -37,10 +43,10 @@ def test_variable_inequality():
     assert var1 != var2
 
     #   Arrays 
-    arr1 = PsLinearizedArray("arr", Fp(64), 3)
+    arr1 = PsLinearizedArray("arr", Fp(64), shape, strides)
     bp1 = PsArrayBasePointer("arr_data", arr1)
 
-    arr2 = PsLinearizedArray("arr", Fp(32), 3)
+    arr2 = PsLinearizedArray("arr", Fp(32), shape, strides)
     bp2 = PsArrayBasePointer("arr_data", arr2)
     assert bp1 != bp2
 
-- 
GitLab