Commit e09f0646 authored by Stephan Seitz's avatar Stephan Seitz Committed by Martin Bauer
Browse files

Fixup for DestructuringBindingsForFieldClass

- rename header Field.h is not a unique name in waLBerla context
- add PyStencilsField.h
- bindings were lacking data type
parent 3b291b02
...@@ -670,10 +670,10 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -670,10 +670,10 @@ class DestructuringBindingsForFieldClass(Node):
""" """
CLASS_TO_MEMBER_DICT = { CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data", FieldPointerSymbol: "data",
FieldShapeSymbol: "shape", FieldShapeSymbol: "shape[%i]",
FieldStrideSymbol: "stride" FieldStrideSymbol: "stride[%i]"
} }
CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>")
@property @property
def fields_accessed(self) -> Set['ResolvedFieldAccess']: def fields_accessed(self) -> Set['ResolvedFieldAccess']:
...@@ -682,7 +682,7 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -682,7 +682,7 @@ class DestructuringBindingsForFieldClass(Node):
def __init__(self, body): def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__() super(DestructuringBindingsForFieldClass, self).__init__()
self.headers = ['<Field.h>'] self.headers = ['<PyStencilsField.h>']
self.body = body self.body = body
@property @property
......
...@@ -6,15 +6,14 @@ import sympy as sp ...@@ -6,15 +6,14 @@ import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import DestructuringBindingsForFieldClass, KernelFunction, Node from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
reinterpret_cast_func, vector_memory_access) vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import ( from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
try: try:
...@@ -260,14 +259,15 @@ class CBackend: ...@@ -260,14 +259,15 @@ class CBackend:
result += "else " + false_block result += "else " + false_block
return result return result
def _print_DestructuringBindingsForFieldClass(self, node: Node): def _print_DestructuringBindingsForFieldClass(self, node):
# Define all undefined symbols # Define all undefined symbols
undefined_field_symbols = node.symbols_defined undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s = %s.%s%s;" % destructuring_bindings = ["%s %s = %s.%s;" %
(u.name, (u.dtype,
u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0], u.field_name if hasattr(u, 'field_name') else u.field_names[0],
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__], node.CLASS_TO_MEMBER_DICT[u.__class__] %
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate)) (() if type(u) == FieldPointerSymbol else (u.coordinate,)))
for u in undefined_field_symbols for u in undefined_field_symbols
] ]
destructuring_bindings.sort() # only for code aesthetics destructuring_bindings.sort() # only for code aesthetics
......
#pragma once
extern "C++" {
#ifdef __CUDA_ARCH__
template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
DTYPE_T *data;
DTYPE_T shape[DIMENSION];
DTYPE_T stride[DIMENSION];
};
#else
#include <array>
template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
DTYPE_T *data;
std::array<DTYPE_T, DIMENSION> shape;
std::array<DTYPE_T, DIMENSION> stride;
};
#endif
}
import sympy import sympy
import jinja2
import pystencils import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass from pystencils.astnodes import DestructuringBindingsForFieldClass
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
def test_destructuring_field_class(): def test_destructuring_field_class():
...@@ -10,15 +14,39 @@ def test_destructuring_field_class(): ...@@ -10,15 +14,39 @@ def test_destructuring_field_class():
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments) ast = pystencils.create_kernel(normal_assignments, target='gpu')
print(pystencils.show_code(ast)) print(pystencils.show_code(ast))
ast.body = DestructuringBindingsForFieldClass(ast.body) ast.body = DestructuringBindingsForFieldClass(ast.body)
print(pystencils.show_code(ast)) print(pystencils.show_code(ast))
ast.compile()
class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "🥶",
FieldShapeSymbol: "😳_%i",
FieldStrideSymbol: "🥵_%i"
}
CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>")
def __init__(self, node):
super().__init__(node)
self.headers = []
def test_destructuring_alternative_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments, target='gpu')
ast.body = DestructuringEmojiClass(ast.body)
print(pystencils.show_code(ast))
def main(): def main():
test_destructuring_field_class() test_destructuring_field_class()
test_destructuring_alternative_field_class()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment