Commit 8e63c9ff authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add DestructuringBindingsForFieldClass to use pystencils kernels in a more C++-ish way

DestructuringBindingsForFieldClass defines all field-related variables
in its subordinated block.
However, it leaves a TypedSymbol of type 'Field' for each field
undefined.
By that trick we can generate kernels that accept structs as
kernelparameters.
Either to include a pystencils specific Field struct of the following
definition:

```cpp
template<DTYPE_T, DIMENSION>
struct Field
{
    DTYPE_T* data;
    std::array<DTYPE_T, DIMENSION> shape;
    std::array<DTYPE_T, DIMENSION> stride;
}

or to be able to destructure user defined types like `pybind11::array`,
`at::Tensor`, `tensorflow::Tensor`

```
parent 8c4a6f1e
from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
from pystencils.kernelparameters import (FieldPointerSymbol, FieldShapeSymbol,
FieldStrideSymbol)
from pystencils.sympyextensions import fast_subs
from typing import List, Set, Optional, Union, Any, Sequence
NodeOrExpr = Union['Node', sp.Expr]
......@@ -130,6 +133,7 @@ class KernelFunction(Node):
defined in pystencils.kernelparameters.
If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
"""
def __init__(self, symbol, fields):
self.symbol = symbol # type: TypedSymbol
self.fields = fields # type: Sequence[Field]
......@@ -582,6 +586,7 @@ class TemporaryMemoryAllocation(Node):
size: number of elements to allocate
align_offset: the align_offset's element is aligned
"""
def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
super(TemporaryMemoryAllocation, self).__init__(parent=None)
self.symbol = typed_symbol
......@@ -639,3 +644,51 @@ class TemporaryMemoryFree(Node):
def early_out(condition):
from pystencils.cpu.vectorization import vec_all
return Conditional(vec_all(condition), Block([SkipIteration()]))
class DestructuringBindingsForFieldClass(Node):
"""
Defines all variables needed for describing a field (shape, pointer, strides)
"""
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data",
FieldShapeSymbol: "shape",
FieldStrideSymbol: "stride"
}
CLASS_NAME = "Field"
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
self.headers = ['<Field.h>']
self.body = body
@property
def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node."""
return set()
@property
def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node."""
undefined_field_symbols = {s for s in self.body.undefined_symbols
if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))}
return undefined_field_symbols
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
self.body.subs(subs_dict)
@property
def func(self):
return self.__class__
def atoms(self, arg_type) -> Set[Any]:
return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
import sympy as sp
from collections import namedtuple
from sympy.core import S
from typing import Set
import jinja2
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.astnodes import (DestructuringBindingsForFieldClass,
KernelFunction, Node)
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (PointerType, VectorType, address_of,
cast_func, create_type, reinterpret_cast_func,
cast_func, create_type,
get_type_of_expression,
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_ceil, int_div, int_power_of_2
from pystencils.astnodes import Node, KernelFunction
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
......@@ -255,6 +261,29 @@ class CBackend:
result += "else " + false_block
return result
def _print_DestructuringBindingsForFieldClass(self, node: Node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s = %s.%s%s;" %
(u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
for u in undefined_field_symbols
]
template = jinja2.Template(
"""{
{% for binding in bindings -%}
{{ binding | indent(3) }}
{% endfor -%}
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
# ------------------------------------------ Helper function & classes -------------------------------------------------
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy
import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
def test_destructuring_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments)
print(pystencils.show_code(ast))
ast.body = DestructuringBindingsForFieldClass(ast.body)
print(pystencils.show_code(ast))
def main():
test_destructuring_field_class()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment