diff --git a/backends/cbackend.py b/backends/cbackend.py index 83dfc48ff65aa88bee5abe2be26154f978c9dc01..54ef38149199d844d25561e7878d281875fc8ad1 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -12,7 +12,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift bitwise_or, modulo_ceil from pystencils.astnodes import Node, KernelFunction from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ - vector_memory_access + vector_memory_access, reinterpret_cast_func __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -251,7 +251,10 @@ class CustomSympyPrinter(CCodePrinter): } if hasattr(expr, 'to_c'): return expr.to_c(self._print) - if isinstance(expr, cast_func): + if isinstance(expr, reinterpret_cast_func): + arg, data_type = expr.args + return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + elif isinstance(expr, cast_func): arg, data_type = expr.args if isinstance(arg, sp.Number): return self._typed_number(arg, data_type) diff --git a/data_types.py b/data_types.py index fc1bd44e54b04998e424ae26bed42c6dc136b94f..56d19d4f3c88d6e2c4040d100ea47d787f12ad3d 100644 --- a/data_types.py +++ b/data_types.py @@ -56,6 +56,11 @@ class vector_memory_access(cast_func): nargs = (4,) +# noinspection PyPep8Naming +class reinterpret_cast_func(cast_func): + pass + + # noinspection PyPep8Naming class pointer_arithmetic_func(sp.Function, Boolean): @property diff --git a/transformations.py b/transformations.py index 8a2cf999259692bb104d6f94217acfd6b98b6a0c..fb48c4a738d3ace2b603f7f838a8698283af43cc 100644 --- a/transformations.py +++ b/transformations.py @@ -10,8 +10,8 @@ from sympy.tensor import IndexedBase from pystencils.assignment import Assignment from pystencils.assignment_collection.nestedscopes import NestedScopes from pystencils.field import Field, FieldType -from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \ - pointer_arithmetic_func, get_type_of_expression, collate_types, create_type +from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \ + cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type from pystencils.kernelparameters import FieldPointerSymbol from pystencils.slicing import normalize_slice import pystencils.astnodes as ast @@ -427,7 +427,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), if isinstance(get_base_type(field_access.field.dtype), StructType): new_type = field_access.field.dtype.get_element_type(field_access.index[0]) - result = cast_func(result, new_type) + result = reinterpret_cast_func(result, new_type) return visit_sympy_expr(result, enclosing_block, sympy_assignment) else: