diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index a0ca26fdd9efb9be827ef5ffbf2deed9d5a79a1f..67ad6869fa6b98b9b807940a742afb7d9544097a 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -676,59 +676,5 @@ def early_out(condition): return Conditional(vec_all(condition), Block([SkipIteration()])) -class DestructuringBindingsForFieldClass(Node): - """ - Defines all variables needed for describing a field (shape, pointer, strides) - """ - CLASS_TO_MEMBER_DICT = { - FieldPointerSymbol: "data", - FieldShapeSymbol: "shape[%i]", - FieldStrideSymbol: "stride[%i]" - } - CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>" - - @property - def fields_accessed(self) -> Set['ResolvedFieldAccess']: - """Set of Field instances: fields which are accessed inside this kernel function""" - return set(o.field for o in self.atoms(ResolvedFieldAccess)) - - def __init__(self, body): - super(DestructuringBindingsForFieldClass, self).__init__() - self.headers = ['<PyStencilsField.h>'] - self.body = body - - @property - def args(self) -> List[NodeOrExpr]: - """Returns all arguments/children of this node.""" - return set() - - @property - def symbols_defined(self) -> Set[sp.Symbol]: - """Set of symbols which are defined by this node.""" - undefined_field_symbols = {s for s in self.body.undefined_symbols - if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))} - return undefined_field_symbols - - @property - def undefined_symbols(self) -> Set[sp.Symbol]: - field_map = {f.name: f for f in self.fields_accessed} - undefined_field_symbols = self.symbols_defined - corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} - corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') - for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols) - - def subs(self, subs_dict) -> None: - """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" - self.body.subs(subs_dict) - - @property - def func(self): - return self.__class__ - - def atoms(self, arg_type) -> Set[Any]: - return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)} - - def get_dummy_symbol(dtype='bool'): return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype)) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 197e21d8140f82b0f4d92aad47975f79885884fc..ec5a930fe9cba4fbe71b8cb5ff3086d4be216073 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -15,7 +15,6 @@ from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqr from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) -from pystencils.kernelparameters import FieldPointerSymbol try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter @@ -276,24 +275,6 @@ class CBackend: result += "else " + false_block return result - def _print_DestructuringBindingsForFieldClass(self, node): - # Define all undefined symbols - undefined_field_symbols = node.symbols_defined - destructuring_bindings = ["%s %s = %s.%s;" % - (u.dtype, - u.name, - u.field_name if hasattr(u, 'field_name') else u.field_names[0], - node.CLASS_TO_MEMBER_DICT[u.__class__] % - (() if type(u) == FieldPointerSymbol else (u.coordinate,))) - for u in undefined_field_symbols - ] - destructuring_bindings.sort() # only for code aesthetics - return "{\n" + self._indent + \ - ("\n" + self._indent).join(destructuring_bindings) + \ - "\n" + self._indent + \ - ("\n" + self._indent).join(self._print(node.body).splitlines()) + \ - "\n}" - # ------------------------------------------ Helper function & classes ------------------------------------------------- diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py deleted file mode 100644 index 4b9faf63277e7c00f469b10ea59f7d629baae41e..0000000000000000000000000000000000000000 --- a/pystencils_tests/test_destructuring_field_class.py +++ /dev/null @@ -1,43 +0,0 @@ -import sympy - -import pystencils -from pystencils.astnodes import DestructuringBindingsForFieldClass -from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol - - -def test_destructuring_field_class(): - z, x, y = pystencils.fields("z, y, x: [2d]") - - normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( - z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) - - ast = pystencils.create_kernel(normal_assignments, target='gpu') - print(pystencils.show_code(ast)) - - ast.body = DestructuringBindingsForFieldClass(ast.body) - print(pystencils.show_code(ast)) - ast.compile() - - -class DestructuringEmojiClass(DestructuringBindingsForFieldClass): - CLASS_TO_MEMBER_DICT = { - FieldPointerSymbol: "🥶", - FieldShapeSymbol: "😳_%i", - FieldStrideSymbol: "🥵_%i" - } - CLASS_NAME_TEMPLATE = "🤯<{dtype}, {ndim}>" - - def __init__(self, node): - super().__init__(node) - self.headers = [] - - -def test_destructuring_alternative_field_class(): - z, x, y = pystencils.fields("z, y, x: [2d]") - - normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( - z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) - - ast = pystencils.create_kernel(normal_assignments, target='gpu') - ast.body = DestructuringEmojiClass(ast.body) - print(pystencils.show_code(ast))