From e09f0646550ba4e262d66266bd422670052f57b2 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 18 Jul 2019 10:13:11 +0200 Subject: [PATCH] Fixup for DestructuringBindingsForFieldClass - rename header Field.h is not a unique name in waLBerla context - add PyStencilsField.h - bindings were lacking data type --- pystencils/astnodes.py | 8 ++--- pystencils/backends/cbackend.py | 20 ++++++------- pystencils/include/PyStencilsField.h | 19 ++++++++++++ .../test_destructuring_field_class.py | 30 ++++++++++++++++++- 4 files changed, 62 insertions(+), 15 deletions(-) create mode 100644 pystencils/include/PyStencilsField.h diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index a6215923d..661cabd12 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -670,10 +670,10 @@ class DestructuringBindingsForFieldClass(Node): """ CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "data", - FieldShapeSymbol: "shape", - FieldStrideSymbol: "stride" + FieldShapeSymbol: "shape[%i]", + FieldStrideSymbol: "stride[%i]" } - CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>") @property def fields_accessed(self) -> Set['ResolvedFieldAccess']: @@ -682,7 +682,7 @@ class DestructuringBindingsForFieldClass(Node): def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() - self.headers = ['<Field.h>'] + self.headers = ['<PyStencilsField.h>'] self.body = body @property diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index c0b1a763d..965b5d84b 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -6,15 +6,14 @@ import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter -from pystencils.astnodes import DestructuringBindingsForFieldClass, KernelFunction, Node +from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, - reinterpret_cast_func, vector_memory_access) + PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, + vector_memory_access) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( - bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, - int_div, int_power_of_2, modulo_ceil) + bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) from pystencils.kernelparameters import FieldPointerSymbol try: @@ -260,14 +259,15 @@ class CBackend: result += "else " + false_block return result - def _print_DestructuringBindingsForFieldClass(self, node: Node): + def _print_DestructuringBindingsForFieldClass(self, node): # Define all undefined symbols undefined_field_symbols = node.symbols_defined - destructuring_bindings = ["%s = %s.%s%s;" % - (u.name, + destructuring_bindings = ["%s %s = %s.%s;" % + (u.dtype, + u.name, u.field_name if hasattr(u, 'field_name') else u.field_names[0], - DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__], - "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate)) + node.CLASS_TO_MEMBER_DICT[u.__class__] % + (() if type(u) == FieldPointerSymbol else (u.coordinate,))) for u in undefined_field_symbols ] destructuring_bindings.sort() # only for code aesthetics diff --git a/pystencils/include/PyStencilsField.h b/pystencils/include/PyStencilsField.h new file mode 100644 index 000000000..3055cae23 --- /dev/null +++ b/pystencils/include/PyStencilsField.h @@ -0,0 +1,19 @@ +#pragma once + +extern "C++" { +#ifdef __CUDA_ARCH__ +template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { + DTYPE_T *data; + DTYPE_T shape[DIMENSION]; + DTYPE_T stride[DIMENSION]; +}; +#else +#include <array> + +template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { + DTYPE_T *data; + std::array<DTYPE_T, DIMENSION> shape; + std::array<DTYPE_T, DIMENSION> stride; +}; +#endif +} diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py index aca1516cb..6e1eefff9 100644 --- a/pystencils_tests/test_destructuring_field_class.py +++ b/pystencils_tests/test_destructuring_field_class.py @@ -1,7 +1,11 @@ import sympy +import jinja2 + import pystencils from pystencils.astnodes import DestructuringBindingsForFieldClass +from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol + def test_destructuring_field_class(): @@ -10,15 +14,39 @@ def test_destructuring_field_class(): normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) - ast = pystencils.create_kernel(normal_assignments) + ast = pystencils.create_kernel(normal_assignments, target='gpu') print(pystencils.show_code(ast)) ast.body = DestructuringBindingsForFieldClass(ast.body) print(pystencils.show_code(ast)) + ast.compile() + + +class DestructuringEmojiClass(DestructuringBindingsForFieldClass): + CLASS_TO_MEMBER_DICT = { + FieldPointerSymbol: "🥶", + FieldShapeSymbol: "😳_%i", + FieldStrideSymbol: "🥵_%i" + } + CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>") + def __init__(self, node): + super().__init__(node) + self.headers = [] + + +def test_destructuring_alternative_field_class(): + z, x, y = pystencils.fields("z, y, x: [2d]") + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments, target='gpu') + ast.body = DestructuringEmojiClass(ast.body) + print(pystencils.show_code(ast)) def main(): test_destructuring_field_class() + test_destructuring_alternative_field_class() if __name__ == '__main__': -- GitLab