diff --git a/__init__.py b/__init__.py index d4455a766682f032373ada14e5c9024b71d4a383..4d4dacb55c0f236e360656afbd7b6b3abcd16733 100644 --- a/__init__.py +++ b/__init__.py @@ -2,7 +2,7 @@ from pystencils.field import Field, FieldType, extractCommonSubexpressions from pystencils.data_types import TypedSymbol from pystencils.slicing import makeSlice from pystencils.kernelcreation import createKernel, createIndexedKernel -from pystencils.display_utils import showCode, toDot +from pystencils.display_utils import show_code, to_dot from pystencils.assignment_collection import AssignmentCollection from pystencils.assignment import Assignment from pystencils.sympyextensions import SymbolCreator @@ -11,7 +11,7 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', 'TypedSymbol', 'makeSlice', 'createKernel', 'createIndexedKernel', - 'showCode', 'toDot', + 'show_code', 'to_dot', 'AssignmentCollection', 'Assignment', 'SymbolCreator'] diff --git a/assignment.py b/assignment.py index 135a5d1264c095a8a6955e9a75d29b62e20048d8..251ff6966af05372abdf953f7e69291a2c151148 100644 --- a/assignment.py +++ b/assignment.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from sympy.codegen.ast import Assignment from sympy.printing.latex import LatexPrinter @@ -11,4 +12,9 @@ def print_assignment_latex(printer, expr): return f"{printed_lhs} \leftarrow {printed_rhs}" +def assignment_str(assignment): + return f"{assignment.lhs} ↠{assignment.rhs}" + + +Assignment.__str__ = assignment_str LatexPrinter._print_Assignment = print_assignment_latex diff --git a/assignment_collection/__init__.py b/assignment_collection/__init__.py index a71a7d05cec3af128062e31ebbdd1c246159ef51..7648bbab8624019ab4e7de218576c1212673ee31 100644 --- a/assignment_collection/__init__.py +++ b/assignment_collection/__init__.py @@ -1,2 +1,10 @@ from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy +from pystencils.assignment_collection.simplifications import sympy_cse, sympy_cse_on_assignment_list, \ + apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, \ + subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions + +__all__ = ['AssignmentCollection', 'SimplificationStrategy', + 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', + 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', + 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions'] diff --git a/assignment_collection/assignment_collection.py b/assignment_collection/assignment_collection.py index 4c66773589775dc276abfa89479b97644b8c84f3..2907cb5fea3f77ebf971ea706096da6fe38dcbd9 100644 --- a/assignment_collection/assignment_collection.py +++ b/assignment_collection/assignment_collection.py @@ -61,7 +61,7 @@ class AssignmentCollection: left hand side symbol (which could have been generated) """ if lhs is None: - lhs = sp.Dummy() + lhs = next(self.subexpression_symbol_generator) eq = Assignment(lhs, rhs) self.subexpressions.append(eq) if topological_sort: @@ -135,25 +135,6 @@ class AssignmentCollection: return handled_symbols - def get(self, symbols: Sequence[sp.Symbol], from_main_assignments_only=False) -> List[Assignment]: - """Extracts all assignments that have a left hand side that is contained in the symbols parameter. - - Args: - symbols: return assignments that have one of these symbols as left hand side - from_main_assignments_only: search only in main assignments (exclude subexpressions) - """ - if not hasattr(symbols, "__len__"): - symbols = set(symbols) - else: - symbols = set(symbols) - - if not from_main_assignments_only: - assignments_to_search = self.all_assignments - else: - assignments_to_search = self.main_assignments - - return [assignment for assignment in assignments_to_search if assignment.lhs in symbols] - def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None): """Returns a python function to evaluate this equation collection. @@ -343,26 +324,25 @@ class AssignmentCollection: return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments]) def __str__(self): - result = "Subexpressions\n" + result = "Subexpressions:\n" for eq in self.subexpressions: - result += str(eq) + "\n" - result += "Main Assignments\n" + result += f"\t{eq}\n" + result += "Main Assignments:\n" for eq in self.main_assignments: - result += str(eq) + "\n" + result += f"{eq}\n" return result class SymbolGen: """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" - def __init__(self): + def __init__(self, symbol="xi"): self._ctr = 0 + self._symbol = symbol def __iter__(self): return self def __next__(self): + name = f"{self._symbol}_{self._ctr}" self._ctr += 1 - return sp.Symbol("xi_" + str(self._ctr)) - - def next(self): - return self.__next__() + return sp.Symbol(name) diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py index b7707bc998b95885069efa92cfe2b2125404ec1f..f8b2b23adb0809bfade757bf1d94b1cf4fd0426c 100644 --- a/assignment_collection/simplifications.py +++ b/assignment_collection/simplifications.py @@ -1,15 +1,10 @@ import sympy as sp from typing import Callable, List -from pystencils import Assignment, AssignmentCollection +from pystencils.assignment import Assignment +from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.sympyextensions import subs_additive -def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: - """Extracts common subexpressions from a list of assignments.""" - ec = AssignmentCollection(assignments, []) - return sympy_cse(ec).all_assignments - - def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: """Searches for common subexpressions inside the equation collection. @@ -32,21 +27,28 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: return ac.copy(modified_update_equations, new_subexpressions) +def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: + """Extracts common subexpressions from a list of assignments.""" + ec = AssignmentCollection(assignments, []) + return sympy_cse(ec).all_assignments + + def apply_to_all_assignments(assignment_collection: AssignmentCollection, operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: - """Applies sympy expand operation to all equations in collection""" + """Applies sympy expand operation to all equations in collection.""" result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] return assignment_collection.copy(result) def apply_on_all_subexpressions(ac: AssignmentCollection, operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: + """Applies the given operation on all subexpressions of the AssignmentCollection.""" result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] return ac.copy(ac.main_assignments, result) def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection: - """Goes through the subexpressions list and replaces the term in the following subexpressions""" + """Goes through the subexpressions list and replaces the term in the following subexpressions.""" result = [] for outerCtr, s in enumerate(ac.subexpressions): new_rhs = s.rhs diff --git a/astnodes.py b/astnodes.py index 58780ac4ea1d745af1986cb1a56bd357cb87199c..0a10b8645ba3699fded7e424496e3db61a1e2fd2 100644 --- a/astnodes.py +++ b/astnodes.py @@ -1,32 +1,36 @@ import sympy as sp from sympy.tensor import IndexedBase from pystencils.field import Field -from pystencils.data_types import TypedSymbol, createType, castFunc +from pystencils.data_types import TypedSymbol, create_type, castFunc from pystencils.sympyextensions import fast_subs +from typing import List, Set, Optional, Union, Any + +NodeOrExpr = Union['Node', sp.Expr] class Node(object): - """Base class for all AST nodes""" + """Base class for all AST nodes.""" - def __init__(self, parent=None): + def __init__(self, parent: Optional['Node'] = None): self.parent = parent - def args(self): - """Returns all arguments/children of this node""" + @property + def args(self) -> List[NodeOrExpr]: + """Returns all arguments/children of this node.""" return [] @property - def symbolsDefined(self): - """Set of symbols which are defined by this node. """ + def symbols_defined(self) -> Set[sp.Symbol]: + """Set of symbols which are defined by this node.""" return set() @property - def undefinedSymbols(self): - """Symbols which are used but are not defined inside this node""" + def undefined_symbols(self) -> Set[sp.Symbol]: + """Symbols which are used but are not defined inside this node.""" raise NotImplementedError() - def subs(self, *args, **kwargs): - """Inplace! substitute, similar to sympys but modifies ast and returns None""" + def subs(self, *args, **kwargs) -> None: + """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" for a in self.args: a.subs(*args, **kwargs) @@ -34,32 +38,35 @@ class Node(object): def func(self): return self.__class__ - def atoms(self, argType): - """ - Returns a set of all children which are an instance of the given argType - """ + def atoms(self, arg_type) -> Set[Any]: + """Returns a set of all descendants recursively, which are an instance of the given type.""" result = set() for arg in self.args: - if isinstance(arg, argType): + if isinstance(arg, arg_type): result.add(arg) - result.update(arg.atoms(argType)) + result.update(arg.atoms(arg_type)) return result class Conditional(Node): - """Conditional""" - def __init__(self, conditionExpr, trueBlock, falseBlock=None): - """ - Create a new conditional node - - :param conditionExpr: sympy relational expression - :param trueBlock: block which is run if conditional is true - :param falseBlock: block which is run if conditional is false, or None if not needed - """ - assert conditionExpr.is_Boolean or conditionExpr.is_Relational - self.conditionExpr = conditionExpr - - def handleChild(c): + """Conditional that maps to a 'if' statement in C/C++. + + Try to avoid using this node inside of loops, since currently this construction can not be vectorized. + Consider using assignments with sympy.Piecewise in this case. + + Args: + condition_expr: sympy relational expression + true_block: block which is run if conditional is true + false_block: optional block which is run if conditional is false + """ + + def __init__(self, condition_expr: sp.Expr, true_block: 'Block', false_block: Optional['Block'] = None) -> None: + super(Conditional, self).__init__(parent=None) + + assert condition_expr.is_Boolean or condition_expr.is_Relational + self.conditionExpr = condition_expr + + def handle_child(c): if c is None: return None if not isinstance(c, Block): @@ -67,8 +74,8 @@ class Conditional(Node): c.parent = self return c - self.trueBlock = handleChild(trueBlock) - self.falseBlock = handleChild(falseBlock) + self.trueBlock = handle_child(true_block) + self.falseBlock = handle_child(false_block) def subs(self, *args, **kwargs): self.trueBlock.subs(*args, **kwargs) @@ -84,14 +91,14 @@ class Conditional(Node): return result @property - def symbolsDefined(self): + def symbols_defined(self): return set() @property - def undefinedSymbols(self): - result = self.trueBlock.undefinedSymbols + def undefined_symbols(self): + result = self.trueBlock.undefined_symbols if self.falseBlock: - result.update(self.falseBlock.undefinedSymbols) + result.update(self.falseBlock.undefined_symbols) result.update(self.conditionExpr.atoms(sp.Symbol)) return result @@ -105,7 +112,7 @@ class Conditional(Node): class KernelFunction(Node): class Argument: - def __init__(self, name, dtype, symbol, kernelFunctionNode): + def __init__(self, name, dtype, symbol, kernel_function_node): from pystencils.transformations import symbolNameToVariableName self.name = name self.dtype = dtype @@ -132,8 +139,8 @@ class KernelFunction(Node): self.field = None if self.isFieldArgument: - fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed} - self.field = fieldMap[self.fieldName] + field_map = {symbolNameToVariableName(f.name): f for f in kernel_function_node.fields_accessed} + self.field = field_map[self.fieldName] def __lt__(self, other): def score(l): @@ -155,30 +162,30 @@ class KernelFunction(Node): def __repr__(self): return '<{0} {1}>'.format(self.dtype, self.name) - def __init__(self, body, ghostLayers=None, functionName="kernel", backend=""): + def __init__(self, body, ghost_layers=None, function_name="kernel", backend=""): super(KernelFunction, self).__init__() self._body = body body.parent = self self._parameters = None - self.functionName = functionName + self.functionName = function_name self._body.parent = self self.compile = None - self.ghostLayers = ghostLayers + self.ghostLayers = ghost_layers # these variables are assumed to be global, so no automatic parameter is generated for them self.globalVariables = set() self.backend = backend @property - def symbolsDefined(self): + def symbols_defined(self): return set() @property - def undefinedSymbols(self): + def undefined_symbols(self): return set() @property def parameters(self): - self._updateParameters() + self._update_parameters() return self._parameters @property @@ -190,30 +197,30 @@ class KernelFunction(Node): return [self._body] @property - def fieldsAccessed(self): + def fields_accessed(self): """Set of Field instances: fields which are accessed inside this kernel function""" return set(o.field for o in self.atoms(ResolvedFieldAccess)) - def _updateParameters(self): - undefinedSymbols = self._body.undefinedSymbols - self.globalVariables - self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols] + def _update_parameters(self): + undefined_symbols = self._body.undefined_symbols - self.globalVariables + self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefined_symbols] self._parameters.sort() def __str__(self): - self._updateParameters() + self._update_parameters() return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters, ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): - self._updateParameters() + self._update_parameters() return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters) class Block(Node): - def __init__(self, listOfNodes): - super(Node, self).__init__() - self._nodes = listOfNodes + def __init__(self, nodes: List[Node]): + super(Block, self).__init__() + self._nodes = nodes self.parent = None for n in self._nodes: n.parent = self @@ -222,23 +229,23 @@ class Block(Node): def args(self): return self._nodes - def insertFront(self, node): + def insert_front(self, node): node.parent = self self._nodes.insert(0, node) - def insertBefore(self, newNode, insertBefore): - newNode.parent = self - idx = self._nodes.index(insertBefore) + def insert_before(self, new_node, insert_before): + new_node.parent = self + idx = self._nodes.index(insert_before) # move all assignment (definitions to the top) - if isinstance(newNode, SympyAssignment) and newNode.isDeclaration: + if isinstance(new_node, SympyAssignment) and new_node.is_declaration: while idx > 0: pn = self._nodes[idx - 1] if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): idx -= 1 else: break - self._nodes.insert(idx, newNode) + self._nodes.insert(idx, new_node) def append(self, node): if isinstance(node, list) or isinstance(node, tuple): @@ -249,7 +256,7 @@ class Block(Node): node.parent = self self._nodes.append(node) - def takeChildNodes(self): + def take_child_nodes(self): tmp = self._nodes self._nodes = [] return tmp @@ -266,19 +273,19 @@ class Block(Node): self._nodes.insert(idx, replacements) @property - def symbolsDefined(self): + def symbols_defined(self): result = set() for a in self.args: - result.update(a.symbolsDefined) + result.update(a.symbols_defined) return result @property - def undefinedSymbols(self): + def undefined_symbols(self): result = set() defined_symbols = set() for a in self.args: - result.update(a.undefinedSymbols) - defined_symbols.update(a.symbolsDefined) + result.update(a.undefined_symbols) + defined_symbols.update(a.symbols_defined) return result - defined_symbols def __str__(self): @@ -289,10 +296,10 @@ class Block(Node): class PragmaBlock(Block): - def __init__(self, pragmaLine, listOfNodes): - super(PragmaBlock, self).__init__(listOfNodes) - self.pragmaLine = pragmaLine - for n in listOfNodes: + def __init__(self, pragma_line, nodes): + super(PragmaBlock, self).__init__(nodes) + self.pragmaLine = pragma_line + for n in nodes: n.parent = self def __repr__(self): @@ -302,18 +309,19 @@ class PragmaBlock(Block): class LoopOverCoordinate(Node): LOOP_COUNTER_NAME_PREFIX = "ctr" - def __init__(self, body, coordinateToLoopOver, start, stop, step=1): + def __init__(self, body, coordinate_to_loop_over, start, stop, step=1): + super(LoopOverCoordinate, self).__init__(parent=None) self.body = body body.parent = self - self.coordinateToLoopOver = coordinateToLoopOver + self.coordinateToLoopOver = coordinate_to_loop_over self.start = start self.stop = stop self.step = step self.body.parent = self self.prefixLines = [] - def newLoopWithDifferentBody(self, newBody): - result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step) + def new_loop_with_different_body(self, new_body): + result = LoopOverCoordinate(new_body, self.coordinateToLoopOver, self.start, self.stop, self.step) result.prefixLines = [l for l in self.prefixLines] return result @@ -345,84 +353,85 @@ class LoopOverCoordinate(Node): self.stop = replacement @property - def symbolsDefined(self): - return set([self.loopCounterSymbol]) + def symbols_defined(self): + return {self.loop_counter_symbol} @property - def undefinedSymbols(self): - result = self.body.undefinedSymbols + def undefined_symbols(self): + result = self.body.undefined_symbols for possibleSymbol in [self.start, self.stop, self.step]: if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic): result.update(possibleSymbol.atoms(sp.Symbol)) - return result - set([self.loopCounterSymbol]) + return result - {self.loop_counter_symbol} @staticmethod - def getLoopCounterName(coordinateToLoopOver): - return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver) + def get_loop_counter_name(coordinate_to_loop_over): + return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over) @property - def loopCounterName(self): - return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver) + def loop_counter_name(self): + return LoopOverCoordinate.get_loop_counter_name(self.coordinateToLoopOver) @staticmethod - def isLoopCounterSymbol(symbol): + def is_loop_counter_symbol(symbol): prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX if not symbol.name.startswith(prefix): return None - if symbol.dtype != createType('int'): + if symbol.dtype != create_type('int'): return None coordinate = int(symbol.name[len(prefix)+1:]) return coordinate @staticmethod - def getLoopCounterSymbol(coordinateToLoopOver): - return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int') + def get_loop_counter_symbol(coordinate_to_loop_over): + return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int') @property - def loopCounterSymbol(self): - return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver) + def loop_counter_symbol(self): + return LoopOverCoordinate.get_loop_counter_symbol(self.coordinateToLoopOver) @property - def isOutermostLoop(self): + def is_outermost_loop(self): from pystencils.transformations import getNextParentOfType return getNextParentOfType(self, LoopOverCoordinate) is None @property - def isInnermostLoop(self): + def is_innermost_loop(self): return len(self.atoms(LoopOverCoordinate)) == 0 def __str__(self): - return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start, - self.loopCounterName, self.stop, - self.loopCounterName, self.step, + return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start, + self.loop_counter_name, self.stop, + self.loop_counter_name, self.step, ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): - return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start, - self.loopCounterName, self.stop, - self.loopCounterName, self.step) + return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start, + self.loop_counter_name, self.stop, + self.loop_counter_name, self.step) class SympyAssignment(Node): - def __init__(self, lhsSymbol, rhsTerm, isConst=True): - self._lhsSymbol = lhsSymbol - self.rhs = rhsTerm + def __init__(self, lhs_symbol, rhs_expr, is_const=True): + super(SympyAssignment, self).__init__(parent=None) + self._lhsSymbol = lhs_symbol + self.rhs = rhs_expr self._isDeclaration = True - isCast = self._lhsSymbol.func == castFunc - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or isCast: + is_cast = self._lhsSymbol.func == castFunc + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or is_cast: self._isDeclaration = False - self._isConst = isConst + self._isConst = is_const @property def lhs(self): return self._lhsSymbol @lhs.setter - def lhs(self, newValue): - self._lhsSymbol = newValue + def lhs(self, new_value): + self._lhsSymbol = new_value self._isDeclaration = True - isCast = self._lhsSymbol.func == castFunc - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: + is_cast = self._lhsSymbol.func == castFunc + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or is_cast: self._isDeclaration = False def subs(self, *args, **kwargs): @@ -434,30 +443,30 @@ class SympyAssignment(Node): return [self._lhsSymbol, self.rhs] @property - def symbolsDefined(self): + def symbols_defined(self): if not self._isDeclaration: return set() - return set([self._lhsSymbol]) + return {self._lhsSymbol} @property - def undefinedSymbols(self): + def undefined_symbols(self): result = self.rhs.atoms(sp.Symbol) # Add loop counters if there a field accesses - loopCounters = set() + loop_counters = set() for symbol in result: if isinstance(symbol, Field.Access): for i in range(len(symbol.offsets)): - loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i)) - result.update(loopCounters) + loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) + result.update(loop_counters) result.update(self._lhsSymbol.atoms(sp.Symbol)) return result @property - def isDeclaration(self): + def is_declaration(self): return self._isDeclaration @property - def isConst(self): + def is_const(self): return self._isConst def replace(self, child, replacement): @@ -480,13 +489,13 @@ class SympyAssignment(Node): class ResolvedFieldAccess(sp.Indexed): - def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues): + def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values): if not isinstance(base, IndexedBase): base = IndexedBase(base, shape=(1,)) - obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex) + obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index) obj.field = field obj.offsets = offsets - obj.idxCoordinateValues = idxCoordinateValues + obj.idxCoordinateValues = idx_coordinate_values return obj def _eval_subs(self, old, new): @@ -502,32 +511,33 @@ class ResolvedFieldAccess(sp.Indexed): self.field, self.offsets, self.idxCoordinateValues) def _hashable_content(self): - superClassContents = super(ResolvedFieldAccess, self)._hashable_content() - return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) + super_class_contents = super(ResolvedFieldAccess, self)._hashable_content() + return super_class_contents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) @property - def typedSymbol(self): + def typed_symbol(self): return self.base.label def __str__(self): top = super(ResolvedFieldAccess, self).__str__() - return "%s (%s)" % (top, self.typedSymbol.dtype) + return "%s (%s)" % (top, self.typed_symbol.dtype) def __getnewargs__(self): return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues class TemporaryMemoryAllocation(Node): - def __init__(self, typedSymbol, size): - self.symbol = typedSymbol + def __init__(self, typed_symbol, size): + super(TemporaryMemoryAllocation, self).__init__(parent=None) + self.symbol = typed_symbol self.size = size @property - def symbolsDefined(self): - return set([self.symbol]) + def symbols_defined(self): + return {self.symbol} @property - def undefinedSymbols(self): + def undefined_symbols(self): if isinstance(self.size, sp.Basic): return self.size.atoms(sp.Symbol) else: @@ -539,18 +549,18 @@ class TemporaryMemoryAllocation(Node): class TemporaryMemoryFree(Node): - def __init__(self, typedSymbol): - self.symbol = typedSymbol + def __init__(self, typed_symbol): + super(TemporaryMemoryFree, self).__init__(parent=None) + self.symbol = typed_symbol @property - def symbolsDefined(self): + def symbols_defined(self): return set() @property - def undefinedSymbols(self): + def undefined_symbols(self): return set() @property def args(self): return [] - diff --git a/backends/__init__.py b/backends/__init__.py index 690fa2d0bc172dfa33500db504a3d85db6426a67..e66688a98b80c4dbad5492411a62de513920531c 100644 --- a/backends/__init__.py +++ b/backends/__init__.py @@ -1,7 +1,7 @@ -from .cbackend import generateC +from .cbackend import print_c try: - from .dot import dotprint + from .dot import print_dot from .llvm import generateLLVM except ImportError: pass diff --git a/backends/cbackend.py b/backends/cbackend.py index 7710c65ee6a0d09bc2d28106a9ef10a1d707f6df..43d585cae4a98b3419c83b22548c248b7f0a8541 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -1,45 +1,59 @@ import sympy as sp - -from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr +from collections import namedtuple +from sympy.core import S +from typing import Optional try: - from sympy.utilities.codegen import CCodePrinter -except ImportError: from sympy.printing.ccode import C99CodePrinter as CCodePrinter +except ImportError: + from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 -from collections import namedtuple -from sympy.core.mul import _keep_coeff -from sympy.core import S - +from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment -from pystencils.data_types import createType, PointerType, getTypeOfExpression, VectorType, castFunc +from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, castFunc from pystencils.backends.simd_instruction_sets import selectedInstructionSet +__all__ = ['print_c'] -def generateC(astNode, signatureOnly=False): - """ - Prints the abstract syntax tree as C function + +def print_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str: + """Prints an abstract syntax tree node as C or CUDA code. + + This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded + in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel + functions. + + Args: + ast_node: + signature_only: + use_float_constants: + + Returns: + C-like code for the ast node and its descendants """ - fieldTypes = set([f.dtype for f in astNode.fieldsAccessed]) - useFloatConstants = createType("double") not in fieldTypes + if use_float_constants is None: + field_types = set(o.field.dtype for o in ast_node.atoms(ResolvedFieldAccess)) + double = create_type('double') + use_float_constants = double not in field_types - vectorIS = selectedInstructionSet['double'] - printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly, vectorInstructionSet=vectorIS) - return printer(astNode) + vector_is = selectedInstructionSet['double'] + printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only, + vector_instruction_set=vector_is) + return printer(ast_node) -def getHeaders(astNode): +def get_headers(ast_node): headers = set() - if hasattr(astNode, 'headers'): - headers.update(astNode.headers) - elif isinstance(astNode, SympyAssignment): - if type(getTypeOfExpression(astNode.rhs)) is VectorType: + if hasattr(ast_node, 'headers'): + headers.update(ast_node.headers) + elif isinstance(ast_node, SympyAssignment): + if type(get_type_of_expression(ast_node.rhs)) is VectorType: headers.update(selectedInstructionSet['double']['headers']) - for a in astNode.args: + for a in ast_node.args: if isinstance(a, Node): - headers.update(getHeaders(a)) + headers.update(get_headers(a)) return headers @@ -48,10 +62,11 @@ def getHeaders(astNode): class CustomCppCode(Node): - def __init__(self, code, symbolsRead, symbolsDefined): + def __init__(self, code, symbols_read, symbols_defined, parent=None): + super(CustomCppCode, self).__init__(parent=parent) self._code = "\n" + code - self._symbolsRead = set(symbolsRead) - self._symbolsDefined = set(symbolsDefined) + self._symbolsRead = set(symbols_read) + self._symbolsDefined = set(symbols_defined) self.headers = [] @property @@ -63,75 +78,78 @@ class CustomCppCode(Node): return [] @property - def symbolsDefined(self): + def symbols_defined(self): return self._symbolsDefined @property - def undefinedSymbols(self): - return self.symbolsDefined - self._symbolsRead + def undefined_symbols(self): + return self.symbols_defined - self._symbolsRead class PrintNode(CustomCppCode): - def __init__(self, symbolToPrint): - code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name) - super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set()) + # noinspection SpellCheckingInspection + def __init__(self, symbol_to_print): + code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name) + super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set()) self.headers.append("<iostream>") # ------------------------------------------- Printer ------------------------------------------------------------------ -class CBackend(object): +# noinspection PyPep8Naming +class CBackend: - def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None): - if sympyPrinter is None: - self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) - if vectorInstructionSet is not None: - self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats) + def __init__(self, constants_as_floats=False, sympy_printer=None, + signature_only=False, vector_instruction_set=None): + if sympy_printer is None: + self.sympyPrinter = CustomSympyPrinter(constants_as_floats) + if vector_instruction_set is not None: + self.sympyPrinter = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats) else: - self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) + self.sympyPrinter = CustomSympyPrinter(constants_as_floats) else: - self.sympyPrinter = sympyPrinter + self.sympyPrinter = sympy_printer - self._vectorInstructionSet = vectorInstructionSet + self._vectorInstructionSet = vector_instruction_set self._indent = " " - self._signatureOnly = signatureOnly + self._signatureOnly = signature_only def __call__(self, node): - prevIs = VectorType.instructionSet + prev_is = VectorType.instructionSet VectorType.instructionSet = self._vectorInstructionSet result = str(self._print(node)) - VectorType.instructionSet = prevIs + VectorType.instructionSet = prev_is return result def _print(self, node): for cls in type(node).__mro__: - methodName = "_print_" + cls.__name__ - if hasattr(self, methodName): - return getattr(self, methodName)(node) - raise NotImplementedError("CBackend does not support node of type " + cls.__name__) + method_name = "_print_" + cls.__name__ + if hasattr(self, method_name): + return getattr(self, method_name)(node) + raise NotImplementedError("CBackend does not support node of type " + str(type(node))) def _print_KernelFunction(self, node): - functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] - funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments)) + function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] + func_declaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(function_arguments)) if self._signatureOnly: - return funcDeclaration + return func_declaration body = self._print(node.body) - return funcDeclaration + "\n" + body + return func_declaration + "\n" + body def _print_Block(self, node): - blockContents = "\n".join([self._print(child) for child in node.args]) - return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True))) + block_contents = "\n".join([self._print(child) for child in node.args]) + return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True))) def _print_PragmaBlock(self, node): return "%s\n%s" % (node.pragmaLine, self._print_Block(node)) def _print_LoopOverCoordinate(self, node): - counterVar = node.loopCounterName - start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start)) - condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop)) - update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),) + counter_symbol = node.loop_counter_name + start = "int %s = %s" % (counter_symbol, self.sympyPrinter.doprint(node.start)) + condition = "%s < %s" % (counter_symbol, self.sympyPrinter.doprint(node.stop)) + update = "%s += %s" % (counter_symbol, self.sympyPrinter.doprint(node.step),) loopStr = "for (%s; %s; %s)" % (start, condition, update) prefix = "\n".join(node.prefixLines) @@ -140,12 +158,12 @@ class CBackend(object): return "%s%s\n%s" % (prefix, loopStr, self._print(node.body)) def _print_SympyAssignment(self, node): - if node.isDeclaration: - dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " " - return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) + if node.is_declaration: + data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " " + return "%s %s = %s;" % (data_type, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) else: - lhsType = getTypeOfExpression(node.lhs) - if type(lhsType) is VectorType and node.lhs.func == castFunc: + lhs_type = get_type_of_expression(node.lhs) + if type(lhs_type) is VectorType and node.lhs.func == castFunc: return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]), self.sympyPrinter.doprint(node.rhs)) + ';' else: @@ -153,31 +171,33 @@ class CBackend(object): def _print_TemporaryMemoryAllocation(self, node): return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name), - node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size)) + node.symbol.dtype.base_type, self.sympyPrinter.doprint(node.size)) def _print_TemporaryMemoryFree(self, node): return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),) - def _print_CustomCppCode(self, node): + @staticmethod + def _print_CustomCppCode(node): return node.code def _print_Conditional(self, node): - conditionExpr = self.sympyPrinter.doprint(node.conditionExpr) - trueBlock = self._print_Block(node.trueBlock) - result = "if (%s)\n%s " % (conditionExpr, trueBlock) + condition_expr = self.sympyPrinter.doprint(node.conditionExpr) + true_block = self._print_Block(node.trueBlock) + result = "if (%s)\n%s " % (condition_expr, true_block) if node.falseBlock: - falseBlock = self._print_Block(node.falseBlock) - result += "else " + falseBlock + false_block = self._print_Block(node.falseBlock) + result += "else " + false_block return result # ------------------------------------------ Helper function & classes ------------------------------------------------- +# noinspection PyPep8Naming class CustomSympyPrinter(CCodePrinter): - def __init__(self, constantsAsFloats=False): - self._constantsAsFloats = constantsAsFloats + def __init__(self, constants_as_floats=False): + self._constantsAsFloats = constants_as_floats super(CustomSympyPrinter, self).__init__() def _print_Pow(self, expr): @@ -210,7 +230,7 @@ class CustomSympyPrinter(CCodePrinter): return res def _print_Function(self, expr): - functionMap = { + function_map = { xor: '^', rightShift: '>>', leftShift: '<<', @@ -218,33 +238,34 @@ class CustomSympyPrinter(CCodePrinter): bitwiseAnd: '&', } if expr.func == castFunc: - arg, type = expr.args - return "*((%s)(& %s))" % (PointerType(type), self._print(arg)) - elif expr.func in functionMap: - return "(%s %s %s)" % (self._print(expr.args[0]), functionMap[expr.func], self._print(expr.args[1])) + arg, data_type = expr.args + return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg)) + elif expr.func in function_map: + return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1])) else: return super(CustomSympyPrinter, self)._print_Function(expr) +# noinspection PyPep8Naming class VectorizedCustomSympyPrinter(CustomSympyPrinter): SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) - def __init__(self, instructionSet, constantsAsFloats=False): - super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats) - self.instructionSet = instructionSet + def __init__(self, instruction_set, constants_as_floats=False): + super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats) + self.instructionSet = instruction_set - def _scalarFallback(self, funcName, expr, *args, **kwargs): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs) + def _scalarFallback(self, func_name, expr, *args, **kwargs): + expr_type = get_type_of_expression(expr) + if type(expr_type) is not VectorType: + return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs) else: - assert self.instructionSet['width'] == exprType.width + assert self.instructionSet['width'] == expr_type.width return None def _print_Function(self, expr): if expr.func == castFunc: - arg, dtype = expr.args - if type(dtype) is VectorType: + arg, data_type = expr.args + if type(data_type) is VectorType: if type(arg) is ResolvedFieldAccess: return self.instructionSet['loadU'].format("& " + self._print(arg)) else: @@ -257,10 +278,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if result: return result - argStrings = [self._print(a) for a in expr.args] - assert len(argStrings) > 0 - result = argStrings[0] - for item in argStrings[1:]: + arg_strings = [self._print(a) for a in expr.args] + assert len(arg_strings) > 0 + result = arg_strings[0] + for item in arg_strings[1:]: result = self.instructionSet['&'].format(result, item) return result @@ -269,10 +290,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if result: return result - argStrings = [self._print(a) for a in expr.args] - assert len(argStrings) > 0 - result = argStrings[0] - for item in argStrings[1:]: + arg_strings = [self._print(a) for a in expr.args] + assert len(arg_strings) > 0 + result = arg_strings[0] + for item in arg_strings[1:]: result = self.instructionSet['|'].format(result, item) return result @@ -284,7 +305,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): summands = [] for term in expr.args: if term.func == sp.Mul: - sign, t = self._print_Mul(term, insideAdd=True) + sign, t = self._print_Mul(term, inside_add=True) else: t = self._print(term) sign = 1 @@ -318,7 +339,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): else: raise ValueError("Generic exponential not supported") - def _print_Mul(self, expr, insideAdd=False): + def _print_Mul(self, expr, inside_add=False): + # noinspection PyProtectedMember + from sympy.core.mul import _keep_coeff + result = self._scalarFallback('_print_Mul', expr) if result: return result @@ -359,7 +383,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): denominator_str = self.instructionSet['*'].format(denominator_str, item) result = self.instructionSet['/'].format(result, denominator_str) - if insideAdd: + if inside_add: return sign, result else: if sign < 0: @@ -384,7 +408,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if result: return result - if expr.args[-1].cond != True: + if expr.args[-1].cond.args[0] is not sp.sympify(True): # We need the last conditional to be a True, otherwise the resulting # function may not return a result. raise ValueError("All Piecewise expressions must contain an " @@ -395,5 +419,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for trueExpr, condition in reversed(expr.args[:-1]): + # noinspection SpellCheckingInspection result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition)) return result diff --git a/backends/dot.py b/backends/dot.py index e78ac1bb0c550be5af7537edbf098f4aa086bf0f..fa1c7067d18656584208541d5de343725967e147 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -3,13 +3,14 @@ from graphviz import Digraph, lang import graphviz +# noinspection PyPep8Naming class DotPrinter(Printer): """ A printer which converts ast to DOT (graph description language). """ - def __init__(self, nodeToStrFunction, full, **kwargs): + def __init__(self, node_to_str_function, full, **kwargs): super(DotPrinter, self).__init__() - self._nodeToStrFunction = nodeToStrFunction + self._nodeToStrFunction = node_to_str_function self.full = full self.dot = Digraph(**kwargs) self.dot.quote_edge = lang.quote @@ -33,7 +34,8 @@ class DotPrinter(Printer): self.dot.edge(str(id(block)), str(id(node))) def _print_SympyAssignment(self, assignment): - self.dot.node(str(id(assignment)), style='filled', fillcolor='#56db7f', label=self._nodeToStrFunction(assignment)) + self.dot.node(str(id(assignment)), style='filled', fillcolor='#56db7f', + label=self._nodeToStrFunction(assignment)) if self.full: for node in assignment.args: self._print(node) @@ -48,7 +50,7 @@ class DotPrinter(Printer): self._print(expr.falseBlock) self.dot.edge(str(id(expr)), str(id(expr.falseBlock))) - def emptyPrinter(self, expr): + def empty_printer(self, expr): if self.full: self.dot.node(str(id(expr)), label=self._nodeToStrFunction(expr)) for node in expr.args: @@ -56,7 +58,7 @@ class DotPrinter(Printer): for node in expr.args: self.dot.edge(str(id(expr)), str(id(node))) else: - raise NotImplementedError('Dotprinter cannot print', type(expr), expr) + raise NotImplementedError('DotPrinter cannot print', type(expr), expr) def doprint(self, expr): self._print(expr) @@ -68,7 +70,7 @@ def __shortened(node): if isinstance(node, LoopOverCoordinate): return "Loop over dim %d" % (node.coordinateToLoopOver,) elif isinstance(node, KernelFunction): - params = [f.name for f in node.fieldsAccessed] + params = [f.name for f in node.fields_accessed] params += [p.name for p in node.parameters if not p.isFieldArgument] return "Func: %s (%s)" % (node.functionName, ",".join(params)) elif isinstance(node, SympyAssignment): @@ -81,11 +83,11 @@ def __shortened(node): raise NotImplementedError("Cannot handle node type %s" % (type(node),)) -def dotprint(node, view=False, short=False, full=False, **kwargs): +def print_dot(node, view=False, short=False, full=False, **kwargs): """ Returns a string which can be used to generate a DOT-graph :param node: The ast which should be generated - :param view: Boolen, if rendering of the image directly should occur. + :param view: Boolean, if rendering of the image directly should occur. :param short: Uses the __shortened output :param full: Prints the whole tree with type information :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph @@ -95,7 +97,8 @@ def dotprint(node, view=False, short=False, full=False, **kwargs): if short: node_to_str_function = __shortened elif full: - node_to_str_function = lambda expr: repr(type(expr)) + repr(expr) + def node_to_str_function(expr): + return repr(type(expr)) + repr(expr) printer = DotPrinter(node_to_str_function, full, **kwargs) dot = printer.doprint(node) if view: diff --git a/backends/logging.json b/backends/logging.json deleted file mode 100644 index 42617b3878e3fefc3b298e35e857d560e37447b5..0000000000000000000000000000000000000000 --- a/backends/logging.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "version" : 1, - "disable_existing_loggers" : false, - "formatters" : { - "simple" :{ - "format" : "[%(levelname)s]: %(message)s" - } - }, - "handlers" : { - "console": { - "class": "logging.StreamHandler", - "level": "INFO", - "formatter": "simple", - "stream": "ext://sys.stdout" - }, - "log_file": { - "class": "logging.FileHandler", - "level": "DEBUG", - "formatter": "simple", - "filename": "gen.log", - "mode" : "w", - "encoding": "utf8" - } - }, - "loggers" : { - "generator" : { - "level" : "DEBUG", - "handlers" : ["console", "log_file"] - } - } -} \ No newline at end of file diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index 6760837c8eb91ce1e9bb8f3cc43ac03ada7d8a6e..0949214545ccd0e76756e4882db147624ff8a371 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -1,7 +1,8 @@ -def x86VectorInstructionSet(dataType='double', instructionSet='avx'): - baseNames = { +# noinspection SpellCheckingInspection +def x86_vector_instruction_set(data_type='double', instruction_set='avx'): + base_names = { '+': 'add[0, 1]', '-': 'sub[0, 1]', '*': 'mul[0, 1]', @@ -53,41 +54,41 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'): } result = {} - pre = prefix[instructionSet] - suf = suffix[dataType] - for intrinsicId, functionShortcut in baseNames.items(): - functionShortcut = functionShortcut.strip() - name = functionShortcut[:functionShortcut.index('[')] - args = functionShortcut[functionShortcut.index('[') + 1: -1] - argString = "(" + pre = prefix[instruction_set] + suf = suffix[data_type] + for intrinsicId, function_shortcut in base_names.items(): + function_shortcut = function_shortcut.strip() + name = function_shortcut[:function_shortcut.index('[')] + args = function_shortcut[function_shortcut.index('[') + 1: -1] + arg_string = "(" for arg in args.split(","): arg = arg.strip() if not arg: continue if arg in ('0', '1', '2', '3', '4', '5'): - argString += "{" + arg + "}," + arg_string += "{" + arg + "}," else: - argString += arg + "," - argString = argString[:-1] + ")" - result[intrinsicId] = pre + "_" + name + "_" + suf + argString + arg_string += arg + "," + arg_string = arg_string[:-1] + ")" + result[intrinsicId] = pre + "_" + name + "_" + suf + arg_string - result['width'] = width[(dataType, instructionSet)] + result['width'] = width[(data_type, instruction_set)] result['dataTypePrefix'] = { 'double': "_" + pre + 'd', 'float': "_" + pre, } - bitWidth = result['width'] * 64 - result['double'] = "__m%dd" % (bitWidth,) - result['float'] = "__m%d" % (bitWidth,) - result['int'] = "__m%di" % (bitWidth,) - result['bool'] = "__m%dd" % (bitWidth,) + bit_width = result['width'] * 64 + result['double'] = "__m%dd" % (bit_width,) + result['float'] = "__m%d" % (bit_width,) + result['int'] = "__m%di" % (bit_width,) + result['bool'] = "__m%dd" % (bit_width,) - result['headers'] = headers[instructionSet] + result['headers'] = headers[instruction_set] return result selectedInstructionSet = { - 'float': x86VectorInstructionSet('float', 'avx'), - 'double': x86VectorInstructionSet('double', 'avx'), + 'float': x86_vector_instruction_set('float', 'avx'), + 'double': x86_vector_instruction_set('double', 'avx'), } diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py index 76999a549ca25639401b03c667ce7ca951430e1c..aa39d35a4e4e57af0f778aa49068216c369a8085 100644 --- a/boundaries/boundaryhandling.py +++ b/boundaries/boundaryhandling.py @@ -5,7 +5,7 @@ from pystencils import Field, TypedSymbol, createIndexedKernel from pystencils.backends.cbackend import CustomCppCode from pystencils.boundaries.createindexlist import numpyDataTypeForBoundaryObject, createBoundaryIndexArray from pystencils.cache import memorycache -from pystencils.data_types import createType +from pystencils.data_types import create_type class FlagInterface: @@ -350,19 +350,19 @@ class BoundaryOffsetInfo(CustomCppCode): code += "const int %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(invDirs)) offsetSymbols = BoundaryOffsetInfo._offsetSymbols(dim) - super(BoundaryOffsetInfo, self).__init__(code, symbolsRead=set(), - symbolsDefined=set(offsetSymbols + [self.INV_DIR_SYMBOL])) + super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(), + symbols_defined=set(offsetSymbols + [self.INV_DIR_SYMBOL])) @staticmethod def _offsetSymbols(dim): - return [TypedSymbol("c_%d" % (d,), createType(np.int64)) for d in range(dim)] + return [TypedSymbol("c_%d" % (d,), create_type(np.int64)) for d in range(dim)] INV_DIR_SYMBOL = TypedSymbol("invDir", "int") def createBoundaryKernel(field, indexField, stencil, boundaryFunctor, target='cpu', openMP=True): elements = [BoundaryOffsetInfo(stencil)] - indexArrDtype = indexField.dtype.numpyDtype + indexArrDtype = indexField.dtype.numpy_dtype dirSymbol = TypedSymbol("dir", indexArrDtype.fields['dir'][0]) elements += [Assignment(dirSymbol, indexField[0]('dir'))] elements += boundaryFunctor(field, directionSymbol=dirSymbol, indexField=indexField) diff --git a/boundaries/createindexlist.py b/boundaries/createindexlist.py index ee46b0b426b96e44e372abd066f8379240483ae6..69ddb4c1b14f91628fe23d3fa8a4f8ef68d67ac2 100644 --- a/boundaries/createindexlist.py +++ b/boundaries/createindexlist.py @@ -22,7 +22,7 @@ def numpyDataTypeForBoundaryObject(boundaryObject, dim): coordinateNames = boundaryIndexArrayCoordinateNames[:dim] return np.dtype([(name, np.int32) for name in coordinateNames] + [(directionMemberName, np.int32)] + - [(i[0], i[1].numpyDtype) for i in boundaryObject.additionalData], align=True) + [(i[0], i[1].numpy_dtype) for i in boundaryObject.additionalData], align=True) def _createBoundaryIndexListPython(flagFieldArr, nrOfGhostLayers, boundaryMask, fluidMask, stencil): diff --git a/boundaries/inkernel.py b/boundaries/inkernel.py index 777e7090a50673c616f617a24349b33053aba62b..22f8d3e3771bb50a2dbee61c0f6e9d493acf1262 100644 --- a/boundaries/inkernel.py +++ b/boundaries/inkernel.py @@ -2,7 +2,7 @@ import sympy as sp from pystencils import Field, TypedSymbol from pystencils.bitoperations import bitwiseAnd from pystencils.boundaries.boundaryhandling import FlagInterface -from pystencils.data_types import createType +from pystencils.data_types import create_type def addNeumannBoundary(eqs, fields, flagField, boundaryFlag="neumannFlag", inverseFlag=False): @@ -21,7 +21,7 @@ def addNeumannBoundary(eqs, fields, flagField, boundaryFlag="neumannFlag", inver fields = set(fields) if type(boundaryFlag) is str: - boundaryFlag = TypedSymbol(boundaryFlag, dtype=createType(FlagInterface.FLAG_DTYPE)) + boundaryFlag = TypedSymbol(boundaryFlag, dtype=create_type(FlagInterface.FLAG_DTYPE)) substitutions = {} for eq in eqs: diff --git a/cpu/__init__.py b/cpu/__init__.py index cec36fd9298f5645d1669edac101c24207cfd655..00f0f9e7235061d21282396f96c32f7db26d718d 100644 --- a/cpu/__init__.py +++ b/cpu/__init__.py @@ -1,3 +1,3 @@ from pystencils.cpu.kernelcreation import createKernel, createIndexedKernel, addOpenMP from pystencils.cpu.cpujit import makePythonFunction -from pystencils.backends.cbackend import generateC +from pystencils.backends.cbackend import print_c diff --git a/cpu/cpujit.py b/cpu/cpujit.py index fa04eda775b2d0aeba8050ec717c710057f6b672..894621c0a4a9842503413311ec2ee5cf7b5782e7 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -71,10 +71,10 @@ import shutil import numpy as np from appdirs import user_config_dir, user_cache_dir from ctypes import cdll -from pystencils.backends.cbackend import generateC, getHeaders +from pystencils.backends.cbackend import print_c, get_headers from collections import OrderedDict, Mapping from pystencils.transformations import symbolNameToVariableName -from pystencils.data_types import toCtypes, getBaseType, StructType +from pystencils.data_types import to_ctypes, get_base_type, StructType from pystencils.field import FieldType @@ -272,11 +272,11 @@ atexit.register(compileObjectCacheToSharedLibrary) def generateCode(ast, restrictQualifier, functionPrefix, targetFile): - headers = getHeaders(ast) + headers = get_headers(ast) headers.update(['<cmath>', '<cstdint>']) with open(targetFile, 'w') as sourceFile: - code = generateC(ast) + code = print_c(ast) includes = "\n".join(["#include %s" % (includeFile,) for includeFile in headers]) print(includes, file=sourceFile) print("#define RESTRICT %s" % (restrictQualifier,), file=sourceFile) @@ -339,7 +339,7 @@ def compileWindows(ast, codeHashStr, srcFile, libFile): def compileAndLoad(ast): cacheConfig = getCacheConfig() - codeHashStr = hashlib.sha256(generateC(ast).encode()).hexdigest() + codeHashStr = hashlib.sha256(print_c(ast).encode()).hexdigest() ast.functionName = hashToFunctionName(codeHashStr) srcFile = os.path.join(cacheConfig['objectCache'], codeHashStr + ".cpp") @@ -373,7 +373,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): symbolicField = arg.field if arg.isFieldPtrArgument: - ctArguments.append(fieldArr.ctypes.data_as(toCtypes(arg.dtype))) + ctArguments.append(fieldArr.ctypes.data_as(to_ctypes(arg.dtype))) if symbolicField.hasFixedShape: symbolicFieldShape = tuple(int(i) for i in symbolicField.shape) if isinstance(symbolicField.dtype, StructType): @@ -395,10 +395,10 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) elif arg.isFieldShapeArgument: - dataType = toCtypes(getBaseType(arg.dtype)) + dataType = to_ctypes(get_base_type(arg.dtype)) ctArguments.append(fieldArr.ctypes.shape_as(dataType)) elif arg.isFieldStrideArgument: - dataType = toCtypes(getBaseType(arg.dtype)) + dataType = to_ctypes(get_base_type(arg.dtype)) strides = fieldArr.ctypes.strides_as(dataType) for i in range(len(fieldArr.shape)): assert strides[i] % fieldArr.itemsize == 0 @@ -411,7 +411,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): param = argumentDict[arg.name] except KeyError: raise KeyError("Missing parameter for kernel call " + arg.name) - expectedType = toCtypes(arg.dtype) + expectedType = to_ctypes(arg.dtype) ctArguments.append(expectedType(param)) if len(arrayShapes) > 1: diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index c4c2833b49aa1c6282621164630e0074220ae6bf..ce37a08876e4a6a359dff01d4a01c3a3858995c6 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -4,7 +4,7 @@ from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, Kern from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ substituteArrayAccessesWithConstants -from pystencils.data_types import TypedSymbol, BasicType, StructType, createType +from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type from pystencils.field import Field, FieldType import pystencils.astnodes as ast from pystencils.cpu.cpujit import makePythonFunction @@ -38,45 +38,45 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', return term elif isinstance(term, sp.Symbol): if not hasattr(typeForSymbol, '__getitem__'): - return TypedSymbol(term.name, createType(typeForSymbol)) + return TypedSymbol(term.name, create_type(typeForSymbol)) else: return TypedSymbol(term.name, typeForSymbol[term.name]) else: raise ValueError("Term has to be field access or symbol") - fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) - allFields = fieldsRead.union(fieldsWritten) - readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) + fields_read, fields_written, assignments = typeAllEquations(listOfEquations, typeForSymbol) + all_fields = fields_read.union(fields_written) + read_only_fields = set([f.name for f in fields_read - fields_written]) - buffers = set([f for f in allFields if FieldType.isBuffer(f)]) - fieldsWithoutBuffers = allFields - buffers + buffers = set([f for f in all_fields if FieldType.isBuffer(f)]) + fields_without_buffers = all_fields - buffers body = ast.Block(assignments) - loopOrder = getOptimalLoopOrdering(fieldsWithoutBuffers) - code, loopStrides, loopVars = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, - ghostLayers=ghostLayers, loopOrder=loopOrder) + loop_order = getOptimalLoopOrdering(fields_without_buffers) + code, loop_strides, loop_vars = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, + ghostLayers=ghostLayers, loopOrder=loop_order) code.target = 'cpu' if splitGroups: - typedSplitGroups = [[type_symbol(s) for s in splitGroup] for splitGroup in splitGroups] - splitInnerLoop(code, typedSplitGroups) + typed_split_groups = [[type_symbol(s) for s in splitGroup] for splitGroup in splitGroups] + splitInnerLoop(code, typed_split_groups) - basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']] - basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) - for field in fieldsWithoutBuffers} + base_pointer_info = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']] + base_pointer_infos = {field.name: parseBasePointerInfo(base_pointer_info, loop_order, field) + for field in fields_without_buffers} - bufferBasePointerInfos = {field.name: parseBasePointerInfo([['spatialInner0']], [0], field) for field in buffers} - basePointerInfos.update(bufferBasePointerInfos) + buffer_base_pointer_infos = {field.name: parseBasePointerInfo([['spatialInner0']], [0], field) for field in buffers} + base_pointer_infos.update(buffer_base_pointer_infos) - baseBufferIndex = loopVars[0] + base_buffer_index = loop_vars[0] stride = 1 - for idx, var in enumerate(loopVars[1:]): - curStride = loopStrides[idx] - stride *= int(curStride) if isinstance(curStride, float) else curStride - baseBufferIndex += var * stride + for idx, var in enumerate(loop_vars[1:]): + cur_stride = loop_strides[idx] + stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride + base_buffer_index += var * stride - resolveBufferAccesses(code, baseBufferIndex, readOnlyFields) - resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) + resolveBufferAccesses(code, base_buffer_index, read_only_fields) + resolveFieldAccesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_infos) substituteArrayAccessesWithConstants(code) moveConstantsBeforeLoop(code) code.compile = partial(makePythonFunction, code) @@ -118,9 +118,9 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ for indexField in indexFields: assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype" dataType = indexField.dtype - if dataType.hasElement(name): + if dataType.has_element(name): rhs = indexField[0](name) - lhs = TypedSymbol(name, BasicType(dataType.getElementType(name))) + lhs = TypedSymbol(name, BasicType(dataType.get_element_type(name))) return SympyAssignment(lhs, rhs) raise ValueError("Index %s not found in any of the passed index fields" % (name,)) @@ -130,16 +130,16 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ # make 1D loop over index fields loopBody = Block([]) - loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0]) + loopNode = LoopOverCoordinate(loopBody, coordinate_to_loop_over=0, start=0, stop=indexFields[0].shape[0]) for assignment in assignments: loopBody.append(assignment) functionBody = Block([loopNode]) - ast = KernelFunction(functionBody, backend="cpu", functionName=functionName) + ast = KernelFunction(functionBody, backend="cpu", function_name=functionName) fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields} - resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping) + resolveFieldAccesses(ast, set(['indexField']), field_to_fixed_coordinates=fixedCoordinateMapping) substituteArrayAccessesWithConstants(ast) moveConstantsBeforeLoop(ast) ast.compile = partial(makePythonFunction, ast) @@ -160,10 +160,10 @@ def addOpenMP(astNode, schedule="static", numThreads=True): assert type(astNode) is ast.KernelFunction body = astNode.body threadsClause = "" if numThreads and isinstance(numThreads,bool) else " num_threads(%s)" % (numThreads,) - wrapperBlock = ast.PragmaBlock('#pragma omp parallel' + threadsClause, body.takeChildNodes()) + wrapperBlock = ast.PragmaBlock('#pragma omp parallel' + threadsClause, body.take_child_nodes()) body.append(wrapperBlock) - outerLoops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.isOutermostLoop] + outerLoops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.is_outermost_loop] assert outerLoops, "No outer loop found" assert len(outerLoops) <= 1, "More than one outer loop found. Which one should be parallelized?" loopToParallelize = outerLoops[0] diff --git a/data_types.py b/data_types.py index 2e0b5e0e7e67e964a571df6329b543a381b0cc23..c92875374da01b9de08ee31e9aa7b7e0d399d6ce 100644 --- a/data_types.py +++ b/data_types.py @@ -44,7 +44,7 @@ class TypedSymbol(sp.Symbol): def __new_stage2__(cls, name, dtype): obj = super(TypedSymbol, cls).__xnew__(cls, name) try: - obj._dtype = createType(dtype) + obj._dtype = create_type(dtype) except TypeError: # on error keep the string obj._dtype = dtype @@ -58,14 +58,14 @@ class TypedSymbol(sp.Symbol): return self._dtype def _hashable_content(self): - superClassContents = list(super(TypedSymbol, self)._hashable_content()) - return tuple(superClassContents + [hash(self._dtype)]) + super_class_contents = list(super(TypedSymbol, self)._hashable_content()) + return tuple(super_class_contents + [hash(self._dtype)]) def __getnewargs__(self): return self.name, self.dtype -def createType(specification): +def create_type(specification): """ Create a subclass of Type according to a string or an object of subclass Type :param specification: Type object, or a string @@ -74,15 +74,15 @@ def createType(specification): if isinstance(specification, Type): return specification else: - npDataType = np.dtype(specification) - if npDataType.fields is None: - return BasicType(npDataType, const=False) + numpy_dtype = np.dtype(specification) + if numpy_dtype.fields is None: + return BasicType(numpy_dtype, const=False) else: - return StructType(npDataType, const=False) + return StructType(numpy_dtype, const=False) @memorycache(maxsize=64) -def createCompositeTypeFromString(specification): +def create_composite_type_from_string(specification): """ Creates a new Type object from a c-like string specification :param specification: Specification string @@ -100,16 +100,16 @@ def createCompositeTypeFromString(specification): if len(current) > 0: parts.append(current) # Parse native part - basePart = parts.pop(0) + base_part = parts.pop(0) const = False - if 'const' in basePart: + if 'const' in base_part: const = True - basePart.remove('const') - assert len(basePart) == 1 - if basePart[0][-1] == "*": - basePart[0] = basePart[0][:-1] + base_part.remove('const') + assert len(base_part) == 1 + if base_part[0][-1] == "*": + base_part[0] = base_part[0][:-1] parts.append('*') - currentType = BasicType(np.dtype(basePart[0]), const) + current_type = BasicType(np.dtype(base_part[0]), const) # Parse pointer parts for part in parts: restrict = False @@ -121,30 +121,31 @@ def createCompositeTypeFromString(specification): const = True part.remove("const") assert len(part) == 1 and part[0] == '*' - currentType = PointerType(currentType, const, restrict) - return currentType + current_type = PointerType(current_type, const, restrict) + return current_type -def getBaseType(type): - while type.baseType is not None: - type = type.baseType - return type +def get_base_type(data_type): + while data_type.base_type is not None: + data_type = data_type.base_type + return data_type -def toCtypes(dataType): +def to_ctypes(data_type): """ Transforms a given Type into ctypes - :param dataType: Subclass of Type + :param data_type: Subclass of Type :return: ctypes type object """ - if isinstance(dataType, PointerType): - return ctypes.POINTER(toCtypes(dataType.baseType)) - elif isinstance(dataType, StructType): + if isinstance(data_type, PointerType): + return ctypes.POINTER(to_ctypes(data_type.base_type)) + elif isinstance(data_type, StructType): return ctypes.POINTER(ctypes.c_uint8) else: - return toCtypes.map[dataType.numpyDtype] + return to_ctypes.map[data_type.numpy_dtype] -toCtypes.map = { + +to_ctypes.map = { np.dtype(np.int8): ctypes.c_int8, np.dtype(np.int16): ctypes.c_int16, np.dtype(np.int32): ctypes.c_int32, @@ -199,9 +200,10 @@ def to_llvm_type(data_type): if not ir: raise _ir_importerror if isinstance(data_type, PointerType): - return to_llvm_type(data_type.baseType).as_pointer() + return to_llvm_type(data_type.base_type).as_pointer() else: - return to_llvm_type.map[data_type.numpyDtype] + return to_llvm_type.map[data_type.numpy_dtype] + if ir: to_llvm_type.map = { @@ -220,13 +222,13 @@ if ir: } -def peelOffType(dtype, typeToPeelOff): - while type(dtype) is typeToPeelOff: - dtype = dtype.baseType +def peel_off_type(dtype, type_to_peel_off): + while type(dtype) is type_to_peel_off: + dtype = dtype.base_type return dtype -def collateTypes(types): +def collate_types(types): """ Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. @@ -234,44 +236,44 @@ def collateTypes(types): # Pointer arithmetic case i.e. pointer + integer is allowed if any(type(t) is PointerType for t in types): - pointerType = None + pointer_type = None for t in types: if type(t) is PointerType: - if pointerType is not None: + if pointer_type is not None: raise ValueError("Cannot collate the combination of two pointer types") - pointerType = t + pointer_type = t elif type(t) is BasicType: if not (t.is_int() or t.is_uint()): raise ValueError("Invalid pointer arithmetic") else: raise ValueError("Invalid pointer arithmetic") - return pointerType + return pointer_type # peel of vector types, if at least one vector type occurred the result will also be the vector type - vectorType = [t for t in types if type(t) is VectorType] - if not allEqual(t.width for t in vectorType): + vector_type = [t for t in types if type(t) is VectorType] + if not allEqual(t.width for t in vector_type): raise ValueError("Collation failed because of vector types with different width") - types = [peelOffType(t, VectorType) for t in types] + types = [peel_off_type(t, VectorType) for t in types] # now we should have a list of basic types - struct types are not yet supported assert all(type(t) is BasicType for t in types) # use numpy collation -> create type from numpy type -> and, put vector type around if necessary - resultNumpyType = np.result_type(*(t.numpyDtype for t in types)) - result = BasicType(resultNumpyType) - if vectorType: - result = VectorType(result, vectorType[0].width) + result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) + result = BasicType(result_numpy_type) + if vector_type: + result = VectorType(result, vector_type[0].width) return result @memorycache(maxsize=2048) -def getTypeOfExpression(expr): +def get_type_of_expression(expr): from pystencils.astnodes import ResolvedFieldAccess expr = sp.sympify(expr) if isinstance(expr, sp.Integer): - return createType("int") + return create_type("int") elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): - return createType("double") + return create_type("double") elif isinstance(expr, ResolvedFieldAccess): return expr.field.dtype elif isinstance(expr, TypedSymbol): @@ -281,24 +283,24 @@ def getTypeOfExpression(expr): elif hasattr(expr, 'func') and expr.func == castFunc: return expr.args[1] elif hasattr(expr, 'func') and expr.func == sp.Piecewise: - collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args)) - collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args)) - if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType: - collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width) - return collatedResultType + collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args)) + collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args)) + if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType: + collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width) + return collated_result_type elif isinstance(expr, sp.Indexed): - typedSymbol = expr.base.label - return typedSymbol.dtype.baseType + typed_symbol = expr.base.label + return typed_symbol.dtype.base_type elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean - result = createType("bool") - vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)] - if vecArgs: - result = VectorType(result, width=vecArgs[0].width) + result = create_type("bool") + vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)] + if vec_args: + result = VectorType(result, width=vec_args[0].width) return result elif isinstance(expr, sp.Expr): - types = tuple(getTypeOfExpression(a) for a in expr.args) - return collateTypes(types) + types = tuple(get_type_of_expression(a) for a in expr.args) + return collate_types(types) raise NotImplementedError("Could not determine type for", expr, type(expr)) @@ -311,7 +313,7 @@ class Type(sp.Basic): # Needed for sorting the types inside an expression if isinstance(self, BasicType): if isinstance(other, BasicType): - return self.numpyDtype > other.numpyDtype # TODO const + return self.numpy_dtype > other.numpy_dtype # TODO const elif isinstance(other, PointerType): return False else: # isinstance(other, StructType): @@ -320,7 +322,7 @@ class Type(sp.Basic): if isinstance(other, BasicType): return True elif isinstance(other, PointerType): - return self.baseType > other.baseType # TODO const, restrict + return self.base_type > other.base_type # TODO const, restrict else: # isinstance(other, StructType): raise NotImplementedError("Struct type comparison is not yet implemented") elif isinstance(self, StructType): @@ -331,13 +333,10 @@ class Type(sp.Basic): def _sympystr(self, *args, **kwargs): return str(self) - def _sympystr(self, *args, **kwargs): - return str(self) - class BasicType(Type): @staticmethod - def numpyNameToC(name): + def numpy_name_to_c(name): if name == 'float64': return 'double' elif name == 'float32': @@ -356,7 +355,7 @@ class BasicType(Type): def __init__(self, dtype, const=False): self.const = const if isinstance(dtype, Type): - self._dtype = dtype.numpyDtype + self._dtype = dtype.numpy_dtype else: self._dtype = np.dtype(dtype) assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type" @@ -364,41 +363,41 @@ class BasicType(Type): assert self._dtype.subdtype is None def __getnewargs__(self): - return self.numpyDtype, self.const + return self.numpy_dtype, self.const @property - def baseType(self): + def base_type(self): return None @property - def numpyDtype(self): + def numpy_dtype(self): return self._dtype @property - def itemSize(self): + def item_size(self): return 1 def is_int(self): - return self.numpyDtype in np.sctypes['int'] + return self.numpy_dtype in np.sctypes['int'] def is_float(self): - return self.numpyDtype in np.sctypes['float'] + return self.numpy_dtype in np.sctypes['float'] def is_uint(self): - return self.numpyDtype in np.sctypes['uint'] + return self.numpy_dtype in np.sctypes['uint'] - def is_comlex(self): - return self.numpyDtype in np.sctypes['complex'] + def is_complex(self): + return self.numpy_dtype in np.sctypes['complex'] def is_other(self): - return self.numpyDtype in np.sctypes['others'] + return self.numpy_dtype in np.sctypes['others'] @property - def baseName(self): - return BasicType.numpyNameToC(str(self._dtype)) + def base_name(self): + return BasicType.numpy_name_to_c(str(self._dtype)) def __str__(self): - result = BasicType.numpyNameToC(str(self._dtype)) + result = BasicType.numpy_name_to_c(str(self._dtype)) if self.const: result += " const" return result @@ -410,7 +409,7 @@ class BasicType(Type): if not isinstance(other, BasicType): return False else: - return (self.numpyDtype, self.const) == (other.numpyDtype, other.const) + return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) def __hash__(self): return hash(str(self)) @@ -419,115 +418,115 @@ class BasicType(Type): class VectorType(Type): instructionSet = None - def __init__(self, baseType, width=4): - self._baseType = baseType + def __init__(self, base_type, width=4): + self._base_type = base_type self.width = width @property - def baseType(self): - return self._baseType + def base_type(self): + return self._base_type @property - def itemSize(self): - return self.width * self.baseType.itemSize + def item_size(self): + return self.width * self.base_type.item_size def __eq__(self, other): if not isinstance(other, VectorType): return False else: - return (self.baseType, self.width) == (other.baseType, other.width) + return (self.base_type, self.width) == (other.base_type, other.width) def __str__(self): if self.instructionSet is None: - return "%s[%d]" % (self.baseType, self.width) + return "%s[%d]" % (self.base_type, self.width) else: - if self.baseType == createType("int64"): + if self.base_type == create_type("int64"): return self.instructionSet['int'] - elif self.baseType == createType("float64"): + elif self.base_type == create_type("float64"): return self.instructionSet['double'] - elif self.baseType == createType("float32"): + elif self.base_type == create_type("float32"): return self.instructionSet['float'] - elif self.baseType == createType("bool"): + elif self.base_type == create_type("bool"): return self.instructionSet['bool'] else: raise NotImplementedError() def __hash__(self): - return hash((self.baseType, self.width)) + return hash((self.base_type, self.width)) class PointerType(Type): - def __init__(self, baseType, const=False, restrict=True): - self._baseType = baseType + def __init__(self, base_type, const=False, restrict=True): + self._base_type = base_type self.const = const self.restrict = restrict def __getnewargs__(self): - return self.baseType, self.const, self.restrict + return self.base_type, self.const, self.restrict @property def alias(self): return not self.restrict @property - def baseType(self): - return self._baseType + def base_type(self): + return self._base_type @property - def itemSize(self): - return self.baseType.itemSize + def item_size(self): + return self.base_type.item_size def __eq__(self, other): if not isinstance(other, PointerType): return False else: - return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict) + return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) def __str__(self): - return "%s *%s%s" % (self.baseType, " RESTRICT " if self.restrict else "", " const " if self.const else "") + return "%s *%s%s" % (self.base_type, " RESTRICT " if self.restrict else "", " const " if self.const else "") def __repr__(self): return str(self) def __hash__(self): - return hash((self._baseType, self.const, self.restrict)) + return hash((self._base_type, self.const, self.restrict)) class StructType(object): - def __init__(self, numpyType, const=False): + def __init__(self, numpy_type, const=False): self.const = const - self._dtype = np.dtype(numpyType) + self._dtype = np.dtype(numpy_type) def __getnewargs__(self): - return self.numpyDtype, self.const + return self.numpy_dtype, self.const @property - def baseType(self): + def base_type(self): return None @property - def numpyDtype(self): + def numpy_dtype(self): return self._dtype @property - def itemSize(self): - return self.numpyDtype.itemsize + def item_size(self): + return self.numpy_dtype.itemsize - def getElementOffset(self, elementName): - return self.numpyDtype.fields[elementName][1] + def get_element_offset(self, element_name): + return self.numpy_dtype.fields[element_name][1] - def getElementType(self, elementName): - npElementType = self.numpyDtype.fields[elementName][0] - return BasicType(npElementType, self.const) + def get_element_type(self, element_name): + np_element_type = self.numpy_dtype.fields[element_name][0] + return BasicType(np_element_type, self.const) - def hasElement(self, elementName): - return elementName in self.numpyDtype.fields + def has_element(self, element_name): + return element_name in self.numpy_dtype.fields def __eq__(self, other): if not isinstance(other, StructType): return False else: - return (self.numpyDtype, self.const) == (other.numpyDtype, other.const) + return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) def __str__(self): # structs are handled byte-wise @@ -540,7 +539,7 @@ class StructType(object): return str(self) def __hash__(self): - return hash((self.numpyDtype, self.const)) + return hash((self.numpy_dtype, self.const)) # TODO this should not work at all!!! def __gt__(self, other): @@ -566,11 +565,11 @@ def get_type_from_sympy(node): raise TypeError(node, 'is not a sp.Number') if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber): - return createType('double'), float(node) + return create_type('double'), float(node) elif isinstance(node, sp.Integer): - return createType('int'), int(node) + return create_type('int'), int(node) elif isinstance(node, sp.Rational): # TODO is it always float? - return createType('double'), float(node.p/node.q) + return create_type('double'), float(node.p / node.q) else: raise TypeError(node, ' is not a supported type (yet)!') diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py index 758b074f8d7f8aaf5df887e6f1c1b972cd9f8b66..4b91b10ffad37bb2333125e3dd8def4bb0cd17d0 100644 --- a/datahandling/serial_datahandling.py +++ b/datahandling/serial_datahandling.py @@ -275,7 +275,7 @@ class SerialDataHandling(DataHandling): resultFunctors.append(getPeriodicBoundaryFunctor(filteredStencil, self._domainSize, indexDimensions=self.fields[name].indexDimensions, indexDimShape=self._fieldInformation[name]['fSize'], - dtype=self.fields[name].dtype.numpyDtype, + dtype=self.fields[name].dtype.numpy_dtype, ghostLayers=gls)) if target == 'cpu': diff --git a/display_utils.py b/display_utils.py index 841529ec9982589f68cf274c02848a23e27a76ee..30341e2345ec99a721600551385941445cdc9be2 100644 --- a/display_utils.py +++ b/display_utils.py @@ -1,44 +1,54 @@ +import sympy as sp +from typing import Any, Dict, Optional +from pystencils.astnodes import KernelFunction -def toDot(expr, graphStyle={}): + +def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None): """Show a sympy or pystencils AST as dot graph""" from pystencils.astnodes import Node import graphviz + graph_style = {} if graph_style is None else graph_style + if isinstance(expr, Node): - from pystencils.backends.dot import dotprint - return graphviz.Source(dotprint(expr, short=True, graph_attr=graphStyle)) + from pystencils.backends.dot import print_dot + return graphviz.Source(print_dot(expr, short=True, graph_attr=graph_style)) else: from sympy.printing.dot import dotprint - return graphviz.Source(dotprint(expr, graph_attr=graphStyle)) + return graphviz.Source(dotprint(expr, graph_attr=graph_style)) -def highlightCpp(code): - """Highlight the given C/C++ source code with Pygments""" +def highlight_cpp(code: str): + """Highlight the given C/C++ source code with pygments.""" from IPython.display import HTML, display from pygments import highlight + # noinspection PyUnresolvedReferences from pygments.formatters import HtmlFormatter + # noinspection PyUnresolvedReferences from pygments.lexers import CppLexer - display(HTML(""" - <style> - {pygments_css} - </style> - """.format(pygments_css=HtmlFormatter().get_style_defs('.highlight')))) + css = HtmlFormatter().get_style_defs('.highlight') + css_tag = f"<style>{css}</style>" + display(HTML(css_tag)) return HTML(highlight(code, CppLexer(), HtmlFormatter())) -def showCode(ast): - from pystencils.cpu import generateC +def show_code(ast: KernelFunction): + """Returns an object to display C code. + + Can either be displayed as HTML in Jupyter notebooks or printed as normal string. + """ + from pystencils.cpu import print_c class CodeDisplay: - def __init__(self, astInput): - self.ast = astInput + def __init__(self, ast_input): + self.ast = ast_input def _repr_html_(self): - return highlightCpp(generateC(self.ast)).__html__() + return highlight_cpp(print_c(self.ast)).__html__() def __str__(self): - return generateC(self.ast) + return print_c(self.ast) def __repr__(self): - return generateC(self.ast) + return print_c(self.ast) return CodeDisplay(ast) diff --git a/field.py b/field.py index 60ac7dd32e7e0b225dbc8b7c905417d7fe13655f..6085a314f9c576bc99b1ab040f22955eca7b240b 100644 --- a/field.py +++ b/field.py @@ -7,7 +7,7 @@ from sympy.tensor import IndexedBase from pystencils.assignment import Assignment from pystencils.alignedarray import aligned_empty -from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType +from pystencils.data_types import TypedSymbol, create_type, create_composite_type_from_string, StructType from pystencils.sympyextensions import is_integer_sequence @@ -186,7 +186,7 @@ class Field(object): self._fieldName = fieldName assert isinstance(fieldType, FieldType) self.fieldType = fieldType - self._dtype = createType(dtype) + self._dtype = create_type(dtype) self._layout = normalizeLayout(layout) self.shape = shape self.strides = strides @@ -300,8 +300,8 @@ class Field(object): PREFIX = "f" STRIDE_PREFIX = PREFIX + "stride_" SHAPE_PREFIX = PREFIX + "shape_" - STRIDE_DTYPE = createCompositeTypeFromString("const int *") - SHAPE_DTYPE = createCompositeTypeFromString("const int *") + STRIDE_DTYPE = create_composite_type_from_string("const int *") + SHAPE_DTYPE = create_composite_type_from_string("const int *") DATA_PREFIX = PREFIX + "d_" class Access(sp.Symbol): diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py index 2cfafb0ee2910ac80ff01f53298bdde4d263a0d1..c9eb9a45bf1f81fd0a338a1fc1e98f523710217f 100644 --- a/gpucuda/cudajit.py +++ b/gpucuda/cudajit.py @@ -1,7 +1,7 @@ import numpy as np -from pystencils.backends.cbackend import generateC +from pystencils.backends.cbackend import print_c from pystencils.transformations import symbolNameToVariableName -from pystencils.data_types import StructType, getBaseType +from pystencils.data_types import StructType, get_base_type from pystencils.field import FieldType @@ -22,7 +22,7 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}): code = "#include <cstdint>\n" code += "#define FUNC_PREFIX __global__\n" code += "#define RESTRICT __restrict__\n\n" - code += str(generateC(kernelFunctionNode)) + code += str(print_c(kernelFunctionNode)) mod = SourceModule(code, options=["-w", "-std=c++11"]) func = mod.get_function(kernelFunctionNode.functionName) @@ -68,24 +68,24 @@ def _buildNumpyArgumentList(parameters, argumentDict): field = argumentDict[arg.fieldName] if arg.isFieldPtrArgument: actualType = field.dtype - expectedType = arg.dtype.baseType.numpyDtype + expectedType = arg.dtype.base_type.numpy_dtype if expectedType != actualType: raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." % (arg.fieldName, expectedType, actualType)) result.append(field) elif arg.isFieldStrideArgument: - dtype = getBaseType(arg.dtype).numpyDtype + dtype = get_base_type(arg.dtype).numpy_dtype strideArr = np.array(field.strides, dtype=dtype) // field.dtype.itemsize result.append(cuda.In(strideArr)) elif arg.isFieldShapeArgument: - dtype = getBaseType(arg.dtype).numpyDtype + dtype = get_base_type(arg.dtype).numpy_dtype shapeArr = np.array(field.shape, dtype=dtype) result.append(cuda.In(shapeArr)) else: assert False else: param = argumentDict[arg.name] - expectedType = arg.dtype.numpyDtype + expectedType = arg.dtype.numpy_dtype result.append(expectedType.type(param)) assert len(result) == len(parameters) return result diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py index aa05078e884b325a7c59a42aedb2d96745c3629d..e699c7657573478d94592f6be656d1b304662eab 100644 --- a/gpucuda/indexing.py +++ b/gpucuda/indexing.py @@ -4,15 +4,15 @@ import sympy as sp from pystencils.astnodes import Conditional, Block from pystencils.slicing import normalizeSlice -from pystencils.data_types import TypedSymbol, createType +from pystencils.data_types import TypedSymbol, create_type from functools import partial AUTO_BLOCKSIZE_LIMITING = True -BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')] -THREAD_IDX = [TypedSymbol("threadIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')] -BLOCK_DIM = [TypedSymbol("blockDim." + coord, createType("int")) for coord in ('x', 'y', 'z')] -GRID_DIM = [TypedSymbol("gridDim." + coord, createType("int")) for coord in ('x', 'y', 'z')] +BLOCK_IDX = [TypedSymbol("blockIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')] +THREAD_IDX = [TypedSymbol("threadIdx." + coord, create_type("int")) for coord in ('x', 'y', 'z')] +BLOCK_DIM = [TypedSymbol("blockDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')] +GRID_DIM = [TypedSymbol("gridDim." + coord, create_type("int")) for coord in ('x', 'y', 'z')] class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index 2c25243d3f7fd61bc930294e99d820b2f3e64a8c..2ef07df5108d6afdd90fa036383474c7fcd06857 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -45,7 +45,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, block = Block(assignments) block = indexing.guard(block, commonShape) - ast = KernelFunction(block, functionName=functionName, ghostLayers=ghostLayers, backend='gpucuda') + ast = KernelFunction(block, function_name=functionName, ghost_layers=ghostLayers, backend='gpucuda') ast.globalVariables.update(indexing.indexVariables) coordMapping = indexing.coordinates @@ -64,8 +64,8 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, baseBufferIndex += var * stride resolveBufferAccesses(ast, baseBufferIndex, readOnlyFields) - resolveFieldAccesses(ast, readOnlyFields, fieldToBasePointerInfo=basePointerInfos, - fieldToFixedCoordinates=coordMapping) + resolveFieldAccesses(ast, readOnlyFields, field_to_base_pointer_info=basePointerInfos, + field_to_fixed_coordinates=coordMapping) substituteArrayAccessesWithConstants(ast) @@ -74,10 +74,10 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, # If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity), # they are defined here - undefinedLoopCounters = {LoopOverCoordinate.isLoopCounterSymbol(s): s for s in ast.body.undefinedSymbols - if LoopOverCoordinate.isLoopCounterSymbol(s) is not None} + undefinedLoopCounters = {LoopOverCoordinate.is_loop_counter_symbol(s): s for s in ast.body.undefined_symbols + if LoopOverCoordinate.is_loop_counter_symbol(s) is not None} for i, loopCounter in undefinedLoopCounters.items(): - ast.body.insertFront(SympyAssignment(loopCounter, indexing.coordinates[i])) + ast.body.insert_front(SympyAssignment(loopCounter, indexing.coordinates[i])) ast.indexing = indexing ast.compile = partial(makePythonFunction, ast) @@ -104,9 +104,9 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" for indexField in indexFields: assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype" dataType = indexField.dtype - if dataType.hasElement(name): + if dataType.has_element(name): rhs = indexField[0](name) - lhs = TypedSymbol(name, BasicType(dataType.getElementType(name))) + lhs = TypedSymbol(name, BasicType(dataType.get_element_type(name))) return SympyAssignment(lhs, rhs) raise ValueError("Index %s not found in any of the passed index fields" % (name,)) @@ -118,7 +118,7 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" functionBody = Block(coordinateSymbolAssignments + assignments) functionBody = indexing.guard(functionBody, getCommonShape(indexFields)) - ast = KernelFunction(functionBody, functionName=functionName, backend='gpucuda') + ast = KernelFunction(functionBody, function_name=functionName, backend='gpucuda') ast.globalVariables.update(indexing.indexVariables) coordMapping = indexing.coordinates @@ -127,8 +127,8 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" coordMapping = {f.name: coordMapping for f in indexFields} coordMapping.update({f.name: coordinateTypedSymbols for f in nonIndexFields}) - resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping, - fieldToBasePointerInfo=basePointerInfos) + resolveFieldAccesses(ast, readOnlyFields, field_to_fixed_coordinates=coordMapping, + field_to_base_pointer_info=basePointerInfos) substituteArrayAccessesWithConstants(ast) # add the function which determines #blocks and #threads as additional member to KernelFunction node diff --git a/kerncraft_coupling/generate_benchmark.py b/kerncraft_coupling/generate_benchmark.py index 8bff2b0c36b08952cf20a5a041af5414c1616535..ac7339936d38cd0b51e1c3c5dc43ab95e9c58483 100644 --- a/kerncraft_coupling/generate_benchmark.py +++ b/kerncraft_coupling/generate_benchmark.py @@ -1,7 +1,7 @@ from jinja2 import Template -from pystencils.cpu import generateC +from pystencils.cpu import print_c from pystencils.sympyextensions import prod -from pystencils.data_types import getBaseType +from pystencils.data_types import get_base_type benchmarkTemplate = Template(""" #include "kerncraft.h" @@ -85,7 +85,7 @@ int main(int argc, char **argv) def generateBenchmark(ast, likwid=False): - accessedFields = {f.name: f for f in ast.fieldsAccessed} + accessedFields = {f.name: f for f in ast.fields_accessed} constants = [] fields = [] callParameters = [] @@ -96,13 +96,13 @@ def generateBenchmark(ast, likwid=False): else: assert p.isFieldPtrArgument, "Benchmark implemented only for kernels with fixed loop size" field = accessedFields[p.fieldName] - dtype = str(getBaseType(p.dtype)) + dtype = str(get_base_type(p.dtype)) fields.append((p.fieldName, dtype, prod(field.shape))) callParameters.append(p.fieldName) args = { 'likwid': likwid, - 'kernelCode': generateC(ast), + 'kernelCode': print_c(ast), 'kernelName': ast.functionName, 'fields': fields, 'constants': constants, diff --git a/kerncraft_coupling/kerncraft_interface.py b/kerncraft_coupling/kerncraft_interface.py index e44bd1b9b4e2bbb91a3d0b47a18d171f555df0c0..a583869b2a6fa6c4c7534484c18b53ac8d61c633 100644 --- a/kerncraft_coupling/kerncraft_interface.py +++ b/kerncraft_coupling/kerncraft_interface.py @@ -30,7 +30,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): self.temporaryDir = TemporaryDirectory() # Loops - innerLoops = [l for l in ast.atoms(LoopOverCoordinate) if l.isInnermostLoop] + innerLoops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop] if len(innerLoops) == 0: raise ValueError("No loop found in pystencils AST") elif len(innerLoops) > 1: @@ -42,7 +42,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): curNode = innerLoop while curNode is not None: if isinstance(curNode, LoopOverCoordinate): - loopCounterSym = curNode.loopCounterSymbol + loopCounterSym = curNode.loop_counter_symbol loopInfo = (loopCounterSym.name, curNode.start, curNode.stop, curNode.step) self._loop_stack.append(loopInfo) curNode = curNode.parent @@ -55,7 +55,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): reads, writes = searchResolvedFieldAccessesInAst(innerLoop) for accesses, targetDict in [(reads, self.sources), (writes, self.destinations)]: for fa in accesses: - coord = [sp.Symbol(LoopOverCoordinate.getLoopCounterName(i), positive=True, integer=True) + off + coord = [sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i), positive=True, integer=True) + off for i, off in enumerate(fa.offsets)] coord += list(fa.idxCoordinateValues) layout = getLayoutFromStrides(fa.field.strides) @@ -63,7 +63,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): targetDict[fa.field.name].append(permutedCoord) # Variables (arrays) - fieldsAccessed = ast.fieldsAccessed + fieldsAccessed = ast.fields_accessed for field in fieldsAccessed: layout = getLayoutFromStrides(field.strides) permutedShape = list(field.shape[i] for i in layout) diff --git a/kernelcreation.py b/kernelcreation.py index c9599c2bd2dd35393ca709549c96b3842d120d33..c5d6b988d3b9705eb20254fc4f6e67832fee0153 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -49,7 +49,7 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None import pystencils.backends.simd_instruction_sets as vec from pystencils.vectorization import vectorize vecParams = cpuVectorizeInfo - vec.selectedInstructionSet = vec.x86VectorInstructionSet(instructionSet=vecParams[0], dataType=vecParams[1]) + vec.selectedInstructionSet = vec.x86_vector_instruction_set(instruction_set=vecParams[0], data_type=vecParams[1]) vectorize(ast) return ast elif target == 'llvm': diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index aa690b957d687c2f2dcd695ca2396ef3d29f2550..c47ed21eb220f6ada761a84ddd68e4b76b8dc15f 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -71,9 +71,9 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ for indexField in indexFields: assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype" dataType = indexField.dtype - if dataType.hasElement(name): + if dataType.has_element(name): rhs = indexField[0](name) - lhs = TypedSymbol(name, BasicType(dataType.getElementType(name))) + lhs = TypedSymbol(name, BasicType(dataType.get_element_type(name))) return SympyAssignment(lhs, rhs) raise ValueError("Index %s not found in any of the passed index fields" % (name,)) @@ -83,7 +83,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ # make 1D loop over index fields loopBody = Block([]) - loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0]) + loopNode = LoopOverCoordinate(loopBody, coordinate_to_loop_over=0, start=0, stop=indexFields[0].shape[0]) for assignment in assignments: loopBody.append(assignment) @@ -92,7 +92,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ ast = KernelFunction(functionBody, None, functionName, backend='llvm') fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields} - resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping) + resolveFieldAccesses(ast, set(['indexField']), field_to_fixed_coordinates=fixedCoordinateMapping) moveConstantsBeforeLoop(ast) desympy_ast(ast) diff --git a/llvm/llvm.py b/llvm/llvm.py index 08ad220bb0a07627392e42d70484c697f7941a79..4a0d8d97fb6acc7f5d3ef6eb0482e9e847205db0 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -6,8 +6,8 @@ from sympy import S # S is numbers? from pystencils.llvm.control_flow import Loop -from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression, collateTypes, \ - createCompositeTypeFromString +from pystencils.data_types import create_type, to_llvm_type, get_type_of_expression, collate_types, \ + create_composite_type_from_string from sympy import Indexed from pystencils.assignment import Assignment @@ -48,9 +48,9 @@ class LLVMPrinter(Printer): del self.tmp_var[name] def _print_Number(self, n): - if getTypeOfExpression(n) == createType("int"): + if get_type_of_expression(n) == create_type("int"): return ir.Constant(self.integer, int(n)) - elif getTypeOfExpression(n) == createType("double"): + elif get_type_of_expression(n) == create_type("double"): return ir.Constant(self.fp_type, float(n)) else: raise NotImplementedError("Numbers can only have int and double", n) @@ -100,7 +100,7 @@ class LLVMPrinter(Printer): def _print_Mul(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] - if getTypeOfExpression(expr) == createType('double'): + if get_type_of_expression(expr) == create_type('double'): mul = self.builder.fmul else: # int TODO unsigned/signed mul = self.builder.mul @@ -111,7 +111,7 @@ class LLVMPrinter(Printer): def _print_Add(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] - if getTypeOfExpression(expr) == createType('double'): + if get_type_of_expression(expr) == create_type('double'): add = self.builder.fadd else: # int TODO unsigned/signed add = self.builder.add @@ -152,7 +152,7 @@ class LLVMPrinter(Printer): return self._comparison('==', expr) def _comparison(self, cmpop, expr): - if collateTypes([getTypeOfExpression(arg) for arg in expr.args]) == createType('double'): + if collate_types([get_type_of_expression(arg) for arg in expr.args]) == create_type('double'): comparison = self.builder.fcmp_unordered else: comparison = self.builder.icmp_signed @@ -189,10 +189,10 @@ class LLVMPrinter(Printer): def _print_LoopOverCoordinate(self, loop): with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), - loop.loopCounterName, loop.loopCounterSymbol.name) as i: - self._add_tmp_var(loop.loopCounterSymbol, i) + loop.loop_counter_name, loop.loop_counter_symbol.name) as i: + self._add_tmp_var(loop.loop_counter_symbol, i) self._print(loop.body) - self._remove_tmp_var(loop.loopCounterSymbol) + self._remove_tmp_var(loop.loop_counter_symbol) def _print_SympyAssignment(self, assignment): expr = self._print(assignment.rhs) @@ -207,30 +207,30 @@ class LLVMPrinter(Printer): def _print_castFunc(self, conversion): node = self._print(conversion.args[0]) - to_dtype = getTypeOfExpression(conversion) - from_dtype = getTypeOfExpression(conversion.args[0]) + to_dtype = get_type_of_expression(conversion) + from_dtype = get_type_of_expression(conversion.args[0]) # (From, to) decision = { - (createCompositeTypeFromString("int"), createCompositeTypeFromString("double")): functools.partial( + (create_composite_type_from_string("int"), create_composite_type_from_string("double")): functools.partial( self.builder.sitofp, node, self.fp_type), - (createCompositeTypeFromString("double"), createCompositeTypeFromString("int")): functools.partial( + (create_composite_type_from_string("double"), create_composite_type_from_string("int")): functools.partial( self.builder.fptosi, node, self.integer), - (createCompositeTypeFromString("double *"), createCompositeTypeFromString("int")): functools.partial( + (create_composite_type_from_string("double *"), create_composite_type_from_string("int")): functools.partial( self.builder.ptrtoint, node, self.integer), - (createCompositeTypeFromString("int"), createCompositeTypeFromString("double *")): functools.partial(self.builder.inttoptr, node, - self.fp_pointer), - (createCompositeTypeFromString("double * restrict"), createCompositeTypeFromString("int")): functools.partial( + (create_composite_type_from_string("int"), create_composite_type_from_string("double *")): functools.partial(self.builder.inttoptr, node, + self.fp_pointer), + (create_composite_type_from_string("double * restrict"), create_composite_type_from_string("int")): functools.partial( self.builder.ptrtoint, node, self.integer), - (createCompositeTypeFromString("int"), - createCompositeTypeFromString("double * restrict")): functools.partial(self.builder.inttoptr, node, - self.fp_pointer), - (createCompositeTypeFromString("double * restrict const"), - createCompositeTypeFromString("int")): functools.partial(self.builder.ptrtoint, node, - self.integer), - (createCompositeTypeFromString("int"), - createCompositeTypeFromString("double * restrict const")): functools.partial(self.builder.inttoptr, node, - self.fp_pointer), + (create_composite_type_from_string("int"), + create_composite_type_from_string("double * restrict")): functools.partial(self.builder.inttoptr, node, + self.fp_pointer), + (create_composite_type_from_string("double * restrict const"), + create_composite_type_from_string("int")): functools.partial(self.builder.ptrtoint, node, + self.integer), + (create_composite_type_from_string("int"), + create_composite_type_from_string("double * restrict const")): functools.partial(self.builder.inttoptr, node, + self.fp_pointer), } # TODO float, TEST: const, restrict # TODO bitcast, addrspacecast @@ -285,7 +285,7 @@ class LLVMPrinter(Printer): self.builder.branch(after_block) self.builder.position_at_end(falseBlock) - phi = self.builder.phi(to_llvm_type(getTypeOfExpression(piece))) + phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece))) for (val, block) in phiData: phi.add_incoming(val, block) return phi diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py index bd848671a39ed6e1b566b4195779c0dfbec7e06d..7625eec5786b1b0e6127d0265ab7f92c031f65c7 100644 --- a/llvm/llvmjit.py +++ b/llvm/llvmjit.py @@ -5,8 +5,8 @@ import ctypes as ct import subprocess import shutil -from pystencils.data_types import createCompositeTypeFromString -from ..data_types import toCtypes, ctypes_from_llvm +from pystencils.data_types import create_composite_type_from_string +from ..data_types import to_ctypes, ctypes_from_llvm from .llvm import generateLLVM from ..cpu.cpujit import buildCTypeArgumentList, makePythonFunctionIncompleteParams @@ -128,7 +128,7 @@ class Jit(object): if not function.is_declaration: return_type = None if function.ftype.return_type != ir.VoidType(): - return_type = toCtypes(createCompositeTypeFromString(str(function.ftype.return_type))) + return_type = to_ctypes(create_composite_type_from_string(str(function.ftype.return_type))) args = [ctypes_from_llvm(arg) for arg in function.ftype.args] function_address = self.ee.get_function_address(function.name) fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address) diff --git a/sympyextensions.py b/sympyextensions.py index 8bcfae4cff34ccfaecc5a6b635b1e9951cba5e09..c59df3c160a65bd7bc84292d771f888c0435e651 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -7,7 +7,7 @@ import sympy as sp from sympy.functions import Abs from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple -from pystencils.data_types import getTypeOfExpression, getBaseType +from pystencils.data_types import get_type_of_expression, get_base_type from pystencils.assignment import Assignment T = TypeVar('T') @@ -448,7 +448,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], if only_type is None: return True try: - base_type = getBaseType(getTypeOfExpression(e)) + base_type = get_base_type(get_type_of_expression(e)) except ValueError: return False if only_type == 'int' and (base_type.is_int() or base_type.is_uint()): diff --git a/transformations/stage2.py b/transformations/stage2.py index f46f09b0b547873a7a9b28b8c2bc10034e53d70d..dd539c3601fbf03998cb2e15f5ded962b3770858 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -1,8 +1,5 @@ -from operator import attrgetter - import sympy as sp - -from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, getTypeOfExpression, collateTypes, castFunc, pointerArithmeticFunc +from pystencils.data_types import PointerType, get_type_of_expression, collate_types, castFunc, pointerArithmeticFunc import pystencils.astnodes as ast @@ -21,7 +18,7 @@ def insertCasts(node): """ casted_args = [] for arg, dataType in zippedArgsTypes: - if dataType.numpyDtype != target.numpyDtype: # ignoring const + if dataType.numpy_dtype != target.numpy_dtype: # ignoring const casted_args.append(castFunc(arg, target)) else: casted_args.append(arg) @@ -54,9 +51,9 @@ def insertCasts(node): # TODO indexed, LoopOverCoordinate if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge): # TODO optimize pow, don't cast integer on double - types = [getTypeOfExpression(arg) for arg in args] + types = [get_type_of_expression(arg) for arg in args] assert len(types) > 0 - target = collateTypes(types) + target = collate_types(types) zipped = list(zip(args, types)) if target.func is PointerType: assert node.func is sp.Add @@ -66,11 +63,11 @@ def insertCasts(node): elif node.func is ast.SympyAssignment: lhs = args[0] rhs = args[1] - target = getTypeOfExpression(lhs) + target = get_type_of_expression(lhs) if target.func is PointerType: return node.func(*args) # TODO fix, not complete else: - return node.func(lhs, *cast([(rhs, getTypeOfExpression(rhs))], target)) + return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) elif node.func is ast.ResolvedFieldAccess: return node elif node.func is ast.Block: @@ -83,86 +80,10 @@ def insertCasts(node): return node elif node.func is sp.Piecewise: exprs = [expr for (expr, _) in args] - types = [getTypeOfExpression(expr) for expr in exprs] - target = collateTypes(types) + types = [get_type_of_expression(expr) for expr in exprs] + target = collate_types(types) zipped = list(zip(exprs, types)) casted_exprs = cast(zipped, target) args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_exprs)] return node.func(*args) - - -def insert_casts(node): - """ - Inserts casts and dtype where needed - :param node: ast which should be traversed - :return: node - """ - def conversion(args): - target = args[0] - if isinstance(target.dtype, PointerType): - # Pointer arithmetic - for arg in args[1:]: - # Check validness - if not arg.dtype.is_int() and not arg.dtype.is_uint(): - raise ValueError("Impossible pointer arithmetic", target, arg) - pointer = ast.PointerArithmetic(ast.Add(args[1:]), target) - return [pointer] - - else: - for i in range(len(args)): - if args[i].dtype.numpyDtype != target.dtype.numpyDtype: # TODO ignoring const -> valid behavior? - args[i] = ast.Conversion(args[i], createType(target.dtype), node) - return args - - for arg in node.args: - insert_casts(arg) - if isinstance(node, ast.Indexed): - # TODO need to do something here? - pass - elif isinstance(node, ast.Expr): - args = sorted((arg for arg in node.args), key=attrgetter('dtype')) - target = args[0] - node.args = conversion(args) - node.dtype = target.dtype - elif isinstance(node, ast.SympyAssignment): - if node.lhs.dtype != node.rhs.dtype: - node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype)) - elif isinstance(node, ast.LoopOverCoordinate): - pass - return node - - -#def desympy_ast(node): -# """ -# Remove Sympy Expressions, which have more then one argument. -# This is necessary for further changes in the tree. -# :param node: ast which should be traversed. Only node's children will be modified. -# :return: (modified) node -# """ -# if node.args is None: -# return node -# for i in range(len(node.args)): -# arg = node.args[i] -# if isinstance(arg, sp.Add): -# node.replace(arg, ast.Add(arg.args, node)) -# elif isinstance(arg, sp.Number): -# node.replace(arg, ast.Number(arg, node)) -# elif isinstance(arg, sp.Mul): -# node.replace(arg, ast.Mul(arg.args, node)) -# elif isinstance(arg, sp.Pow): -# node.replace(arg, ast.Pow(arg, node)) -# elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed): -# node.replace(arg, ast.Indexed(arg.args, arg.base, node)) -# elif isinstance(arg, sp.tensor.IndexedBase): -# node.replace(arg, arg.target) -# elif isinstance(arg, sp.Function): -# node.replace(arg, ast.Function(arg.func, arg.args, node)) -# #elif isinstance(arg, sp.containers.Tuple): -# # -# else: -# #print('Not transforming:', type(arg), arg) -# pass -# for arg in node.args: -# desympy_ast(arg) -# return node diff --git a/transformations/transformations.py b/transformations/transformations.py index ae5b08bd6ce59b2262f24cfbd334e14e8f7da9d5..0fb6d0164d1178e3a4312e9d127e7697af05f7b5 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -9,7 +9,7 @@ from sympy.tensor import IndexedBase from pystencils.assignment import Assignment from pystencils.field import Field, FieldType, offsetComponentToDirectionString -from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc +from pystencils.data_types import TypedSymbol, create_type, PointerType, StructType, get_base_type, castFunc from pystencils.slicing import normalizeSlice import pystencils.astnodes as ast @@ -94,7 +94,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None lastLoop = newLoop currentBody = ast.Block([lastLoop]) loopStrides.append(getLoopStride(begin, end, 1)) - loopVars.append(newLoop.loopCounterSymbol) + loopVars.append(newLoop.loop_counter_symbol) else: sliceComponent = iterationSlice[loopCoordinate] if type(sliceComponent) is slice: @@ -103,14 +103,14 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None lastLoop = newLoop currentBody = ast.Block([lastLoop]) loopStrides.append(getLoopStride(sc.start, sc.stop, sc.step)) - loopVars.append(newLoop.loopCounterSymbol) + loopVars.append(newLoop.loop_counter_symbol) else: - assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), + assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loopCoordinate), sp.sympify(sliceComponent)) - currentBody.insertFront(assignment) + currentBody.insert_front(assignment) loopVars = [numBufferAccesses * var for var in loopVars] - astNode = ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName, backend='cpu') + astNode = ast.KernelFunction(currentBody, ghost_layers=ghostLayers, function_name=functionName, backend='cpu') return (astNode, loopStrides, loopVars) @@ -256,17 +256,17 @@ def substituteArrayAccessesWithConstants(astNode): for indexedExpr in indexedExprs: base, idx = indexedExpr.args typedSymbol = base.args[0] - baseType = deepcopy(getBaseType(typedSymbol.dtype)) + baseType = deepcopy(get_base_type(typedSymbol.dtype)) baseType.const = False constantReplacingIndexed = TypedSymbol(typedSymbol.name + str(idx), baseType) constantsDefinitions.append(ast.SympyAssignment(constantReplacingIndexed, indexedExpr)) constantSubstitutions[indexedExpr] = constantReplacingIndexed constantsDefinitions.sort(key=lambda e: e.lhs.name) - alreadyDefined = parentBlock.symbolsDefined + alreadyDefined = parentBlock.symbols_defined for newAssignment in constantsDefinitions: if newAssignment.lhs not in alreadyDefined: - parentBlock.insertBefore(newAssignment, astNode) + parentBlock.insert_before(newAssignment, astNode) return expr.subs(constantSubstitutions) @@ -330,95 +330,95 @@ def resolveBufferAccesses(astNode, baseBufferIndex, readOnlyFieldNames=set()): return visitNode(astNode) -def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): +def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), field_to_base_pointer_info={}, field_to_fixed_coordinates={}): """ Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing :param astNode: the AST root :param readOnlyFieldNames: set of field names which are considered read-only - :param fieldToBasePointerInfo: a list of tuples indicating which intermediate base pointers should be created + :param field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created for details see :func:`parseBasePointerInfo` - :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop + :param field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop counters to index the field these symbols are used as coordinates :return: transformed AST """ - fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0])) - fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0])) + field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) + field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) - def visitSympyExpr(expr, enclosingBlock, sympyAssignment): + def visit_sympy_expr(expr, enclosing_block, sympy_assignment): if isinstance(expr, Field.Access): - fieldAccess = expr - field = fieldAccess.field + field_access = expr + field = field_access.field - if field.name in fieldToBasePointerInfo: - basePointerInfo = fieldToBasePointerInfo[field.name] + if field.name in field_to_base_pointer_info: + base_pointer_info = field_to_base_pointer_info[field.name] else: - basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))] + base_pointer_info = [list(range(field.indexDimensions + field.spatialDimensions))] dtype = PointerType(field.dtype, const=field.name in readOnlyFieldNames, restrict=True) - fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype) + field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype) - def createCoordinateDict(group): - coordDict = {} + def create_coordinate_dict(group): + coord_dict = {} for e in group: if e < field.spatialDimensions: - if field.name in fieldToFixedCoordinates: - coordDict[e] = fieldToFixedCoordinates[field.name][e] + if field.name in field_to_fixed_coordinates: + coord_dict[e] = field_to_fixed_coordinates[field.name][e] else: - ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX - coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int') - coordDict[e] *= field.dtype.itemSize + ctr_name = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + coord_dict[e] = TypedSymbol("%s_%d" % (ctr_name, e), 'int') + coord_dict[e] *= field.dtype.item_size else: if isinstance(field.dtype, StructType): assert field.indexDimensions == 1 - accessedFieldName = fieldAccess.index[0] - assert isinstance(accessedFieldName, str) - coordDict[e] = field.dtype.getElementOffset(accessedFieldName) + accessed_field_name = field_access.index[0] + assert isinstance(accessed_field_name, str) + coord_dict[e] = field.dtype.get_element_offset(accessed_field_name) else: - coordDict[e] = fieldAccess.index[e - field.spatialDimensions] + coord_dict[e] = field_access.index[e - field.spatialDimensions] - return coordDict + return coord_dict - lastPointer = fieldPtr + last_pointer = field_ptr - for group in reversed(basePointerInfo[1:]): - coordDict = createCoordinateDict(group) - newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) - if newPtr not in enclosingBlock.symbolsDefined: - newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False) - enclosingBlock.insertBefore(newAssignment, sympyAssignment) - lastPointer = newPtr + for group in reversed(base_pointer_info[1:]): + coord_dict = create_coordinate_dict(group) + new_ptr, offset = createIntermediateBasePointer(field_access, coord_dict, last_pointer) + if new_ptr not in enclosing_block.symbols_defined: + new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) + enclosing_block.insert_before(new_assignment, sympy_assignment) + last_pointer = new_ptr - coordDict = createCoordinateDict(basePointerInfo[0]) + coord_dict = create_coordinate_dict(base_pointer_info[0]) - _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) - result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, - fieldAccess.offsets, fieldAccess.index) + _, offset = createIntermediateBasePointer(field_access, coord_dict, last_pointer) + result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field, + field_access.offsets, field_access.index) - if isinstance(getBaseType(fieldAccess.field.dtype), StructType): - newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0]) - result = castFunc(result, newType) + if isinstance(get_base_type(field_access.field.dtype), StructType): + new_type = field_access.field.dtype.get_element_type(field_access.index[0]) + result = castFunc(result, new_type) - return visitSympyExpr(result, enclosingBlock, sympyAssignment) + return visit_sympy_expr(result, enclosing_block, sympy_assignment) else: if isinstance(expr, ast.ResolvedFieldAccess): return expr - newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args] + new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args] kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} - return expr.func(*newArgs, **kwargs) if newArgs else expr - - def visitNode(subAst): - if isinstance(subAst, ast.SympyAssignment): - enclosingBlock = subAst.parent - assert type(enclosingBlock) is ast.Block - subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst) - subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst) + return expr.func(*new_args, **kwargs) if new_args else expr + + def visit_node(sub_ast): + if isinstance(sub_ast, ast.SympyAssignment): + enclosing_block = sub_ast.parent + assert type(enclosing_block) is ast.Block + sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast) + sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast) else: - for i, a in enumerate(subAst.args): - visitNode(a) + for i, a in enumerate(sub_ast.args): + visit_node(a) - return visitNode(astNode) + return visit_node(astNode) def moveConstantsBeforeLoop(astNode): @@ -450,8 +450,8 @@ def moveConstantsBeforeLoop(astNode): if isinstance(element, ast.Conditional): criticalSymbols = element.conditionExpr.atoms(sp.Symbol) else: - criticalSymbols = element.symbolsDefined - if node.undefinedSymbols.intersection(criticalSymbols): + criticalSymbols = element.symbols_defined + if node.undefined_symbols.intersection(criticalSymbols): break prevElement = element element = element.parent @@ -475,7 +475,7 @@ def moveConstantsBeforeLoop(astNode): allBlocks = [] getBlocks(astNode, allBlocks) for block in allBlocks: - children = block.takeChildNodes() + children = block.take_child_nodes() for child in children: if not isinstance(child, ast.SympyAssignment): block.append(child) @@ -486,7 +486,7 @@ def moveConstantsBeforeLoop(astNode): else: existingAssignment = checkIfAssignmentAlreadyInBlock(child, target) if not existingAssignment: - target.insertBefore(child, childToInsertBefore) + target.insert_before(child, childToInsertBefore) else: assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already" @@ -502,11 +502,11 @@ def splitInnerLoop(astNode, symbolGroups): :return: transformed AST """ allLoops = astNode.atoms(ast.LoopOverCoordinate) - innerLoop = [l for l in allLoops if l.isInnermostLoop] + innerLoop = [l for l in allLoops if l.is_innermost_loop] assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" innerLoop = innerLoop[0] assert type(innerLoop.body) is ast.Block - outerLoop = [l for l in allLoops if l.isOutermostLoop] + outerLoop = [l for l in allLoops if l.is_outermost_loop] assert len(outerLoop) == 1, "Error in AST, multiple outermost loops." outerLoop = outerLoop[0] @@ -533,7 +533,7 @@ def splitInnerLoop(astNode, symbolGroups): if type(symbol) is not Field.Access: assert type(symbol) is TypedSymbol newTs = TypedSymbol(symbol.name, PointerType(symbol.dtype)) - symbolsWithTemporaryArray[symbol] = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol] + symbolsWithTemporaryArray[symbol] = IndexedBase(newTs, shape=(1,))[innerLoop.loop_counter_symbol] assignmentGroup = [] for assignment in innerLoop.body.args: @@ -542,18 +542,18 @@ def splitInnerLoop(astNode, symbolGroups): if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup: assert type(assignment.lhs) is TypedSymbol newTs = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) - newLhs = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol] + newLhs = IndexedBase(newTs, shape=(1,))[innerLoop.loop_counter_symbol] else: newLhs = assignment.lhs assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs)) assignmentGroups.append(assignmentGroup) - newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups] + newLoops = [innerLoop.new_loop_with_different_body(ast.Block(group)) for group in assignmentGroups] innerLoop.parent.replace(innerLoop, ast.Block(newLoops)) for tmpArray in symbolsWithTemporaryArray: tmpArrayPointer = TypedSymbol(tmpArray.name, PointerType(tmpArray.dtype)) - outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop)) + outerLoop.parent.insert_front(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop)) outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer)) @@ -568,7 +568,7 @@ def cutLoop(loopNode, cuttingPoints): for newEnd in cuttingPoints: if newEnd - newStart == 1: newBody = deepcopy(loopNode.body) - newBody.subs({loopNode.loopCounterSymbol: newStart}) + newBody.subs({loopNode.loop_counter_symbol: newStart}) newLoops.append(newBody) else: newLoop = ast.LoopOverCoordinate(deepcopy(loopNode.body), loopNode.coordinateToLoopOver, @@ -634,7 +634,7 @@ def simplifyBooleanExpression(expr, singleVariableRanges): def simplifyConditionals(node, loopConditionals={}): """Simplifies/Removes conditions inside loops that depend on the loop counter.""" if isinstance(node, ast.LoopOverCoordinate): - ctrSym = node.loopCounterSymbol + ctrSym = node.loop_counter_symbol loopConditionals[ctrSym] = sp.And(ctrSym >= node.start, ctrSym < node.stop) simplifyConditionals(node.body) del loopConditionals[ctrSym] @@ -729,7 +729,7 @@ def typeAllEquations(eqs, typeForSymbol): elif isinstance(object, ast.Conditional): falseBlock = None if object.falseBlock is None else visit(object.falseBlock) return ast.Conditional(processRhs(object.conditionExpr), - trueBlock=visit(object.trueBlock), falseBlock=falseBlock) + true_block=visit(object.trueBlock), false_block=falseBlock) elif isinstance(object, ast.Block): return ast.Block([visit(e) for e in object.args]) else: @@ -818,9 +818,9 @@ def get_type(node): # TODO sp.NumberSymbol elif isinstance(node, sp.Number): if isinstance(node, sp.Float): - return createType('double') + return create_type('double') elif isinstance(node, sp.Integer): - return createType('int') + return create_type('int') else: raise NotImplemented('Not yet supported: %s %s' % (node, type(node))) else: diff --git a/vectorization.py b/vectorization.py index 8e1377c04a62c808cbfcef54bd7a7bd8cc8595af..6cb9501a6d00c38b9c2e902968c6c2bd143b8c7d 100644 --- a/vectorization.py +++ b/vectorization.py @@ -1,114 +1,112 @@ import sympy as sp import warnings - from pystencils.sympyextensions import fast_subs from pystencils.transformations import filteredTreeIteration -from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \ - PointerType +from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, castFunc, collate_types, PointerType import pystencils.astnodes as ast -def vectorize(astNode, vectorWidth=4): - vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth) - insertVectorCasts(astNode) +def vectorize(ast_node, vector_width=4): + vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width) + insert_vector_casts(ast_node) -def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4): +def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4): """ Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if - - loop bounds are constant - - loop range is a multiple of vector width + * loop bounds are constant + * loop range is a multiple of vector width """ - innerLoops = [n for n in astNode.atoms(ast.LoopOverCoordinate) if n.isInnermostLoop] + inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop] - for loopNode in innerLoops: - loopRange = loopNode.stop - loopNode.start + for loopNode in inner_loops: + loop_range = loopNode.stop - loopNode.start # Check restrictions - if isinstance(loopRange, sp.Basic) and not loopRange.is_integer: + if isinstance(loop_range, sp.Expr) and not loop_range.is_number: warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop") continue - if loopRange % vectorWidth != 0 or loopNode.step != 1: + if loop_range % vector_width != 0 or loopNode.step != 1: warnings.warn("Currently only loops with loop bounds that are multiples " - "of vectorization width can be vectorized") + "of vectorization width can be vectorized - skipping loop") continue # Find all array accesses (indexed) that depend on the loop counter as offset - loopCounterSymbol = ast.LoopOverCoordinate.getLoopCounterSymbol(loopNode.coordinateToLoopOver) + loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loopNode.coordinateToLoopOver) substitutions = {} successful = True for indexed in loopNode.atoms(sp.Indexed): base, index = indexed.args - if loopCounterSymbol in index.atoms(sp.Symbol): - loopCounterIsOffset = loopCounterSymbol not in (index - loopCounterSymbol).atoms() - if not loopCounterIsOffset: + if loop_counter_symbol in index.atoms(sp.Symbol): + loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms() + if not loop_counter_is_offset: successful = False break - typedSymbol = base.label - assert type(typedSymbol.dtype) is PointerType, "Type of access is " + str(typedSymbol.dtype) + ", " + str(indexed) - substitutions[indexed] = castFunc(indexed, VectorType(typedSymbol.dtype.baseType, vectorWidth)) + typed_symbol = base.label + assert type(typed_symbol.dtype) is PointerType, f"Type of access is {typed_symbol.dtype}, {indexed}" + substitutions[indexed] = castFunc(indexed, VectorType(typed_symbol.dtype.base_type, vector_width)) if not successful: warnings.warn("Could not vectorize loop because of non-consecutive memory access") continue - loopNode.step = vectorWidth + loopNode.step = vector_width loopNode.subs(substitutions) -def insertVectorCasts(astNode): - """ - Inserts necessary casts from scalar values to vector values - """ - def visitExpr(expr): +def insert_vector_casts(ast_node): + """Inserts necessary casts from scalar values to vector values.""" + + def visit_expr(expr): if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc) or \ isinstance(expr, sp.boolalg.BooleanFunction): - newArgs = [visitExpr(a) for a in expr.args] - argTypes = [getTypeOfExpression(a) for a in newArgs] - if not any(type(t) is VectorType for t in argTypes): + new_args = [visit_expr(a) for a in expr.args] + arg_types = [get_type_of_expression(a) for a in new_args] + if not any(type(t) is VectorType for t in arg_types): return expr else: - targetType = collateTypes(argTypes) - castedArgs = [castFunc(a, targetType) if t != targetType else a - for a, t in zip(newArgs, argTypes)] - return expr.func(*castedArgs) + target_type = collate_types(arg_types) + casted_args = [castFunc(a, target_type) if t != target_type else a + for a, t in zip(new_args, arg_types)] + return expr.func(*casted_args) elif expr.func is sp.Pow: - newArg = visitExpr(expr.args[0]) - return sp.Pow(newArg, expr.args[1]) + new_arg = visit_expr(expr.args[0]) + return expr.func(new_arg, expr.args[1]) elif expr.func == sp.Piecewise: - newResults = [visitExpr(a[0]) for a in expr.args] - newConditions = [visitExpr(a[1]) for a in expr.args] - typesOfResults = [getTypeOfExpression(a) for a in newResults] - typesOfConditions = [getTypeOfExpression(a) for a in newConditions] + new_results = [visit_expr(a[0]) for a in expr.args] + new_conditions = [visit_expr(a[1]) for a in expr.args] + types_of_results = [get_type_of_expression(a) for a in new_results] + types_of_conditions = [get_type_of_expression(a) for a in new_conditions] - resultTargetType = getTypeOfExpression(expr) - conditionTargetType = collateTypes(typesOfConditions) - if type(conditionTargetType) is VectorType and type(resultTargetType) is not VectorType: - resultTargetType = VectorType(resultTargetType, width=conditionTargetType.width) + result_target_type = get_type_of_expression(expr) + condition_target_type = collate_types(types_of_conditions) + if type(condition_target_type) is VectorType and type(result_target_type) is not VectorType: + result_target_type = VectorType(result_target_type, width=condition_target_type.width) - castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a - for a, t in zip(newResults, typesOfResults)] + casted_results = [castFunc(a, result_target_type) if t != result_target_type else a + for a, t in zip(new_results, types_of_results)] - castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a - for a, t in zip(newConditions, typesOfConditions)] + casted_conditions = [castFunc(a, condition_target_type) + if t != condition_target_type and a is not True else a + for a, t in zip(new_conditions, types_of_conditions)] - return sp.Piecewise(*[(r, c) for r, c in zip(castedResults, castedConditions)]) + return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) else: return expr - substitutionDict = {} - for asmt in filteredTreeIteration(astNode, ast.SympyAssignment): - subsExpr = fast_subs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) - asmt.rhs = visitExpr(subsExpr) - rhsType = getTypeOfExpression(asmt.rhs) - if isinstance(asmt.lhs, TypedSymbol): - lhsType = asmt.lhs.dtype - if type(rhsType) is VectorType and type(lhsType) is not VectorType: - newLhsType = VectorType(lhsType, rhsType.width) - newLhs = TypedSymbol(asmt.lhs.name, newLhsType) - substitutionDict[asmt.lhs] = newLhs - asmt.lhs = newLhs - elif asmt.lhs.func == castFunc: - lhsType = asmt.lhs.args[1] - if type(lhsType) is VectorType and type(rhsType) is not VectorType: - asmt.rhs = castFunc(asmt.rhs, lhsType) + substitution_dict = {} + for assignment in filteredTreeIteration(ast_node, ast.SympyAssignment): + subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) + assignment.rhs = visit_expr(subs_expr) + rhs_type = get_type_of_expression(assignment.rhs) + if isinstance(assignment.lhs, TypedSymbol): + lhs_type = assignment.lhs.dtype + if type(rhs_type) is VectorType and type(lhs_type) is not VectorType: + new_lhs_type = VectorType(lhs_type, rhs_type.width) + new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) + substitution_dict[assignment.lhs] = new_lhs + assignment.lhs = new_lhs + elif assignment.lhs.func == castFunc: + lhs_type = assignment.lhs.args[1] + if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: + assignment.rhs = castFunc(assignment.rhs, lhs_type)