Commit 8e63c9ff authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add DestructuringBindingsForFieldClass to use pystencils kernels in a more C++-ish way

DestructuringBindingsForFieldClass defines all field-related variables
in its subordinated block.
However, it leaves a TypedSymbol of type 'Field' for each field
undefined.
By that trick we can generate kernels that accept structs as
kernelparameters.
Either to include a pystencils specific Field struct of the following
definition:

```cpp
template<DTYPE_T, DIMENSION>
struct Field
{
    DTYPE_T* data;
    std::array<DTYPE_T, DIMENSION> shape;
    std::array<DTYPE_T, DIMENSION> stride;
}

or to be able to destructure user defined types like `pybind11::array`,
`at::Tensor`, `tensorflow::Tensor`

```
parent 8c4a6f1e
from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
from pystencils.field import Field from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func from pystencils.kernelparameters import (FieldPointerSymbol, FieldShapeSymbol,
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol FieldStrideSymbol)
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
from typing import List, Set, Optional, Union, Any, Sequence
NodeOrExpr = Union['Node', sp.Expr] NodeOrExpr = Union['Node', sp.Expr]
...@@ -130,6 +133,7 @@ class KernelFunction(Node): ...@@ -130,6 +133,7 @@ class KernelFunction(Node):
defined in pystencils.kernelparameters. defined in pystencils.kernelparameters.
If the parameter is related to one or multiple fields, these fields are referenced in the fields property. If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
""" """
def __init__(self, symbol, fields): def __init__(self, symbol, fields):
self.symbol = symbol # type: TypedSymbol self.symbol = symbol # type: TypedSymbol
self.fields = fields # type: Sequence[Field] self.fields = fields # type: Sequence[Field]
...@@ -582,6 +586,7 @@ class TemporaryMemoryAllocation(Node): ...@@ -582,6 +586,7 @@ class TemporaryMemoryAllocation(Node):
size: number of elements to allocate size: number of elements to allocate
align_offset: the align_offset's element is aligned align_offset: the align_offset's element is aligned
""" """
def __init__(self, typed_symbol: TypedSymbol, size, align_offset): def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
super(TemporaryMemoryAllocation, self).__init__(parent=None) super(TemporaryMemoryAllocation, self).__init__(parent=None)
self.symbol = typed_symbol self.symbol = typed_symbol
...@@ -639,3 +644,51 @@ class TemporaryMemoryFree(Node): ...@@ -639,3 +644,51 @@ class TemporaryMemoryFree(Node):
def early_out(condition): def early_out(condition):
from pystencils.cpu.vectorization import vec_all from pystencils.cpu.vectorization import vec_all
return Conditional(vec_all(condition), Block([SkipIteration()])) return Conditional(vec_all(condition), Block([SkipIteration()]))
class DestructuringBindingsForFieldClass(Node):
"""
Defines all variables needed for describing a field (shape, pointer, strides)
"""
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data",
FieldShapeSymbol: "shape",
FieldStrideSymbol: "stride"
}
CLASS_NAME = "Field"
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
self.headers = ['<Field.h>']
self.body = body
@property
def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node."""
return set()
@property
def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node."""
undefined_field_symbols = {s for s in self.body.undefined_symbols
if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))}
return undefined_field_symbols
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
self.body.subs(subs_dict)
@property
def func(self):
return self.__class__
def atoms(self, arg_type) -> Set[Any]:
return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
import sympy as sp
from collections import namedtuple from collections import namedtuple
from sympy.core import S
from typing import Set from typing import Set
import jinja2
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all from pystencils.astnodes import (DestructuringBindingsForFieldClass,
KernelFunction, Node)
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (PointerType, VectorType, address_of, from pystencils.data_types import (PointerType, VectorType, address_of,
cast_func, create_type, reinterpret_cast_func, cast_func, create_type,
get_type_of_expression, get_type_of_expression,
vector_memory_access) reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError: except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_ceil, int_div, int_power_of_2
from pystencils.astnodes import Node, KernelFunction
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
...@@ -255,6 +261,29 @@ class CBackend: ...@@ -255,6 +261,29 @@ class CBackend:
result += "else " + false_block result += "else " + false_block
return result return result
def _print_DestructuringBindingsForFieldClass(self, node: Node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s = %s.%s%s;" %
(u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
for u in undefined_field_symbols
]
template = jinja2.Template(
"""{
{% for binding in bindings -%}
{{ binding | indent(3) }}
{% endfor -%}
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
# ------------------------------------------ Helper function & classes ------------------------------------------------- # ------------------------------------------ Helper function & classes -------------------------------------------------
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy
import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
def test_destructuring_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments)
print(pystencils.show_code(ast))
ast.body = DestructuringBindingsForFieldClass(ast.body)
print(pystencils.show_code(ast))
def main():
test_destructuring_field_class()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment