From bd2c7d881e8f1b6cfcd849acea319aa5bf3af634 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 31 Jan 2024 15:36:24 +0100 Subject: [PATCH] Extend struct support; add lookup freeze and typificaiton; add various tests --- src/pystencils/nbackend/emission.py | 8 ++- .../nbackend/kernelcreation/context.py | 7 +++ .../nbackend/kernelcreation/freeze.py | 37 ++++++++++++-- .../nbackend/kernelcreation/options.py | 10 +++- .../nbackend/kernelcreation/typification.py | 49 ++++++++++++++---- src/pystencils/nbackend/types/basic_types.py | 29 ++++++++--- tests/nbackend/__init__.py | 0 tests/nbackend/kernelcreation/__init__.py | 0 .../{ => kernelcreation}/test_freeze.py | 0 .../kernelcreation/test_index_kernels.py | 31 +++++++++++ tests/nbackend/kernelcreation/test_options.py | 28 ++++++++++ .../{ => kernelcreation}/test_typification.py | 27 +++++++++- tests/nbackend/types/__init__.py | 0 tests/nbackend/types/test_quick_types.py | 30 ----------- tests/nbackend/{ => types}/test_types.py | 51 ++++++++++++++++++- 15 files changed, 247 insertions(+), 60 deletions(-) create mode 100644 tests/nbackend/__init__.py create mode 100644 tests/nbackend/kernelcreation/__init__.py rename tests/nbackend/{ => kernelcreation}/test_freeze.py (100%) create mode 100644 tests/nbackend/kernelcreation/test_index_kernels.py create mode 100644 tests/nbackend/kernelcreation/test_options.py rename tests/nbackend/{ => kernelcreation}/test_typification.py (77%) create mode 100644 tests/nbackend/types/__init__.py delete mode 100644 tests/nbackend/types/test_quick_types.py rename tests/nbackend/{ => types}/test_types.py (56%) diff --git a/src/pystencils/nbackend/emission.py b/src/pystencils/nbackend/emission.py index f25070192..e47854506 100644 --- a/src/pystencils/nbackend/emission.py +++ b/src/pystencils/nbackend/emission.py @@ -13,6 +13,7 @@ from .ast import ( PsConditional, ) from .ast.kernelfunction import PsKernelFunction +from .typed_expressions import PsTypedVariable def emit_code(kernel: PsKernelFunction): @@ -42,7 +43,7 @@ class CPrinter: @visit.case(PsKernelFunction) def function(self, func: PsKernelFunction) -> str: params_spec = func.get_parameters() - params_str = ", ".join(f"{p.dtype} {p.name}" for p in params_spec.params) + params_str = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in params_spec.params) decl = f"FUNC_PREFIX void {func.name} ({params_str})" body = self.visit(func.body) return f"{decl}\n{body}" @@ -64,10 +65,11 @@ class CPrinter: @visit.case(PsDeclaration) def declaration(self, decl: PsDeclaration): lhs_symb = decl.declared_variable.symbol + assert isinstance(lhs_symb, PsTypedVariable) lhs_dtype = lhs_symb.dtype rhs_code = self.visit(decl.rhs) - return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};") + return self.indent(f"{lhs_dtype.c_string()} {lhs_symb.name} = {rhs_code};") @visit.case(PsAssignment) def assignment(self, asm: PsAssignment): @@ -78,6 +80,8 @@ class CPrinter: @visit.case(PsLoop) def loop(self, loop: PsLoop): ctr_symbol = loop.counter.symbol + assert isinstance(ctr_symbol, PsTypedVariable) + ctr = ctr_symbol.name start_code = self.visit(loop.start) stop_code = self.visit(loop.stop) diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py index a77189331..29eeb0265 100644 --- a/src/pystencils/nbackend/kernelcreation/context.py +++ b/src/pystencils/nbackend/kernelcreation/context.py @@ -120,6 +120,13 @@ class KernelCreationContext: assert isinstance(field.dtype, (BasicType, StructType)) element_type = make_type(field.dtype.numpy_dtype) + # The frontend doesn't quite agree with itself on how to model + # fields with trivial index dimensions. Sometimes the index_shape is empty, + # sometimes its (1,). This is canonicalized here. + if not field.index_shape: + arr_shape += (1,) + arr_strides += (1,) + arr = PsLinearizedArray( field.name, element_type, arr_shape, arr_strides, self.index_dtype ) diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index 22eb1dbf4..38f100e72 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -20,12 +20,16 @@ from ..ast.nodes import ( PsLvalueExpr, PsExpression, ) -from ..types import constify, make_type +from ..types import constify, make_type, PsStructType from ..typed_expressions import PsTypedVariable from ..arrays import PsArrayAccess from ..exceptions import PsInputError +class FreezeError(Exception): + """Signifies an error during expression freezing.""" + + class FreezeExpressions(SympyToPymbolicMapper): def __init__(self, ctx: KernelCreationContext): self._ctx = ctx @@ -85,7 +89,6 @@ class FreezeExpressions(SympyToPymbolicMapper): ptr = array.base_pointer offsets: list[pb.Expression] = [self.rec(o) for o in access.offsets] - indices: list[pb.Expression] = [self.rec(o) for o in access.index] if not access.is_absolute_access: match field.field_type: @@ -101,7 +104,7 @@ class FreezeExpressions(SympyToPymbolicMapper): # flake8: noqa sparse_ispace = self._ctx.get_sparse_iteration_space() # Add sparse iteration counter to offset - assert len(offsets) == 1 # must have been checked by the context + assert len(offsets) == 1 # must have been checked by the context offsets = [offsets[0] + sparse_ispace.sparse_counter] case FieldType.CUSTOM: raise ValueError("Custom fields support only absolute accesses.") @@ -110,6 +113,26 @@ class FreezeExpressions(SympyToPymbolicMapper): f"Cannot translate accesses to field type {unknown} yet." ) + # If the array type is a struct, accesses are modelled using strings + # In that case, the index is empty + if isinstance(array.element_type, PsStructType): + if isinstance(access.index, str): + struct_member_name = access.index + indices = [0] + elif len(access.index) == 1 and isinstance(access.index[0], str): + struct_member_name = access.index[0] + indices = [0] + else: + raise FreezeError( + f"Unsupported access into field with struct-type elements: {access}" + ) + else: + struct_member_name = None + indices = [self.rec(i) for i in access.index] + if not indices: + # For canonical representation, there must always be at least one index dimension + indices = [0] + summands = tuple( idx * stride for idx, stride in zip(offsets + indices, array.strides, strict=True) @@ -117,8 +140,12 @@ class FreezeExpressions(SympyToPymbolicMapper): index = summands[0] if len(summands) == 1 else pb.Sum(summands) - return PsArrayAccess(ptr, index) - + if struct_member_name is not None: + # Produce a pb.Lookup here, don't check yet if the member name is valid. That's the typifier's job. + return pb.Lookup(PsArrayAccess(ptr, index), struct_member_name) + else: + return PsArrayAccess(ptr, index) + def map_Function(self, func: sp.Function): """Map a SymPy function to a backend-supported function symbol. diff --git a/src/pystencils/nbackend/kernelcreation/options.py b/src/pystencils/nbackend/kernelcreation/options.py index 5f5028a94..53fbbd640 100644 --- a/src/pystencils/nbackend/kernelcreation/options.py +++ b/src/pystencils/nbackend/kernelcreation/options.py @@ -2,7 +2,7 @@ from typing import Sequence from dataclasses import dataclass from ...enums import Target -from ...field import Field +from ...field import Field, FieldType from ..exceptions import PsOptionsError from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType @@ -73,3 +73,11 @@ class KernelCreationOptions: "Parameters `iteration_slice`, `ghost_layers` and 'index_field` are mutually exclusive; " "at most one of them may be set." ) + + if ( + self.index_field is not None + and self.index_field.field_type != FieldType.INDEXED + ): + raise PsOptionsError( + "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" + ) diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index f1d299e28..c817bde3f 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -6,7 +6,7 @@ import pymbolic.primitives as pb from pymbolic.mapper import Mapper from .context import KernelCreationContext -from ..types import PsAbstractType, PsNumericType, deconstify +from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..arrays import PsArrayAccess from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment @@ -24,9 +24,10 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class UndeterminedType(PsNumericType): """Placeholder for types that could not yet be determined by the typifier. - + Instances of this class should never leave the typifier; it is an error if they do. """ + def create_constant(self, value: Any) -> Any: return None @@ -51,7 +52,7 @@ class UndeterminedType(PsNumericType): def __eq__(self, other: object) -> bool: self._err() - def _c_string(self) -> str: + def c_string(self) -> str: self._err() @@ -69,7 +70,7 @@ class DeferredTypedConstant(PsTypedConstant): class TypeContext: - def __init__(self, target_type: PsNumericType | None): + def __init__(self, target_type: PsAbstractType | None): self._target_type = deconstify(target_type) if target_type is not None else None self._deferred_constants: list[DeferredTypedConstant] = [] @@ -78,18 +79,28 @@ class TypeContext: dc = DeferredTypedConstant(value) self._deferred_constants.append(dc) return dc + elif not isinstance(self._target_type, PsNumericType): + raise TypificationError( + f"Can't typify constant with non-numeric type {self._target_type}" + ) else: return PsTypedConstant(value, self._target_type) - def apply(self, target_type: PsNumericType): + def apply(self, target_type: PsAbstractType): assert self._target_type is None, "Type context was already resolved" self._target_type = deconstify(target_type) + for dc in self._deferred_constants: + if not isinstance(self._target_type, PsNumericType): + raise TypificationError( + f"Can't typify constant with non-numeric type {self._target_type}" + ) dc.resolve(self._target_type) + self._deferred_constants = [] @property - def target_type(self) -> PsNumericType | None: + def target_type(self) -> PsAbstractType | None: return self._target_type @@ -194,15 +205,32 @@ class Typifier(Mapper): return tc.make_constant(value) - # Array Access + # Array Accesses and Lookups def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess: self._apply_target_type(access, access.dtype, tc) - index, _ = self.rec( + index = self.rec( access.index_tuple[0], TypeContext(self._ctx.options.index_dtype) ) return PsArrayAccess(access.base_ptr, index) + def map_lookup(self, lookup: pb.Lookup, tc: TypeContext) -> pb.Lookup: + aggr_tc = TypeContext(None) + aggregate = self.rec(lookup.aggregate, aggr_tc) + aggr_type = aggr_tc.target_type + + if not isinstance(aggr_type, PsStructType): + raise TypificationError("Aggregate type of lookup was not a struct type.") + + member = aggr_type.get_member(lookup.name) + if member is None: + raise TypificationError( + f"Aggregate of type {aggr_type} does not have a member {member}." + ) + + self._apply_target_type(lookup, member.dtype, tc) + return pb.Lookup(aggregate, member.name) + # Arithmetic Expressions def map_sum(self, expr: pb.Sum, tc: TypeContext) -> pb.Sum: @@ -210,7 +238,7 @@ class Typifier(Mapper): def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product: return pb.Product(tuple(self.rec(c, tc) for c in expr.children)) - + # Functions def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call: @@ -218,14 +246,13 @@ class Typifier(Mapper): TODO: Figure out how to describe function signatures """ raise NotImplementedError() - + # Internals def _apply_target_type( self, expr: ExprOrConstant, expr_type: PsAbstractType, tc: TypeContext ): if tc.target_type is None: - assert isinstance(expr_type, PsNumericType) tc.apply(expr_type) elif deconstify(expr_type) != tc.target_type: raise TypificationError( diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index bb27e3493..ce27c6fb3 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -67,7 +67,7 @@ class PsAbstractType(ABC): return "const " if self._const else "" @abstractmethod - def _c_string(self) -> str: + def c_string(self) -> str: ... # ------------------------------------------------------------------------------------------- @@ -79,7 +79,7 @@ class PsAbstractType(ABC): ... def __str__(self) -> str: - return self._c_string() + return self.c_string() @abstractmethod def __hash__(self) -> int: @@ -107,7 +107,7 @@ class PsCustomType(PsAbstractType): def __hash__(self) -> int: return hash(("PsCustomType", self._name, self._const)) - def _c_string(self) -> str: + def c_string(self) -> str: return f"{self._const_string()} {self._name}" def __repr__(self) -> str: @@ -143,8 +143,8 @@ class PsPointerType(PsAbstractType): def __hash__(self) -> int: return hash(("PsPointerType", self._base_type, self._restrict, self._const)) - def _c_string(self) -> str: - base_str = self._base_type._c_string() + def c_string(self) -> str: + base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" return f"{base_str} *{restrict_str} {self._const_string()}" @@ -189,6 +189,13 @@ class PsStructType(PsAbstractType): def members(self) -> tuple[PsStructType.Member, ...]: return self._members + def get_member(self, member_name: str) -> PsStructType.Member | None: + """Find a member by name""" + for m in self._members: + if m.name == member_name: + return m + return None + @property def name(self) -> str: if self._name is None: @@ -206,12 +213,18 @@ class PsStructType(PsAbstractType): members = [(m.name, m.dtype.numpy_dtype) for m in self._members] return np.dtype(members) - def _c_string(self) -> str: + def c_string(self) -> str: if self._name is None: raise PsInternalCompilerError( "Cannot retrieve C string for anonymous struct type" ) return self._name + + def __str__(self) -> str: + if self._name is None: + return "<anonymous>" + else: + return self._name def __eq__(self, other: object) -> bool: if not isinstance(other, PsStructType): @@ -359,7 +372,7 @@ class PsIntegerType(PsScalarType, ABC): def __hash__(self) -> int: return hash(("PsIntegerType", self._width, self._signed, self._const)) - def _c_string(self) -> str: + def c_string(self) -> str: prefix = "" if self._signed else "u" return f"{self._const_string()}{prefix}int{self._width}_t" @@ -499,7 +512,7 @@ class PsIeeeFloatType(PsScalarType): def __hash__(self) -> int: return hash(("PsIeeeFloatType", self._width, self._const)) - def _c_string(self) -> str: + def c_string(self) -> str: match self._width: case 16: return f"{self._const_string()}half" diff --git a/tests/nbackend/__init__.py b/tests/nbackend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/nbackend/kernelcreation/__init__.py b/tests/nbackend/kernelcreation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/nbackend/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py similarity index 100% rename from tests/nbackend/test_freeze.py rename to tests/nbackend/kernelcreation/test_freeze.py diff --git a/tests/nbackend/kernelcreation/test_index_kernels.py b/tests/nbackend/kernelcreation/test_index_kernels.py new file mode 100644 index 000000000..fde27632e --- /dev/null +++ b/tests/nbackend/kernelcreation/test_index_kernels.py @@ -0,0 +1,31 @@ +import pytest + +import sympy as sp +import numpy as np + +from pystencils import Assignment, Field, FieldType, AssignmentCollection +from pystencils.nbackend.kernelcreation import create_kernel, KernelCreationOptions +from pystencils.cpu.cpujit import compile_and_load + +def test_indexed_kernel(): + arr = np.zeros((3, 4)) + dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)]) + index_arr = np.zeros((3,), dtype=dtype) + index_arr[0] = (0, 2, 3.0) + index_arr[1] = (1, 3, 42.0) + index_arr[2] = (2, 1, 5.0) + + index_field = Field.create_from_numpy_array('index', index_arr, field_type=FieldType.INDEXED) + normal_field = Field.create_from_numpy_array('f', arr) + update_rule = AssignmentCollection([ + Assignment(normal_field[0, 0], index_field('value')) + ]) + + options = KernelCreationOptions(index_field=index_field) + ast = create_kernel(update_rule, options) + kernel = compile_and_load(ast) + + kernel(f=arr, index=index_arr) + + for i in range(index_arr.shape[0]): + np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13) diff --git a/tests/nbackend/kernelcreation/test_options.py b/tests/nbackend/kernelcreation/test_options.py new file mode 100644 index 000000000..7f26288d4 --- /dev/null +++ b/tests/nbackend/kernelcreation/test_options.py @@ -0,0 +1,28 @@ +import pytest + +from pystencils.field import Field, FieldType +from pystencils.nbackend.types.quick import * +from pystencils.nbackend.kernelcreation.options import ( + KernelCreationOptions, + PsOptionsError, +) + + +def test_invalid_iteration_region_options(): + idx_field = Field.create_generic( + "idx", spatial_dimensions=1, field_type=FieldType.INDEXED + ) + with pytest.raises(PsOptionsError): + KernelCreationOptions( + ghost_layers=2, iteration_slice=(slice(1, -1), slice(1, -1)) + ) + with pytest.raises(PsOptionsError): + KernelCreationOptions(ghost_layers=2, index_field=idx_field) + + +def test_index_field_options(): + with pytest.raises(PsOptionsError): + idx_field = Field.create_generic( + "idx", spatial_dimensions=1, field_type=FieldType.GENERIC + ) + KernelCreationOptions(index_field=idx_field) diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py similarity index 77% rename from tests/nbackend/test_typification.py rename to tests/nbackend/kernelcreation/test_typification.py index ae477fe19..e5e88b2f6 100644 --- a/tests/nbackend/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -3,10 +3,11 @@ import sympy as sp import numpy as np import pymbolic.primitives as pb -from pystencils import Assignment, TypedSymbol +from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils.nbackend.ast import PsDeclaration -from pystencils.nbackend.types import constify, make_numeric_type +from pystencils.nbackend.types import constify, deconstify, PsStructType +from pystencils.nbackend.types.quick import * from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable from pystencils.nbackend.kernelcreation.options import KernelCreationOptions from pystencils.nbackend.kernelcreation.context import KernelCreationContext @@ -45,6 +46,28 @@ def test_typify_simple(): check(fasm.rhs.expression) +def test_typify_structs(): + options = KernelCreationOptions(default_dtype=Fp(32)) + ctx = KernelCreationContext(options) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) + f = Field.create_generic("f", 1, dtype=np_struct, field_type=FieldType.CUSTOM) + x = sp.Symbol("x") + + # Good + asm = Assignment(x, f.absolute_access((0,), "data")) + fasm = freeze(asm) + fasm = typify(fasm) + + # Bad + asm = Assignment(x, f.absolute_access((0,), "size")) + fasm = freeze(asm) + with pytest.raises(TypificationError): + fasm = typify(fasm) + + def test_contextual_typing(): options = KernelCreationOptions() ctx = KernelCreationContext(options) diff --git a/tests/nbackend/types/__init__.py b/tests/nbackend/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/nbackend/types/test_quick_types.py b/tests/nbackend/types/test_quick_types.py deleted file mode 100644 index f45bf565d..000000000 --- a/tests/nbackend/types/test_quick_types.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from pystencils.nbackend.types.quick import * - - -def test_parsing_positive(): - assert make_type("const uint32_t * restrict") == Ptr(UInt(32, const=True), restrict=True) - assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) - assert make_type("uint16 * const") == Ptr(UInt(16), const=True) - assert make_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) - -def test_parsing_negative(): - bad_specs = [ - "const notatype * const", - "cnost uint32_t", - "uint45_t", - "int", # plain ints are ambiguous - "float float", - "double * int", - "bool" - ] - - for spec in bad_specs: - with pytest.raises(ValueError): - make_type(spec) - -def test_numpy(): - import numpy as np - assert make_type(np.single) == make_type(np.float32) == PsIeeeFloatType(32) - assert make_type(float) == make_type(np.double) == make_type(np.float64) == PsIeeeFloatType(64) - assert make_type(int) == make_type(np.int64) == PsSignedIntegerType(64) diff --git a/tests/nbackend/test_types.py b/tests/nbackend/types/test_types.py similarity index 56% rename from tests/nbackend/test_types.py rename to tests/nbackend/types/test_types.py index ba5746222..082b39205 100644 --- a/tests/nbackend/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -6,6 +6,54 @@ from pystencils.nbackend.types import * from pystencils.nbackend.types.quick import * +@pytest.mark.parametrize("Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType]) +def test_widths(Type): + for width in Type.SUPPORTED_WIDTHS: + assert Type(width).width == width + + for width in (1, 9, 33, 63): + with pytest.raises(ValueError): + Type(width) + + +def test_parsing_positive(): + assert make_type("const uint32_t * restrict") == Ptr( + UInt(32, const=True), restrict=True + ) + assert make_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) + assert make_type("uint16 * const") == Ptr(UInt(16), const=True) + assert make_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) + + +def test_parsing_negative(): + bad_specs = [ + "const notatype * const", + "cnost uint32_t", + "uint45_t", + "int", # plain ints are ambiguous + "float float", + "double * int", + "bool", + ] + + for spec in bad_specs: + with pytest.raises(ValueError): + make_type(spec) + + +def test_numpy(): + import numpy as np + + assert make_type(np.single) == make_type(np.float32) == PsIeeeFloatType(32) + assert ( + make_type(float) + == make_type(np.double) + == make_type(np.float64) + == PsIeeeFloatType(64) + ) + assert make_type(int) == make_type(np.int64) == PsSignedIntegerType(64) + + @pytest.mark.parametrize( "numpy_type", [ @@ -68,5 +116,6 @@ def test_struct_types(): ) assert t.anonymous + assert str(t) == "<anonymous>" with pytest.raises(PsInternalCompilerError): - str(t) + t.c_string() -- GitLab