diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b1ed610e1a9262e407d65e3678d8b12d261845f4..2d3174a1a564bfd00633d36150d8119f7dacaec5 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,9 +1,13 @@ +from typing import Any, List, Optional, Sequence, Set, Union + +import jinja2 import sympy as sp + +from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.field import Field -from pystencils.data_types import TypedSymbol, create_type, cast_func -from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol +from pystencils.kernelparameters import (FieldPointerSymbol, FieldShapeSymbol, + FieldStrideSymbol) from pystencils.sympyextensions import fast_subs -from typing import List, Set, Optional, Union, Any, Sequence NodeOrExpr = Union['Node', sp.Expr] @@ -130,6 +134,7 @@ class KernelFunction(Node): defined in pystencils.kernelparameters. If the parameter is related to one or multiple fields, these fields are referenced in the fields property. """ + def __init__(self, symbol, fields): self.symbol = symbol # type: TypedSymbol self.fields = fields # type: Sequence[Field] @@ -582,6 +587,7 @@ class TemporaryMemoryAllocation(Node): size: number of elements to allocate align_offset: the align_offset's element is aligned """ + def __init__(self, typed_symbol: TypedSymbol, size, align_offset): super(TemporaryMemoryAllocation, self).__init__(parent=None) self.symbol = typed_symbol @@ -639,3 +645,58 @@ class TemporaryMemoryFree(Node): def early_out(condition): from pystencils.cpu.vectorization import vec_all return Conditional(vec_all(condition), Block([SkipIteration()])) + + +class DestructuringBindingsForFieldClass(Node): + """ + Defines all variables needed for describing a field (shape, pointer, strides) + """ + CLASS_TO_MEMBER_DICT = { + FieldPointerSymbol: "data", + FieldShapeSymbol: "shape", + FieldStrideSymbol: "stride" + } + CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + + @property + def fields_accessed(self) -> Set['ResolvedFieldAccess']: + """Set of Field instances: fields which are accessed inside this kernel function""" + return set(o.field for o in self.atoms(ResolvedFieldAccess)) + + def __init__(self, body): + super(DestructuringBindingsForFieldClass, self).__init__() + self.headers = ['<Field.h>'] + self.body = body + + @property + def args(self) -> List[NodeOrExpr]: + """Returns all arguments/children of this node.""" + return set() + + @property + def symbols_defined(self) -> Set[sp.Symbol]: + """Set of symbols which are defined by this node.""" + undefined_field_symbols = {s for s in self.body.undefined_symbols + if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))} + return undefined_field_symbols + + @property + def undefined_symbols(self) -> Set[sp.Symbol]: + field_map = {f.name: f for f in self.fields_accessed} + undefined_field_symbols = self.symbols_defined + corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} + corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + for f in corresponding_field_names} | \ + (self.body.undefined_symbols - undefined_field_symbols) + + def subs(self, subs_dict) -> None: + """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" + self.body.subs(subs_dict) + + @property + def func(self): + return self.__class__ + + def atoms(self, arg_type) -> Set[Any]: + return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)} diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 1c28e324981d7bc4f5ea3ee310ca4ede0ee4e5f8..7c4937d1f524fdd76299a0f86b97bae4377eab16 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -1,24 +1,30 @@ -import sympy as sp from collections import namedtuple -from sympy.core import S from typing import Set + +import jinja2 +import sympy as sp +from sympy.core import S from sympy.printing.ccode import C89CodePrinter -from pystencils.cpu.vectorization import vec_any, vec_all +from pystencils.astnodes import (DestructuringBindingsForFieldClass, + KernelFunction, Node) +from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import (PointerType, VectorType, address_of, - cast_func, create_type, reinterpret_cast_func, + cast_func, create_type, get_type_of_expression, - vector_memory_access) -from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt + reinterpret_cast_func, vector_memory_access) +from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, + fast_sqrt) +from pystencils.integer_functions import (bit_shift_left, bit_shift_right, + bitwise_and, bitwise_or, bitwise_xor, + int_div, int_power_of_2, modulo_ceil) +from pystencils.kernelparameters import FieldPointerSymbol try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter except ImportError: from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 -from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ - bitwise_or, modulo_ceil, int_div, int_power_of_2 -from pystencils.astnodes import Node, KernelFunction __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -255,6 +261,30 @@ class CBackend: result += "else " + false_block return result + def _print_DestructuringBindingsForFieldClass(self, node: Node): + # Define all undefined symbols + undefined_field_symbols = node.symbols_defined + destructuring_bindings = ["%s = %s.%s%s;" % + (u.name, + u.field_name if hasattr(u, 'field_name') else u.field_names[0], + DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__], + "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate)) + for u in undefined_field_symbols + ] + destructuring_bindings.sort() # only for code aesthetics + template = jinja2.Template( + """{ + {% for binding in bindings -%} + {{ binding | indent(3) }} + {% endfor -%} + {{ block | indent(3) }} +} + +""") + code = template.render(bindings=destructuring_bindings, + block=self._print(node.body)) + return code + # ------------------------------------------ Helper function & classes ------------------------------------------------- diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 93ef5c1c798df3b0316fdd4232462e13a9415538..3f1a02c6833756985308282c49148fe0ecaa0153 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol): obj = super(TypedSymbol, cls).__xnew__(cls, name) try: obj._dtype = create_type(dtype) - except TypeError: + except (TypeError, ValueError): # on error keep the string obj._dtype = dtype return obj diff --git a/pystencils/field.py b/pystencils/field.py index c1720654867fe5b86ad321e1a1c6b78ec9b328f6..0c3af6dfa4edec15cb0f1cdff4822083eaacc017 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -306,6 +306,10 @@ class Field(AbstractField): def index_dimensions(self) -> int: return len(self.shape) - len(self._layout) + @property + def ndim(self) -> int: + return len(self.shape) + @property def layout(self): return self._layout diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py new file mode 100644 index 0000000000000000000000000000000000000000..248963ae3aab3a0fb10e5965478e174ccecb39f8 --- /dev/null +++ b/pystencils_tests/test_destructuring_field_class.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +import sympy + +import pystencils +from pystencils.astnodes import DestructuringBindingsForFieldClass + + +def test_destructuring_field_class(): + z, x, y = pystencils.fields("z, y, x: [2d]") + + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments) + print(pystencils.show_code(ast)) + + ast.body = DestructuringBindingsForFieldClass(ast.body) + print(pystencils.show_code(ast)) + + +def main(): + test_destructuring_field_class() + + +if __name__ == '__main__': + main()