Commit 3bcfac93 authored by Martin Bauer's avatar Martin Bauer
Browse files

PEP8 naming

parent ef924b18
......@@ -2,7 +2,7 @@ from pystencils.field import Field, FieldType, extractCommonSubexpressions
from pystencils.data_types import TypedSymbol
from pystencils.slicing import makeSlice
from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot
from pystencils.display_utils import show_code, to_dot
from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
......@@ -11,7 +11,7 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol',
'makeSlice',
'createKernel', 'createIndexedKernel',
'showCode', 'toDot',
'show_code', 'to_dot',
'AssignmentCollection',
'Assignment',
'SymbolCreator']
# -*- coding: utf-8 -*-
from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter
......@@ -11,4 +12,9 @@ def print_assignment_latex(printer, expr):
return f"{printed_lhs} \leftarrow {printed_rhs}"
def assignment_str(assignment):
return f"{assignment.lhs}{assignment.rhs}"
Assignment.__str__ = assignment_str
LatexPrinter._print_Assignment = print_assignment_latex
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy
from pystencils.assignment_collection.simplifications import sympy_cse, sympy_cse_on_assignment_list, \
apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, \
subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions']
......@@ -61,7 +61,7 @@ class AssignmentCollection:
left hand side symbol (which could have been generated)
"""
if lhs is None:
lhs = sp.Dummy()
lhs = next(self.subexpression_symbol_generator)
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topological_sort:
......@@ -135,25 +135,6 @@ class AssignmentCollection:
return handled_symbols
def get(self, symbols: Sequence[sp.Symbol], from_main_assignments_only=False) -> List[Assignment]:
"""Extracts all assignments that have a left hand side that is contained in the symbols parameter.
Args:
symbols: return assignments that have one of these symbols as left hand side
from_main_assignments_only: search only in main assignments (exclude subexpressions)
"""
if not hasattr(symbols, "__len__"):
symbols = set(symbols)
else:
symbols = set(symbols)
if not from_main_assignments_only:
assignments_to_search = self.all_assignments
else:
assignments_to_search = self.main_assignments
return [assignment for assignment in assignments_to_search if assignment.lhs in symbols]
def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None):
"""Returns a python function to evaluate this equation collection.
......@@ -343,26 +324,25 @@ class AssignmentCollection:
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])
def __str__(self):
result = "Subexpressions\n"
result = "Subexpressions:\n"
for eq in self.subexpressions:
result += str(eq) + "\n"
result += "Main Assignments\n"
result += f"\t{eq}\n"
result += "Main Assignments:\n"
for eq in self.main_assignments:
result += str(eq) + "\n"
result += f"{eq}\n"
return result
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self):
def __init__(self, symbol="xi"):
self._ctr = 0
self._symbol = symbol
def __iter__(self):
return self
def __next__(self):
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr))
def next(self):
return self.__next__()
return sp.Symbol(name)
import sympy as sp
from typing import Callable, List
from pystencils import Assignment, AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.sympyextensions import subs_additive
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
"""Searches for common subexpressions inside the equation collection.
......@@ -32,21 +27,28 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def apply_to_all_assignments(assignment_collection: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection"""
"""Applies sympy expand operation to all equations in collection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
def apply_on_all_subexpressions(ac: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies the given operation on all subexpressions of the AssignmentCollection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions"""
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = []
for outerCtr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs
......
This diff is collapsed.
from .cbackend import generateC
from .cbackend import print_c
try:
from .dot import dotprint
from .dot import print_dot
from .llvm import generateLLVM
except ImportError:
pass
import sympy as sp
from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr
from collections import namedtuple
from sympy.core import S
from typing import Optional
try:
from sympy.utilities.codegen import CCodePrinter
except ImportError:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from collections import namedtuple
from sympy.core.mul import _keep_coeff
from sympy.core import S
from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.data_types import createType, PointerType, getTypeOfExpression, VectorType, castFunc
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, castFunc
from pystencils.backends.simd_instruction_sets import selectedInstructionSet
__all__ = ['print_c']
def generateC(astNode, signatureOnly=False):
"""
Prints the abstract syntax tree as C function
def print_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
functions.
Args:
ast_node:
signature_only:
use_float_constants:
Returns:
C-like code for the ast node and its descendants
"""
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = createType("double") not in fieldTypes
if use_float_constants is None:
field_types = set(o.field.dtype for o in ast_node.atoms(ResolvedFieldAccess))
double = create_type('double')
use_float_constants = double not in field_types
vectorIS = selectedInstructionSet['double']
printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly, vectorInstructionSet=vectorIS)
return printer(astNode)
vector_is = selectedInstructionSet['double']
printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only,
vector_instruction_set=vector_is)
return printer(ast_node)
def getHeaders(astNode):
def get_headers(ast_node):
headers = set()
if hasattr(astNode, 'headers'):
headers.update(astNode.headers)
elif isinstance(astNode, SympyAssignment):
if type(getTypeOfExpression(astNode.rhs)) is VectorType:
if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers)
elif isinstance(ast_node, SympyAssignment):
if type(get_type_of_expression(ast_node.rhs)) is VectorType:
headers.update(selectedInstructionSet['double']['headers'])
for a in astNode.args:
for a in ast_node.args:
if isinstance(a, Node):
headers.update(getHeaders(a))
headers.update(get_headers(a))
return headers
......@@ -48,10 +62,11 @@ def getHeaders(astNode):
class CustomCppCode(Node):
def __init__(self, code, symbolsRead, symbolsDefined):
def __init__(self, code, symbols_read, symbols_defined, parent=None):
super(CustomCppCode, self).__init__(parent=parent)
self._code = "\n" + code
self._symbolsRead = set(symbolsRead)
self._symbolsDefined = set(symbolsDefined)
self._symbolsRead = set(symbols_read)
self._symbolsDefined = set(symbols_defined)
self.headers = []
@property
......@@ -63,75 +78,78 @@ class CustomCppCode(Node):
return []
@property
def symbolsDefined(self):
def symbols_defined(self):
return self._symbolsDefined
@property
def undefinedSymbols(self):
return self.symbolsDefined - self._symbolsRead
def undefined_symbols(self):
return self.symbols_defined - self._symbolsRead
class PrintNode(CustomCppCode):
def __init__(self, symbolToPrint):
code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name)
super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set())
# noinspection SpellCheckingInspection
def __init__(self, symbol_to_print):
code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name)
super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set())
self.headers.append("<iostream>")
# ------------------------------------------- Printer ------------------------------------------------------------------
class CBackend(object):
# noinspection PyPep8Naming
class CBackend:
def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
if sympyPrinter is None:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
if vectorInstructionSet is not None:
self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats)
def __init__(self, constants_as_floats=False, sympy_printer=None,
signature_only=False, vector_instruction_set=None):
if sympy_printer is None:
self.sympyPrinter = CustomSympyPrinter(constants_as_floats)
if vector_instruction_set is not None:
self.sympyPrinter = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
else:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
self.sympyPrinter = CustomSympyPrinter(constants_as_floats)
else:
self.sympyPrinter = sympyPrinter
self.sympyPrinter = sympy_printer
self._vectorInstructionSet = vectorInstructionSet
self._vectorInstructionSet = vector_instruction_set
self._indent = " "
self._signatureOnly = signatureOnly
self._signatureOnly = signature_only
def __call__(self, node):
prevIs = VectorType.instructionSet
prev_is = VectorType.instructionSet
VectorType.instructionSet = self._vectorInstructionSet
result = str(self._print(node))
VectorType.instructionSet = prevIs
VectorType.instructionSet = prev_is
return result
def _print(self, node):
for cls in type(node).__mro__:
methodName = "_print_" + cls.__name__
if hasattr(self, methodName):
return getattr(self, methodName)(node)
raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
method_name = "_print_" + cls.__name__
if hasattr(self, method_name):
return getattr(self, method_name)(node)
raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node):
functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
func_declaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(function_arguments))
if self._signatureOnly:
return funcDeclaration
return func_declaration
body = self._print(node.body)
return funcDeclaration + "\n" + body
return func_declaration + "\n" + body
def _print_Block(self, node):
blockContents = "\n".join([self._print(child) for child in node.args])
return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)))
block_contents = "\n".join([self._print(child) for child in node.args])
return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
def _print_PragmaBlock(self, node):
return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
def _print_LoopOverCoordinate(self, node):
counterVar = node.loopCounterName
start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),)
counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympyPrinter.doprint(node.start))
condition = "%s < %s" % (counter_symbol, self.sympyPrinter.doprint(node.stop))
update = "%s += %s" % (counter_symbol, self.sympyPrinter.doprint(node.step),)
loopStr = "for (%s; %s; %s)" % (start, condition, update)
prefix = "\n".join(node.prefixLines)
......@@ -140,12 +158,12 @@ class CBackend(object):
return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
def _print_SympyAssignment(self, node):
if node.isDeclaration:
dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " "
return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
return "%s %s = %s;" % (data_type, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
else:
lhsType = getTypeOfExpression(node.lhs)
if type(lhsType) is VectorType and node.lhs.func == castFunc:
lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and node.lhs.func == castFunc:
return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
self.sympyPrinter.doprint(node.rhs)) + ';'
else:
......@@ -153,31 +171,33 @@ class CBackend(object):
def _print_TemporaryMemoryAllocation(self, node):
return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size))
node.symbol.dtype.base_type, self.sympyPrinter.doprint(node.size))
def _print_TemporaryMemoryFree(self, node):
return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)
def _print_CustomCppCode(self, node):
@staticmethod
def _print_CustomCppCode(node):
return node.code
def _print_Conditional(self, node):
conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
trueBlock = self._print_Block(node.trueBlock)
result = "if (%s)\n%s " % (conditionExpr, trueBlock)
condition_expr = self.sympyPrinter.doprint(node.conditionExpr)
true_block = self._print_Block(node.trueBlock)
result = "if (%s)\n%s " % (condition_expr, true_block)
if node.falseBlock:
falseBlock = self._print_Block(node.falseBlock)
result += "else " + falseBlock
false_block = self._print_Block(node.falseBlock)
result += "else " + false_block
return result
# ------------------------------------------ Helper function & classes -------------------------------------------------
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
def __init__(self, constantsAsFloats=False):
self._constantsAsFloats = constantsAsFloats
def __init__(self, constants_as_floats=False):
self._constantsAsFloats = constants_as_floats
super(CustomSympyPrinter, self).__init__()
def _print_Pow(self, expr):
......@@ -210,7 +230,7 @@ class CustomSympyPrinter(CCodePrinter):
return res
def _print_Function(self, expr):
functionMap = {
function_map = {
xor: '^',
rightShift: '>>',
leftShift: '<<',
......@@ -218,33 +238,34 @@ class CustomSympyPrinter(CCodePrinter):
bitwiseAnd: '&',
}
if expr.func == castFunc:
arg, type = expr.args
return "*((%s)(& %s))" % (PointerType(type), self._print(arg))
elif expr.func in functionMap:
return "(%s %s %s)" % (self._print(expr.args[0]), functionMap[expr.func], self._print(expr.args[1]))
arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
elif expr.func in function_map:
return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1]))
else:
return super(CustomSympyPrinter, self)._print_Function(expr)
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instructionSet, constantsAsFloats=False):
super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
self.instructionSet = instructionSet
def __init__(self, instruction_set, constants_as_floats=False):
super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats)
self.instructionSet = instruction_set
def _scalarFallback(self, funcName, expr, *args, **kwargs):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs)
def _scalarFallback(self, func_name, expr, *args, **kwargs):
expr_type = get_type_of_expression(expr)
if type(expr_type) is not VectorType:
return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
else:
assert self.instructionSet['width'] == exprType.width
assert self.instructionSet['width'] == expr_type.width
return None
def _print_Function(self, expr):
if expr.func == castFunc:
arg, dtype = expr.args
if type(dtype) is VectorType:
arg, data_type = expr.args
if type(data_type) is VectorType:
if type(arg) is ResolvedFieldAccess:
return self.instructionSet['loadU'].format("& " + self._print(arg))
else:
......@@ -257,10 +278,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
return result
argStrings = [self._print(a) for a in expr.args]
assert len(argStrings) > 0
result = argStrings[0]
for item in argStrings[1:]:
arg_strings = [self._print(a) for a in expr.args]
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instructionSet['&'].format(result, item)
return result
......@@ -269,10 +290,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
return result
argStrings = [self._print(a) for a in expr.args]
assert len(argStrings) > 0
result = argStrings[0]
for item in argStrings[1:]:
arg_strings = [self._print(a) for a in expr.args]
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instructionSet['|'].format(result, item)
return result
......@@ -284,7 +305,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
summands = []
for term in expr.args:
if term.func == sp.Mul:
sign, t = self._print_Mul(term, insideAdd=True)
sign, t = self._print_Mul(term, inside_add=True)
else:
t = self._print(term)
sign = 1
......@@ -318,7 +339,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
else:
raise ValueError("Generic exponential not supported")
def _print_Mul(self, expr, insideAdd=False):
def _print_Mul(self, expr, inside_add=False):
# noinspection PyProtectedMember
from sympy.core.mul import _keep_coeff
result = self._scalarFallback('_print_Mul', expr)
if result:
return result
......@@ -359,7 +383,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
denominator_str = self.instructionSet['*'].format(denominator_str, item)
result = self.instructionSet['/'].format(result, denominator_str)
if insideAdd:
if inside_add:
return sign, result
else:
if sign < 0:
......@@ -384,7 +408,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
return result
if expr.args[-1].cond != True:
if expr.args[-1].cond.args[0] is not sp.sympify(True):
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
......@@ -395,5 +419,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0])
for trueExpr, condition in reversed(expr.args[:-1]):
# noinspection SpellCheckingInspection
result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
return result
......@@ -3,13 +3,14 @@ from graphviz import Digraph, lang
import graphviz
# noinspection PyPep8Naming
class DotPrinter(Printer):
"""
A printer which converts ast to DOT (graph description language).
"""
def __init__(self, nodeToStrFunction, full, **kwargs):
def __init__(self, node_to_str_function, full, **kwargs):
super(DotPrinter, self).__init__()
self._nodeToStrFunction = nodeToStrFunction
self._nodeToStrFunction = node_to_str_function
self.full = full
self.dot = Digraph(**kwargs)
self.dot.quote_edge = lang.quote
......@@ -33,7 +34,8 @@ class DotPrinter(Printer):
self.dot.edge(str(id(block)), str(id(node)))