From 9f79445e6fd8d7b564570366d3fa8dcb831ea8a8 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 7 Jun 2019 09:38:02 +0200
Subject: [PATCH] Add test for `address_of`

---
 pystencils_tests/test_address_of.py | 58 +++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 pystencils_tests/test_address_of.py

diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py
new file mode 100644
index 000000000..8de48e2bb
--- /dev/null
+++ b/pystencils_tests/test_address_of.py
@@ -0,0 +1,58 @@
+
+"""
+Test of pystencils.data_types.address_of
+"""
+
+from pystencils.data_types import address_of, cast_func, PointerType
+import pystencils
+from pystencils.simp.simplifications import sympy_cse
+import sympy
+
+
+def test_address_of():
+    x, y = pystencils.fields('x,y: int64[2d]')
+    s = pystencils.TypedSymbol('s', PointerType('int64'))
+
+    assignments = pystencils.AssignmentCollection({
+        s: address_of(x[0, 0]),
+        y[0, 0]: cast_func(s, 'int64')
+    }, {})
+
+    ast = pystencils.create_kernel(assignments)
+    code = pystencils.show_code(ast)
+    print(code)
+
+    assignments = pystencils.AssignmentCollection({
+        y[0, 0]: cast_func(address_of(x[0, 0]), 'int64')
+    }, {})
+
+    ast = pystencils.create_kernel(assignments)
+    code = pystencils.show_code(ast)
+    print(code)
+
+
+def test_address_of_with_cse():
+    x, y = pystencils.fields('x,y: int64[2d]')
+    s = pystencils.TypedSymbol('s', PointerType('int64'))
+
+    assignments = pystencils.AssignmentCollection({
+        y[0, 0]: cast_func(address_of(x[0, 0]), 'int64'),
+        x[0, 0]: cast_func(address_of(x[0, 0]), 'int64') + 1
+    }, {})
+
+    ast = pystencils.create_kernel(assignments)
+    code = pystencils.show_code(ast)
+    assignments_cse = sympy_cse(assignments)
+
+    ast = pystencils.create_kernel(assignments_cse)
+    code = pystencils.show_code(ast)
+    print(code)
+
+
+def main():
+    test_address_of()
+    test_address_of_with_cse()
+
+
+if __name__ == '__main__':
+    main()
-- 
GitLab