diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 92a6080c73389a815163157efefac2186aeee09e..8eb44766199e2aaad990bab167fc9267bc4017c9 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -5,7 +5,11 @@ from typing import Set from sympy.printing.ccode import C89CodePrinter from pystencils.cpu.vectorization import vec_any, vec_all -from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt +from pystencils.data_types import (PointerType, VectorType, address_of, + cast_func, create_type, reinterpret_cast_func, + get_type_of_expression, + vector_memory_access) +from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter @@ -15,8 +19,6 @@ except ImportError: from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ bitwise_or, modulo_ceil, int_div, int_power_of_2 from pystencils.astnodes import Node, KernelFunction -from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ - vector_memory_access, reinterpret_cast_func __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -276,6 +278,9 @@ class CustomSympyPrinter(CCodePrinter): if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + elif isinstance(expr, address_of): + assert len(expr.args) == 1, "address_of must only have one argument" + return "&(%s)" % self._print(expr.args[0]) elif isinstance(expr, cast_func): arg, data_type = expr.args if isinstance(arg, sp.Number): diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 7bdc9d340664fa20178ec941cc9e5305d99fd02c..93ef5c1c798df3b0316fdd4232462e13a9415538 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -14,6 +14,33 @@ from pystencils.utils import all_equal from sympy.logic.boolalg import Boolean +# noinspection PyPep8Naming +class address_of(sp.Function): + is_Atom = True + + def __new__(cls, arg): + obj = sp.Function.__new__(cls, arg) + return obj + + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() + + @property + def is_commutative(self): + return self.args[0].is_commutative + + @property + def dtype(self): + if hasattr(self.args[0], 'dtype'): + return PointerType(self.args[0].dtype, restrict=True) + else: + return PointerType('void', restrict=True) + + # noinspection PyPep8Naming class cast_func(sp.Function): is_Atom = True diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py new file mode 100644 index 0000000000000000000000000000000000000000..8de48e2bb0783dc66452a0bf31a7a7529d8d8292 --- /dev/null +++ b/pystencils_tests/test_address_of.py @@ -0,0 +1,58 @@ + +""" +Test of pystencils.data_types.address_of +""" + +from pystencils.data_types import address_of, cast_func, PointerType +import pystencils +from pystencils.simp.simplifications import sympy_cse +import sympy + + +def test_address_of(): + x, y = pystencils.fields('x,y: int64[2d]') + s = pystencils.TypedSymbol('s', PointerType('int64')) + + assignments = pystencils.AssignmentCollection({ + s: address_of(x[0, 0]), + y[0, 0]: cast_func(s, 'int64') + }, {}) + + ast = pystencils.create_kernel(assignments) + code = pystencils.show_code(ast) + print(code) + + assignments = pystencils.AssignmentCollection({ + y[0, 0]: cast_func(address_of(x[0, 0]), 'int64') + }, {}) + + ast = pystencils.create_kernel(assignments) + code = pystencils.show_code(ast) + print(code) + + +def test_address_of_with_cse(): + x, y = pystencils.fields('x,y: int64[2d]') + s = pystencils.TypedSymbol('s', PointerType('int64')) + + assignments = pystencils.AssignmentCollection({ + y[0, 0]: cast_func(address_of(x[0, 0]), 'int64'), + x[0, 0]: cast_func(address_of(x[0, 0]), 'int64') + 1 + }, {}) + + ast = pystencils.create_kernel(assignments) + code = pystencils.show_code(ast) + assignments_cse = sympy_cse(assignments) + + ast = pystencils.create_kernel(assignments_cse) + code = pystencils.show_code(ast) + print(code) + + +def main(): + test_address_of() + test_address_of_with_cse() + + +if __name__ == '__main__': + main()