Commit c8b455fe authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: Cleaned up type system

- use data type class consistently instead of strings (in TypedSymbol, Field and jit module)
- new datatype class is based on numpy types with additional specifier information (const and restrict)
- translation between data type class and other modules (numpy, ctypes)
parent c2807e92
import sympy as sp
from sympy.tensor import IndexedBase, Indexed
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.types import TypedSymbol, DataType, _c_dtype_dict
from pystencils.types import TypedSymbol
class Node(object):
......@@ -262,7 +262,7 @@ class LoopOverCoordinate(Node):
@staticmethod
def getLoopCounterSymbol(coordinateToLoopOver):
return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), DataType('int'))
return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')
@property
def loopCounterSymbol(self):
......
from sympy.utilities.codegen import CCodePrinter
from pystencils.astnodes import Node
from pystencils.types import createType
def generateC(astNode):
......@@ -7,7 +8,7 @@ def generateC(astNode):
Prints the abstract syntax tree as C function
"""
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = "double" not in fieldTypes
useFloatConstants = createType("double") not in fieldTypes
printer = CBackend(constantsAsFloats=useFloatConstants)
return printer(astNode)
......
......@@ -62,17 +62,17 @@ compiled into the shared library. Then, the same script can be run from the comp
from __future__ import print_function
import os
import subprocess
from ctypes import cdll, c_double, c_float, sizeof
import shutil
from pystencils.backends.cbackend import generateC
import numpy as np
import hashlib
import json
import platform
import glob
import atexit
import shutil
from ctypes import cdll, sizeof
from pystencils.backends.cbackend import generateC
from collections import OrderedDict, Mapping
from pystencils.transformations import symbolNameToVariableName
from pystencils.types import toCtypes, getBaseType, createType
def makePythonFunction(kernelFunctionNode, argumentDict={}):
......@@ -173,7 +173,7 @@ def readConfig():
defaultCompilerConfig = OrderedDict([
('os', 'linux'),
('command', 'g++'),
('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp'),
('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
('restrictQualifier', '__restrict__')
])
defaultCacheConfig = OrderedDict([
......@@ -240,41 +240,6 @@ def getCacheConfig():
return _config['cache']
def ctypeFromString(typename, includePointers=True):
import ctypes as ct
typename = str(typename).replace("*", " * ")
typeComponents = typename.split()
basicTypeMap = {
'double': ct.c_double,
'float': ct.c_float,
'int': ct.c_int,
'long': ct.c_long,
}
resultType = None
for typeComponent in typeComponents:
typeComponent = typeComponent.strip()
if typeComponent == "const" or typeComponent == "restrict" or typeComponent == "volatile":
continue
if typeComponent in basicTypeMap:
resultType = basicTypeMap[typeComponent]
elif typeComponent == "*" and includePointers:
assert resultType is not None
resultType = ct.POINTER(resultType)
return resultType
def ctypeFromNumpyType(numpyType):
typeMap = {
np.dtype('float64'): c_double,
np.dtype('float32'): c_float,
}
return typeMap[numpyType]
def hashToFunctionName(h):
res = "func_%s" % (h,)
return res.replace('-', 'm')
......@@ -344,7 +309,7 @@ def compileLinux(ast, codeHashStr, srcFile, libFile):
objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.o')
# Compilation
if not os.path.exists(objectFile):
generateCode(ast, ['iostream', 'cmath'], compilerConfig['restrictQualifier'], '', srcFile)
generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'], '', srcFile)
compileCmd = [compilerConfig['command'], '-c'] + compilerConfig['flags'].split()
compileCmd += ['-o', objectFile, srcFile]
runCompileStep(compileCmd)
......@@ -360,7 +325,7 @@ def compileWindows(ast, codeHashStr, srcFile, libFile):
objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.obj')
# Compilation
if not os.path.exists(objectFile):
generateCode(ast, ['iostream', 'cmath'], compilerConfig['restrictQualifier'],
generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'],
'__declspec(dllexport)', srcFile)
# /c compiles only, /EHsc turns of exception handling in c code
......@@ -407,7 +372,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
symbolicField = arg.field
if arg.isFieldPtrArgument:
ctArguments.append(field.ctypes.data_as(ctypeFromString(arg.dtype)))
ctArguments.append(field.ctypes.data_as(toCtypes(arg.dtype)))
if symbolicField.hasFixedShape:
if tuple(int(i) for i in symbolicField.shape) != field.shape:
raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
......@@ -420,11 +385,11 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
if not symbolicField.isIndexField:
arrayShapes.add(field.shape[:symbolicField.spatialDimensions])
elif arg.isFieldShapeArgument:
dataType = ctypeFromString(arg.dtype, includePointers=False)
dataType = toCtypes(getBaseType(arg.dtype))
ctArguments.append(field.ctypes.shape_as(dataType))
elif arg.isFieldStrideArgument:
dataType = ctypeFromString(arg.dtype, includePointers=False)
baseFieldType = ctypeFromNumpyType(field.dtype)
dataType = toCtypes(getBaseType(arg.dtype))
baseFieldType = toCtypes(createType(field.dtype))
strides = field.ctypes.strides_as(dataType)
for i in range(len(field.shape)):
assert strides[i] % sizeof(baseFieldType) == 0
......@@ -437,7 +402,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
param = argumentDict[arg.name]
except KeyError:
raise KeyError("Missing parameter for kernel call " + arg.name)
expectedType = ctypeFromString(arg.dtype)
expectedType = toCtypes(arg.dtype)
ctArguments.append(expectedType(param))
if len(arrayShapes) > 1:
......
import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
from pystencils.types import TypedSymbol, DataType
from pystencils.types import TypedSymbol
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -37,7 +37,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
return TypedSymbol(term.name, typeForSymbol[term.name])
else:
raise ValueError("Term has to be field access or symbol")
......
......@@ -3,7 +3,7 @@ import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase
from pystencils.types import TypedSymbol
from pystencils.types import TypedSymbol, createType
class Field(object):
......@@ -122,7 +122,7 @@ class Field(object):
def __init__(self, fieldName, dtype, layout, shape, strides):
"""Do not use directly. Use static create* methods"""
self._fieldName = fieldName
self._dtype = numpyDataTypeToC(dtype)
self._dtype = createType(dtype)
self._layout = normalizeLayout(layout)
self.shape = shape
self.strides = strides
......@@ -372,17 +372,6 @@ def computeStrides(shape, layout):
return tuple(strides)
def numpyDataTypeToC(dtype):
"""Mapping numpy data types to C data types"""
if dtype == np.float64:
return "double"
elif dtype == np.float32:
return "float"
elif dtype == np.int32:
return "int"
raise NotImplementedError("Cannot convert type " + str(dtype))
def offsetComponentToDirectionString(coordinateId, value):
"""
Translates numerical offset to string notation.
......
......@@ -2,7 +2,7 @@ import sympy as sp
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
desympy_ast, insert_casts
from pystencils.types import TypedSymbol, DataType
from pystencils.types import TypedSymbol
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -36,7 +36,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
return TypedSymbol(term.name, typeForSymbol[term.name])
else:
raise ValueError("Term has to be field access or symbol")
......
......@@ -6,7 +6,7 @@ from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.types import TypedSymbol, DataType
from pystencils.types import TypedSymbol, createType, PointerType
from pystencils.slicing import normalizeSlice
import pystencils.astnodes as ast
......@@ -109,7 +109,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
Example:
>>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
>>> x, y = sp.symbols("x y")
>>> prevPointer = TypedSymbol("ptr", DataType("double"))
>>> prevPointer = TypedSymbol("ptr", "double")
>>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
(ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
>>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
......@@ -140,7 +140,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
if len(listToHash) > 0:
name += "%0.6X" % (abs(hash(tuple(listToHash))))
newPtr = TypedSymbol(previousPtr.name + name, DataType(previousPtr.dtype))
newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
return newPtr, offset
......@@ -234,12 +234,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
else:
basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
dtype = DataType(field.dtype)
dtype.alias = False
dtype.ptr = True
if field.name in readOnlyFieldNames:
dtype.const = True
dtype = PointerType(field.dtype, const=field.name in readOnlyFieldNames, restrict=True)
fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)
lastPointer = fieldPtr
......@@ -252,7 +247,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
coordDict[e] = fieldToFixedCoordinates[field.name][e]
else:
ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), DataType('int'))
coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
else:
coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
return coordDict
......@@ -433,7 +428,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(symbolNameToVariableName(term.name), DataType(typeForSymbol[term.name]))
return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
else:
newArgs = [processRhs(arg) for arg in term.args]
return term.func(*newArgs) if newArgs else term
......@@ -446,7 +441,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, DataType(typeForSymbol[term.name]))
return TypedSymbol(term.name, typeForSymbol[term.name])
else:
assert False, "Expected a symbol as left-hand-side"
......@@ -539,9 +534,9 @@ def get_type(node):
# TODO sp.NumberSymbol
elif isinstance(node, sp.Number):
if isinstance(node, sp.Float):
return DataType('double')
return createType('double')
elif isinstance(node, sp.Integer):
return DataType('int')
return createType('int')
else:
raise NotImplemented('Not yet supported: %s %s' % (node, type(node)))
else:
......
import ctypes
import sympy as sp
import numpy as np
from sympy.core.cache import cacheit
......@@ -10,7 +12,7 @@ class TypedSymbol(sp.Symbol):
def __new_stage2__(cls, name, dtype):
obj = super(TypedSymbol, cls).__xnew__(cls, name)
obj._dtype = DataType(dtype) if isinstance(dtype, str) else dtype
obj._dtype = createType(dtype)
return obj
__xnew__ = staticmethod(__new_stage2__)
......@@ -29,41 +31,161 @@ class TypedSymbol(sp.Symbol):
return self.name, self.dtype
_c_dtype_dict = {0: 'bool', 1: 'int', 2: 'float', 3: 'double'}
_dtype_dict = {'bool': 0, 'int': 1, 'float': 2, 'double': 3}
class DataType(object):
def __init__(self, dtype):
self.alias = True
self.const = False
self.ptr = False
self.dtype = 0
if isinstance(dtype, str):
for s in dtype.split():
if s == 'const':
self.const = True
elif s == '*':
self.ptr = True
elif s == 'RESTRICT':
self.alias = False
else:
self.dtype = _dtype_dict[s]
elif isinstance(dtype, DataType):
self.__dict__.update(dtype.__dict__)
def createType(specification):
if isinstance(specification, Type):
return specification
elif isinstance(specification, str):
return createTypeFromString(specification)
else:
npDataType = np.dtype(specification)
return BasicType(npDataType, const=False)
def createTypeFromString(specification):
specification = specification.lower().split()
parts = []
current = []
for s in specification:
if s == '*':
parts.append(current)
current = [s]
else:
self.dtype = dtype
current.append(s)
if len(current) > 0:
parts.append(current)
# Parse native part
basePart = parts.pop(0)
const = False
if 'const' in basePart:
const = True
basePart.remove('const')
assert len(basePart) == 1
baseType = BasicType(basePart[0], const)
currentType = baseType
# Parse pointer parts
for part in parts:
restrict = False
const = False
if 'restrict' in part:
restrict = True
part.remove('restrict')
if 'const' in part:
const = True
part.remove("const")
assert len(part) == 1 and part[0] == '*'
currentType = PointerType(currentType, const, restrict)
return currentType
def getBaseType(type):
while type.baseType is not None:
type = type.baseType
return type
def toCtypes(dataType):
if isinstance(dataType, PointerType):
return ctypes.POINTER(toCtypes(dataType.baseType))
else:
return toCtypes.map[dataType.numpyDtype]
toCtypes.map = {
np.dtype(np.int8): ctypes.c_int8,
np.dtype(np.int16): ctypes.c_int16,
np.dtype(np.int32): ctypes.c_int32,
np.dtype(np.int64): ctypes.c_int64,
np.dtype(np.uint8): ctypes.c_uint8,
np.dtype(np.uint16): ctypes.c_uint16,
np.dtype(np.uint32): ctypes.c_uint32,
np.dtype(np.uint64): ctypes.c_uint64,
np.dtype(np.float32): ctypes.c_float,
np.dtype(np.float64): ctypes.c_double,
}
class Type(object):
pass
class BasicType(Type):
@staticmethod
def numpyNameToC(name):
if name == 'float64': return 'double'
elif name == 'float32': return 'float'
elif name.startswith('int'):
width = int(name[len("int"):])
return "int%d_t" % (width,)
elif name.startswith('uint'):
width = int(name[len("int"):])
return "uint%d_t" % (width,)
elif name == 'bool':
return 'bool'
else:
raise NotImplemented("Can map numpy to C name for %s" % (name,))
def __init__(self, dtype, const=False):
self.const = const
self._dtype = np.dtype(dtype)
assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
assert self._dtype.hasobject is False
assert self._dtype.subdtype is None
@property
def baseType(self):
return None
def __repr__(self):
return "{!s} {!s}{!s} {!s}".format("const" if self.const else "", _c_dtype_dict[self.dtype],
"*" if self.ptr else "", "RESTRICT" if not self.alias else "")
@property
def numpyDtype(self):
return self._dtype
def __str__(self):
result = BasicType.numpyNameToC(str(self._dtype))
if self.const:
result += " const"
return result
def __eq__(self, other):
if self.alias == other.alias and self.const == other.const and self.ptr == other.ptr and self.dtype == other.dtype:
return True
if not isinstance(other, BasicType):
return False
else:
return (self.numpyDtype, self.const) == (other.numpyDtype, other.const)
def __hash__(self):
return hash(str(self))
class PointerType(Type):
def __init__(self, baseType, const=False, restrict=True):
self._baseType = baseType
self.const = const
self.restrict = restrict
@property
def alias(self):
return not self.restrict
@property
def baseType(self):
return self._baseType
def __eq__(self, other):
if not isinstance(other, PointerType):
return False
else:
return (self.baseType, self.const, self.restrict) == (other.baseType, other.const, other.restrict)
def __str__(self):
return "%s * %s%s" % (self.baseType, "RESTRICT " if self.restrict else "", "const " if self.const else "")
def __hash__(self):
return hash(str(self))
class StructType(object):
def __init__(self, numpyType):
self._dtype = np.dtype(numpyType)
def get_type_from_sympy(node):
return DataType('int')
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment