diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index a6215923d21c32c20338eacf4330350f889aee43..661cabd12456fb11b6bc00f49d7af604930b35c6 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -670,10 +670,10 @@ class DestructuringBindingsForFieldClass(Node): """ CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "data", - FieldShapeSymbol: "shape", - FieldStrideSymbol: "stride" + FieldShapeSymbol: "shape[%i]", + FieldStrideSymbol: "stride[%i]" } - CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>") @property def fields_accessed(self) -> Set['ResolvedFieldAccess']: @@ -682,7 +682,7 @@ class DestructuringBindingsForFieldClass(Node): def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() - self.headers = ['<Field.h>'] + self.headers = ['<PyStencilsField.h>'] self.body = body @property diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index c0b1a763d43e5f4a3f9d43f93c1f7951957efe89..965b5d84bdf19eddec05bdd2d3b9ac83182eaa58 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -6,15 +6,14 @@ import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter -from pystencils.astnodes import DestructuringBindingsForFieldClass, KernelFunction, Node +from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, - reinterpret_cast_func, vector_memory_access) + PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, + vector_memory_access) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( - bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, - int_div, int_power_of_2, modulo_ceil) + bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) from pystencils.kernelparameters import FieldPointerSymbol try: @@ -260,14 +259,15 @@ class CBackend: result += "else " + false_block return result - def _print_DestructuringBindingsForFieldClass(self, node: Node): + def _print_DestructuringBindingsForFieldClass(self, node): # Define all undefined symbols undefined_field_symbols = node.symbols_defined - destructuring_bindings = ["%s = %s.%s%s;" % - (u.name, + destructuring_bindings = ["%s %s = %s.%s;" % + (u.dtype, + u.name, u.field_name if hasattr(u, 'field_name') else u.field_names[0], - DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__], - "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate)) + node.CLASS_TO_MEMBER_DICT[u.__class__] % + (() if type(u) == FieldPointerSymbol else (u.coordinate,))) for u in undefined_field_symbols ] destructuring_bindings.sort() # only for code aesthetics diff --git a/pystencils/include/PyStencilsField.h b/pystencils/include/PyStencilsField.h new file mode 100644 index 0000000000000000000000000000000000000000..3055cae2365279e28fdcaab4353779b97ca27d35 --- /dev/null +++ b/pystencils/include/PyStencilsField.h @@ -0,0 +1,19 @@ +#pragma once + +extern "C++" { +#ifdef __CUDA_ARCH__ +template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { + DTYPE_T *data; + DTYPE_T shape[DIMENSION]; + DTYPE_T stride[DIMENSION]; +}; +#else +#include <array> + +template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { + DTYPE_T *data; + std::array<DTYPE_T, DIMENSION> shape; + std::array<DTYPE_T, DIMENSION> stride; +}; +#endif +} diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py index aca1516cb5200f41e65a0d7648d8443fd3e4fb16..6e1eefff9e47cdee1e6c9508796886050bb0505a 100644 --- a/pystencils_tests/test_destructuring_field_class.py +++ b/pystencils_tests/test_destructuring_field_class.py @@ -1,7 +1,11 @@ import sympy +import jinja2 + import pystencils from pystencils.astnodes import DestructuringBindingsForFieldClass +from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol + def test_destructuring_field_class(): @@ -10,15 +14,39 @@ def test_destructuring_field_class(): normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) - ast = pystencils.create_kernel(normal_assignments) + ast = pystencils.create_kernel(normal_assignments, target='gpu') print(pystencils.show_code(ast)) ast.body = DestructuringBindingsForFieldClass(ast.body) print(pystencils.show_code(ast)) + ast.compile() + + +class DestructuringEmojiClass(DestructuringBindingsForFieldClass): + CLASS_TO_MEMBER_DICT = { + FieldPointerSymbol: "🥶", + FieldShapeSymbol: "😳_%i", + FieldStrideSymbol: "🥵_%i" + } + CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>") + def __init__(self, node): + super().__init__(node) + self.headers = [] + + +def test_destructuring_alternative_field_class(): + z, x, y = pystencils.fields("z, y, x: [2d]") + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments, target='gpu') + ast.body = DestructuringEmojiClass(ast.body) + print(pystencils.show_code(ast)) def main(): test_destructuring_field_class() + test_destructuring_alternative_field_class() if __name__ == '__main__':