Skip to content
Snippets Groups Projects
Commit 472f6f6c authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'remove-DestructuringBindingsForFieldClass' into 'master'

Remove DestructuringBindingsForFieldClass

See merge request pycodegen/pystencils!55
parents 32538b37 1f0a4b46
Branches
No related merge requests found
......@@ -676,59 +676,5 @@ def early_out(condition):
return Conditional(vec_all(condition), Block([SkipIteration()]))
class DestructuringBindingsForFieldClass(Node):
"""
Defines all variables needed for describing a field (shape, pointer, strides)
"""
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data",
FieldShapeSymbol: "shape[%i]",
FieldStrideSymbol: "stride[%i]"
}
CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
self.headers = ['<PyStencilsField.h>']
self.body = body
@property
def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node."""
return set()
@property
def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node."""
undefined_field_symbols = {s for s in self.body.undefined_symbols
if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))}
return undefined_field_symbols
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
field_map = {f.name: f for f in self.fields_accessed}
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | (self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
self.body.subs(subs_dict)
@property
def func(self):
return self.__class__
def atoms(self, arg_type) -> Set[Any]:
return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
def get_dummy_symbol(dtype='bool'):
return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype))
......@@ -15,7 +15,6 @@ from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqr
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
......@@ -276,24 +275,6 @@ class CBackend:
result += "else " + false_block
return result
def _print_DestructuringBindingsForFieldClass(self, node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s %s = %s.%s;" %
(u.dtype,
u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
node.CLASS_TO_MEMBER_DICT[u.__class__] %
(() if type(u) == FieldPointerSymbol else (u.coordinate,)))
for u in undefined_field_symbols
]
destructuring_bindings.sort() # only for code aesthetics
return "{\n" + self._indent + \
("\n" + self._indent).join(destructuring_bindings) + \
"\n" + self._indent + \
("\n" + self._indent).join(self._print(node.body).splitlines()) + \
"\n}"
# ------------------------------------------ Helper function & classes -------------------------------------------------
......
import sympy
import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
def test_destructuring_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments, target='gpu')
print(pystencils.show_code(ast))
ast.body = DestructuringBindingsForFieldClass(ast.body)
print(pystencils.show_code(ast))
ast.compile()
class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "🥶",
FieldShapeSymbol: "😳_%i",
FieldStrideSymbol: "🥵_%i"
}
CLASS_NAME_TEMPLATE = "🤯<{dtype}, {ndim}>"
def __init__(self, node):
super().__init__(node)
self.headers = []
def test_destructuring_alternative_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments, target='gpu')
ast.body = DestructuringEmojiClass(ast.body)
print(pystencils.show_code(ast))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment