diff --git a/ast.py b/ast.py index 04ab30cf94c4bcbdaeb683d02ce2b6e82963a5ae..b4a32a000ed20c80cdf95eb6d3c951df0a055a32 100644 --- a/ast.py +++ b/ast.py @@ -2,7 +2,7 @@ import sympy as sp import textwrap as textwrap from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field -from pystencils.types import TypedSymbol +from pystencils.types import TypedSymbol, DataType class Node(object): @@ -266,7 +266,7 @@ class LoopOverCoordinate(Node): @staticmethod def getLoopCounterSymbol(coordinateToLoopOver): - return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), "int") + return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), DataType('int')) @property def loopCounterSymbol(self): diff --git a/backends/__init__.py b/backends/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..b4f7b6786b4061c7174fbad0bff8bf9a1deb55a6 100644 --- a/backends/__init__.py +++ b/backends/__init__.py @@ -0,0 +1,2 @@ +from .llvm import generateLLVM +from .cbackend import generateC, generateCUDA diff --git a/backends/llvm.py b/backends/llvm.py index 4a45844f58575c0325a9374750915c4ee24e7ca3..5b4fcd6c5b0d0a8e93c0eb090cea8fc485f7bddc 100644 --- a/backends/llvm.py +++ b/backends/llvm.py @@ -78,6 +78,7 @@ class LLVMPrinter(Printer): def _print_Mul(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + print(nodes) for node in nodes[1:]: e = self.builder.fmul(e, node) return e @@ -120,10 +121,12 @@ class LLVMPrinter(Printer): def _print_LoopOverCoordinate(self, loop): with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), loop.loopCounterName, loop.loopCounterSymbol.name) as i: + self._add_tmp_var(loop.loopCounterSymbol, i) self._print(loop.body) - def _print_SympyAssignment(self, loop): - pass + def _print_SympyAssignment(self, assignment): + expr = self._print(assignment.rhs) + # Should have a list of math library functions to validate this. diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index e8e722c13940ad157b138ddc31e83074800873a5..13fb0d785d0e064bcd1db84534cd850135021d38 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.types import TypedSymbol +from pystencils.types import TypedSymbol, DataType from pystencils.field import Field import pystencils.ast as ast @@ -37,7 +37,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): - return TypedSymbol(term.name, typeForSymbol[term.name]) + return TypedSymbol(term.name, DataType(typeForSymbol[term.name])) else: raise ValueError("Term has to be field access or symbol") diff --git a/llvm/__init__.py b/llvm/__init__.py index 681ec00a7a3e268fa55252521da2a3d1c7a610f8..da5dfa39db26286f274c84e10aa38dd635b75465 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1 +1 @@ -from pystencils.cpu.kernelcreation import createKernel +from .kernelcreation import createKernel diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 54d4ed0ce5d197d1c160092a07de4285a56efdfb..e0957fc8d551661deddd3153010783f62a5f62a3 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop -from pystencils.types import TypedSymbol +from pystencils.types import TypedSymbol, DataType from pystencils.field import Field import pystencils.ast as ast @@ -35,7 +35,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl if isinstance(term, Field.Access) or isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): - return TypedSymbol(term.name, typeForSymbol[term.name]) + return TypedSymbol(term.name, DataType(typeForSymbol[term.name])) else: raise ValueError("Term has to be field access or symbol") diff --git a/transformations.py b/transformations.py index f3f3675f00d02a2fba3c8067dc4b8a15d1abba19..5f38bc384fb24b194ed924ba85c27be2574d1613 100644 --- a/transformations.py +++ b/transformations.py @@ -98,7 +98,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): Example: >>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1) >>> x, y = sp.symbols("x y") - >>> prevPointer = TypedSymbol("ptr", "double") + >>> prevPointer = TypedSymbol("ptr", DataType("double")) >>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer) (ptr_E, x*fstride_myfield[0] + fstride_myfield[0]) >>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer) @@ -129,7 +129,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): if len(listToHash) > 0: name += "%0.6X" % (abs(hash(tuple(listToHash)))) - newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype) + newPtr = TypedSymbol(previousPtr.name + name, DataType(previousPtr.dtype)) return newPtr, offset @@ -238,7 +238,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn coordDict[e] = fieldToFixedCoordinates[field.name][e] else: ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX - coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), "int") + coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), DataType('int')) else: coordDict[e] = fieldAccess.index[e-field.spatialDimensions] return coordDict @@ -420,7 +420,7 @@ def typeAllEquations(eqs, typeForSymbol): elif isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): - return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name]) + return TypedSymbol(symbolNameToVariableName(term.name), DataType(typeForSymbol[term.name])) else: newArgs = [processRhs(arg) for arg in term.args] return term.func(*newArgs) if newArgs else term @@ -433,7 +433,7 @@ def typeAllEquations(eqs, typeForSymbol): elif isinstance(term, TypedSymbol): return term elif isinstance(term, sp.Symbol): - return TypedSymbol(term.name, typeForSymbol[term.name]) + return TypedSymbol(term.name, DataType(typeForSymbol[term.name])) else: assert False, "Expected a symbol as left-hand-side" diff --git a/types.py b/types.py index 2da1e3fc45c891d175d860b54d115434cddd4e7a..8d964c8ec75a884cbb4f1f37f91d88f67eb1c184 100644 --- a/types.py +++ b/types.py @@ -10,7 +10,7 @@ class TypedSymbol(sp.Symbol): def __new_stage2__(cls, name, dtype): obj = super(TypedSymbol, cls).__xnew__(cls, name) - obj._dtype = dtype + obj._dtype = DataType(dtype) if isinstance(dtype, str) else dtype return obj __xnew__ = staticmethod(__new_stage2__) @@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol): return self.name, self.dtype -_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'} -_dtype_dict = {'int': 0, 'double': 1, 'float': 2} +_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float', 3: 'bool'} +_dtype_dict = {'int': 0, 'double': 1, 'float': 2, 'bool': 3} class DataType(object): @@ -38,11 +38,28 @@ class DataType(object): self.alias = True self.const = False self.ptr = False + self.dtype = 0 if isinstance(dtype, str): - self.dtype = _dtype_dict[dtype] + for s in dtype.split(): + if s == 'const': + self.const = True + elif s == '*': + self.ptr = True + elif s == '__restrict__': + self.alias = False + else: + self.dtype = _dtype_dict[s] + elif isinstance(dtype, DataType): + self.__dict__.update(dtype.__dict__) else: self.dtype = dtype def __repr__(self): return "{!s} {!s}{!s} {!s}".format("const" if self.const else "", _c_dtype_dict[self.dtype], "*" if self.ptr else "", "__restrict__" if not self.alias else "") + + def __eq__(self, other): + if self.alias == other.alias and self.const == other.const and self.ptr == other.ptr and self.dtype == other.dtype: + return True + else: + return False