diff --git a/astnodes.py b/astnodes.py index 9a36744db7145b45e941decda8b1c90f436cbe03..91f2e494f6aefdbf48649859cfd30b55fc9a85c0 100644 --- a/astnodes.py +++ b/astnodes.py @@ -1,7 +1,7 @@ import sympy as sp from sympy.tensor import IndexedBase from pystencils.field import Field -from pystencils.types import TypedSymbol +from pystencils.types import TypedSymbol, createType, get_type_from_sympy class Node(object): @@ -294,7 +294,7 @@ class SympyAssignment(Node): self._lhsSymbol = lhsSymbol self.rhs = rhsTerm self._isDeclaration = True - isCast = str(self._lhsSymbol.func).lower() == 'cast' + isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast: self._isDeclaration = False self._isConst = isConst @@ -307,7 +307,7 @@ class SympyAssignment(Node): def lhs(self, newValue): self._lhsSymbol = newValue self._isDeclaration = True - isCast = str(self._lhsSymbol.func).lower() == 'cast' + isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: self._isDeclaration = False @@ -344,7 +344,8 @@ class SympyAssignment(Node): def replace(self, child, replacement): if child == self.lhs: - self.lhs = child + replacement.parent = self + self.lhs = replacement elif child == self.rhs: replacement.parent = self self.rhs = replacement @@ -394,8 +395,6 @@ class TemporaryMemoryFree(Node): # TODO implement defined & undefinedSymbols - - class Conversion(Node): def __init__(self, child, dtype, parent=None): super(Conversion, self).__init__(parent) @@ -422,9 +421,9 @@ class Conversion(Node): raise set() def __repr__(self): - return '(%s)' % (self.dtype,) + repr(self.args) + return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args) -# TODO everything which is not Atomic expression: Pow) +# TODO Pow _expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'} @@ -480,15 +479,33 @@ class Pow(Expr): class Indexed(Expr): + def __init__(self, args, base, parent=None): + super(Indexed, self).__init__(args, parent) + self.base = base + # Get dtype from label, and unpointer it + self.dtype = createType(base.label.dtype.baseType) + def __repr__(self): return '%s[%s]' % (self.args[0], self.args[1]) -class Number(Node): +class PointerArithmetic(Expr): + def __init__(self, args, pointer, parent=None): + super(PointerArithmetic, self).__init__([args] + [pointer], parent) + self.pointer = pointer + self.offset = args + self.dtype = pointer.dtype + + def __repr__(self): + return '*(%s + %s)' % (self.pointer, self.args) + + +class Number(Node, sp.AtomicExpr): def __init__(self, number, parent=None): super(Number, self).__init__(parent) - self._args = None - self.dtype = dtype + + self.dtype, self.value = get_type_from_sympy(number) + self._args = tuple() @property def args(self): @@ -506,6 +523,12 @@ class Number(Node): raise set() def __repr__(self): - return '(%s)' % (self.dtype,) + repr(self.args) + return repr(self.value) + + def __float__(self): + return float(self.value) + + def __int__(self): + return int(self.value) diff --git a/backends/__init__.py b/backends/__init__.py index 3b23aca7af30894625ab210ec25a9eeb6d14d50b..bd2b52ecba0ee4f95806bf29e6e28c23407a288f 100644 --- a/backends/__init__.py +++ b/backends/__init__.py @@ -1,6 +1,3 @@ -try: - from .llvm import generateLLVM -except ImportError: - pass - +from .llvm import generateLLVM from .cbackend import generateC +from .dot import dotprint diff --git a/backends/dot.py b/backends/dot.py index b90ffc40a8434f16b8beac04495b5731d78d73c7..aac41d02de940ae80fe1ddb43a563eec8aba4e5f 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -6,9 +6,10 @@ class DotPrinter(Printer): """ A printer which converts ast to DOT (graph description language). """ - def __init__(self, nodeToStrFunction, **kwargs): + def __init__(self, nodeToStrFunction, full, **kwargs): super(DotPrinter, self).__init__() self._nodeToStrFunction = nodeToStrFunction + self.full = full self.dot = Digraph(**kwargs) self.dot.quote_edge = lang.quote @@ -30,6 +31,21 @@ class DotPrinter(Printer): def _print_SympyAssignment(self, assignment): self.dot.node(self._nodeToStrFunction(assignment)) + if self.full: + for node in assignment.args: + self._print(node) + for node in assignment.args: + self.dot.edge(self._nodeToStrFunction(assignment), self._nodeToStrFunction(node)) + + def emptyPrinter(self, expr): + if self.full: + self.dot.node(self._nodeToStrFunction(expr)) + for node in expr.args: + self._print(node) + for node in expr.args: + self.dot.edge(self._nodeToStrFunction(expr), self._nodeToStrFunction(node)) + else: + raise NotImplemented('Dotprinter cannot print', expr) def doprint(self, expr): self._print(expr) @@ -48,17 +64,20 @@ def __shortened(node): return "Assignment: " + repr(node.lhs) -def dotprint(ast, view=False, short=False, **kwargs): +def dotprint(node, view=False, short=False, full=False, **kwargs): """ Returns a string which can be used to generate a DOT-graph - :param ast: The ast which should be generated + :param node: The ast which should be generated :param view: Boolen, if rendering of the image directly should occur. + :param short: Uses the __shortened output + :param full: Prints the whole tree with type information :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :return: string in DOT format """ nodeToStrFunction = __shortened if short else repr - printer = DotPrinter(nodeToStrFunction, **kwargs) - dot = printer.doprint(ast) + nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr) if full else nodeToStrFunction + printer = DotPrinter(nodeToStrFunction, full, **kwargs) + dot = printer.doprint(node) if view: printer.dot.render(view=view) return dot @@ -80,4 +99,4 @@ if __name__ == "__main__": from pystencils.cpu import createKernel ast = createKernel([updateRule]) - print(dotprint(ast, short=True)) \ No newline at end of file + print(dotprint(ast, short=True)) diff --git a/backends/llvm.py b/backends/llvm.py index fe11e77a46291a336e9c984ed741b9641de2c577..3a7e3f1940383062bfdd83f0e1bf4b8e65a9e2ca 100644 --- a/backends/llvm.py +++ b/backends/llvm.py @@ -1,18 +1,19 @@ import llvmlite.ir as ir +import functools from sympy.printing.printer import Printer from sympy import S # S is numbers? from pystencils.llvm.control_flow import Loop +from ..types import createType +from ..astnodes import Indexed -def generateLLVM(ast_node): +def generateLLVM(ast_node, module=ir.Module(), builder=ir.IRBuilder()): """ Prints the ast as llvm code """ - module = ir.Module() - builder = ir.IRBuilder() printer = LLVMPrinter(module, builder) return printer._print(ast_node) @@ -25,6 +26,7 @@ class LLVMPrinter(Printer): self.fp_type = ir.DoubleType() self.fp_pointer = self.fp_type.as_pointer() self.integer = ir.IntType(64) + self.integer_pointer = self.integer.as_pointer() self.void = ir.VoidType() self.module = module self.builder = builder @@ -35,11 +37,16 @@ class LLVMPrinter(Printer): def _add_tmp_var(self, name, value): self.tmp_var[name] = value - def _print_Number(self, n, **kwargs): - return ir.Constant(self.fp_type, float(n)) + def _print_Number(self, n): + if n.dtype == createType("int"): + return ir.Constant(self.integer, int(n)) + elif n.dtype == createType("double"): + return ir.Constant(self.fp_type, float(n)) + else: + raise NotImplementedError("Numbers can only have int and double", n) def _print_Float(self, expr): - return ir.Constant(self.fp_type, float(expr.p)) + return ir.Constant(self.fp_type, expr.p) def _print_Integer(self, expr): return ir.Constant(self.integer, expr.p) @@ -81,24 +88,31 @@ class LLVMPrinter(Printer): def _print_Mul(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + if expr.dtype == createType('double'): + mul = self.builder.fmul + else: # int TODO others? + mul = self.builder.mul for node in nodes[1:]: - e = self.builder.fmul(e, node) + e = mul(e, node) return e def _print_Add(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + if expr.dtype == createType('double'): + add = self.builder.fadd + else: # int TODO others? + add = self.builder.add for node in nodes[1:]: - print(e, node) - e = self.builder.fadd(e, node) + e = add(e, node) return e def _print_KernelFunction(self, function): return_type = self.void - # TODO argument in their own call? + # TODO argument in their own call? -> nope parameter_type = [] for parameter in function.parameters: - # TODO what bout ptr shape and stride argument? + # TODO what about ptr shape and stride argument? if parameter.isFieldArgument: parameter_type.append(self.fp_pointer) else: @@ -118,6 +132,7 @@ class LLVMPrinter(Printer): block = fn.append_basic_block(name="entry") self.builder = ir.IRBuilder(block) self._print(function.body) + self.builder.ret_void() self.fn = fn return fn @@ -129,16 +144,57 @@ class LLVMPrinter(Printer): with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), loop.loopCounterName, loop.loopCounterSymbol.name) as i: self._add_tmp_var(loop.loopCounterSymbol, i) + # TODO remove tmp var self._print(loop.body) def _print_SympyAssignment(self, assignment): expr = self._print(assignment.rhs) - - - - # Should have a list of math library functions to validate this. - - # TODO delete this -> NO this should be a function call + lhs = assignment.lhs + if isinstance(lhs, Indexed): + ptr = self._print(lhs.base.label) + index = self._print(lhs.args[1]) + gep = self.builder.gep(ptr, [index]) + return self.builder.store(expr, gep) + self.func_arg_map[assignment.lhs.name] = expr + return expr + + def _print_Conversion(self, conversion): + node = self._print(conversion.args[0]) + to_dtype = conversion.dtype + from_dtype = conversion.args[0].dtype + # (From, to) + decision = { + (createType("int"), createType("double")): functools.partial(self.builder.sitofp, node, self.fp_type), + (createType("double"), createType("int")): functools.partial(self.builder.fptosi, node, self.integer), + (createType("double *"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (createType("int"), createType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + (createType("double * restrict"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (createType("int"), createType("double * restrict")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + } + # TODO float, const, restrict + # TODO bitcast, addrspacecast + # print([x for x in decision.keys()]) + # print("Types:") + # print([(type(x), type(y)) for (x, y) in decision.keys()]) + # print("Cast:") + # print((from_dtype, to_dtype)) + return decision[(from_dtype, to_dtype)]() + + def _print_Indexed(self, indexed): + ptr = self._print(indexed.base.label) + index = self._print(indexed.args[1]) + gep = self.builder.gep(ptr, [index]) + return self.builder.load(gep, name=indexed.base.label.name) + + def _print_PointerArithmetic(self, pointer): + ptr = self._print(pointer.pointer) + index = self._print(pointer.offset) + return self.builder.gep(ptr, [index]) + + # Should have a list of math library functions to validate this. + # TODO function calls def _print_Function(self, expr): name = expr.func.__name__ e0 = self._print(expr.args[0]) @@ -150,5 +206,5 @@ class LLVMPrinter(Printer): return self.builder.call(fn, [e0], name) def emptyPrinter(self, expr): - raise TypeError("Unsupported type for LLVM JIT conversion: %s" - % type(expr)) + raise TypeError("Unsupported type for LLVM JIT conversion: %s %s" + % (type(expr), expr)) diff --git a/llvm/__init__.py b/llvm/__init__.py index da5dfa39db26286f274c84e10aa38dd635b75465..f34532f8bea8ee95dd295cde4f521fb8902cb3d0 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1 +1,2 @@ from .kernelcreation import createKernel +from .jit import compileLLVM \ No newline at end of file diff --git a/llvm/jit.py b/llvm/jit.py index 8e6fdb56f947ce6a98ba93aad5f8bb2ef9006014..aa53eb285859dee7807f4d839e6191ec26a108fc 100644 --- a/llvm/jit.py +++ b/llvm/jit.py @@ -1,71 +1,81 @@ +import llvmlite.ir as ir import llvmlite.binding as llvm -import logging.config +from ..types import toCtypes, createType +import ctypes as ct -class Eval(object): + +def compileLLVM(module): + jit = Jit() + jit.parse(module) + jit.optimize() + jit.compile() + return jit + + +class Jit(object): def __init__(self): llvm.initialize() llvm.initialize_all_targets() llvm.initialize_native_target() llvm.initialize_native_asmprinter() + + self.module = None + self.llvmmod = None self.target = llvm.Target.from_default_triple() + self.cpu = llvm.get_host_cpu_name() + self.cpu_features = llvm.get_host_cpu_features() + self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), opt=2) + self.ee = None + self.fptr = None - def compile(self, module): - logger.debug('=============Preparse') - logger.debug(str(module)) + def parse(self, module): + self.module = module llvmmod = llvm.parse_assembly(str(module)) llvmmod.verify() - logger.debug('=============Function in IR') - logger.debug(str(llvmmod)) - # TODO cpu, features, opt - cpu = llvm.get_host_cpu_name() - features = llvm.get_host_cpu_features() - logger.debug('=======Things') - logger.debug(cpu) - logger.debug(features.flatten()) - target_machine = self.target.create_target_machine(cpu=cpu, features=features.flatten(), opt=2) - - logger.debug('Machine = ' + str(target_machine.target_data)) + self.llvmmod = llvmmod - with open('gen.ll', 'w') as f: - f.write(str(llvmmod)) - optimize = True - if optimize: - pmb = llvm.create_pass_manager_builder() - pmb.opt_level = 2 - pmb.disable_unit_at_a_time = False - pmb.loop_vectorize = True - pmb.slp_vectorize = True - # TODO possible to pass for functions - pm = llvm.create_module_pass_manager() - pm.add_instruction_combining_pass() - pm.add_function_attrs_pass() - pm.add_constant_merge_pass() - pm.add_licm_pass() - pmb.populate(pm) - pm.run(llvmmod) - logger.debug("==========Opt") - logger.debug(str(llvmmod)) - with open('gen_opt.ll', 'w') as f: - f.write(str(llvmmod)) + def write_ll(self, file): + with open(file, 'w') as f: + f.write(str(self.llvmmod)) - with llvm.create_mcjit_compiler(llvmmod, target_machine) as ee: - ee.finalize_object() + def optimize(self): + pmb = llvm.create_pass_manager_builder() + pmb.opt_level = 2 + pmb.disable_unit_at_a_time = False + pmb.loop_vectorize = True + pmb.slp_vectorize = True + # TODO possible to pass for functions + pm = llvm.create_module_pass_manager() + pm.add_instruction_combining_pass() + pm.add_function_attrs_pass() + pm.add_constant_merge_pass() + pm.add_licm_pass() + pmb.populate(pm) + pm.run(self.llvmmod) - logger.debug('==========Machine code') - logger.debug(target_machine.emit_assembly(llvmmod)) - with open('gen.S', 'w') as f: - f.write(target_machine.emit_assembly(llvmmod)) - with open('gen.o', 'wb') as f: - f.write(target_machine.emit_object(llvmmod)) + def compile(self, assembly_file=None, object_file=None): + ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine) + ee.finalize_object() - # fptr = CFUNCTYPE(c_double, c_double, c_double)(ee.get_function_address('add2')) - # result = fptr(2, 3) - # print(result) - return 0 + if assembly_file is not None: + with open(assembly_file, 'w') as f: + f.write(self.target_machine.emit_assembly(self.llvmmod)) + if object_file is not None: + with open(object_file, 'wb') as f: + f.write(self.target_machine.emit_object(self.llvmmod)) + fptr = {} + for function in self.module.functions: + if not function.is_declaration: + return_type = None + if function.ftype.return_type != ir.VoidType(): + return_type = toCtypes(createType(str(function.ftype.return_type))) + args = [toCtypes(createType(str(arg))) for arg in function.ftype.args] + function_address = ee.get_function_address(function.name) + fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address) + self.ee = ee + self.fptr = fptr -if __name__ == "__main__": - logger = logging.getLogger(__name__) -else: - logger = logging.getLogger(__name__) + def __call__(self, function, *args, **kwargs): + self.fptr[function](*args, **kwargs) diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 75bac76cc3c57534e51b0ffb6a2bd6d128edcb2a..5473b9ce7ed5034de38fb748affd176a2a17b05f 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -60,9 +60,11 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) + print('Ast:') + print(code) desympy_ast(code) + print('Desympied ast:') + print(code) insert_casts(code) - - - return code \ No newline at end of file + return code diff --git a/transformations.py b/transformations.py index 870904b54740532068f8084952ba273767dbf41d..73b684a3c3d91b8c7dc6af51d0cdb60cf15fc4a9 100644 --- a/transformations.py +++ b/transformations.py @@ -566,23 +566,45 @@ def get_type(node): def insert_casts(node): """ - Inserts casts where needed + Inserts casts and dtype where needed :param node: ast which should be traversed :return: node """ - def add_conversion(node, dtype): - return node + def conversion(args): + target = args[0] + if isinstance(target.dtype, PointerType): + # Pointer arithmetic + for arg in args[1:]: + # Check validness + if not arg.dtype.is_int() and not arg.dtype.is_uint(): + raise ValueError("Impossible pointer arithmetic", target, arg) + pointer = ast.PointerArithmetic(ast.Add(args[1:]), target) + return [pointer] + + else: + for i in range(len(args)): + if args[i].dtype != target.dtype: + args[i] = ast.Conversion(args[i], target.dtype, node) + return args for arg in node.args: insert_casts(arg) if isinstance(node, ast.Indexed): + #TODO revmove this pass elif isinstance(node, ast.Expr): - args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype')) + #print(node, node.args) + #print([type(arg) for arg in node.args]) + #print([arg.dtype for arg in node.args]) + args = sorted((arg for arg in node.args), key=attrgetter('dtype')) target = args[0] - for i in range(len(args)): - args[i] = add_conversion(args[i], target.dtype) - node.args = args + node.args = conversion(args) + node.dtype = target.dtype + #print(node.dtype) + #print(node) + elif isinstance(node, ast.SympyAssignment): + if node.lhs.dtype != node.rhs.dtype: + node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype)) elif isinstance(node, ast.LoopOverCoordinate): pass return node @@ -595,16 +617,45 @@ def desympy_ast(node): :param node: ast which should be traversed. Only node's children will be modified. :return: (modified) node """ + if node.args is None: + return node for i in range(len(node.args)): arg = node.args[i] if isinstance(arg, sp.Add): node.replace(arg, ast.Add(arg.args, node)) + elif isinstance(arg, sp.Number): + node.replace(arg, ast.Number(arg, node)) elif isinstance(arg, sp.Mul): node.replace(arg, ast.Mul(arg.args, node)) elif isinstance(arg, sp.Pow): node.replace(arg, ast.Pow(arg.args, node)) - elif isinstance(arg, sp.tensor.Indexed): - node.replace(arg, ast.Indexed(arg.args, node)) + elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed): + node.replace(arg, ast.Indexed(arg.args, arg.base, node)) + elif isinstance(arg, sp.tensor.IndexedBase): + node.replace(arg, arg.label) + #elif isinstance(arg, sp.containers.Tuple): + # + else: + #print('Not transforming:', type(arg), arg) + pass for arg in node.args: desympy_ast(arg) return node + + +def check_dtype(node): + if isinstance(node, ast.KernelFunction): + pass + elif isinstance(node, ast.Block): + pass + elif isinstance(node, ast.LoopOverCoordinate): + pass + elif isinstance(node, ast.SympyAssignment): + pass + else: + #print(node) + #print(node.dtype) + pass + for arg in node.args: + check_dtype(arg) + diff --git a/types.py b/types.py index 251688ce873fff3a8b0eaf5edc300ad8b7bf4e77..61f2fc63aa7555a530dd3c19fead33a24ba63e78 100644 --- a/types.py +++ b/types.py @@ -1,6 +1,7 @@ import ctypes import sympy as sp import numpy as np +# import llvmlite.ir as ir from sympy.core.cache import cacheit @@ -50,6 +51,11 @@ class TypedSymbol(sp.Symbol): def createType(specification): + """ + Create a subclass of Type according to a string or an object of subclass Type + :param specification: Type object, or a string + :return: Type object, or a new Type object parsed from the string + """ if isinstance(specification, Type): return specification elif isinstance(specification, str): @@ -63,6 +69,11 @@ def createType(specification): def createTypeFromString(specification): + """ + Creates a new Type object from a c-like string specification + :param specification: Specification string + :return: Type object + """ specification = specification.lower().split() parts = [] current = [] @@ -74,16 +85,17 @@ def createTypeFromString(specification): current.append(s) if len(current) > 0: parts.append(current) - - # Parse native part + # Parse native part basePart = parts.pop(0) const = False if 'const' in basePart: const = True basePart.remove('const') assert len(basePart) == 1 + if basePart[0][-1] == "*": + basePart[0] = basePart[0][:-1] + parts.append('*') baseType = BasicType(basePart[0], const) - currentType = baseType # Parse pointer parts for part in parts: @@ -107,6 +119,11 @@ def getBaseType(type): def toCtypes(dataType): + """ + Transforms a given Type into ctypes + :param dataType: Subclass of Type + :return: ctypes type object + """ if isinstance(dataType, PointerType): return ctypes.POINTER(toCtypes(dataType.baseType)) elif isinstance(dataType, StructType): @@ -114,6 +131,7 @@ def toCtypes(dataType): else: return toCtypes.map[dataType.numpyDtype] + toCtypes.map = { np.dtype(np.int8): ctypes.c_int8, np.dtype(np.int16): ctypes.c_int16, @@ -130,16 +148,66 @@ toCtypes.map = { } +#def to_llvmlite_type(data_type): +# """ +# Transforms a given type into ctypes +# :param data_type: Subclass of Type +# :return: llvmlite type object +# """ +# if isinstance(data_type, PointerType): +# return to_llvmlite_type.map[data_type.baseType].as_pointer() +# else: +# return to_llvmlite_type.map[data_type.numpyDType] +# +#to_llvmlite_type.map = { +# np.dtype(np.int8): ir.IntType(8), +# np.dtype(np.int16): ir.IntType(16), +# np.dtype(np.int32): ir.IntType(32), +# np.dtype(np.int64): ir.IntType(64), +# +# # TODO llvmlite doesn't seem to differentiate between Int types +# np.dtype(np.uint8): ir.IntType(8), +# np.dtype(np.uint16): ir.IntType(16), +# np.dtype(np.uint32): ir.IntType(32), +# np.dtype(np.uint64): ir.IntType(64), +# +# np.dtype(np.float32): ir.FloatType(), +# np.dtype(np.float64): ir.DoubleType(), +# # TODO const, restrict, void +#} + + class Type(sp.Basic): def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) + def __lt__(self, other): + # Needed for sorting the types inside an expression + if isinstance(self, BasicType): + if isinstance(other, BasicType): + return self.numpyDtype < other.numpyDtype # TODO const + if isinstance(other, PointerType): + return False + if isinstance(other, StructType): + raise NotImplementedError("Struct type comparison is not yet implemented") + if isinstance(self, PointerType): + if isinstance(other, BasicType): + return True + if isinstance(other, PointerType): + return self.baseType < other.baseType # TODO const, restrict + if isinstance(other, StructType): + raise NotImplementedError("Struct type comparison is not yet implemented") + if isinstance(self, StructType): + raise NotImplementedError("Struct type comparison is not yet implemented") + class BasicType(Type): @staticmethod def numpyNameToC(name): - if name == 'float64': return 'double' - elif name == 'float32': return 'float' + if name == 'float64': + return 'double' + elif name == 'float32': + return 'float' elif name.startswith('int'): width = int(name[len("int"):]) return "int%d_t" % (width,) @@ -176,7 +244,22 @@ class BasicType(Type): def itemSize(self): return 1 - def __str__(self): + def is_int(self): + return self.numpyDtype in np.sctypes['int'] + + def is_float(self): + return self.numpyDtype in np.sctypes['float'] + + def is_uint(self): + return self.numpyDtype in np.sctypes['uint'] + + def is_comlex(self): + return self.numpyDtype in np.sctypes['complex'] + + def is_other(self): + return self.numpyDtype in np.sctypes['others'] + + def __repr__(self): result = BasicType.numpyNameToC(str(self._dtype)) if self.const: result += " const" @@ -219,8 +302,8 @@ class PointerType(Type): else: return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict) - def __str__(self): - return "%s *%s%s" % (self.baseType, " RESTRICT" if self.restrict else "", " const" if self.const else "") + def __repr__(self): + return "%s * %s%s" % (self.baseType, "RESTRICT " if self.restrict else "", "const " if self.const else "") def __hash__(self): return hash(str(self)) @@ -271,3 +354,35 @@ class StructType(object): def __hash__(self): return hash((self.numpyDtype, self.const)) + + # TODO this should not work at all!!! + def __gt__(self, other): + if self.ptr and not other.ptr: + return True + if self.dtype > other.dtype: + return True + + +def get_type_from_sympy(node): + """ + Creates a Type object from a Sympy object + :param node: Sympy object + :return: Type object + """ + # Rational, NumberSymbol? + # Zero, One, NegativeOne )= Integer + # Half )= Rational + # NAN, Infinity, Negative Inifinity, + # Exp1, Imaginary Unit, Pi, EulerGamma, Catalan, Golden Ratio + # Pow, Mul, Add, Mod, Relational + if not isinstance(node, sp.Number): + raise TypeError(node, 'is not a sp.Number') + + if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber): + return createType('double'), float(node) + elif isinstance(node, sp.Integer): + return createType('int'), int(node) + elif isinstance(node, sp.Rational): + raise NotImplementedError('Rationals are not supported yet') + else: + raise TypeError(node, ' is not a supported type (yet)!')