Commit 3278b6ba authored by Jan Hoenig's avatar Jan Hoenig
Browse files

Added repr for various ast classes.

parent 6fffa85a
......@@ -4,7 +4,7 @@ from pystencils.field import Field
from pystencils.typedsymbol import TypedSymbol
class Node:
class Node(object):
"""Base class for all AST nodes"""
def __init__(self, parent=None):
......@@ -35,6 +35,12 @@ class Node:
return result
def parents(self):
return None
def children(self):
return None
class KernelFunction(Node):
......@@ -62,6 +68,9 @@ class KernelFunction(Node):
self.isFieldArgument = True
self.fieldName = name[len(Field.STRIDE_PREFIX):]
def __repr__(self):
return '<{0} {1}>'.format(self.dtype,
def __init__(self, body, functionName="kernel"):
super(KernelFunction, self).__init__()
self._body = body
......@@ -102,6 +111,13 @@ class KernelFunction(Node):
def children(self):
yield self.body
def __repr__(self):
return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters, self.body)
class Block(Node):
def __init__(self, listOfNodes):
......@@ -156,6 +172,12 @@ class Block(Node):
return result
def children(self):
yield self._nodes
def __repr__(self):
return ''.join('\t{!r}\n'.format(node) for node in self._nodes)
class PragmaBlock(Block):
def __init__(self, pragmaLine, listOfNodes):
......@@ -252,6 +274,13 @@ class LoopOverCoordinate(Node):
def coordinateToLoopOver(self):
return self._coordinateToLoopOver
def children(self):
return self.body
def __repr__(self):
return 'loop:{!s} {!s} in {!s}:{!s}:{!s}\n'.format(self.loopCounterName, self.coordinateToLoopOver, self.start,
self.stop, self.step) + '\t{!r}\n'.format(self.body)
class SympyAssignment(Node):
import as ir
import llvmlite.binding as llvm
import logging.config
from sympy.utilities.codegen import CCodePrinter
from pystencils.ast import Node
from sympy.printing.printer import Printer
from sympy import S
# S is numbers?
def generateLLVM(astNode):
return None
class LLVMPrinter(Printer):
"""Convert expressions to LLVM IR"""
def __init__(self, module, builder, fn, *args, **kwargs):
self.func_arg_map = kwargs.pop("func_arg_map", {})
super(LLVMPrinter, self).__init__(*args, **kwargs)
self.fp_type = ir.DoubleType()
#self.integer = ir.IntType(64)
self.module = module
self.builder = builder
self.fn = fn
self.ext_fn = {} # keep track of wrappers to external functions
self.tmp_var = {}
def _add_tmp_var(self, name, value):
self.tmp_var[name] = value
def _print_Number(self, n, **kwargs):
return ir.Constant(self.fp_type, float(n))
def _print_Integer(self, expr):
return ir.Constant(self.fp_type, float(expr.p))
def _print_Symbol(self, s):
val = self.tmp_var.get(s)
if not val:
# look up parameter with name s
val = self.func_arg_map.get(s)
if not val:
raise LookupError("Symbol not found: %s" % s)
return val
def _print_Pow(self, expr):
base0 = self._print(expr.base)
if expr.exp == S.NegativeOne:
return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
if expr.exp == S.Half:
fn = self.ext_fn.get("sqrt")
if not fn:
fn_type = ir.FunctionType(self.fp_type, [self.fp_type])
fn = ir.Function(self.module, fn_type, "sqrt")
self.ext_fn["sqrt"] = fn
return, [base0], "sqrt")
if expr.exp == 2:
return self.builder.fmul(base0, base0)
exp0 = self._print(expr.exp)
fn = self.ext_fn.get("pow")
if not fn:
fn_type = ir.FunctionType(self.fp_type, [self.fp_type, self.fp_type])
fn = ir.Function(self.module, fn_type, "pow")
self.ext_fn["pow"] = fn
return, [base0, exp0], "pow")
def _print_Mul(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
for node in nodes[1:]:
e = self.builder.fmul(e, node)
return e
def _print_Add(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
for node in nodes[1:]:
e = self.builder.fadd(e, node)
return e
# TODO - assumes all called functions take one double precision argument.
# Should have a list of math library functions to validate this.
def _print_Function(self, expr):
name = expr.func.__name__
e0 = self._print(expr.args[0])
fn = self.ext_fn.get(name)
if not fn:
fn_type = ir.FunctionType(self.fp_type, [self.fp_type])
fn = ir.Function(self.module, fn_type, name)
self.ext_fn[name] = fn
return, [e0], name)
def emptyPrinter(self, expr):
raise TypeError("Unsupported type for LLVM JIT conversion: %s"
% type(expr))
class Eval(object):
def __init__(self):
llvm.initialize_native_asmprinter() = llvm.Target.from_default_triple()
def compile(self, module):
llvmmod = llvm.parse_assembly(str(module))
logger.debug('=============Function in IR')
# TODO cpu, features, opt
cpu = llvm.get_host_cpu_name()
features = llvm.get_host_cpu_features()
target_machine =, features=features.flatten(), opt=2)
logger.debug('Machine = ' + str(target_machine.target_data))
with open('gen.ll', 'w') as f:
optimize = True
if optimize:
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 2
pmb.disable_unit_at_a_time = False
pmb.loop_vectorize = True
pmb.slp_vectorize = True
# TODO possible to pass for functions
pm = llvm.create_module_pass_manager()
with open('gen_opt.ll', 'w') as f:
with llvm.create_mcjit_compiler(llvmmod, target_machine) as ee:
logger.debug('==========Machine code')
with open('gen.S', 'w') as f:
with open('gen.o', 'wb') as f:
# fptr = CFUNCTYPE(c_double, c_double, c_double)(ee.get_function_address('add2'))
# result = fptr(2, 3)
# print(result)
return 0
if __name__ == "__main__":
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
"version" : 1,
"disable_existing_loggers" : false,
"formatters" : {
"simple" :{
"format" : "[%(levelname)s]: %(message)s"
"handlers" : {
"console": {
"class": "logging.StreamHandler",
"level": "INFO",
"formatter": "simple",
"stream": "ext://sys.stdout"
"log_file": {
"class": "logging.FileHandler",
"level": "DEBUG",
"formatter": "simple",
"filename": "gen.log",
"mode" : "w",
"encoding": "utf8"
"loggers" : {
"generator" : {
"level" : "DEBUG",
"handlers" : ["console", "log_file"]
\ No newline at end of file
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.typedsymbol import TypedSymbol
from pystencils.field import Field
import pystencils.ast as ast
def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=(),
iterationSlice=None, ghostLayers=None):
Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
Loops are created according to the field accesses in the equations.
:param listOfEquations: list of sympy equations, containing accesses to :class:`pystencils.field.Field`.
Defining the update rules of the kernel
:param functionName: name of the generated function - only important if generated code is written out
:param typeForSymbol: a map from symbol name to a C type specifier. If not specified all symbols are assumed to
be of type 'double' except symbols which occur on the left hand side of equations where the
right hand side is a sympy Boolean which are assumed to be 'bool' .
:param splitGroups: Specification on how to split up inner loop into multiple loops. For details see
transformation :func:`pystencils.transformation.splitInnerLoop`
:param iterationSlice: if not None, iteration is done only over this slice of the field
:param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
if None, the number of ghost layers is determined automatically and assumed to be equal for a
all dimensions
:return: :class:`pystencils.ast.KernelFunction` node
if not typeForSymbol:
typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
def typeSymbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(, typeForSymbol[])
raise ValueError("Term has to be field access or symbol")
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
allFields = fieldsRead.union(fieldsWritten)
for field in allFields:
for field in fieldsRead - fieldsWritten:
body = ast.Block(assignments)
code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, ghostLayers=ghostLayers)
if splitGroups:
typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
splitInnerLoop(code, typedSplitGroups)
loopOrder = getOptimalLoopOrdering(allFields)
basePointerInfo = [['spatialInner0'], ['spatialInner1']]
basePointerInfos = { parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos)
return code
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment