import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
from pystencils.sympyextensions import fast_subs
from typing import List, Set, Optional, Union, Any, Sequence

NodeOrExpr = Union['Node', sp.Expr]


class Node:
    """Base class for all AST nodes."""

    def __init__(self, parent: Optional['Node'] = None):
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
        for a in self.args:
            a.subs(subs_dict)

    @property
    def func(self):
        """The class of this node (mirrors the ``func`` attribute of sympy objects)."""
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        result = set()
        for arg in self.args:
            if isinstance(arg, arg_type):
                result.add(arg)
            result.update(arg.atoms(arg_type))
        return result


class Conditional(Node):
    """Conditional that maps to a 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true
        false_block: optional block which is run if conditional is false
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        assert condition_expr.is_Boolean or condition_expr.is_Relational
        self.condition_expr = condition_expr

        def handle_child(c):
            # Wrap anything that is not already a Block into a single-node Block
            # and make this conditional its parent; None (no branch) is passed through.
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.true_block = handle_child(true_block)
        self.false_block = handle_child(false_block)

    def subs(self, subs_dict):
        """Substitutes in place in both branches and in the condition expression."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        """Condition, true block and - if present - false block."""
        result = [self.condition_expr, self.true_block]
        if self.false_block:
            result.append(self.false_block)
        return result

    @property
    def symbols_defined(self):
        """A conditional reports no defined symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus all symbols of the condition."""
        result = self.true_block.undefined_symbols
        if self.false_block:
            result.update(self.false_block.undefined_symbols)
        result.update(self.condition_expr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.condition_expr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.condition_expr)

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block"""
        self.parent.replace(self, [self.false_block] if self.false_block else [])


class KernelFunction(Node):
    """Root AST node representing a kernel function with a body, a name and a backend tag."""

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a
        symbol defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            """True if the symbol is a `FieldStrideSymbol`."""
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            """True if the symbol is a `FieldShapeSymbol`."""
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            """True if the symbol is a `FieldPointerSymbol`."""
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            """True if the symbol is a field pointer, shape or stride symbol."""
            return self.is_field_pointer or self.is_field_shape or self.is_field_stride

        @property
        def field_name(self):
            """Name of the first field referenced by this parameter (requires a non-empty ``fields``)."""
            return self.fields[0].name

    def __init__(self, body, ghost_layers=None, function_name="kernel", backend=""):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self.function_name = function_name
        self.compile = None
        self.ghost_layers = ghost_layers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.backend = backend
        # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        self.instruction_set = None

    @property
    def symbols_defined(self):
        """The kernel function itself reports no defined symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """The kernel function itself reports no undefined symbols (they become its parameters)."""
        return set()

    @property
    def body(self):
        """Body node of the function; assigning a new body re-parents it to this node."""
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        return [self._body]

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function.

        The parameters are the undefined symbols of the body minus ``global_variables``,
        sorted by symbol name.

        This function is expensive, cache the result where possible!
        """
        field_map = {f.name: f for f in self.fields_accessed}

        def get_fields(symbol):
            # Field-related symbols carry either a single `field_name` or several `field_names`.
            if hasattr(symbol, 'field_name'):
                return (field_map[symbol.field_name],)
            elif hasattr(symbol, 'field_names'):
                return tuple(field_map[fn] for fn in symbol.field_names)
            return ()

        argument_symbols = self._body.undefined_symbols - self.global_variables
        parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
        parameters.sort(key=lambda p: p.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)


class Block(Node):
    """Ordered sequence of child nodes.

    Note:
        The list passed to the constructor is stored as-is (not copied).
    """

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        """Substitutes in place in all children.

        Declarations whose right hand side is itself a key of ``subs_dict`` are dropped from the
        block; their left hand side is added to ``subs_dict`` (the passed dict is modified) so that
        later uses are substituted as well.
        """
        new_args = []
        for a in self.args:
            if isinstance(a, SympyAssignment) and a.is_declaration and a.rhs in subs_dict:
                subs_dict[a.lhs] = subs_dict[a.rhs]
            else:
                new_args.append(a)
        self._nodes = new_args

        for a in self.args:
            a.subs(subs_dict)

    def insert_front(self, node):
        """Inserts ``node`` as first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Inserts ``new_node`` before the existing child ``insert_before``.

        Declarations are additionally moved upwards past directly preceding loops and conditionals.
        """
        new_node.parent = self
        idx = self._nodes.index(insert_before)

        # move all assignment (definitions to the top)
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, (LoopOverCoordinate, Conditional)):
                    idx -= 1
                else:
                    break
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Appends a single node, or every node of a list/tuple, to the end of the block."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Removes all children from the block and returns them as a list."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replaces ``child`` by a single node or by a list/tuple of nodes (possibly empty)."""
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, (list, tuple)):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + list(replacements) + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        """Union of the symbols defined by all children."""
        result = set()
        for a in self.args:
            result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        """Undefined symbols of all children minus the symbols defined by any child."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefined_symbols)
            defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"


class PragmaBlock(Block):
    """Block that additionally stores a pragma line, which is also used as its ``repr``."""

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already makes this block the parent of all nodes
        super(PragmaBlock, self).__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line


class LoopOverCoordinate(Node):
    """Loop over one coordinate with ``start``, ``stop`` and ``step``.

    The loop counter is a typed 'int' symbol whose name is built from a prefix
    (different for block loops) and the coordinate that is looped over.
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Returns a new loop with the same bounds, flags and (copied) prefix lines but another body."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """Substitutes in place in the body and in start/stop/step where these support ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    @property
    def args(self):
        """Body followed by those of start/stop/step that have an ``args`` attribute."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replaces the first of body/start/step/stop that compares equal to ``child``.

        Nothing happens if ``child`` matches none of them.
        """
        if child == self.body:
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        """The loop defines its loop counter symbol."""
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        """Undefined symbols of the body and of start/stop/step, without the loop counter."""
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        """Name of the counter of a normal loop over the given coordinate."""
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        """Name of the counter of a block loop over the given coordinate."""
        return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        """Counter name of this loop, depending on ``is_block_loop``."""
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Returns the coordinate if ``symbol`` is a (non-block) loop counter, otherwise None.

        The symbol has to start with the loop counter prefix and be of type 'int'; the coordinate
        is parsed from the part of the name after the prefix and its separator character.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        if symbol.dtype != create_type('int'):
            return None
        coordinate = int(symbol.name[len(prefix) + 1:])
        return coordinate

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        """Typed 'int' symbol for the counter of a normal loop over the given coordinate."""
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        """Typed 'int' symbol for the counter of a block loop over the given coordinate."""
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int')

    @property
    def loop_counter_symbol(self):
        """Counter symbol of this loop, depending on ``is_block_loop``."""
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        """True if no ancestor of this node is a `LoopOverCoordinate`."""
        # NOTE(review): imported locally, presumably to avoid a circular import - confirm
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        """True if no descendant of this node is a `LoopOverCoordinate`."""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                   self.loop_counter_name, self.stop,
                                                                   self.loop_counter_name, self.step,
                                                                   ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                             self.loop_counter_name, self.stop,
                                                             self.loop_counter_name, self.step)


class SympyAssignment(Node):
    """Assignment of a right hand side expression to a left hand side.

    The assignment counts as a declaration unless the left hand side is a `cast_func`,
    a `Field.Access`, a `sympy.Indexed` or a `TemporaryMemoryAllocation`.
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = lhs_symbol
        self.rhs = rhs_expr
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()

    def __is_declaration(self):
        non_declaring_types = (cast_func, Field.Access, sp.Indexed, TemporaryMemoryAllocation)
        return not isinstance(self._lhs_symbol, non_declaring_types)

    @property
    def lhs(self):
        """Left hand side; assigning a new one re-evaluates whether this is a declaration."""
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        """Substitutes in place in both sides using `fast_subs`."""
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs]

    @property
    def symbols_defined(self):
        """The left hand side if this is a declaration, otherwise the empty set."""
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        """All symbols of both sides, plus loop counters for every field access on the right."""
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there a field accesses
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replaces the left or right hand side by ``replacement``.

        Raises:
            ValueError: if ``child`` is neither the left nor the right hand side.
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the node that was looked for, not the one that should have been inserted
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)


class ResolvedFieldAccess(sp.Indexed):
    """`sympy.Indexed` with a single linearized index that remembers the field it belongs to.

    Besides base and index, the instance stores ``field``, ``offsets`` and ``idx_coordinate_values``.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        if not isinstance(base, IndexedBase):
            base = IndexedBase(base, shape=(1,))
        obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
        obj.field = field
        obj.offsets = offsets
        obj.idx_coordinate_values = idx_coordinate_values
        return obj

    def _eval_subs(self, old, new):
        # substitute only inside the index and keep the additional attributes
        return ResolvedFieldAccess(self.args[0],
                                   self.args[1].subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions):
        """Returns the replacement for this access if there is one, else substitutes in base and index."""
        if self in substitutions:
            return substitutions[self]
        return ResolvedFieldAccess(self.args[0].subs(substitutions),
                                   self.args[1].subs(substitutions),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        # extend the sympy content by the extra attributes so that they take part in hashing/equality
        super_class_contents = super(ResolvedFieldAccess, self)._hashable_content()
        return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))

    @property
    def typed_symbol(self):
        """Label of the indexed base."""
        return self.base.label

    def __str__(self):
        top = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (top, self.typed_symbol.dtype)

    def __getnewargs__(self):
        # arguments for __new__, needed to re-create the object including its extra attributes
        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values


class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: the align_offset's element is aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        """The allocation defines its pointer symbol."""
        return {self.symbol}

    @property
    def undefined_symbols(self):
        """Symbols occurring in ``size`` if it is a sympy object, otherwise the empty set."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        ``byte_alignment`` has to be a multiple of the item size of the element type.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        # floor division is exact here (see assert) and keeps the element count an integer
        return -self._align_offset % (byte_alignment // np_dtype.itemsize)


class TemporaryMemoryFree(Node):
    """Node that refers to a `TemporaryMemoryAllocation` node and exposes its symbol and offset."""

    def __init__(self, alloc_node):
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the referenced allocation node."""
        return self.alloc_node.symbol

    def offset(self, byte_alignment):
        """Forwards to the ``offset`` of the referenced allocation node."""
        return self.alloc_node.offset(byte_alignment)

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return []