diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 92a6080c73389a815163157efefac2186aeee09e..8eb44766199e2aaad990bab167fc9267bc4017c9 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -5,7 +5,11 @@ from typing import Set from sympy.printing.ccode import C89CodePrinter from pystencils.cpu.vectorization import vec_any, vec_all -from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt +from pystencils.data_types import (PointerType, VectorType, address_of, + cast_func, create_type, reinterpret_cast_func, + get_type_of_expression, + vector_memory_access) +from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter @@ -15,8 +19,6 @@ except ImportError: from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ bitwise_or, modulo_ceil, int_div, int_power_of_2 from pystencils.astnodes import Node, KernelFunction -from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ - vector_memory_access, reinterpret_cast_func __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -276,6 +278,9 @@ class CustomSympyPrinter(CCodePrinter): if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + elif isinstance(expr, address_of): + assert len(expr.args) == 1, "address_of must only have one argument" + return "&(%s)" % self._print(expr.args[0]) elif isinstance(expr, cast_func): arg, data_type = expr.args if isinstance(arg, sp.Number):