Commit edb6fdcd authored by Jan Hoenig's avatar Jan Hoenig
Browse files

More fixes on DataType transition

Move my llvm demo notebook in the correct folder
parent 815edd12
......@@ -2,7 +2,7 @@ import sympy as sp
import textwrap as textwrap
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field
from pystencils.types import TypedSymbol
from pystencils.types import TypedSymbol, DataType
class Node(object):
......@@ -266,7 +266,7 @@ class LoopOverCoordinate(Node):
@staticmethod
def getLoopCounterSymbol(coordinateToLoopOver):
return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), "int")
return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), DataType('int'))
@property
def loopCounterSymbol(self):
......
from .llvm import generateLLVM
from .cbackend import generateC, generateCUDA
......@@ -78,6 +78,7 @@ class LLVMPrinter(Printer):
def _print_Mul(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
print(nodes)
for node in nodes[1:]:
e = self.builder.fmul(e, node)
return e
......@@ -120,10 +121,12 @@ class LLVMPrinter(Printer):
def _print_LoopOverCoordinate(self, loop):
with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
loop.loopCounterName, loop.loopCounterSymbol.name) as i:
self._add_tmp_var(loop.loopCounterSymbol, i)
self._print(loop.body)
def _print_SympyAssignment(self, loop):
pass
def _print_SympyAssignment(self, assignment):
expr = self._print(assignment.rhs)
# Should have a list of math library functions to validate this.
......
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.types import TypedSymbol
from pystencils.types import TypedSymbol, DataType
from pystencils.field import Field
import pystencils.ast as ast
......@@ -37,7 +37,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, typeForSymbol[term.name])
return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
else:
raise ValueError("Term has to be field access or symbol")
......
from pystencils.cpu.kernelcreation import createKernel
from .kernelcreation import createKernel
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.types import TypedSymbol
from pystencils.types import TypedSymbol, DataType
from pystencils.field import Field
import pystencils.ast as ast
......@@ -35,7 +35,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, typeForSymbol[term.name])
return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
else:
raise ValueError("Term has to be field access or symbol")
......
......@@ -98,7 +98,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
Example:
>>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
>>> x, y = sp.symbols("x y")
>>> prevPointer = TypedSymbol("ptr", "double")
>>> prevPointer = TypedSymbol("ptr", DataType("double"))
>>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
(ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
>>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
......@@ -129,7 +129,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
if len(listToHash) > 0:
name += "%0.6X" % (abs(hash(tuple(listToHash))))
newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
newPtr = TypedSymbol(previousPtr.name + name, DataType(previousPtr.dtype))
return newPtr, offset
......@@ -238,7 +238,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
coordDict[e] = fieldToFixedCoordinates[field.name][e]
else:
ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), "int")
coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), DataType('int'))
else:
coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
return coordDict
......@@ -420,7 +420,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
return TypedSymbol(symbolNameToVariableName(term.name), DataType(typeForSymbol[term.name]))
else:
newArgs = [processRhs(arg) for arg in term.args]
return term.func(*newArgs) if newArgs else term
......@@ -433,7 +433,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, typeForSymbol[term.name])
return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
else:
assert False, "Expected a symbol as left-hand-side"
......
......@@ -10,7 +10,7 @@ class TypedSymbol(sp.Symbol):
def __new_stage2__(cls, name, dtype):
obj = super(TypedSymbol, cls).__xnew__(cls, name)
obj._dtype = dtype
obj._dtype = DataType(dtype) if isinstance(dtype, str) else dtype
return obj
__xnew__ = staticmethod(__new_stage2__)
......@@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol):
return self.name, self.dtype
_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'}
_dtype_dict = {'int': 0, 'double': 1, 'float': 2}
_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float', 3: 'bool'}
_dtype_dict = {'int': 0, 'double': 1, 'float': 2, 'bool': 3}
class DataType(object):
......@@ -38,11 +38,28 @@ class DataType(object):
self.alias = True
self.const = False
self.ptr = False
self.dtype = 0
if isinstance(dtype, str):
self.dtype = _dtype_dict[dtype]
for s in dtype.split():
if s == 'const':
self.const = True
elif s == '*':
self.ptr = True
elif s == '__restrict__':
self.alias = False
else:
self.dtype = _dtype_dict[s]
elif isinstance(dtype, DataType):
self.__dict__.update(dtype.__dict__)
else:
self.dtype = dtype
def __repr__(self):
return "{!s} {!s}{!s} {!s}".format("const" if self.const else "", _c_dtype_dict[self.dtype],
"*" if self.ptr else "", "__restrict__" if not self.alias else "")
def __eq__(self, other):
if self.alias == other.alias and self.const == other.const and self.ptr == other.ptr and self.dtype == other.dtype:
return True
else:
return False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment