Commit 3bcfac93 authored by Martin Bauer's avatar Martin Bauer
Browse files

PEP8 naming

parent ef924b18
......@@ -2,7 +2,7 @@ from pystencils.field import Field, FieldType, extractCommonSubexpressions
from pystencils.data_types import TypedSymbol
from pystencils.slicing import makeSlice
from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot
from pystencils.display_utils import show_code, to_dot
from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
......@@ -11,7 +11,7 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol',
'makeSlice',
'createKernel', 'createIndexedKernel',
'showCode', 'toDot',
'show_code', 'to_dot',
'AssignmentCollection',
'Assignment',
'SymbolCreator']
# -*- coding: utf-8 -*-
from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter
......@@ -11,4 +12,9 @@ def print_assignment_latex(printer, expr):
return f"{printed_lhs} \leftarrow {printed_rhs}"
def assignment_str(assignment):
return f"{assignment.lhs}{assignment.rhs}"
Assignment.__str__ = assignment_str
LatexPrinter._print_Assignment = print_assignment_latex
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy
from pystencils.assignment_collection.simplifications import sympy_cse, sympy_cse_on_assignment_list, \
apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, \
subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions']
......@@ -61,7 +61,7 @@ class AssignmentCollection:
left hand side symbol (which could have been generated)
"""
if lhs is None:
lhs = sp.Dummy()
lhs = next(self.subexpression_symbol_generator)
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topological_sort:
......@@ -135,25 +135,6 @@ class AssignmentCollection:
return handled_symbols
def get(self, symbols: Sequence[sp.Symbol], from_main_assignments_only=False) -> List[Assignment]:
"""Extracts all assignments that have a left hand side that is contained in the symbols parameter.
Args:
symbols: return assignments that have one of these symbols as left hand side
from_main_assignments_only: search only in main assignments (exclude subexpressions)
"""
if not hasattr(symbols, "__len__"):
symbols = set(symbols)
else:
symbols = set(symbols)
if not from_main_assignments_only:
assignments_to_search = self.all_assignments
else:
assignments_to_search = self.main_assignments
return [assignment for assignment in assignments_to_search if assignment.lhs in symbols]
def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None):
"""Returns a python function to evaluate this equation collection.
......@@ -343,26 +324,25 @@ class AssignmentCollection:
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])
def __str__(self):
result = "Subexpressions\n"
result = "Subexpressions:\n"
for eq in self.subexpressions:
result += str(eq) + "\n"
result += "Main Assignments\n"
result += f"\t{eq}\n"
result += "Main Assignments:\n"
for eq in self.main_assignments:
result += str(eq) + "\n"
result += f"{eq}\n"
return result
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self):
def __init__(self, symbol="xi"):
self._ctr = 0
self._symbol = symbol
def __iter__(self):
return self
def __next__(self):
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr))
def next(self):
return self.__next__()
return sp.Symbol(name)
import sympy as sp
from typing import Callable, List
from pystencils import Assignment, AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.sympyextensions import subs_additive
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
"""Searches for common subexpressions inside the equation collection.
......@@ -32,21 +27,28 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def apply_to_all_assignments(assignment_collection: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection"""
"""Applies sympy expand operation to all equations in collection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
def apply_on_all_subexpressions(ac: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies the given operation on all subexpressions of the AssignmentCollection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions"""
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = []
for outerCtr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs
......
import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, createType, castFunc
from pystencils.data_types import TypedSymbol, create_type, castFunc
from pystencils.sympyextensions import fast_subs
from typing import List, Set, Optional, Union, Any
NodeOrExpr = Union['Node', sp.Expr]
class Node(object):
"""Base class for all AST nodes"""
"""Base class for all AST nodes."""
def __init__(self, parent=None):
def __init__(self, parent: Optional['Node'] = None):
self.parent = parent
def args(self):
"""Returns all arguments/children of this node"""
@property
def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node."""
return []
@property
def symbolsDefined(self):
"""Set of symbols which are defined by this node. """
def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node."""
return set()
@property
def undefinedSymbols(self):
"""Symbols which are used but are not defined inside this node"""
def undefined_symbols(self) -> Set[sp.Symbol]:
"""Symbols which are used but are not defined inside this node."""
raise NotImplementedError()
def subs(self, *args, **kwargs):
"""Inplace! substitute, similar to sympys but modifies ast and returns None"""
def subs(self, *args, **kwargs) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
a.subs(*args, **kwargs)
......@@ -34,32 +38,35 @@ class Node(object):
def func(self):
return self.__class__
def atoms(self, argType):
"""
Returns a set of all children which are an instance of the given argType
"""
def atoms(self, arg_type) -> Set[Any]:
"""Returns a set of all descendants recursively, which are an instance of the given type."""
result = set()
for arg in self.args:
if isinstance(arg, argType):
if isinstance(arg, arg_type):
result.add(arg)
result.update(arg.atoms(argType))
result.update(arg.atoms(arg_type))
return result
class Conditional(Node):
"""Conditional"""
def __init__(self, conditionExpr, trueBlock, falseBlock=None):
"""
Create a new conditional node
:param conditionExpr: sympy relational expression
:param trueBlock: block which is run if conditional is true
:param falseBlock: block which is run if conditional is false, or None if not needed
"""
assert conditionExpr.is_Boolean or conditionExpr.is_Relational
self.conditionExpr = conditionExpr
def handleChild(c):
"""Conditional that maps to a 'if' statement in C/C++.
Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
Consider using assignments with sympy.Piecewise in this case.
Args:
condition_expr: sympy relational expression
true_block: block which is run if conditional is true
false_block: optional block which is run if conditional is false
"""
def __init__(self, condition_expr: sp.Expr, true_block: 'Block', false_block: Optional['Block'] = None) -> None:
super(Conditional, self).__init__(parent=None)
assert condition_expr.is_Boolean or condition_expr.is_Relational
self.conditionExpr = condition_expr
def handle_child(c):
if c is None:
return None
if not isinstance(c, Block):
......@@ -67,8 +74,8 @@ class Conditional(Node):
c.parent = self
return c
self.trueBlock = handleChild(trueBlock)
self.falseBlock = handleChild(falseBlock)
self.trueBlock = handle_child(true_block)
self.falseBlock = handle_child(false_block)
def subs(self, *args, **kwargs):
self.trueBlock.subs(*args, **kwargs)
......@@ -84,14 +91,14 @@ class Conditional(Node):
return result
@property
def symbolsDefined(self):
def symbols_defined(self):
return set()
@property
def undefinedSymbols(self):
result = self.trueBlock.undefinedSymbols
def undefined_symbols(self):
result = self.trueBlock.undefined_symbols
if self.falseBlock:
result.update(self.falseBlock.undefinedSymbols)
result.update(self.falseBlock.undefined_symbols)
result.update(self.conditionExpr.atoms(sp.Symbol))
return result
......@@ -105,7 +112,7 @@ class Conditional(Node):
class KernelFunction(Node):
class Argument:
def __init__(self, name, dtype, symbol, kernelFunctionNode):
def __init__(self, name, dtype, symbol, kernel_function_node):
from pystencils.transformations import symbolNameToVariableName
self.name = name
self.dtype = dtype
......@@ -132,8 +139,8 @@ class KernelFunction(Node):
self.field = None
if self.isFieldArgument:
fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
self.field = fieldMap[self.fieldName]
field_map = {symbolNameToVariableName(f.name): f for f in kernel_function_node.fields_accessed}
self.field = field_map[self.fieldName]
def __lt__(self, other):
def score(l):
......@@ -155,30 +162,30 @@ class KernelFunction(Node):
def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name)
def __init__(self, body, ghostLayers=None, functionName="kernel", backend=""):
def __init__(self, body, ghost_layers=None, function_name="kernel", backend=""):
super(KernelFunction, self).__init__()
self._body = body
body.parent = self
self._parameters = None
self.functionName = functionName
self.functionName = function_name
self._body.parent = self
self.compile = None
self.ghostLayers = ghostLayers
self.ghostLayers = ghost_layers
# these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set()
self.backend = backend
@property
def symbolsDefined(self):
def symbols_defined(self):
return set()
@property
def undefinedSymbols(self):
def undefined_symbols(self):
return set()
@property
def parameters(self):
self._updateParameters()
self._update_parameters()
return self._parameters
@property
......@@ -190,30 +197,30 @@ class KernelFunction(Node):
return [self._body]
@property
def fieldsAccessed(self):
def fields_accessed(self):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def _updateParameters(self):
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols]
def _update_parameters(self):
undefined_symbols = self._body.undefined_symbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefined_symbols]
self._parameters.sort()
def __str__(self):
self._updateParameters()
self._update_parameters()
return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters,
("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self):
self._updateParameters()
self._update_parameters()
return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)
class Block(Node):
def __init__(self, listOfNodes):
super(Node, self).__init__()
self._nodes = listOfNodes
def __init__(self, nodes: List[Node]):
super(Block, self).__init__()
self._nodes = nodes
self.parent = None
for n in self._nodes:
n.parent = self
......@@ -222,23 +229,23 @@ class Block(Node):
def args(self):
return self._nodes
def insertFront(self, node):
def insert_front(self, node):
node.parent = self
self._nodes.insert(0, node)
def insertBefore(self, newNode, insertBefore):
newNode.parent = self
idx = self._nodes.index(insertBefore)
def insert_before(self, new_node, insert_before):
new_node.parent = self
idx = self._nodes.index(insert_before)
# move all assignment (definitions to the top)
if isinstance(newNode, SympyAssignment) and newNode.isDeclaration:
if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
while idx > 0:
pn = self._nodes[idx - 1]
if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
idx -= 1
else:
break
self._nodes.insert(idx, newNode)
self._nodes.insert(idx, new_node)
def append(self, node):
if isinstance(node, list) or isinstance(node, tuple):
......@@ -249,7 +256,7 @@ class Block(Node):
node.parent = self
self._nodes.append(node)
def takeChildNodes(self):
def take_child_nodes(self):
tmp = self._nodes
self._nodes = []
return tmp
......@@ -266,19 +273,19 @@ class Block(Node):
self._nodes.insert(idx, replacements)
@property
def symbolsDefined(self):
def symbols_defined(self):
result = set()
for a in self.args:
result.update(a.symbolsDefined)
result.update(a.symbols_defined)
return result
@property
def undefinedSymbols(self):
def undefined_symbols(self):
result = set()
defined_symbols = set()
for a in self.args:
result.update(a.undefinedSymbols)
defined_symbols.update(a.symbolsDefined)
result.update(a.undefined_symbols)
defined_symbols.update(a.symbols_defined)
return result - defined_symbols
def __str__(self):
......@@ -289,10 +296,10 @@ class Block(Node):
class PragmaBlock(Block):
def __init__(self, pragmaLine, listOfNodes):
super(PragmaBlock, self).__init__(listOfNodes)
self.pragmaLine = pragmaLine
for n in listOfNodes:
def __init__(self, pragma_line, nodes):
super(PragmaBlock, self).__init__(nodes)
self.pragmaLine = pragma_line
for n in nodes:
n.parent = self
def __repr__(self):
......@@ -302,18 +309,19 @@ class PragmaBlock(Block):
class LoopOverCoordinate(Node):
LOOP_COUNTER_NAME_PREFIX = "ctr"
def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
def __init__(self, body, coordinate_to_loop_over, start, stop, step=1):
super(LoopOverCoordinate, self).__init__(parent=None)
self.body = body
body.parent = self
self.coordinateToLoopOver = coordinateToLoopOver
self.coordinateToLoopOver = coordinate_to_loop_over
self.start = start
self.stop = stop
self.step = step
self.body.parent = self
self.prefixLines = []
def newLoopWithDifferentBody(self, newBody):
result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinateToLoopOver, self.start, self.stop, self.step)
result.prefixLines = [l for l in self.prefixLines]
return result
......@@ -345,84 +353,85 @@ class LoopOverCoordinate(Node):
self.stop = replacement
@property
def symbolsDefined(self):
return set([self.loopCounterSymbol])
def symbols_defined(self):
return {self.loop_counter_symbol}
@property
def undefinedSymbols(self):
result = self.body.undefinedSymbols
def undefined_symbols(self):
result = self.body.undefined_symbols
for possibleSymbol in [self.start, self.stop, self.step]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
result.update(possibleSymbol.atoms(sp.Symbol))
return result - set([self.loopCounterSymbol])
return result - {self.loop_counter_symbol}
@staticmethod
def getLoopCounterName(coordinateToLoopOver):
return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)
def get_loop_counter_name(coordinate_to_loop_over):
return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)
@property
def loopCounterName(self):
return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)
def loop_counter_name(self):
return LoopOverCoordinate.get_loop_counter_name(self.coordinateToLoopOver)
@staticmethod
def isLoopCounterSymbol(symbol):
def is_loop_counter_symbol(symbol):
prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
if not symbol.name.startswith(prefix):
return None
if symbol.dtype != createType('int'):
if symbol.dtype != create_type('int'):
return None
coordinate = int(symbol.name[len(prefix)+1:])
return coordinate
@staticmethod
def getLoopCounterSymbol(coordinateToLoopOver):
return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')
def get_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')
@property
def loopCounterSymbol(self):
return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)
def loop_counter_symbol(self):
return LoopOverCoordinate.get_loop_counter_symbol(self.coordinateToLoopOver)
@property
def isOutermostLoop(self):
def is_outermost_loop(self):
from pystencils.transformations import getNextParentOfType
return getNextParentOfType(self, LoopOverCoordinate) is None
@property
def isInnermostLoop(self):
def is_innermost_loop(self):
return len(self.atoms(LoopOverCoordinate)) == 0
def __str__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
self.loopCounterName, self.stop,
self.loopCounterName, self.step,
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
self.loop_counter_name, self.stop,
self.loop_counter_name, self.step,
("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
self.loopCounterName, self.stop,
self.loopCounterName, self.step)
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
self.loop_counter_name, self.stop,
self.loop_counter_name, self.step)
class SympyAssignment(Node):
def __init__(self, lhsSymbol, rhsTerm, isConst=True):
self._lhsSymbol = lhsSymbol
self.rhs = rhsTerm
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhsSymbol = lhs_symbol
self.rhs = rhs_expr
self._isDeclaration = True
isCast = self._lhsSymbol.func == castFunc
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or isCast:
is_cast = self._lhsSymbol.func == castFunc
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or is_cast:
self._isDeclaration = False
self._isConst = isConst
self._isConst = is_const
@property
def lhs(self):
return self._lhsSymbol
@lhs.setter
def lhs(self, newValue):
self._lhsSymbol = newValue
def lhs(self, new_value):
self._lhsSymbol = new_value
self._isDeclaration = True
isCast = self._lhsSymbol.func == castFunc
if isinstance