diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index a8a3dbc8d467585ba71570d6c0df7d31a8738dcf..3057f68706a2d920e0927c8fa5553be2d56e9c2d 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -470,7 +470,7 @@ class CustomSympyPrinter(CCodePrinter): else: return f'fabs({self._print(expr.args[0])})' - def _print_Type(self, node): + def _print_AbstractType(self, node): return str(node) def _print_Function(self, expr): diff --git a/pystencils/field.py b/pystencils/field.py index 4a29a1be201cf34cc396e5da355dc77fdba23446..ceaeb82acc1a5eebfa6880fe5b8004ed832ee415 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -13,7 +13,7 @@ from sympy.core.cache import cacheit import pystencils from pystencils.alignedarray import aligned_empty -from pystencils.typing import StructType, TypedSymbol, create_type +from pystencils.typing import StructType, TypedSymbol, BasicType, create_type from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol from pystencils.stencil import ( direction_string_to_offset, inverse_direction, offset_to_direction_string) @@ -673,7 +673,11 @@ class Field: if superscript is not None: symbol_name += "^" + superscript - obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) + if dtype: + obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype) + else: + obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) + obj._field = field obj._offsets = [] for o in offsets: @@ -716,7 +720,11 @@ class Field: if len(idx) != self.field.index_dimensions: raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}") - return Field.Access(self.field, self._offsets, idx, dtype=self.dtype) + if len(idx) == 1 and isinstance(idx[0], str): + dtype = BasicType(self.field.dtype.numpy_dtype[idx[0]]) + return Field.Access(self.field, self._offsets, idx, dtype=dtype) + else: + return Field.Access(self.field, self._offsets, idx, dtype=self.dtype) def __getitem__(self, *idx): return self.__call__(*idx) diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index b70c8bb483a3dfe7a25bb7f45191bebb871cf18c..8200e96974d1859d86a0315134eeac1c13d9877a 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -15,8 +15,9 @@ class CastFunc(sp.Function): pass expr, dtype, *other_args = args - # If we have two consecutive casts, throw the inner one away - if isinstance(expr, CastFunc): + # If we have two consecutive casts, throw the inner one away. + # This optimisation is only available for simple casts. Thus the == is intended here! + if expr.__class__ == CastFunc: expr = expr.args[0] if not isinstance(dtype, AbstractType): dtype = BasicType(dtype) diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index f92a0db73df6f787a715b35478c7f406d10e78b5..b620c9c7e1257881522db57b8a1f5aad138a8485 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -113,6 +113,7 @@ class TypeAdder: # TOOO: check the access if isinstance(expr, Field.Access): + # TODO if Struct, look at the reinterpreted dtype return expr, expr.dtype elif isinstance(expr, TypedSymbol): return expr, expr.dtype diff --git a/pystencils_tests/test_indexed_kernels.py b/pystencils_tests/test_indexed_kernels.py index 87b24b354accb1f5936660378078d8a9f70a94fa..fa06a8f166702b53519a398bc544fcdc30f5cc94 100644 --- a/pystencils_tests/test_indexed_kernels.py +++ b/pystencils_tests/test_indexed_kernels.py @@ -1,4 +1,5 @@ import numpy as np +import pystencils as ps from pystencils import Assignment, Field, CreateKernelConfig, create_kernel, Target @@ -18,6 +19,7 @@ def test_indexed_kernel(): ast = create_kernel([update_rule], config=config) kernel = ast.compile() kernel(f=arr, index=index_arr) + code = ps.get_code_str(kernel) for i in range(index_arr.shape[0]): np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13) diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index 2310f4c1dd06e33b1c2a9d0aa34e281f6a804252..5d839da2fff0830980851ef1b605a71b53b57dd5 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -4,6 +4,7 @@ import pytest import pystencils as ps from pystencils.astnodes import SympyAssignment +from pystencils.node_collection import NodeCollection from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.cpu.cpujit import get_compiler_config @@ -163,13 +164,15 @@ def test_rng_symbol(vectorized): dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target=Target.CPU) f = dh.add_array("f", values_per_cell=2 * dh.dim, alignment=True) - ac = ps.AssignmentCollection([ps.Assignment(f(i), 0) for i in range(f.shape[-1])]) - rng_symbol_gen = random_symbol(ac.subexpressions, dim=dh.dim) + nc = NodeCollection([SympyAssignment(f(i), 0) for i in range(f.shape[-1])]) + subexpressions = [] + rng_symbol_gen = random_symbol(subexpressions, dim=dh.dim) for i in range(f.shape[-1]): - ac.main_assignments[i] = ps.Assignment(ac.main_assignments[i].lhs, next(rng_symbol_gen)) - symbols = [a.rhs for a in ac.main_assignments] + nc.all_assignments[i] = SympyAssignment(nc.all_assignments[i].lhs, next(rng_symbol_gen)) + symbols = [a.rhs for a in nc.all_assignments] + [nc.all_assignments.insert(0, subexpression) for subexpression in subexpressions] assert len(symbols) == f.shape[-1] and len(set(symbols)) == f.shape[-1] - ps.create_kernel(ac, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() + ps.create_kernel(nc, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() @pytest.mark.parametrize('vectorized', (False, True))