From 57da00b47c2666497aae54a94a3563c2218199ff Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sun, 10 Mar 2024 13:37:56 +0100 Subject: [PATCH] introduce basic array handling --- src/pystencils/backend/ast/expressions.py | 26 ++++++++ src/pystencils/backend/emission.py | 30 ++++++++-- .../backend/kernelcreation/freeze.py | 18 +++++- .../backend/kernelcreation/typification.py | 60 ++++++++++++++++++- .../backend/platforms/generic_cpu.py | 4 +- src/pystencils/types/__init__.py | 4 ++ src/pystencils/types/basic_types.py | 48 ++++++++++++--- src/pystencils/types/quick.py | 4 ++ 8 files changed, 176 insertions(+), 18 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index a2c13548a..7c24b55b2 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -161,6 +161,9 @@ class PsSubscript(PsLvalueExpr): case 1: self.index = failing_cast(PsExpression, c) + def __str__(self) -> str: + return f"Subscript({self._base})[{self._index}]" + class PsArrayAccess(PsSubscript): __match_args__ = ("base_ptr", "index") @@ -484,3 +487,26 @@ class PsDiv(PsBinOp): # python_operator not implemented because can't unambigously decide # between intdiv and truediv pass + + +class PsArrayInitList(PsExpression): + __match_args__ = ("items",) + + def __init__(self, items: Sequence[PsExpression]): + self._items = list(items) + + @property + def items(self) -> list[PsExpression]: + return self._items + + def get_children(self) -> tuple[PsAstNode, ...]: + return tuple(self._items) + + def set_child(self, idx: int, c: PsAstNode): + self._items[idx] = failing_cast(PsExpression, c) + + def clone(self) -> PsExpression: + return PsArrayInitList([expr.clone() for expr in self._items]) + + def __repr__(self) -> str: + return f"PsArrayInitList({repr(self._items)})" diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 054cd9b44..190f4f001 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -28,9 +28,11 @@ from .ast.expressions import ( PsDeref, PsAddressOf, PsCast, + PsArrayInitList, ) -from ..types import PsScalarType +from .symbols import PsSymbol +from ..types import PsScalarType, PsArrayType from .kernelfunction import KernelFunction @@ -153,12 +155,10 @@ class CAstPrinter: case PsDeclaration(lhs, rhs): lhs_symb = lhs.symbol - lhs_dtype = lhs_symb.get_dtype() + lhs_code = self._symbol_decl(lhs_symb) rhs_code = self.visit(rhs, pc) - return pc.indent( - f"{lhs_dtype.c_string()} {lhs_symb.name} = {rhs_code};" - ) + return pc.indent(f"{lhs_code} = {rhs_code};") case PsAssignment(lhs, rhs): lhs_code = self.visit(lhs, pc) @@ -281,9 +281,29 @@ class CAstPrinter: type_str = target_type.c_string() return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast) + case PsArrayInitList(items): + pc.push_op(Ops.Weakest, LR.Middle) + items_str = ", ".join(self.visit(item, pc) for item in items) + pc.pop_op() + return "{ " + items_str + " }" + case _: raise NotImplementedError(f"Don't know how to print {node}") + def _symbol_decl(self, symb: PsSymbol): + dtype = symb.get_dtype() + + array_dims = [] + while isinstance(dtype, PsArrayType): + array_dims.append(dtype.length) + dtype = dtype.base_type + + code = f"{dtype.c_string()} {symb.name}" + for d in array_dims: + code += f"[{str(d) if d is not None else ''}]" + + return code + def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]: match node: case PsAdd(): diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index a13f21ae2..e0a5b130f 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -24,6 +24,8 @@ from ..ast.expressions import ( PsLookup, PsCall, PsConstantExpr, + PsArrayInitList, + PsSubscript, ) from ..constants import PsConstant @@ -86,7 +88,7 @@ class FreezeExpressions: raise FreezeError(f"Don't know how to freeze {obj}") def visit_expr(self, expr: sp.Basic): - if not isinstance(expr, sp.Expr): + if not isinstance(expr, (sp.Expr, sp.Tuple)): raise FreezeError(f"Cannot freeze {expr} to an expression") return cast(PsExpression, self.visit(expr)) @@ -182,7 +184,7 @@ class FreezeExpressions: def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: value = int(expr) return PsConstantExpr(PsConstant(value)) - + def map_Float(self, expr: sp.Float) -> PsConstantExpr: value = float(expr) # TODO: check accuracy of evaluation return PsConstantExpr(PsConstant(value)) @@ -197,6 +199,18 @@ class FreezeExpressions: symb = self._ctx.get_symbol(expr.name, dtype) return PsSymbolExpr(symb) + def map_Tuple(self, expr: sp.Tuple) -> PsArrayInitList: + items = [self.visit_expr(item) for item in expr] + return PsArrayInitList(items) + + def map_Indexed(self, expr: sp.Indexed) -> PsSubscript: + assert isinstance(expr.base, sp.IndexedBase) + base = self.visit_expr(expr.base.label) + subscript = PsSubscript(base, self.visit_expr(expr.indices[0])) + for idx in expr.indices[1:]: + subscript = PsSubscript(subscript, self.visit_expr(idx)) + return subscript + def map_Access(self, access: Field.Access): field = access.field array = self._ctx.get_array(field) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index ef04a617d..cc526e895 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -3,15 +3,25 @@ from __future__ import annotations from typing import TypeVar from .context import KernelCreationContext -from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify +from ...types import ( + PsType, + PsNumericType, + PsStructType, + PsIntegerType, + PsArrayType, + PsSubscriptableType, + deconstify, +) from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment from ..ast.expressions import ( PsSymbolExpr, PsConstantExpr, PsBinOp, PsArrayAccess, + PsSubscript, PsLookup, PsCall, + PsArrayInitList, ) from ..functions import PsMathFunction @@ -192,6 +202,26 @@ class Typifier: f"Array index is not of integer type: {idx} has type {index_tc.target_type}" ) + case PsSubscript(arr, idx): + arr_tc = TypeContext() + self.visit_expr(arr, arr_tc) + + if not isinstance(arr_tc.target_type, PsSubscriptableType): + raise TypificationError( + "Type of subscript base is not subscriptable." + ) + + tc.apply_and_check(expr, arr_tc.target_type.base_type) + + index_tc = TypeContext() + self.visit_expr(idx, index_tc) + if index_tc.target_type is None: + index_tc.apply_and_check(idx, self._ctx.index_dtype) + elif not isinstance(index_tc.target_type, PsIntegerType): + raise TypificationError( + f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" + ) + case PsLookup(aggr, member_name): aggr_tc = TypeContext(None) self.visit_expr(aggr, aggr_tc) @@ -199,7 +229,7 @@ class Typifier: if not isinstance(aggr_type, PsStructType): raise TypificationError( - "Aggregate type of lookup was not a struct type." + "Aggregate type of lookup is not a struct type." ) member = aggr_type.find_member(member_name) @@ -224,5 +254,31 @@ class Typifier: f"Don't know how to typify calls to {function}" ) + case PsArrayInitList(items): + items_tc = TypeContext() + for item in items: + self.visit_expr(item, items_tc) + + if items_tc.target_type is None: + if tc.target_type is None: + raise TypificationError(f"Unable to infer type of array {expr}") + elif not isinstance(tc.target_type, PsArrayType): + raise TypificationError( + f"Cannot apply type {tc.target_type} to an array initializer." + ) + elif ( + tc.target_type.length is not None + and tc.target_type.length != len(items) + ): + raise TypificationError( + "Array size mismatch: Cannot typify initializer list with " + f"{len(items)} items as {tc.target_type}" + ) + else: + items_tc.apply_and_check(expr, tc.target_type.base_type) + else: + arr_type = PsArrayType(items_tc.target_type, len(items)) + tc.apply_and_check(expr, arr_type) + case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 6b19a88ff..8c53a9f16 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -6,7 +6,7 @@ from .platform import Platform from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, - SparseIterationSpace + SparseIterationSpace, ) from ..constants import PsConstant @@ -43,7 +43,7 @@ class GenericCpu(Platform): def _create_domain_loops( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: - + dimensions = ispace.dimensions # Determine loop order by permuting dimensions diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 0c475c8ac..140151b2b 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -5,7 +5,9 @@ from .basic_types import ( PsNumericType, PsScalarType, PsVectorType, + PsSubscriptableType, PsPointerType, + PsArrayType, PsBoolType, PsIntegerType, PsUnsignedIntegerType, @@ -23,7 +25,9 @@ __all__ = [ "PsType", "PsCustomType", "PsStructType", + "PsSubscriptableType", "PsPointerType", + "PsArrayType", "PsNumericType", "PsScalarType", "PsVectorType", diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index 55055205f..4648c3210 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -112,8 +112,20 @@ class PsCustomType(PsType): return f"CustomType( {self.name}, const={self.const} )" +class PsSubscriptableType(PsType, ABC): + __match_args__ = ("base_type",) + + def __init__(self, base_type: PsType, const: bool = False): + super().__init__(const) + self._base_type = base_type + + @property + def base_type(self) -> PsType: + return self._base_type + + @final -class PsPointerType(PsType): +class PsPointerType(PsSubscriptableType): """Class to model C pointer types.""" __match_args__ = ("base_type",) @@ -121,14 +133,9 @@ class PsPointerType(PsType): def __init__( self, base_type: PsType, const: bool = False, restrict: bool = True ): - super().__init__(const) - self._base_type = base_type + super().__init__(base_type, const) self._restrict = restrict - @property - def base_type(self) -> PsType: - return self._base_type - @property def restrict(self) -> bool: return self._restrict @@ -148,6 +155,33 @@ class PsPointerType(PsType): def __repr__(self) -> str: return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" + + +class PsArrayType(PsSubscriptableType): + """Class that models one-dimensional C arrays""" + + def __init__(self, base_type: PsType, length: int | None = None, const: bool = False): + self._length = length + super().__init__(base_type, const) + + @property + def length(self) -> int | None: + return self._length + + def c_string(self) -> str: + return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsArrayType): + return False + + return self._base_equal(other) and self._base_type == other._base_type and self._length == other._length + + def __hash__(self) -> int: + return hash(("PsArrayType", self._base_type, self._length, self._const)) + + def __repr__(self) -> str: + return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})" class PsStructType(PsType): diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py index 24b6968c7..60ec516f1 100644 --- a/src/pystencils/types/quick.py +++ b/src/pystencils/types/quick.py @@ -15,6 +15,7 @@ from .basic_types import ( PsScalarType, PsBoolType, PsPointerType, + PsArrayType, PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType, @@ -73,6 +74,9 @@ Scalar = PsScalarType Ptr = PsPointerType """`Ptr(t)` matches `PsPointerType(base_type=t)`""" +Arr = PsArrayType +"""`Arr(t, s)` matches PsArrayType(base_type=t, size=s)""" + Bool = PsBoolType """Bool() matches PsBoolType()""" -- GitLab