Commit e09f0646 authored by Stephan Seitz's avatar Stephan Seitz Committed by Martin Bauer

Fixup for DestructuringBindingsForFieldClass

- rename header Field.h is not a unique name in waLBerla context
- add PyStencilsField.h
- bindings were lacking data type
parent 3b291b02
Pipeline #16588 canceled with stage
in 37 seconds
......@@ -670,10 +670,10 @@ class DestructuringBindingsForFieldClass(Node):
"""
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data",
FieldShapeSymbol: "shape",
FieldStrideSymbol: "stride"
FieldShapeSymbol: "shape[%i]",
FieldStrideSymbol: "stride[%i]"
}
CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>")
CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>")
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
......@@ -682,7 +682,7 @@ class DestructuringBindingsForFieldClass(Node):
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
self.headers = ['<Field.h>']
self.headers = ['<PyStencilsField.h>']
self.body = body
@property
......
......@@ -6,15 +6,14 @@ import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import DestructuringBindingsForFieldClass, KernelFunction, Node
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access)
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try:
......@@ -260,14 +259,15 @@ class CBackend:
result += "else " + false_block
return result
def _print_DestructuringBindingsForFieldClass(self, node: Node):
def _print_DestructuringBindingsForFieldClass(self, node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s = %s.%s%s;" %
(u.name,
destructuring_bindings = ["%s %s = %s.%s;" %
(u.dtype,
u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
node.CLASS_TO_MEMBER_DICT[u.__class__] %
(() if type(u) == FieldPointerSymbol else (u.coordinate,)))
for u in undefined_field_symbols
]
destructuring_bindings.sort() # only for code aesthetics
......
#pragma once
extern "C++" {
#ifdef __CUDA_ARCH__
template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
DTYPE_T *data;
DTYPE_T shape[DIMENSION];
DTYPE_T stride[DIMENSION];
};
#else
#include <array>
template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
DTYPE_T *data;
std::array<DTYPE_T, DIMENSION> shape;
std::array<DTYPE_T, DIMENSION> stride;
};
#endif
}
import sympy
import jinja2
import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
def test_destructuring_field_class():
......@@ -10,15 +14,39 @@ def test_destructuring_field_class():
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments)
ast = pystencils.create_kernel(normal_assignments, target='gpu')
print(pystencils.show_code(ast))
ast.body = DestructuringBindingsForFieldClass(ast.body)
print(pystencils.show_code(ast))
ast.compile()
class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "🥶",
FieldShapeSymbol: "😳_%i",
FieldStrideSymbol: "🥵_%i"
}
CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>")
def __init__(self, node):
super().__init__(node)
self.headers = []
def test_destructuring_alternative_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments, target='gpu')
ast.body = DestructuringEmojiClass(ast.body)
print(pystencils.show_code(ast))
def main():
test_destructuring_field_class()
test_destructuring_alternative_field_class()
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment