import sympy as sp from sympy.tensor import IndexedBase from pystencils.field import Field from pystencils.data_types import TypedSymbol, createType, castFunc from pystencils.sympyextensions import fastSubs class Node(object): """Base class for all AST nodes""" def __init__(self, parent=None): self.parent = parent def args(self): """Returns all arguments/children of this node""" return [] @property def symbolsDefined(self): """Set of symbols which are defined by this node. """ return set() @property def undefinedSymbols(self): """Symbols which are used but are not defined inside this node""" raise NotImplementedError() def subs(self, *args, **kwargs): """Inplace! substitute, similar to sympys but modifies ast and returns None""" for a in self.args: a.subs(*args, **kwargs) @property def func(self): return self.__class__ def atoms(self, argType): """ Returns a set of all children which are an instance of the given argType """ result = set() for arg in self.args: if isinstance(arg, argType): result.add(arg) result.update(arg.atoms(argType)) return result class Conditional(Node): """Conditional""" def __init__(self, conditionExpr, trueBlock, falseBlock=None): """ Create a new conditional node :param conditionExpr: sympy relational expression :param trueBlock: block which is run if conditional is true :param falseBlock: block which is run if conditional is false, or None if not needed """ assert conditionExpr.is_Boolean or conditionExpr.is_Relational self.conditionExpr = conditionExpr def handleChild(c): if c is None: return None if not isinstance(c, Block): c = Block([c]) c.parent = self return c self.trueBlock = handleChild(trueBlock) self.falseBlock = handleChild(falseBlock) def subs(self, *args, **kwargs): self.trueBlock.subs(*args, **kwargs) if self.falseBlock: self.falseBlock.subs(*args, **kwargs) self.conditionExpr = self.conditionExpr.subs(*args, **kwargs) @property def args(self): result = [self.conditionExpr, self.trueBlock] if self.falseBlock: result.append(self.falseBlock) return result @property def symbolsDefined(self): return set() @property def undefinedSymbols(self): result = self.trueBlock.undefinedSymbols if self.falseBlock: result.update(self.falseBlock.undefinedSymbols) result.update(self.conditionExpr.atoms(sp.Symbol)) return result def __str__(self): return 'if:({!s}) '.format(self.conditionExpr) def __repr__(self): return 'if:({!r}) '.format(self.conditionExpr) class KernelFunction(Node): class Argument: def __init__(self, name, dtype, symbol, kernelFunctionNode): from pystencils.transformations import symbolNameToVariableName self.name = name self.dtype = dtype self.isFieldPtrArgument = False self.isFieldShapeArgument = False self.isFieldStrideArgument = False self.isFieldArgument = False self.fieldName = "" self.coordinate = None self.symbol = symbol if name.startswith(Field.DATA_PREFIX): self.isFieldPtrArgument = True self.isFieldArgument = True self.fieldName = name[len(Field.DATA_PREFIX):] elif name.startswith(Field.SHAPE_PREFIX): self.isFieldShapeArgument = True self.isFieldArgument = True self.fieldName = name[len(Field.SHAPE_PREFIX):] elif name.startswith(Field.STRIDE_PREFIX): self.isFieldStrideArgument = True self.isFieldArgument = True self.fieldName = name[len(Field.STRIDE_PREFIX):] self.field = None if self.isFieldArgument: fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed} self.field = fieldMap[self.fieldName] def __lt__(self, other): def score(l): if l.isFieldPtrArgument: return -4 elif l.isFieldShapeArgument: return -3 elif l.isFieldStrideArgument: return -2 return 0 if score(self) < score(other): return True elif score(self) == score(other): return self.name < other.name else: return False def __repr__(self): return '<{0} {1}>'.format(self.dtype, self.name) def __init__(self, body, ghostLayers=None, functionName="kernel", backend=""): super(KernelFunction, self).__init__() self._body = body body.parent = self self._parameters = None self.functionName = functionName self._body.parent = self self.compile = None self.ghostLayers = ghostLayers # these variables are assumed to be global, so no automatic parameter is generated for them self.globalVariables = set() self.backend = backend @property def symbolsDefined(self): return set() @property def undefinedSymbols(self): return set() @property def parameters(self): self._updateParameters() return self._parameters @property def body(self): return self._body @property def args(self): return [self._body] @property def fieldsAccessed(self): """Set of Field instances: fields which are accessed inside this kernel function""" return set(o.field for o in self.atoms(ResolvedFieldAccess)) def _updateParameters(self): undefinedSymbols = self._body.undefinedSymbols - self.globalVariables self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols] self._parameters.sort() def __str__(self): self._updateParameters() return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters, ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): self._updateParameters() return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters) class Block(Node): def __init__(self, listOfNodes): super(Node, self).__init__() self._nodes = listOfNodes self.parent = None for n in self._nodes: n.parent = self @property def args(self): return self._nodes def insertFront(self, node): node.parent = self self._nodes.insert(0, node) def insertBefore(self, newNode, insertBefore): newNode.parent = self idx = self._nodes.index(insertBefore) # move all assignment (definitions to the top) if isinstance(newNode, SympyAssignment) and newNode.isDeclaration: while idx > 0: pn = self._nodes[idx - 1] if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): idx -= 1 else: break self._nodes.insert(idx, newNode) def append(self, node): if isinstance(node, list) or isinstance(node, tuple): for n in node: n.parent = self self._nodes.append(n) else: node.parent = self self._nodes.append(node) def takeChildNodes(self): tmp = self._nodes self._nodes = [] return tmp def replace(self, child, replacements): idx = self._nodes.index(child) del self._nodes[idx] if type(replacements) is list: for e in replacements: e.parent = self self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:] else: replacements.parent = self self._nodes.insert(idx, replacements) @property def symbolsDefined(self): result = set() for a in self.args: result.update(a.symbolsDefined) return result @property def undefinedSymbols(self): result = set() definedSymbols = set() for a in self.args: result.update(a.undefinedSymbols) definedSymbols.update(a.symbolsDefined) return result - definedSymbols def __str__(self): return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes) def __repr__(self): return "Block" class PragmaBlock(Block): def __init__(self, pragmaLine, listOfNodes): super(PragmaBlock, self).__init__(listOfNodes) self.pragmaLine = pragmaLine for n in listOfNodes: n.parent = self def __repr__(self): return self.pragmaLine class LoopOverCoordinate(Node): LOOP_COUNTER_NAME_PREFIX = "ctr" def __init__(self, body, coordinateToLoopOver, start, stop, step=1): self.body = body body.parent = self self.coordinateToLoopOver = coordinateToLoopOver self.start = start self.stop = stop self.step = step self.body.parent = self self.prefixLines = [] def newLoopWithDifferentBody(self, newBody): result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step) result.prefixLines = [l for l in self.prefixLines] return result def subs(self, *args, **kwargs): self.body.subs(*args, **kwargs) if hasattr(self.start, "subs"): self.start = self.start.subs(*args, **kwargs) if hasattr(self.stop, "subs"): self.stop = self.stop.subs(*args, **kwargs) if hasattr(self.step, "subs"): self.step = self.step.subs(*args, **kwargs) @property def args(self): result = [self.body] for e in [self.start, self.stop, self.step]: if hasattr(e, "args"): result.append(e) return result def replace(self, child, replacement): if child == self.body: self.body = replacement elif child == self.start: self.start = replacement elif child == self.step: self.step = replacement elif child == self.stop: self.stop = replacement @property def symbolsDefined(self): return set([self.loopCounterSymbol]) @property def undefinedSymbols(self): result = self.body.undefinedSymbols for possibleSymbol in [self.start, self.stop, self.step]: if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic): result.update(possibleSymbol.atoms(sp.Symbol)) return result - set([self.loopCounterSymbol]) @staticmethod def getLoopCounterName(coordinateToLoopOver): return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver) @property def loopCounterName(self): return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver) @staticmethod def isLoopCounterSymbol(symbol): prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX if not symbol.name.startswith(prefix): return None if symbol.dtype != createType('int'): return None coordinate = int(symbol.name[len(prefix)+1:]) return coordinate @staticmethod def getLoopCounterSymbol(coordinateToLoopOver): return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int') @property def loopCounterSymbol(self): return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver) @property def isOutermostLoop(self): from pystencils.transformations import getNextParentOfType return getNextParentOfType(self, LoopOverCoordinate) is None @property def isInnermostLoop(self): return len(self.atoms(LoopOverCoordinate)) == 0 def __str__(self): return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start, self.loopCounterName, self.stop, self.loopCounterName, self.step, ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start, self.loopCounterName, self.stop, self.loopCounterName, self.step) class SympyAssignment(Node): def __init__(self, lhsSymbol, rhsTerm, isConst=True): self._lhsSymbol = lhsSymbol self.rhs = rhsTerm self._isDeclaration = True isCast = self._lhsSymbol.func == castFunc if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or isCast: self._isDeclaration = False self._isConst = isConst @property def lhs(self): return self._lhsSymbol @lhs.setter def lhs(self, newValue): self._lhsSymbol = newValue self._isDeclaration = True isCast = self._lhsSymbol.func == castFunc if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: self._isDeclaration = False def subs(self, *args, **kwargs): self.lhs = fastSubs(self.lhs, *args, **kwargs) self.rhs = fastSubs(self.rhs, *args, **kwargs) @property def args(self): return [self._lhsSymbol, self.rhs] @property def symbolsDefined(self): if not self._isDeclaration: return set() return set([self._lhsSymbol]) @property def undefinedSymbols(self): result = self.rhs.atoms(sp.Symbol) # Add loop counters if there a field accesses loopCounters = set() for symbol in result: if isinstance(symbol, Field.Access): for i in range(len(symbol.offsets)): loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i)) result.update(loopCounters) result.update(self._lhsSymbol.atoms(sp.Symbol)) return result @property def isDeclaration(self): return self._isDeclaration @property def isConst(self): return self._isConst def replace(self, child, replacement): if child == self.lhs: replacement.parent = self self.lhs = replacement elif child == self.rhs: replacement.parent = self self.rhs = replacement else: raise ValueError('%s is not in args of %s' % (replacement, self.__class__)) def __repr__(self): return repr(self.lhs) + " = " + repr(self.rhs) class ResolvedFieldAccess(sp.Indexed): def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues): if not isinstance(base, IndexedBase): base = IndexedBase(base, shape=(1,)) obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex) obj.field = field obj.offsets = offsets obj.idxCoordinateValues = idxCoordinateValues return obj def _eval_subs(self, old, new): return ResolvedFieldAccess(self.args[0], self.args[1].subs(old, new), self.field, self.offsets, self.idxCoordinateValues) def fastSubs(self, subsDict): if self in subsDict: return subsDict[self] return ResolvedFieldAccess(self.args[0].subs(subsDict), self.args[1].subs(subsDict), self.field, self.offsets, self.idxCoordinateValues) def _hashable_content(self): superClassContents = super(ResolvedFieldAccess, self)._hashable_content() return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) @property def typedSymbol(self): return self.base.label def __str__(self): top = super(ResolvedFieldAccess, self).__str__() return "%s (%s)" % (top, self.typedSymbol.dtype) def __getnewargs__(self): return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues class TemporaryMemoryAllocation(Node): def __init__(self, typedSymbol, size): self.symbol = typedSymbol self.size = size @property def symbolsDefined(self): return set([self.symbol]) @property def undefinedSymbols(self): if isinstance(self.size, sp.Basic): return self.size.atoms(sp.Symbol) else: return set() @property def args(self): return [self.symbol] class TemporaryMemoryFree(Node): def __init__(self, typedSymbol): self.symbol = typedSymbol @property def symbolsDefined(self): return set() @property def undefinedSymbols(self): return set() @property def args(self): return []