Commit 61800b73 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils type system: distinction between static and reinterpret cast

parent 2d925329
......@@ -12,7 +12,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift
bitwise_or, modulo_ceil
from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access
vector_memory_access, reinterpret_cast_func
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
......@@ -251,7 +251,10 @@ class CustomSympyPrinter(CCodePrinter):
}
if hasattr(expr, 'to_c'):
return expr.to_c(self._print)
if isinstance(expr, cast_func):
if isinstance(expr, reinterpret_cast_func):
arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
elif isinstance(expr, cast_func):
arg, data_type = expr.args
if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type)
......
......@@ -56,6 +56,11 @@ class vector_memory_access(cast_func):
nargs = (4,)
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
pass
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
@property
......
......@@ -10,8 +10,8 @@ from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.assignment_collection.nestedscopes import NestedScopes
from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.slicing import normalize_slice
import pystencils.astnodes as ast
......@@ -427,7 +427,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0])
result = cast_func(result, new_type)
result = reinterpret_cast_func(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment