Commit 26f855cb authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for setting a staggered field

- new kernel creation function for kernels that write to staggered fields
- simplification of conditionals with integers uses ISL library now
  via islpy bindings
parent f13701df
......@@ -109,6 +109,14 @@ class Conditional(Node):
def __repr__(self):
return 'if:({!r}) '.format(self.condition_expr)
def replace_by_true_block(self):
"""Replaces the conditional by its True block"""
self.parent.replace(self, [self.true_block])
def replace_by_false_block(self):
"""Replaces the conditional by its False block"""
self.parent.replace(self, [self.false_block] if self.false_block else [])
class KernelFunction(Node):
......@@ -2,7 +2,7 @@ from functools import partial
from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolve_field_accesses, type_all_equations, parse_base_pointer_info, \
get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses
get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.data_types import TypedSymbol, BasicType, StructType
from pystencils import Field, FieldType
......@@ -22,8 +22,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
num_buffer_accesses = 0
for eq in assignments:
num_buffer_accesses += sum([1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field)])
num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field))
common_shape = get_common_shape(fields_without_buffers)
......@@ -51,6 +50,8 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
assignments = cell_idx_assignments + assignments
block = Block(assignments)
unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
block = indexing.guard(block, common_shape)
ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda')
......@@ -2,71 +2,62 @@
import sympy as sp
import islpy as isl
from typing import Tuple
import pystencils.astnodes as ast
from pystencils.transformations import parents_of_type
#context = isl.Context()
- find all Condition nodes
- check if they depend on integers only
- create ISL space containing names of all loop symbols (counter and bounds) and all integers in Conditional expression
- build up pre-condition set by iteration over each enclosing loop add ISL constraints
- build up ISL space for condition
- if pre_condition_set.intersect(conditional_set) == pre_condition_set
always use True condition
elif pre_condition_set.intersect(conditional_set).is_empty():
always use False condition
def remove_brackets(s):
return s.replace('[', '').replace(']', '')
def _degrees_of_freedom_as_string(expr):
expr = sp.sympify(expr)
indexed = expr.atoms(sp.Indexed)
symbols = expr.atoms(sp.Symbol)
symbols_without_indexed_base = symbols - {ind.base.args[0] for ind in indexed}
return {remove_brackets(str(s)) for s in symbols_without_indexed_base}
def isl_iteration_set(node: ast.Node):
"""Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node. """
conditions = []
loop_symbols = set()
degrees_of_freedom = set()
for loop in parents_of_type(node, ast.LoopOverCoordinate):
if loop.step != 1:
raise NotImplementedError("Loops with strides != 1 are not yet supported.")
loop_start_str = str(loop.start).replace('[', '_bracket1_').replace(']', '_bracket2_')
loop_stop_str = str(loop.stop).replace('[', '_bracket1_').replace(']', '_bracket2_')
loop_start_str = remove_brackets(str(loop.start))
loop_stop_str = remove_brackets(str(loop.stop))
ctr_name = loop.loop_counter_name
conditions.append(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}")
conditions.append(remove_brackets(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"))
symbol_names = ','.join([ for s in loop_symbols])
symbol_names = ','.join(degrees_of_freedom)
condition_str = ' and '.join(conditions)
set_description = f"{{ [{symbol_names}] : {condition_str} }}"
return loop_symbols, isl.BasicSet(set_description)
for loop in parents_of_type(node, ast.LoopOverCoordinate):
ctr_name = loop.loop_counter_name
lower_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, 1: -loop.start})
upper_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, })
def simplify_conditionals_new(ast_node):
for conditional in ast_node.atoms(ast.Conditional):
if conditional.condition_expr == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
loop_symbols, iteration_set = isl_iteration_set(conditional)
symbols_in_condition = conditional.condition_expr.atoms(sp.Symbol)
if symbols_in_condition.issubset(loop_symbols):
symbol_names = ','.join([ for s in loop_symbols])
condition_str = str(conditional.condition_expr)
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
intersection = iteration_set.intersect(condition_set)
if intersection.is_empty():
[conditional.false_block] if conditional.false_block else [])
elif intersection == iteration_set:
conditional.parent.replace(conditional, [conditional.true_block])
return degrees_of_freedom, isl.BasicSet(set_description)
def simplify_loop_counter_dependent_conditional(conditional):
"""Removes conditionals that depend on the loop counter or iteration limits if they are always true/false."""
dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr)
dofs_in_loops, iteration_set = isl_iteration_set(conditional)
if dofs_in_condition.issubset(dofs_in_loops):
symbol_names = ','.join(dofs_in_loops)
condition_str = remove_brackets(str(conditional.condition_expr))
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
if condition_set.is_empty():
intersection = iteration_set.intersect(condition_set)
if intersection.is_empty():
elif intersection == iteration_set:
......@@ -30,24 +30,3 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro
code = insert_casts(code)
code.compile = partial(make_python_function, code)
return code
def create_indexed_kernel(assignments, index_fields, function_name="kernel", type_info=None,
coordinate_names=('x', 'y', 'z')):
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
The coordinates are stored in a separated index_field, which is a one dimensional array with struct data type.
This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel, for
example boundary parameters.
:param assignments: list of update equations or AST nodes
:param index_fields: list of index fields, i.e. 1D fields with struct data type
:param type_info: see documentation of :func:`create_kernel`
:param function_name: see documentation of :func:`create_kernel`
:param coordinate_names: name of the coordinate fields in the struct data type
:return: abstract syntax tree
raise NotImplementedError
import warnings
from collections import defaultdict, OrderedDict
from copy import deepcopy
from types import MappingProxyType
......@@ -19,6 +20,30 @@ def filtered_tree_iteration(node, node_type):
yield from filtered_tree_iteration(arg, node_type)
def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
When creating a kernel with variable array sizes, all passed arrays must have the same size.
This is ensured when the kernel is called. Inside the kernel this means that only on symbol has to be used instead
of one for each field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should also be
represented by the same symbol.
body: ast node, for the kernel part where substitutions is made, is modified in-place
common_shape: shape of the field that was chosen
fields: all fields whose shapes should be replaced by common_shape
substitutions = {}
for field in fields:
assert len(field.spatial_shape) == len(common_shape)
if not field.has_fixed_shape:
for common_shape_component, shape_component in zip(common_shape, field.spatial_shape):
if shape_component != common_shape_component:
substitutions[shape_component] = common_shape_component
if substitutions:
def get_common_shape(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
ValueError is raised"""
......@@ -47,7 +72,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
"""Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
body: list of nodes
body: Block object with inner loop contents
function_name: name of generated C function
iteration_slice: if not None, iteration is done only over this slice of the field
ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
......@@ -68,7 +93,8 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
if loop_order is None:
loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(list(fields))
shape = get_common_shape(fields)
unify_shape_symbols(body, common_shape=shape, fields=fields)
if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape)
......@@ -580,99 +606,33 @@ def cut_loop(loop_node, cutting_points):
loop_node.parent.replace(loop_node, new_loops)
def is_condition_necessary(condition, pre_condition, symbol):
Determines if a logical condition of a single variable is already contained in a stronger pre_condition
so if from pre_condition follows that condition is always true, then this condition is not necessary
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None:
"""Removes conditionals that are always true/false.
condition: sympy relational of one variable
pre_condition: logical expression that is known to be true
symbol: the single symbol of interest
returns not (pre_condition => condition) where "=>" is logical implication
from sympy.solvers.inequalities import reduce_rational_inequalities
from sympy.logic.boolalg import to_dnf
def normalize_relational(e):
if isinstance(e, sp.Rel):
return e.func(e.lhs - e.rhs, 0)
new_args = [normalize_relational(a) for a in e.args]
return e.func(*new_args) if new_args else e
def to_dnf_list(expr):
result = to_dnf(expr)
if isinstance(result, sp.Or):
return [or_term.args for or_term in result.args]
elif isinstance(result, sp.And):
return [result.args]
return [result]
condition = normalize_relational(condition)
pre_condition = normalize_relational(pre_condition)
a1 = to_dnf_list(pre_condition)
a2 = to_dnf_list(condition)
t1 = reduce_rational_inequalities(to_dnf_list(sp.And(condition, pre_condition)), symbol)
t2 = reduce_rational_inequalities(to_dnf_list(pre_condition), symbol)
return t1 != t2
def simplify_boolean_expression(expr, single_variable_ranges):
"""Simplification of boolean expression using known ranges of variables
The singleVariableRanges parameter is a dict mapping a variable name to a sympy logical expression that
contains only this variable and defines a range for it. For example with a being a symbol
{ a: sp.And(a >=0, a < 10) }
node: ast node, all descendants of this node are simplified
loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
depending on the surrounding loop. For example if the surrounding loop goes from
x=0 to 10 and the condition is x < 0, it is removed.
This analysis needs the integer set library (ISL) islpy, so it is not done by
from sympy.core.relational import Relational
from sympy.logic.boolalg import to_dnf
expr = to_dnf(expr)
def visit(e):
if isinstance(e, Relational):
symbols = e.atoms(sp.Symbol).intersection(single_variable_ranges.keys())
if len(symbols) == 1:
symbol = symbols.pop()
if not is_condition_necessary(e, single_variable_ranges[symbol], symbol):
return sp.true
return e
new_args = [visit(a) for a in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def simplify_conditionals(node, loop_conditionals=MappingProxyType({})):
"""Simplifies/Removes conditions inside loops that depend on the loop counter."""
if isinstance(node, ast.LoopOverCoordinate):
ctr_sym = node.loop_counter_symbol
loop_conditionals = loop_conditionals.copy()
loop_conditionals[ctr_sym] = sp.And(ctr_sym >= node.start, ctr_sym < node.stop)
simplify_conditionals(node.body, loop_conditionals)
elif isinstance(node, ast.Conditional):
node.condition_expr = simplify_boolean_expression(node.condition_expr, loop_conditionals)
if node.false_block:
simplify_conditionals(node.false_block, loop_conditionals)
if node.condition_expr == sp.true:
node.parent.replace(node, [node.true_block])
if node.condition_expr == sp.false:
node.parent.replace(node, [node.false_block] if node.false_block else [])
elif isinstance(node, ast.Block):
for a in list(node.args):
simplify_conditionals(a, loop_conditionals)
elif isinstance(node, ast.SympyAssignment):
return node
raise ValueError("Can not handle node", type(node))
def cleanup_blocks(node):
for conditional in node.atoms(ast.Conditional):
conditional.condition_expr = sp.simplify(conditional.condition_expr)
if conditional.condition_expr == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification:
# noinspection PyUnresolvedReferences
from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
except ImportError:
warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
def cleanup_blocks(node: ast.Node) -> None:
"""Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child """
if isinstance(node, ast.SympyAssignment):
......@@ -850,9 +810,9 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -
inner_loop = all_inner_loops.pop()
for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
cut_loop(loop, [loop.stop-1])
cut_loop(loop, [loop.stop - 1])
simplify_conditionals(function_node.body, loop_counter_simplification=True)
......@@ -884,8 +844,10 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
def get_next_parent_of_type(node, parent_type):
Traverses the AST nodes parents until a parent of given type was found. If no such parent is found, None is returned
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
parent = node.parent
while parent is not None:
......@@ -896,21 +858,24 @@ def get_next_parent_of_type(node, parent_type):
def parents_of_type(node, parent_type, include_current=False):
"""Similar to get_next_parent_of_type, but as generator"""
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
return None
def get_optimal_loop_ordering(fields):
Determines the optimal loop order for a given set of fields.
If the fields have different memory layout or different sizes an exception is thrown.
:param fields: sequence of fields
:return: list of coordinate ids, where the first list entry should be the outermost loop
fields: sequence of fields
list of coordinate ids, where the first list entry should be the outermost loop
assert len(fields) > 0
ref_field = next(iter(fields))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment