From 10f8f119784caff63fee557fa9f1fa867a60cfa6 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Fri, 21 Jan 2022 12:12:28 +0100 Subject: [PATCH] Fix indexed kernels --- pystencils/backends/cbackend.py | 2 +- pystencils/field.py | 14 +++++++++++--- pystencils/typing/cast_functions.py | 5 +++-- pystencils/typing/leaf_typing.py | 1 + pystencils_tests/test_indexed_kernels.py | 2 ++ pystencils_tests/test_random.py | 13 ++++++++----- 6 files changed, 26 insertions(+), 11 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index a8a3dbc8d..3057f6870 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -470,7 +470,7 @@ class CustomSympyPrinter(CCodePrinter): else: return f'fabs({self._print(expr.args[0])})' - def _print_Type(self, node): + def _print_AbstractType(self, node): return str(node) def _print_Function(self, expr): diff --git a/pystencils/field.py b/pystencils/field.py index 4a29a1be2..ceaeb82ac 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -13,7 +13,7 @@ from sympy.core.cache import cacheit import pystencils from pystencils.alignedarray import aligned_empty -from pystencils.typing import StructType, TypedSymbol, create_type +from pystencils.typing import StructType, TypedSymbol, BasicType, create_type from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol from pystencils.stencil import ( direction_string_to_offset, inverse_direction, offset_to_direction_string) @@ -673,7 +673,11 @@ class Field: if superscript is not None: symbol_name += "^" + superscript - obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) + if dtype: + obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype) + else: + obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) + obj._field = field obj._offsets = [] for o in offsets: @@ -716,7 +720,11 @@ class Field: if len(idx) != self.field.index_dimensions: raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}") - return Field.Access(self.field, self._offsets, idx, dtype=self.dtype) + if len(idx) == 1 and isinstance(idx[0], str): + dtype = BasicType(self.field.dtype.numpy_dtype[idx[0]]) + return Field.Access(self.field, self._offsets, idx, dtype=dtype) + else: + return Field.Access(self.field, self._offsets, idx, dtype=self.dtype) def __getitem__(self, *idx): return self.__call__(*idx) diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index b70c8bb48..8200e9697 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -15,8 +15,9 @@ class CastFunc(sp.Function): pass expr, dtype, *other_args = args - # If we have two consecutive casts, throw the inner one away - if isinstance(expr, CastFunc): + # If we have two consecutive casts, throw the inner one away. + # This optimisation is only available for simple casts. Thus the == is intended here! + if expr.__class__ == CastFunc: expr = expr.args[0] if not isinstance(dtype, AbstractType): dtype = BasicType(dtype) diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index f92a0db73..b620c9c7e 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -113,6 +113,7 @@ class TypeAdder: # TOOO: check the access if isinstance(expr, Field.Access): + # TODO if Struct, look at the reinterpreted dtype return expr, expr.dtype elif isinstance(expr, TypedSymbol): return expr, expr.dtype diff --git a/pystencils_tests/test_indexed_kernels.py b/pystencils_tests/test_indexed_kernels.py index 87b24b354..fa06a8f16 100644 --- a/pystencils_tests/test_indexed_kernels.py +++ b/pystencils_tests/test_indexed_kernels.py @@ -1,4 +1,5 @@ import numpy as np +import pystencils as ps from pystencils import Assignment, Field, CreateKernelConfig, create_kernel, Target @@ -18,6 +19,7 @@ def test_indexed_kernel(): ast = create_kernel([update_rule], config=config) kernel = ast.compile() kernel(f=arr, index=index_arr) + code = ps.get_code_str(kernel) for i in range(index_arr.shape[0]): np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13) diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index 2310f4c1d..5d839da2f 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -4,6 +4,7 @@ import pytest import pystencils as ps from pystencils.astnodes import SympyAssignment +from pystencils.node_collection import NodeCollection from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.cpu.cpujit import get_compiler_config @@ -163,13 +164,15 @@ def test_rng_symbol(vectorized): dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target=Target.CPU) f = dh.add_array("f", values_per_cell=2 * dh.dim, alignment=True) - ac = ps.AssignmentCollection([ps.Assignment(f(i), 0) for i in range(f.shape[-1])]) - rng_symbol_gen = random_symbol(ac.subexpressions, dim=dh.dim) + nc = NodeCollection([SympyAssignment(f(i), 0) for i in range(f.shape[-1])]) + subexpressions = [] + rng_symbol_gen = random_symbol(subexpressions, dim=dh.dim) for i in range(f.shape[-1]): - ac.main_assignments[i] = ps.Assignment(ac.main_assignments[i].lhs, next(rng_symbol_gen)) - symbols = [a.rhs for a in ac.main_assignments] + nc.all_assignments[i] = SympyAssignment(nc.all_assignments[i].lhs, next(rng_symbol_gen)) + symbols = [a.rhs for a in nc.all_assignments] + [nc.all_assignments.insert(0, subexpression) for subexpression in subexpressions] assert len(symbols) == f.shape[-1] and len(set(symbols)) == f.shape[-1] - ps.create_kernel(ac, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() + ps.create_kernel(nc, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() @pytest.mark.parametrize('vectorized', (False, True)) -- GitLab