From 8e63c9ff64eaaba4cedbc246b5d200d6f3e2c8de Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 5 Jul 2019 19:39:49 +0200 Subject: [PATCH] Add DestructuringBindingsForFieldClass to use pystencils kernels in a more C++-ish way DestructuringBindingsForFieldClass defines all field-related variables in its subordinated block. However, it leaves a TypedSymbol of type 'Field' for each field undefined. By that trick we can generate kernels that accept structs as kernelparameters. Either to include a pystencils specific Field struct of the following definition: ```cpp template<DTYPE_T, DIMENSION> struct Field { DTYPE_T* data; std::array<DTYPE_T, DIMENSION> shape; std::array<DTYPE_T, DIMENSION> stride; } or to be able to destructure user defined types like `pybind11::array`, `at::Tensor`, `tensorflow::Tensor` ``` --- pystencils/astnodes.py | 59 ++++++++++++++++++- pystencils/backends/cbackend.py | 47 ++++++++++++--- .../test_destructuring_field_class.py | 34 +++++++++++ 3 files changed, 128 insertions(+), 12 deletions(-) create mode 100644 pystencils_tests/test_destructuring_field_class.py diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b1ed610..a924c85 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,9 +1,12 @@ +from typing import Any, List, Optional, Sequence, Set, Union + import sympy as sp + +from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.field import Field -from pystencils.data_types import TypedSymbol, create_type, cast_func -from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol +from pystencils.kernelparameters import (FieldPointerSymbol, FieldShapeSymbol, + FieldStrideSymbol) from pystencils.sympyextensions import fast_subs -from typing import List, Set, Optional, Union, Any, Sequence NodeOrExpr = Union['Node', sp.Expr] @@ -130,6 +133,7 @@ class KernelFunction(Node): defined in pystencils.kernelparameters. If the parameter is related to one or multiple fields, these fields are referenced in the fields property. """ + def __init__(self, symbol, fields): self.symbol = symbol # type: TypedSymbol self.fields = fields # type: Sequence[Field] @@ -582,6 +586,7 @@ class TemporaryMemoryAllocation(Node): size: number of elements to allocate align_offset: the align_offset's element is aligned """ + def __init__(self, typed_symbol: TypedSymbol, size, align_offset): super(TemporaryMemoryAllocation, self).__init__(parent=None) self.symbol = typed_symbol @@ -639,3 +644,51 @@ class TemporaryMemoryFree(Node): def early_out(condition): from pystencils.cpu.vectorization import vec_all return Conditional(vec_all(condition), Block([SkipIteration()])) + + +class DestructuringBindingsForFieldClass(Node): + """ + Defines all variables needed for describing a field (shape, pointer, strides) + """ + CLASS_TO_MEMBER_DICT = { + FieldPointerSymbol: "data", + FieldShapeSymbol: "shape", + FieldStrideSymbol: "stride" + } + CLASS_NAME = "Field" + + def __init__(self, body): + super(DestructuringBindingsForFieldClass, self).__init__() + self.headers = ['<Field.h>'] + self.body = body + + @property + def args(self) -> List[NodeOrExpr]: + """Returns all arguments/children of this node.""" + return set() + + @property + def symbols_defined(self) -> Set[sp.Symbol]: + """Set of symbols which are defined by this node.""" + undefined_field_symbols = {s for s in self.body.undefined_symbols + if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))} + return undefined_field_symbols + + @property + def undefined_symbols(self) -> Set[sp.Symbol]: + undefined_field_symbols = self.symbols_defined + corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} + corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} + return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \ + (self.body.undefined_symbols - undefined_field_symbols) + + def subs(self, subs_dict) -> None: + """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" + self.body.subs(subs_dict) + + @property + def func(self): + return self.__class__ + + def atoms(self, arg_type) -> Set[Any]: + return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)} diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 1c28e32..f95995c 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -1,24 +1,30 @@ -import sympy as sp from collections import namedtuple -from sympy.core import S from typing import Set + +import jinja2 +import sympy as sp +from sympy.core import S from sympy.printing.ccode import C89CodePrinter -from pystencils.cpu.vectorization import vec_any, vec_all +from pystencils.astnodes import (DestructuringBindingsForFieldClass, + KernelFunction, Node) +from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import (PointerType, VectorType, address_of, - cast_func, create_type, reinterpret_cast_func, + cast_func, create_type, get_type_of_expression, - vector_memory_access) -from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt + reinterpret_cast_func, vector_memory_access) +from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, + fast_sqrt) +from pystencils.integer_functions import (bit_shift_left, bit_shift_right, + bitwise_and, bitwise_or, bitwise_xor, + int_div, int_power_of_2, modulo_ceil) +from pystencils.kernelparameters import FieldPointerSymbol try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter except ImportError: from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 -from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ - bitwise_or, modulo_ceil, int_div, int_power_of_2 -from pystencils.astnodes import Node, KernelFunction __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -255,6 +261,29 @@ class CBackend: result += "else " + false_block return result + def _print_DestructuringBindingsForFieldClass(self, node: Node): + # Define all undefined symbols + undefined_field_symbols = node.symbols_defined + destructuring_bindings = ["%s = %s.%s%s;" % + (u.name, + u.field_name if hasattr(u, 'field_name') else u.field_names[0], + DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__], + "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate)) + for u in undefined_field_symbols + ] + template = jinja2.Template( + """{ + {% for binding in bindings -%} + {{ binding | indent(3) }} + {% endfor -%} + {{ block | indent(3) }} +} + +""") + code = template.render(bindings=destructuring_bindings, + block=self._print(node.body)) + return code + # ------------------------------------------ Helper function & classes ------------------------------------------------- diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py new file mode 100644 index 0000000..248963a --- /dev/null +++ b/pystencils_tests/test_destructuring_field_class.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +import sympy + +import pystencils +from pystencils.astnodes import DestructuringBindingsForFieldClass + + +def test_destructuring_field_class(): + z, x, y = pystencils.fields("z, y, x: [2d]") + + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments) + print(pystencils.show_code(ast)) + + ast.body = DestructuringBindingsForFieldClass(ast.body) + print(pystencils.show_code(ast)) + + +def main(): + test_destructuring_field_class() + + +if __name__ == '__main__': + main() -- GitLab