From 57da00b47c2666497aae54a94a3563c2218199ff Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 10 Mar 2024 13:37:56 +0100
Subject: [PATCH] introduce basic array handling

---
 src/pystencils/backend/ast/expressions.py     | 26 ++++++++
 src/pystencils/backend/emission.py            | 30 ++++++++--
 .../backend/kernelcreation/freeze.py          | 18 +++++-
 .../backend/kernelcreation/typification.py    | 60 ++++++++++++++++++-
 .../backend/platforms/generic_cpu.py          |  4 +-
 src/pystencils/types/__init__.py              |  4 ++
 src/pystencils/types/basic_types.py           | 48 ++++++++++++---
 src/pystencils/types/quick.py                 |  4 ++
 8 files changed, 176 insertions(+), 18 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index a2c13548a..7c24b55b2 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -161,6 +161,9 @@ class PsSubscript(PsLvalueExpr):
             case 1:
                 self.index = failing_cast(PsExpression, c)
 
+    def __str__(self) -> str:
+        return f"Subscript({self._base})[{self._index}]"
+
 
 class PsArrayAccess(PsSubscript):
     __match_args__ = ("base_ptr", "index")
@@ -484,3 +487,26 @@ class PsDiv(PsBinOp):
     #  python_operator not implemented because can't unambigously decide
     #  between intdiv and truediv
     pass
+
+
+class PsArrayInitList(PsExpression):
+    __match_args__ = ("items",)
+
+    def __init__(self, items: Sequence[PsExpression]):
+        self._items = list(items)
+
+    @property
+    def items(self) -> list[PsExpression]:
+        return self._items
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return tuple(self._items)
+
+    def set_child(self, idx: int, c: PsAstNode):
+        self._items[idx] = failing_cast(PsExpression, c)
+
+    def clone(self) -> PsExpression:
+        return PsArrayInitList([expr.clone() for expr in self._items])
+
+    def __repr__(self) -> str:
+        return f"PsArrayInitList({repr(self._items)})"
diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index 054cd9b44..190f4f001 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -28,9 +28,11 @@ from .ast.expressions import (
     PsDeref,
     PsAddressOf,
     PsCast,
+    PsArrayInitList,
 )
 
-from ..types import PsScalarType
+from .symbols import PsSymbol
+from ..types import PsScalarType, PsArrayType
 
 from .kernelfunction import KernelFunction
 
@@ -153,12 +155,10 @@ class CAstPrinter:
 
             case PsDeclaration(lhs, rhs):
                 lhs_symb = lhs.symbol
-                lhs_dtype = lhs_symb.get_dtype()
+                lhs_code = self._symbol_decl(lhs_symb)
                 rhs_code = self.visit(rhs, pc)
 
-                return pc.indent(
-                    f"{lhs_dtype.c_string()} {lhs_symb.name} = {rhs_code};"
-                )
+                return pc.indent(f"{lhs_code} = {rhs_code};")
 
             case PsAssignment(lhs, rhs):
                 lhs_code = self.visit(lhs, pc)
@@ -281,9 +281,29 @@ class CAstPrinter:
                 type_str = target_type.c_string()
                 return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast)
 
+            case PsArrayInitList(items):
+                pc.push_op(Ops.Weakest, LR.Middle)
+                items_str = ", ".join(self.visit(item, pc) for item in items)
+                pc.pop_op()
+                return "{ " + items_str + " }"
+
             case _:
                 raise NotImplementedError(f"Don't know how to print {node}")
 
+    def _symbol_decl(self, symb: PsSymbol):
+        dtype = symb.get_dtype()
+
+        array_dims = []
+        while isinstance(dtype, PsArrayType):
+            array_dims.append(dtype.length)
+            dtype = dtype.base_type
+
+        code = f"{dtype.c_string()} {symb.name}"
+        for d in array_dims:
+            code += f"[{str(d) if d is not None else ''}]"
+
+        return code
+
     def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]:
         match node:
             case PsAdd():
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index a13f21ae2..e0a5b130f 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -24,6 +24,8 @@ from ..ast.expressions import (
     PsLookup,
     PsCall,
     PsConstantExpr,
+    PsArrayInitList,
+    PsSubscript,
 )
 
 from ..constants import PsConstant
@@ -86,7 +88,7 @@ class FreezeExpressions:
             raise FreezeError(f"Don't know how to freeze {obj}")
 
     def visit_expr(self, expr: sp.Basic):
-        if not isinstance(expr, sp.Expr):
+        if not isinstance(expr, (sp.Expr, sp.Tuple)):
             raise FreezeError(f"Cannot freeze {expr} to an expression")
         return cast(PsExpression, self.visit(expr))
 
@@ -182,7 +184,7 @@ class FreezeExpressions:
     def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
         value = int(expr)
         return PsConstantExpr(PsConstant(value))
-    
+
     def map_Float(self, expr: sp.Float) -> PsConstantExpr:
         value = float(expr)  # TODO: check accuracy of evaluation
         return PsConstantExpr(PsConstant(value))
@@ -197,6 +199,18 @@ class FreezeExpressions:
         symb = self._ctx.get_symbol(expr.name, dtype)
         return PsSymbolExpr(symb)
 
+    def map_Tuple(self, expr: sp.Tuple) -> PsArrayInitList:
+        items = [self.visit_expr(item) for item in expr]
+        return PsArrayInitList(items)
+
+    def map_Indexed(self, expr: sp.Indexed) -> PsSubscript:
+        assert isinstance(expr.base, sp.IndexedBase)
+        base = self.visit_expr(expr.base.label)
+        subscript = PsSubscript(base, self.visit_expr(expr.indices[0]))
+        for idx in expr.indices[1:]:
+            subscript = PsSubscript(subscript, self.visit_expr(idx))
+        return subscript
+
     def map_Access(self, access: Field.Access):
         field = access.field
         array = self._ctx.get_array(field)
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index ef04a617d..cc526e895 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -3,15 +3,25 @@ from __future__ import annotations
 from typing import TypeVar
 
 from .context import KernelCreationContext
-from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify
+from ...types import (
+    PsType,
+    PsNumericType,
+    PsStructType,
+    PsIntegerType,
+    PsArrayType,
+    PsSubscriptableType,
+    deconstify,
+)
 from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment
 from ..ast.expressions import (
     PsSymbolExpr,
     PsConstantExpr,
     PsBinOp,
     PsArrayAccess,
+    PsSubscript,
     PsLookup,
     PsCall,
+    PsArrayInitList,
 )
 from ..functions import PsMathFunction
 
@@ -192,6 +202,26 @@ class Typifier:
                         f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
                     )
 
+            case PsSubscript(arr, idx):
+                arr_tc = TypeContext()
+                self.visit_expr(arr, arr_tc)
+
+                if not isinstance(arr_tc.target_type, PsSubscriptableType):
+                    raise TypificationError(
+                        "Type of subscript base is not subscriptable."
+                    )
+
+                tc.apply_and_check(expr, arr_tc.target_type.base_type)
+
+                index_tc = TypeContext()
+                self.visit_expr(idx, index_tc)
+                if index_tc.target_type is None:
+                    index_tc.apply_and_check(idx, self._ctx.index_dtype)
+                elif not isinstance(index_tc.target_type, PsIntegerType):
+                    raise TypificationError(
+                        f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
+                    )
+
             case PsLookup(aggr, member_name):
                 aggr_tc = TypeContext(None)
                 self.visit_expr(aggr, aggr_tc)
@@ -199,7 +229,7 @@ class Typifier:
 
                 if not isinstance(aggr_type, PsStructType):
                     raise TypificationError(
-                        "Aggregate type of lookup was not a struct type."
+                        "Aggregate type of lookup is not a struct type."
                     )
 
                 member = aggr_type.find_member(member_name)
@@ -224,5 +254,31 @@ class Typifier:
                             f"Don't know how to typify calls to {function}"
                         )
 
+            case PsArrayInitList(items):
+                items_tc = TypeContext()
+                for item in items:
+                    self.visit_expr(item, items_tc)
+
+                if items_tc.target_type is None:
+                    if tc.target_type is None:
+                        raise TypificationError(f"Unable to infer type of array {expr}")
+                    elif not isinstance(tc.target_type, PsArrayType):
+                        raise TypificationError(
+                            f"Cannot apply type {tc.target_type} to an array initializer."
+                        )
+                    elif (
+                        tc.target_type.length is not None
+                        and tc.target_type.length != len(items)
+                    ):
+                        raise TypificationError(
+                            "Array size mismatch: Cannot typify initializer list with "
+                            f"{len(items)} items as {tc.target_type}"
+                        )
+                    else:
+                        items_tc.apply_and_check(expr, tc.target_type.base_type)
+                else:
+                    arr_type = PsArrayType(items_tc.target_type, len(items))
+                    tc.apply_and_check(expr, arr_type)
+
             case _:
                 raise NotImplementedError(f"Can't typify {expr}")
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 6b19a88ff..8c53a9f16 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -6,7 +6,7 @@ from .platform import Platform
 from ..kernelcreation.iteration_space import (
     IterationSpace,
     FullIterationSpace,
-    SparseIterationSpace
+    SparseIterationSpace,
 )
 
 from ..constants import PsConstant
@@ -43,7 +43,7 @@ class GenericCpu(Platform):
     def _create_domain_loops(
         self, body: PsBlock, ispace: FullIterationSpace
     ) -> PsBlock:
-        
+
         dimensions = ispace.dimensions
 
         #   Determine loop order by permuting dimensions
diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py
index 0c475c8ac..140151b2b 100644
--- a/src/pystencils/types/__init__.py
+++ b/src/pystencils/types/__init__.py
@@ -5,7 +5,9 @@ from .basic_types import (
     PsNumericType,
     PsScalarType,
     PsVectorType,
+    PsSubscriptableType,
     PsPointerType,
+    PsArrayType,
     PsBoolType,
     PsIntegerType,
     PsUnsignedIntegerType,
@@ -23,7 +25,9 @@ __all__ = [
     "PsType",
     "PsCustomType",
     "PsStructType",
+    "PsSubscriptableType",
     "PsPointerType",
+    "PsArrayType",
     "PsNumericType",
     "PsScalarType",
     "PsVectorType",
diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py
index 55055205f..4648c3210 100644
--- a/src/pystencils/types/basic_types.py
+++ b/src/pystencils/types/basic_types.py
@@ -112,8 +112,20 @@ class PsCustomType(PsType):
         return f"CustomType( {self.name}, const={self.const} )"
 
 
+class PsSubscriptableType(PsType, ABC):
+    __match_args__ = ("base_type",)
+
+    def __init__(self, base_type: PsType, const: bool = False):
+        super().__init__(const)
+        self._base_type = base_type
+
+    @property
+    def base_type(self) -> PsType:
+        return self._base_type
+
+
 @final
-class PsPointerType(PsType):
+class PsPointerType(PsSubscriptableType):
     """Class to model C pointer types."""
 
     __match_args__ = ("base_type",)
@@ -121,14 +133,9 @@ class PsPointerType(PsType):
     def __init__(
         self, base_type: PsType, const: bool = False, restrict: bool = True
     ):
-        super().__init__(const)
-        self._base_type = base_type
+        super().__init__(base_type, const)
         self._restrict = restrict
 
-    @property
-    def base_type(self) -> PsType:
-        return self._base_type
-
     @property
     def restrict(self) -> bool:
         return self._restrict
@@ -148,6 +155,33 @@ class PsPointerType(PsType):
 
     def __repr__(self) -> str:
         return f"PsPointerType( {repr(self.base_type)}, const={self.const} )"
+    
+
+class PsArrayType(PsSubscriptableType):
+    """Class that models one-dimensional C arrays"""
+
+    def __init__(self, base_type: PsType, length: int | None = None, const: bool = False):
+        self._length = length
+        super().__init__(base_type, const)
+
+    @property
+    def length(self) -> int | None:
+        return self._length
+    
+    def c_string(self) -> str:
+        return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]"
+    
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsArrayType):
+            return False
+        
+        return self._base_equal(other) and self._base_type == other._base_type and self._length == other._length
+    
+    def __hash__(self) -> int:
+        return hash(("PsArrayType", self._base_type, self._length, self._const))
+
+    def __repr__(self) -> str:
+        return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})"
 
 
 class PsStructType(PsType):
diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py
index 24b6968c7..60ec516f1 100644
--- a/src/pystencils/types/quick.py
+++ b/src/pystencils/types/quick.py
@@ -15,6 +15,7 @@ from .basic_types import (
     PsScalarType,
     PsBoolType,
     PsPointerType,
+    PsArrayType,
     PsIntegerType,
     PsUnsignedIntegerType,
     PsSignedIntegerType,
@@ -73,6 +74,9 @@ Scalar = PsScalarType
 Ptr = PsPointerType
 """`Ptr(t)` matches `PsPointerType(base_type=t)`"""
 
+Arr = PsArrayType
+"""`Arr(t, s)` matches PsArrayType(base_type=t, size=s)"""
+
 Bool = PsBoolType
 """Bool() matches PsBoolType()"""
 
-- 
GitLab