Commit 3bcfac93 authored by Martin Bauer's avatar Martin Bauer
Browse files

PEP8 naming

parent ef924b18
...@@ -2,7 +2,7 @@ from pystencils.field import Field, FieldType, extractCommonSubexpressions ...@@ -2,7 +2,7 @@ from pystencils.field import Field, FieldType, extractCommonSubexpressions
from pystencils.data_types import TypedSymbol from pystencils.data_types import TypedSymbol
from pystencils.slicing import makeSlice from pystencils.slicing import makeSlice
from pystencils.kernelcreation import createKernel, createIndexedKernel from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot from pystencils.display_utils import show_code, to_dot
from pystencils.assignment_collection import AssignmentCollection from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator from pystencils.sympyextensions import SymbolCreator
...@@ -11,7 +11,7 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', ...@@ -11,7 +11,7 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol', 'TypedSymbol',
'makeSlice', 'makeSlice',
'createKernel', 'createIndexedKernel', 'createKernel', 'createIndexedKernel',
'showCode', 'toDot', 'show_code', 'to_dot',
'AssignmentCollection', 'AssignmentCollection',
'Assignment', 'Assignment',
'SymbolCreator'] 'SymbolCreator']
# -*- coding: utf-8 -*-
from sympy.codegen.ast import Assignment from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
...@@ -11,4 +12,9 @@ def print_assignment_latex(printer, expr): ...@@ -11,4 +12,9 @@ def print_assignment_latex(printer, expr):
return f"{printed_lhs} \leftarrow {printed_rhs}" return f"{printed_lhs} \leftarrow {printed_rhs}"
def assignment_str(assignment):
return f"{assignment.lhs}{assignment.rhs}"
Assignment.__str__ = assignment_str
LatexPrinter._print_Assignment = print_assignment_latex LatexPrinter._print_Assignment = print_assignment_latex
from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy
from pystencils.assignment_collection.simplifications import sympy_cse, sympy_cse_on_assignment_list, \
apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, \
subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions']
...@@ -61,7 +61,7 @@ class AssignmentCollection: ...@@ -61,7 +61,7 @@ class AssignmentCollection:
left hand side symbol (which could have been generated) left hand side symbol (which could have been generated)
""" """
if lhs is None: if lhs is None:
lhs = sp.Dummy() lhs = next(self.subexpression_symbol_generator)
eq = Assignment(lhs, rhs) eq = Assignment(lhs, rhs)
self.subexpressions.append(eq) self.subexpressions.append(eq)
if topological_sort: if topological_sort:
...@@ -135,25 +135,6 @@ class AssignmentCollection: ...@@ -135,25 +135,6 @@ class AssignmentCollection:
return handled_symbols return handled_symbols
def get(self, symbols: Sequence[sp.Symbol], from_main_assignments_only=False) -> List[Assignment]:
"""Extracts all assignments that have a left hand side that is contained in the symbols parameter.
Args:
symbols: return assignments that have one of these symbols as left hand side
from_main_assignments_only: search only in main assignments (exclude subexpressions)
"""
if not hasattr(symbols, "__len__"):
symbols = set(symbols)
else:
symbols = set(symbols)
if not from_main_assignments_only:
assignments_to_search = self.all_assignments
else:
assignments_to_search = self.main_assignments
return [assignment for assignment in assignments_to_search if assignment.lhs in symbols]
def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None): def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None):
"""Returns a python function to evaluate this equation collection. """Returns a python function to evaluate this equation collection.
...@@ -343,26 +324,25 @@ class AssignmentCollection: ...@@ -343,26 +324,25 @@ class AssignmentCollection:
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments]) return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])
def __str__(self): def __str__(self):
result = "Subexpressions\n" result = "Subexpressions:\n"
for eq in self.subexpressions: for eq in self.subexpressions:
result += str(eq) + "\n" result += f"\t{eq}\n"
result += "Main Assignments\n" result += "Main Assignments:\n"
for eq in self.main_assignments: for eq in self.main_assignments:
result += str(eq) + "\n" result += f"{eq}\n"
return result return result
class SymbolGen: class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ...""" """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self): def __init__(self, symbol="xi"):
self._ctr = 0 self._ctr = 0
self._symbol = symbol
def __iter__(self): def __iter__(self):
return self return self
def __next__(self): def __next__(self):
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1 self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr)) return sp.Symbol(name)
def next(self):
return self.__next__()
import sympy as sp import sympy as sp
from typing import Callable, List from typing import Callable, List
from pystencils import Assignment, AssignmentCollection from pystencils.assignment import Assignment
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.sympyextensions import subs_additive from pystencils.sympyextensions import subs_additive
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
"""Searches for common subexpressions inside the equation collection. """Searches for common subexpressions inside the equation collection.
...@@ -32,21 +27,28 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: ...@@ -32,21 +27,28 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
return ac.copy(modified_update_equations, new_subexpressions) return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def apply_to_all_assignments(assignment_collection: AssignmentCollection, def apply_to_all_assignments(assignment_collection: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection""" """Applies sympy expand operation to all equations in collection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result) return assignment_collection.copy(result)
def apply_on_all_subexpressions(ac: AssignmentCollection, def apply_on_all_subexpressions(ac: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies the given operation on all subexpressions of the AssignmentCollection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result) return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection: def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions""" """Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = [] result = []
for outerCtr, s in enumerate(ac.subexpressions): for outerCtr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs new_rhs = s.rhs
......
This diff is collapsed.
from .cbackend import generateC from .cbackend import print_c
try: try:
from .dot import dotprint from .dot import print_dot
from .llvm import generateLLVM from .llvm import generateLLVM
except ImportError: except ImportError:
pass pass
import sympy as sp import sympy as sp
from collections import namedtuple
from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr from sympy.core import S
from typing import Optional
try: try:
from sympy.utilities.codegen import CCodePrinter
except ImportError:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from collections import namedtuple from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr
from sympy.core.mul import _keep_coeff
from sympy.core import S
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.data_types import createType, PointerType, getTypeOfExpression, VectorType, castFunc from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, castFunc
from pystencils.backends.simd_instruction_sets import selectedInstructionSet from pystencils.backends.simd_instruction_sets import selectedInstructionSet
__all__ = ['print_c']
def generateC(astNode, signatureOnly=False):
""" def print_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str:
Prints the abstract syntax tree as C function """Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
functions.
Args:
ast_node:
signature_only:
use_float_constants:
Returns:
C-like code for the ast node and its descendants
""" """
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed]) if use_float_constants is None:
useFloatConstants = createType("double") not in fieldTypes field_types = set(o.field.dtype for o in ast_node.atoms(ResolvedFieldAccess))
double = create_type('double')
use_float_constants = double not in field_types
vectorIS = selectedInstructionSet['double'] vector_is = selectedInstructionSet['double']
printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly, vectorInstructionSet=vectorIS) printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only,
return printer(astNode) vector_instruction_set=vector_is)
return printer(ast_node)
def getHeaders(astNode): def get_headers(ast_node):
headers = set() headers = set()
if hasattr(astNode, 'headers'): if hasattr(ast_node, 'headers'):
headers.update(astNode.headers) headers.update(ast_node.headers)
elif isinstance(astNode, SympyAssignment): elif isinstance(ast_node, SympyAssignment):
if type(getTypeOfExpression(astNode.rhs)) is VectorType: if type(get_type_of_expression(ast_node.rhs)) is VectorType:
headers.update(selectedInstructionSet['double']['headers']) headers.update(selectedInstructionSet['double']['headers'])
for a in astNode.args: for a in ast_node.args:
if isinstance(a, Node): if isinstance(a, Node):
headers.update(getHeaders(a)) headers.update(get_headers(a))
return headers return headers
...@@ -48,10 +62,11 @@ def getHeaders(astNode): ...@@ -48,10 +62,11 @@ def getHeaders(astNode):
class CustomCppCode(Node): class CustomCppCode(Node):
def __init__(self, code, symbolsRead, symbolsDefined): def __init__(self, code, symbols_read, symbols_defined, parent=None):
super(CustomCppCode, self).__init__(parent=parent)
self._code = "\n" + code self._code = "\n" + code
self._symbolsRead = set(symbolsRead) self._symbolsRead = set(symbols_read)
self._symbolsDefined = set(symbolsDefined) self._symbolsDefined = set(symbols_defined)
self.headers = [] self.headers = []
@property @property
...@@ -63,75 +78,78 @@ class CustomCppCode(Node): ...@@ -63,75 +78,78 @@ class CustomCppCode(Node):
return [] return []
@property @property
def symbolsDefined(self): def symbols_defined(self):
return self._symbolsDefined return self._symbolsDefined
@property @property
def undefinedSymbols(self): def undefined_symbols(self):
return self.symbolsDefined - self._symbolsRead return self.symbols_defined - self._symbolsRead
class PrintNode(CustomCppCode): class PrintNode(CustomCppCode):
def __init__(self, symbolToPrint): # noinspection SpellCheckingInspection
code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name) def __init__(self, symbol_to_print):
super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set()) code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name)
super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set())
self.headers.append("<iostream>") self.headers.append("<iostream>")
# ------------------------------------------- Printer ------------------------------------------------------------------ # ------------------------------------------- Printer ------------------------------------------------------------------
class CBackend(object): # noinspection PyPep8Naming
class CBackend:
def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None): def __init__(self, constants_as_floats=False, sympy_printer=None,
if sympyPrinter is None: signature_only=False, vector_instruction_set=None):
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) if sympy_printer is None:
if vectorInstructionSet is not None: self.sympyPrinter = CustomSympyPrinter(constants_as_floats)
self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats) if vector_instruction_set is not None:
self.sympyPrinter = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
else: else:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) self.sympyPrinter = CustomSympyPrinter(constants_as_floats)
else: else:
self.sympyPrinter = sympyPrinter self.sympyPrinter = sympy_printer
self._vectorInstructionSet = vectorInstructionSet self._vectorInstructionSet = vector_instruction_set
self._indent = " " self._indent = " "
self._signatureOnly = signatureOnly self._signatureOnly = signature_only
def __call__(self, node): def __call__(self, node):
prevIs = VectorType.instructionSet prev_is = VectorType.instructionSet
VectorType.instructionSet = self._vectorInstructionSet VectorType.instructionSet = self._vectorInstructionSet
result = str(self._print(node)) result = str(self._print(node))
VectorType.instructionSet = prevIs VectorType.instructionSet = prev_is
return result return result
def _print(self, node): def _print(self, node):
for cls in type(node).__mro__: for cls in type(node).__mro__:
methodName = "_print_" + cls.__name__ method_name = "_print_" + cls.__name__
if hasattr(self, methodName): if hasattr(self, method_name):
return getattr(self, methodName)(node) return getattr(self, method_name)(node)
raise NotImplementedError("CBackend does not support node of type " + cls.__name__) raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node): def _print_KernelFunction(self, node):
functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments)) func_declaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(function_arguments))
if self._signatureOnly: if self._signatureOnly:
return funcDeclaration return func_declaration
body = self._print(node.body) body = self._print(node.body)
return funcDeclaration + "\n" + body return func_declaration + "\n" + body
def _print_Block(self, node): def _print_Block(self, node):
blockContents = "\n".join([self._print(child) for child in node.args]) block_contents = "\n".join([self._print(child) for child in node.args])
return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True))) return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
def _print_PragmaBlock(self, node): def _print_PragmaBlock(self, node):
return "%s\n%s" % (node.pragmaLine, self._print_Block(node)) return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counterVar = node.loopCounterName counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start)) start = "int %s = %s" % (counter_symbol, self.sympyPrinter.doprint(node.start))
condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop)) condition = "%s < %s" % (counter_symbol, self.sympyPrinter.doprint(node.stop))
update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),) update = "%s += %s" % (counter_symbol, self.sympyPrinter.doprint(node.step),)
loopStr = "for (%s; %s; %s)" % (start, condition, update) loopStr = "for (%s; %s; %s)" % (start, condition, update)
prefix = "\n".join(node.prefixLines) prefix = "\n".join(node.prefixLines)
...@@ -140,12 +158,12 @@ class CBackend(object): ...@@ -140,12 +158,12 @@ class CBackend(object):
return "%s%s\n%s" % (prefix, loopStr, self._print(node.body)) return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.isDeclaration: if node.is_declaration:
dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " " data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) return "%s %s = %s;" % (data_type, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
else: else:
lhsType = getTypeOfExpression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
if type(lhsType) is VectorType and node.lhs.func == castFunc: if type(lhs_type) is VectorType and node.lhs.func == castFunc:
return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]), return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
self.sympyPrinter.doprint(node.rhs)) + ';' self.sympyPrinter.doprint(node.rhs)) + ';'
else: else:
...@@ -153,31 +171,33 @@ class CBackend(object): ...@@ -153,31 +171,33 @@ class CBackend(object):
def _print_TemporaryMemoryAllocation(self, node): def _print_TemporaryMemoryAllocation(self, node):
return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name), return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size)) node.symbol.dtype.base_type, self.sympyPrinter.doprint(node.size))
def _print_TemporaryMemoryFree(self, node): def _print_TemporaryMemoryFree(self, node):
return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),) return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)
def _print_CustomCppCode(self, node): @staticmethod
def _print_CustomCppCode(node):
return node.code return node.code
def _print_Conditional(self, node): def _print_Conditional(self, node):
conditionExpr = self.sympyPrinter.doprint(node.conditionExpr) condition_expr = self.sympyPrinter.doprint(node.conditionExpr)
trueBlock = self._print_Block(node.trueBlock) true_block = self._print_Block(node.trueBlock)
result = "if (%s)\n%s " % (conditionExpr, trueBlock) result = "if (%s)\n%s " % (condition_expr, true_block)
if node.falseBlock: if node.falseBlock:
falseBlock = self._print_Block(node.falseBlock) false_block = self._print_Block(node.falseBlock)
result += "else " + falseBlock result += "else " + false_block
return result return result
# ------------------------------------------ Helper function & classes ------------------------------------------------- # ------------------------------------------ Helper function & classes -------------------------------------------------
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter): class CustomSympyPrinter(CCodePrinter):
def __init__(self, constantsAsFloats=False): def __init__(self, constants_as_floats=False):
self._constantsAsFloats = constantsAsFloats self._constantsAsFloats = constants_as_floats
super(CustomSympyPrinter, self).__init__() super(CustomSympyPrinter, self).__init__()
def _print_Pow(self, expr): def _print_Pow(self, expr):
...@@ -210,7 +230,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -210,7 +230,7 @@ class CustomSympyPrinter(CCodePrinter):
return res return res
def _print_Function(self, expr): def _print_Function(self, expr):
functionMap = { function_map = {
xor: '^', xor: '^',
rightShift: '>>', rightShift: '>>',
leftShift: '<<', leftShift: '<<',
...@@ -218,33 +238,34 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -218,33 +238,34 @@ class CustomSympyPrinter(CCodePrinter):
bitwiseAnd: '&', bitwiseAnd: '&',
} }