Commit a30d5181 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'address_of-function' into 'master'

Address of SymPy-Function `address_of`

See merge request pycodegen/pystencils!1
parents 1754ef27 9f79445e
...@@ -5,7 +5,11 @@ from typing import Set ...@@ -5,7 +5,11 @@ from typing import Set
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt from pystencils.data_types import (PointerType, VectorType, address_of,
cast_func, create_type, reinterpret_cast_func,
get_type_of_expression,
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.ccode import C99CodePrinter as CCodePrinter
...@@ -15,8 +19,6 @@ except ImportError: ...@@ -15,8 +19,6 @@ except ImportError:
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_ceil, int_div, int_power_of_2 bitwise_or, modulo_ceil, int_div, int_power_of_2
from pystencils.astnodes import Node, KernelFunction from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access, reinterpret_cast_func
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
...@@ -276,6 +278,9 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -276,6 +278,9 @@ class CustomSympyPrinter(CCodePrinter):
if isinstance(expr, reinterpret_cast_func): if isinstance(expr, reinterpret_cast_func):
arg, data_type = expr.args arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
elif isinstance(expr, address_of):
assert len(expr.args) == 1, "address_of must only have one argument"
return "&(%s)" % self._print(expr.args[0])
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
arg, data_type = expr.args arg, data_type = expr.args
if isinstance(arg, sp.Number): if isinstance(arg, sp.Number):
......
...@@ -14,6 +14,33 @@ from pystencils.utils import all_equal ...@@ -14,6 +14,33 @@ from pystencils.utils import all_equal
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
# noinspection PyPep8Naming
class address_of(sp.Function):
is_Atom = True
def __new__(cls, arg):
obj = sp.Function.__new__(cls, arg)
return obj
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
@property
def dtype(self):
if hasattr(self.args[0], 'dtype'):
return PointerType(self.args[0].dtype, restrict=True)
else:
return PointerType('void', restrict=True)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class cast_func(sp.Function): class cast_func(sp.Function):
is_Atom = True is_Atom = True
......
"""
Test of pystencils.data_types.address_of
"""
from pystencils.data_types import address_of, cast_func, PointerType
import pystencils
from pystencils.simp.simplifications import sympy_cse
import sympy
def test_address_of():
x, y = pystencils.fields('x,y: int64[2d]')
s = pystencils.TypedSymbol('s', PointerType('int64'))
assignments = pystencils.AssignmentCollection({
s: address_of(x[0, 0]),
y[0, 0]: cast_func(s, 'int64')
}, {})
ast = pystencils.create_kernel(assignments)
code = pystencils.show_code(ast)
print(code)
assignments = pystencils.AssignmentCollection({
y[0, 0]: cast_func(address_of(x[0, 0]), 'int64')
}, {})
ast = pystencils.create_kernel(assignments)
code = pystencils.show_code(ast)
print(code)
def test_address_of_with_cse():
x, y = pystencils.fields('x,y: int64[2d]')
s = pystencils.TypedSymbol('s', PointerType('int64'))
assignments = pystencils.AssignmentCollection({
y[0, 0]: cast_func(address_of(x[0, 0]), 'int64'),
x[0, 0]: cast_func(address_of(x[0, 0]), 'int64') + 1
}, {})
ast = pystencils.create_kernel(assignments)
code = pystencils.show_code(ast)
assignments_cse = sympy_cse(assignments)
ast = pystencils.create_kernel(assignments_cse)
code = pystencils.show_code(ast)
print(code)
def main():
test_address_of()
test_address_of_with_cse()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment