Skip to content
Snippets Groups Projects
Commit ea847bc5 authored by Martin Bauer's avatar Martin Bauer
Browse files

Vectorization & Type system overhaul

- first vectorization tests are running
- type system: use memoized getTypeOfExpression
- casts are done using sp.Function('cast')
- C backend adapted for vectorization support
- AST nodes can require optional headers
parent 697d8cdf
Branches
Tags
No related merge requests found
......@@ -6,6 +6,8 @@ from pystencils.types import TypedSymbol, createType, get_type_from_sympy, creat
class ResolvedFieldAccess(sp.Indexed):
def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
if not isinstance(base, IndexedBase):
base = IndexedBase(base, shape=(1,))
obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex)
obj.field = field
obj.offsets = offsets
......@@ -21,6 +23,14 @@ class ResolvedFieldAccess(sp.Indexed):
superClassContents = super(ResolvedFieldAccess, self)._hashable_content()
return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
@property
def typedSymbol(self):
return self.base.label
def __str__(self):
top = super(ResolvedFieldAccess, self).__str__()
return "%s (%s)" % (top, self.typedSymbol.dtype)
def __getnewargs__(self):
return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
......
......@@ -4,8 +4,13 @@ try:
except ImportError:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from pystencils.astnodes import Node
from pystencils.types import createType, PointerType
from collections import namedtuple
from sympy.core.mul import _keep_coeff
from sympy.core import S
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.types import createType, PointerType, getTypeOfExpression, VectorType, castFunc
from pystencils.backends.simd_instruction_sets import selectedInstructionSet
def generateC(astNode, signatureOnly=False):
......@@ -14,10 +19,26 @@ def generateC(astNode, signatureOnly=False):
"""
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = createType("double") not in fieldTypes
printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly)
vectorIS = selectedInstructionSet['double']
printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly, vectorInstructionSet=vectorIS)
return printer(astNode)
def getHeaders(astNode):
    """Collect the header names required by ``astNode`` and all of its child nodes.

    A node with a ``headers`` attribute contributes exactly those headers. A node without
    that attribute contributes the headers of the selected 'double' SIMD instruction set if
    it is a ``SympyAssignment`` whose right-hand side has a vector type. Every entry of
    ``astNode.args`` that is a ``Node`` is visited recursively.

    Returns:
        set of header strings
    """
    collected = set()

    if not hasattr(astNode, 'headers'):
        if isinstance(astNode, SympyAssignment):
            rhsType = getTypeOfExpression(astNode.rhs)
            if type(rhsType) is VectorType:
                collected.update(selectedInstructionSet['double']['headers'])
    else:
        collected.update(astNode.headers)

    for child in astNode.args:
        if not isinstance(child, Node):
            continue
        collected.update(getHeaders(child))

    return collected
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------
......@@ -26,6 +47,7 @@ class CustomCppCode(Node):
self._code = "\n" + code
self._symbolsRead = set(symbolsRead)
self._symbolsDefined = set(symbolsDefined)
self.headers = []
@property
def code(self):
......@@ -48,24 +70,33 @@ class PrintNode(CustomCppCode):
def __init__(self, symbolToPrint):
code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name)
super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set())
self.headers.append("<iostream>")
# ------------------------------------------- Printer ------------------------------------------------------------------
class CBackend(object):
def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False):
def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
if sympyPrinter is None:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
if vectorInstructionSet is not None:
self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats)
else:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
else:
self.sympyPrinter = sympyPrinter
self._vectorInstructionSet = vectorInstructionSet
self._indent = " "
self._signatureOnly = signatureOnly
def __call__(self, node):
return str(self._print(node))
prevIs = VectorType.instructionSet
VectorType.instructionSet = self._vectorInstructionSet
result = str(self._print(node))
VectorType.instructionSet = prevIs
return result
def _print(self, node):
for cls in type(node).__mro__:
......@@ -103,13 +134,16 @@ class CBackend(object):
return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
def _print_SympyAssignment(self, node):
dtype = ""
if node.isDeclaration:
if node.isConst:
dtype = "const " + str(node.lhs.dtype) + " "
dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " "
return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
else:
lhsType = getTypeOfExpression(node.lhs)
if type(lhsType) is VectorType and node.lhs.func == castFunc:
return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
self.sympyPrinter.doprint(node.rhs)) + ';'
else:
dtype = str(node.lhs.dtype) + " "
return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
def _print_TemporaryMemoryAllocation(self, node):
return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
......@@ -177,3 +211,121 @@ class CustomSympyPrinter(CCodePrinter):
else:
return super(CustomSympyPrinter, self)._print_Function(expr)
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy code printer that emits SIMD intrinsics for expressions of vector type.

    ``instructionSet`` is a dictionary mapping operation ids ('+', '-', '*', '/', 'loadU',
    'makeVec', 'width', ...) to format strings of the corresponding intrinsic call.
    Expressions whose type is not a ``VectorType`` are printed by ``CustomSympyPrinter``.
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instructionSet, constantsAsFloats=False):
        super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
        self.instructionSet = instructionSet

    def _print_Function(self, expr):
        """Print a cast to a vector type as an unaligned load (field access) or a 'makeVec' call."""
        name = str(expr.func).lower()
        if name == 'cast':
            arg, dtype = expr.args
            if type(dtype) is VectorType:
                if type(arg) is ResolvedFieldAccess:
                    return self.instructionSet['loadU'].format("& " + self._print(arg))
                else:
                    return self.instructionSet['makeVec'].format(self._print(arg))
        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_Add(self, expr, order=None):
        """Print a vector-typed sum as nested '+' / '-' intrinsic calls."""
        exprType = getTypeOfExpression(expr)
        if type(exprType) is not VectorType:
            return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order)
        assert self.instructionSet['width'] == exprType.width

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                # products report their sign separately so that a subtraction can be emitted
                sign, t = self._print_Mul(term, insideAdd=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))

        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        # NOTE(review): "0" is a scalar literal; a vector zero ('makeZero') may be intended - confirm
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Mul(self, expr, insideAdd=False):
        """Print a vector-typed product as nested '*' / '/' intrinsic calls.

        With ``insideAdd=True`` a tuple ``(sign, code)`` is returned, where ``sign`` is -1 if a
        negative coefficient was split off and 1 otherwise; with ``insideAdd=False`` only the
        code string is returned.
        """
        exprType = getTypeOfExpression(expr)
        if type(exprType) is not VectorType:
            scalarCode = super(VectorizedCustomSympyPrinter, self)._print_Mul(expr)
            # keep the (sign, code) contract that _print_Add relies on when unpacking the result
            return (1, scalarCode) if insideAdd else scalarCode
        assert self.instructionSet['width'] == exprType.width

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        # The intrinsic format strings live in self.instructionSet (set in __init__);
        # this class has no attribute named 'intrinsics'.
        result = a_str[0]
        for item in a_str[1:]:
            result = self.instructionSet['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instructionSet['*'].format(denominator_str, item)
            result = self.instructionSet['/'].format(result, denominator_str)

        if insideAdd:
            return sign, result
        else:
            if sign < 0:
                return self.instructionSet['*'].format(self._print(S.NegativeOne), result)
            else:
                return result
# def _print_Piecewise(self, expr):
# if expr.args[-1].cond != True:
# # We need the last conditional to be a True, otherwise the resulting
# # function may not return a result.
# raise ValueError("All Piecewise expressions must contain an "
# "(expr, True) statement to be used as a default "
# "condition. Without one, the generated "
# "expression may not evaluate to anything under "
# "some condition.")
#
# result = self._print(expr.args[-1][0])
# for trueExpr, condition in reversed(expr.args[:-1]):
# result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
# return result
#
# def _print_Relational(self, expr):
# return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs)
#
# def _print_Equality(self, expr):
# return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
#
from collections import namedtuple
import sympy as sp
from sympy.core import S
try:
from sympy.utilities.codegen import CCodePrinter
except ImportError:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from sympy.core.mul import _keep_coeff
from pystencils.backends.cbackend import CustomSympyPrinter
from pystencils.types import getBaseType, createTypeFromString
def getInstructionSetInfoIntel(dataType='double', instructionSet='avx'):
    """Build a dictionary of format strings for Intel SIMD intrinsics.

    Args:
        dataType: 'double' or 'float'
        instructionSet: 'sse', 'avx' or 'avx512'

    Returns:
        dict mapping each operation id of ``baseNames`` to a format string such as
        ``'_mm256_add_pd({0},{1})'``, plus the entries 'width' (elements per vector) and
        'dataTypePrefix' (C vector type name for 'double' and 'float').

    Raises:
        KeyError: for an unknown ``dataType`` or ``instructionSet``
    """
    # Shortcut notation: name[args]; numeric args become format placeholders,
    # everything else is copied into the call verbatim.
    baseNames = {
        '+': 'add[0, 1]',
        '-': 'sub[0, 1]',
        '*': 'mul[0, 1]',
        '/': 'div[0, 1]',

        '==': 'cmp[0, 1, _CMP_EQ_UQ  ]',
        '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
        '>=': 'cmp[0, 1, _CMP_GE_OQ  ]',
        '<=': 'cmp[0, 1, _CMP_LE_OQ  ]',
        '<': 'cmp[0, 1, _CMP_NGE_UQ ]',
        '>': 'cmp[0, 1, _CMP_NLE_UQ ]',
        'blendv': 'blendv[0, 1, 2]',

        'sqrt': 'sqrt[0]',

        'makeVec': 'set[0,0,0,0]',
        'makeZero': 'setzero[]',

        'loadU': 'loadu [0]',
        'loadA': 'load  [0]',
        'storeU': 'storeu[0]',
        'storeA': 'store [0]',
    }

    suffix = {
        'double': 'pd',
        'float': 'ps',
    }
    prefix = {
        'sse': '_mm',
        'avx': '_mm256',
        'avx512': '_mm512',
    }

    width = {
        ("double", "sse"): 2,
        ("float", "sse"): 4,
        ("double", "avx"): 4,
        ("float", "avx"): 8,
        ("double", "avx512"): 8,
        ("float", "avx512"): 16,
    }
    bitsPerElement = {
        'double': 64,
        'float': 32,
    }

    pre = prefix[instructionSet]
    suf = suffix[dataType]
    vectorWidth = width[(dataType, instructionSet)]

    result = {}
    for intrinsicId, functionShortcut in baseNames.items():
        functionShortcut = functionShortcut.strip()
        bracketPos = functionShortcut.index('[')
        # strip: the table pads names with blanks for alignment ('loadu [0]')
        name = functionShortcut[:bracketPos].strip()
        formattedArgs = []
        for arg in functionShortcut[bracketPos + 1: -1].split(","):
            arg = arg.strip()
            if not arg:
                continue
            if arg in ('0', '1', '2', '3', '4', '5'):
                formattedArgs.append("{" + arg + "}")
            else:
                formattedArgs.append(arg)
        # join instead of "append comma, cut last char": the latter truncated trailing
        # literal arguments and removed the '(' of argument-less calls
        result[intrinsicId] = pre + "_" + name + "_" + suf + "(" + ",".join(formattedArgs) + ")"

    # the 'set' intrinsic takes one argument per vector element
    result['makeVec'] = pre + "_set_" + suf + "(" + ",".join(["{0}"] * vectorWidth) + ")"

    result['width'] = vectorWidth

    # C vector type names are __m<bits> for float and __m<bits>d for double
    registerBits = vectorWidth * bitsPerElement[dataType]
    result['dataTypePrefix'] = {
        'double': "__m%dd" % (registerBits,),
        'float': "__m%d" % (registerBits,),
    }
    return result
class VectorizedCBackend(object):
    """Prints an AST as C code, using SIMD vector types for values of the kernel's field type.

    All fields accessed by the kernel must share one base type. Printing is dispatched on
    the node's class name to the ``_print_<ClassName>`` methods.
    """

    def __init__(self, astNode, instructionSet='avx'):
        """
        Args:
            astNode: kernel AST; its ``fieldsAccessed`` determine the data type
            instructionSet: name passed on to ``getInstructionSetInfoIntel``

        Raises:
            ValueError: if the accessed fields do not all have the same base type
        """
        fieldTypes = set([getBaseType(f.dtype) for f in astNode.fieldsAccessed])
        if len(fieldTypes) != 1:
            raise ValueError("Vectorized backend does not support kernels with mixed field types")
        fieldType = fieldTypes.pop()
        # NOTE(review): if is_float is a method (like is_int() elsewhere), this assert is always true
        assert fieldType.is_float
        dtypeName = str(fieldType)
        instructionSetInfo = getInstructionSetInfoIntel(dtypeName, instructionSet)
        self.vectorizationWidth = instructionSetInfo['width']
        # vector printer for values of the field type, scalar printer for everything else
        self.sympyVecPrinter = CustomSympyPrinterVectorized(instructionSetInfo)
        self.sympyPrinter = CustomSympyPrinter(constantsAsFloats=(dtypeName == 'float'))
        self._indent = "   "
        self._vecTypeName = instructionSetInfo['dataTypePrefix'][dtypeName]
        self.dtypeName = dtypeName

    def __call__(self, node):
        return str(self._print(node))

    def _print(self, node):
        """Dispatch to the first ``_print_<ClassName>`` method found along the node's MRO."""
        for cls in type(node).__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # report the node's own type; the loop variable ends on the last MRO entry ('object')
        raise NotImplementedError("CBackend does not support node of type " + type(node).__name__)

    def _print_KernelFunction(self, node):
        blockContents = "\n".join([self._print(child) for child in node.body.args])
        # built after printing the body, since printing registers the constants
        constantBlock = self.sympyVecPrinter.getConstantsBlock(self._vecTypeName)
        body = "{\n%s\n%s\n}" % (constantBlock, self._indent + self._indent.join(blockContents.splitlines(True)))

        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        blockContents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)),)

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print a for loop; the innermost loop advances by the vectorization width.

        Raises:
            NotImplementedError: if the innermost loop has a non-integer range, a step other
                than 1, or a range that is not a multiple of the vectorization width
        """
        if node.isInnermostLoop:
            iterRange = node.stop - node.start
            if isinstance(iterRange, sp.Basic) and not iterRange.is_integer:
                raise NotImplementedError("Vectorized backend currently only supports fixed size inner loops")
            if iterRange % self.vectorizationWidth != 0 or node.step != 1:
                raise NotImplementedError("Vectorized backend only supports loop bounds that are "
                                          "multiples of vectorization width")
            step = self.vectorizationWidth
        else:
            step = node.step

        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(step),)
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print an assignment; declarations of the field type use the vector type and printer."""
        dtype = ""
        if node.isDeclaration:
            assert str(getBaseType(node.lhs.dtype)) in (self.dtypeName, 'bool')
            if node.lhs.dtype == createTypeFromString(self.dtypeName):
                dtypeStr = self._vecTypeName
                printer = self.sympyVecPrinter
            else:
                dtypeStr = str(node.lhs.dtype)
                printer = self.sympyPrinter

            if node.isConst:
                dtype = "const " + dtypeStr + " "
            else:
                dtype = dtypeStr + " "
        else:
            printer = self.sympyVecPrinter
        return "%s %s = %s;" % (str(dtype), printer.doprint(node.lhs), printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
                                          node.symbol.dtype, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)

    def _print_CustomCppCode(self, node):
        return node.code
class CustomSympyPrinterVectorized(CCodePrinter):
    """Sympy printer that prints every arithmetic operation as a SIMD intrinsic call.

    ``instructionSetInfo`` maps operation ids to intrinsic format strings. Float and rational
    numbers are not printed inline; each is replaced by a symbol that is recorded in
    ``constantsDict`` and defined by the code returned from ``getConstantsBlock``.
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instructionSetInfo):
        super(CustomSympyPrinterVectorized, self).__init__()
        self.intrinsics = instructionSetInfo
        self.constantsDict = {}  # number -> symbol that stands for it in the printed code

    def _printScalarConstant(self, value):
        """Print a recorded number as a plain C literal using the base class printers."""
        parent = super(CustomSympyPrinterVectorized, self)
        if isinstance(value, sp.Float):
            return parent._print_Float(value)
        return parent._print_Rational(value)

    def getConstantsBlock(self, vecTypeStr):
        """Return C definitions (one per line) of all constants registered while printing."""
        result = ""
        for value, symbol in self.constantsDict.items():
            # self._print(value) would go through the overridden _print_Float/_print_Rational
            # and yield the symbol name itself, i.e. a self-referential definition
            rhsStr = self.intrinsics['makeVec'].format(self._printScalarConstant(value))
            result += "const %s %s = %s;\n" % (vecTypeStr, symbol.name, rhsStr)
        return result

    def _print_Add(self, expr, order=None):
        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, insideAdd=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.intrinsics['-'] if summand.sign == -1 else self.intrinsics['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Mul(self, expr, insideAdd=False):
        """Print a product; returns ``(sign, code)`` if ``insideAdd`` is set, else the code."""
        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.intrinsics['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.intrinsics['*'].format(denominator_str, item)
            result = self.intrinsics['/'].format(result, denominator_str)

        if insideAdd:
            return sign, result
        else:
            if sign < 0:
                return self.intrinsics['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))
        else:
            return super(CustomSympyPrinterVectorized, self)._print_Pow(expr)

    def _print_Float(self, expr):
        if expr not in self.constantsDict:
            self.constantsDict[expr] = sp.Dummy()
        symbol = self.constantsDict[expr]
        return symbol.name

    def _print_Rational(self, expr):
        if expr not in self.constantsDict:
            # a '-' from a negative numerator would not be a valid C identifier character
            signPart = "neg" if expr.p < 0 else ""
            self.constantsDict[expr] = sp.Symbol("__value_%s%d_%d" % (signPart, abs(expr.p), expr.q))
        symbol = self.constantsDict[expr]
        return symbol.name

    def _print_Piecewise(self, expr):
        if expr.args[-1].cond != True:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        # start with the default branch and blend the other branches over it, last one first
        result = self._print(expr.args[-1][0])
        for trueExpr, condition in reversed(expr.args[:-1]):
            result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
        return result

    def _print_Relational(self, expr):
        # operands have to go through this printer (like in _print_Equality), otherwise
        # their default str() form ends up in the generated code
        return self.intrinsics[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
    """Build a description of an x86 SIMD instruction set for the C backend.

    Args:
        dataType: 'double' or 'float'
        instructionSet: 'sse', 'avx' or 'avx512'

    Returns:
        dict with
        - one format string per operation id of ``baseNames``, e.g. ``'_mm256_add_pd({0},{1})'``
        - 'width': number of elements per vector
        - 'double', 'float', 'int': C names of the vector types
        - 'dataTypePrefix': C vector type names for 'double' and 'float'
        - 'headers': list of headers to include

    Raises:
        KeyError: for an unknown ``dataType`` or ``instructionSet``
    """
    # Shortcut notation: name[args]; numeric args become format placeholders,
    # everything else is copied into the call verbatim.
    baseNames = {
        '+': 'add[0, 1]',
        '-': 'sub[0, 1]',
        '*': 'mul[0, 1]',
        '/': 'div[0, 1]',

        '==': 'cmp[0, 1, _CMP_EQ_UQ  ]',
        '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
        '>=': 'cmp[0, 1, _CMP_GE_OQ  ]',
        '<=': 'cmp[0, 1, _CMP_LE_OQ  ]',
        '<': 'cmp[0, 1, _CMP_NGE_UQ ]',
        '>': 'cmp[0, 1, _CMP_NLE_UQ ]',
        'blendv': 'blendv[0, 1, 2]',

        'sqrt': 'sqrt[0]',

        'makeVec': 'set[0,0,0,0]',
        'makeZero': 'setzero[]',

        'loadU': 'loadu[0]',
        'loadA': 'load[0]',
        'storeU': 'storeu[0,1]',
        'storeA': 'store [0,1]',
    }

    headers = {
        'avx512': ['<immintrin.h>'],
        'avx': ['<immintrin.h>'],
        'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
    }

    suffix = {
        'double': 'pd',
        'float': 'ps',
    }
    prefix = {
        'sse': '_mm',
        'avx': '_mm256',
        'avx512': '_mm512',
    }

    width = {
        ("double", "sse"): 2,
        ("float", "sse"): 4,
        ("double", "avx"): 4,
        ("float", "avx"): 8,
        ("double", "avx512"): 8,
        ("float", "avx512"): 16,
    }
    bitsPerElement = {
        'double': 64,
        'float': 32,
    }

    pre = prefix[instructionSet]
    suf = suffix[dataType]
    vectorWidth = width[(dataType, instructionSet)]

    result = {}
    for intrinsicId, functionShortcut in baseNames.items():
        functionShortcut = functionShortcut.strip()
        bracketPos = functionShortcut.index('[')
        # strip: table entries may pad the name with blanks ('store [0,1]')
        name = functionShortcut[:bracketPos].strip()
        formattedArgs = []
        for arg in functionShortcut[bracketPos + 1: -1].split(","):
            arg = arg.strip()
            if not arg:
                continue
            if arg in ('0', '1', '2', '3', '4', '5'):
                formattedArgs.append("{" + arg + "}")
            else:
                formattedArgs.append(arg)
        # join instead of "append comma, cut last char": the latter truncated trailing
        # literal arguments and removed the '(' of argument-less calls
        result[intrinsicId] = pre + "_" + name + "_" + suf + "(" + ",".join(formattedArgs) + ")"

    # the 'set' intrinsic takes one argument per vector element
    result['makeVec'] = pre + "_set_" + suf + "(" + ",".join(["{0}"] * vectorWidth) + ")"

    result['width'] = vectorWidth

    # register size: element count times element size (64 bit only holds for double)
    bitWidth = vectorWidth * bitsPerElement[dataType]
    result['dataTypePrefix'] = {
        'double': "__m%dd" % (bitWidth,),
        'float': "__m%d" % (bitWidth,),
    }
    result['double'] = "__m%dd" % (bitWidth,)
    result['float'] = "__m%d" % (bitWidth,)
    result['int'] = "__m%di" % (bitWidth,)

    result['headers'] = headers[instructionSet]
    return result
# Instruction set descriptions keyed by scalar data type name; both entries are
# built for 'avx' when this module is imported.
selectedInstructionSet = {
'float': x86VectorInstructionSet('float', 'avx'),
'double': x86VectorInstructionSet('double', 'avx'),
}
......@@ -33,7 +33,7 @@ Then 'cl.exe' is used to compile.
where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
- **'restrictQualifier'**: the restrict qualifier is not standardized accross compilers.
- **'restrictQualifier'**: the restrict qualifier is not standardized across compilers.
For Windows compilers the qualifier should be ``__restrict``
......@@ -70,7 +70,7 @@ import glob
import atexit
import shutil
from ctypes import cdll
from pystencils.backends.cbackend import generateC
from pystencils.backends.cbackend import generateC, getHeaders
from collections import OrderedDict, Mapping
from pystencils.transformations import symbolNameToVariableName
from pystencils.types import toCtypes, getBaseType, StructType
......@@ -276,10 +276,13 @@ def compileObjectCacheToSharedLibrary():
atexit.register(compileObjectCacheToSharedLibrary)
def generateCode(ast, includes, restrictQualifier, functionPrefix, targetFile):
def generateCode(ast, restrictQualifier, functionPrefix, targetFile):
headers = getHeaders(ast)
headers.update(['<cmath>', '<cstdint>'])
with open(targetFile, 'w') as sourceFile:
code = generateC(ast)
includes = "\n".join(["#include <%s>" % (includeFile,) for includeFile in includes])
includes = "\n".join(["#include %s" % (includeFile,) for includeFile in headers])
print(includes, file=sourceFile)
print("#define RESTRICT %s" % (restrictQualifier,), file=sourceFile)
print("#define FUNC_PREFIX %s" % (functionPrefix,), file=sourceFile)
......@@ -310,7 +313,7 @@ def compileLinux(ast, codeHashStr, srcFile, libFile):
objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.o')
# Compilation
if not os.path.exists(objectFile):
generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'], '', srcFile)
generateCode(ast, compilerConfig['restrictQualifier'], '', srcFile)
compileCmd = [compilerConfig['command'], '-c'] + compilerConfig['flags'].split()
compileCmd += ['-o', objectFile, srcFile]
runCompileStep(compileCmd)
......@@ -326,7 +329,7 @@ def compileWindows(ast, codeHashStr, srcFile, libFile):
objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.obj')
# Compilation
if not os.path.exists(objectFile):
generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'],
generateCode(ast, compilerConfig['restrictQualifier'],
'__declspec(dllexport)', srcFile)
# /c compiles only, /EHsc turns off exception handling in c code
......
import abc
import sympy as sp
import math
import pycuda.driver as cuda
import pycuda.autoinit
from pystencils.astnodes import Conditional, Block
from pystencils.slicing import normalizeSlice
from pystencils.types import TypedSymbol, createTypeFromString
BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z"))
BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createTypeFromString("int")) for coord in ('x', 'y', 'z')]
THREAD_IDX = [TypedSymbol("threadIdx." + coord, createTypeFromString("int")) for coord in ('x', 'y', 'z')]
class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})):
......
......@@ -7,11 +7,18 @@ from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc
from pystencils.slicing import normalizeSlice
import pystencils.astnodes as ast
def filteredTreeIteration(node, nodeType):
    """Yield, depth first, every descendant of ``node`` that is an instance of ``nodeType``.

    Children are taken from the ``args`` attribute; a matching child is yielded before its
    own descendants are searched. ``node`` itself is never yielded.
    """
    for child in node.args:
        matches = isinstance(child, nodeType)
        if matches:
            yield child
        for descendant in filteredTreeIteration(child, nodeType):
            yield descendant
def fastSubs(term, subsDict):
"""Similar to sympy subs function.
This version is much faster for big substitution dictionaries than sympy version"""
......@@ -332,9 +339,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
coordDict = createCoordinateDict(basePointerInfo[0])
_, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
baseArr = IndexedBase(lastPointer, shape=(1,))
result = ast.ResolvedFieldAccess(baseArr, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index)
castFunc = sp.Function("cast")
result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index)
if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
result = castFunc(result, newType)
......
......@@ -2,7 +2,11 @@ import ctypes
import sympy as sp
import numpy as np
from sympy.core.cache import cacheit
from pystencils.cache import memorycache
from pystencils.utils import allEqual
castFunc = sp.Function("cast")
class TypedSymbol(sp.Symbol):
......@@ -28,7 +32,7 @@ class TypedSymbol(sp.Symbol):
def _hashable_content(self):
superClassContents = list(super(TypedSymbol, self)._hashable_content())
return tuple(superClassContents + [hash(repr(self._dtype))])
return tuple(superClassContents + [hash(str(self._dtype))])
def __getnewargs__(self):
return self.name, self.dtype
......@@ -52,6 +56,7 @@ def createType(specification):
return StructType(npDataType, const=False)
@memorycache(maxsize=64)
def createTypeFromString(specification):
"""
Creates a new Type object from a c-like string specification
......@@ -131,10 +136,79 @@ toCtypes.map = {
}
def peelOffType(dtype, typeToPeelOff):
    """Strip all outer layers of exactly type ``typeToPeelOff`` from ``dtype``.

    As long as the current type is a ``typeToPeelOff`` (exact type, not a subclass), it is
    replaced by its ``baseType``; the first type that is something else is returned.
    """
    current = dtype
    while True:
        if type(current) is not typeToPeelOff:
            return current
        current = current.baseType
def collateTypes(types):
    """
    Determines the "common type" of a sequence of types, e.g. (float, double, float) -> double,
    following the collation rules of numpy.

    If a pointer type is present, the result is that pointer type; in this case exactly one
    pointer may occur and all other types have to be basic integer types.
    If vector types are present, they must all have the same width and the result is a
    vector type of that width.

    Raises:
        ValueError: for two pointer types, for a pointer combined with a non-integer type and
            for vector types of different width
    """
    # Pointer arithmetic case i.e. pointer + integer is allowed
    containsPointer = any(type(t) is PointerType for t in types)
    if containsPointer:
        foundPointer = None
        for candidate in types:
            candidateClass = type(candidate)
            if candidateClass is PointerType:
                if foundPointer is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                foundPointer = candidate
                continue
            if candidateClass is not BasicType:
                raise ValueError("Invalid pointer arithmetic")
            isIntegral = candidate.is_int() or candidate.is_uint()
            if not isIntegral:
                raise ValueError("Invalid pointer arithmetic")
        return foundPointer

    # remember the vector types, then continue with their element types
    vectorTypes = [t for t in types if type(t) is VectorType]
    widthsAgree = allEqual(v.width for v in vectorTypes)
    if not widthsAgree:
        raise ValueError("Collation failed because of vector types with different width")
    scalarTypes = [peelOffType(t, VectorType) for t in types]

    # only basic types may be left at this point - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalarTypes)

    # let numpy determine the common type and wrap it in a vector type again if necessary
    numpyTypes = (t.numpyDtype for t in scalarTypes)
    commonType = BasicType(np.result_type(*numpyTypes))
    if not vectorTypes:
        return commonType
    return VectorType(commonType, vectorTypes[0].width)
@memorycache(maxsize=2048)
def getTypeOfExpression(expr):
    """Determine the data type of a (sympified) expression.

    - integers have type "int", rationals and floats type "double"
    - field accesses, typed symbols and indexed objects carry their own type
    - a cast has the type it casts to (its second argument)
    - Piecewise and all other expressions get the collated type of their branches / arguments
    - remaining boolean expressions have type "bool"

    Raises:
        ValueError: if the expression contains a symbol that is not a TypedSymbol
        NotImplementedError: if no type can be determined
    """
    # Imported unconditionally: the name is needed by the isinstance check below for every
    # kind of expression, not only for typed symbols.
    # NOTE(review): presumably a local import to avoid a circular import - confirm
    from pystencils.astnodes import ResolvedFieldAccess

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return createTypeFromString("int")
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return createTypeFromString("double")
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        raise ValueError("All symbols inside this expression have to be typed!")
    elif hasattr(expr, 'func') and expr.func == castFunc:
        return expr.args[1]
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # only the branch values determine the type, not the conditions
        branchResults = [a[0] for a in expr.args]
        return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults))
    elif isinstance(expr, sp.Indexed):
        typedSymbol = expr.base.label
        return typedSymbol.dtype
    elif isinstance(expr, sp.Expr):
        types = tuple(getTypeOfExpression(a) for a in expr.args)
        return collateTypes(types)
    elif isinstance(expr, sp.boolalg.Boolean):
        return createTypeFromString("bool")

    raise NotImplementedError("Could not determine type for " + str(expr))
class Type(sp.Basic):
......@@ -239,6 +313,44 @@ class BasicType(Type):
return hash(str(self))
class VectorType(Type):
    """Type of a SIMD vector that holds ``width`` elements of type ``baseType``."""

    # Instruction set dictionary used by __str__ to look up the C name of the vector type.
    # With None the generic notation "<baseType>[<width>]" is used.
    instructionSet = None

    def __init__(self, baseType, width=4):
        self._baseType = baseType
        self.width = width

    @property
    def baseType(self):
        return self._baseType

    @property
    def itemSize(self):
        return self.width * self.baseType.itemSize

    def __eq__(self, other):
        if not isinstance(other, VectorType):
            return False
        else:
            return (self.baseType, self.width) == (other.baseType, other.width)

    def __str__(self):
        """Name of the type: generic notation or the C type name of the current instruction set.

        Raises:
            NotImplementedError: if an instruction set is selected and the base type is
                none of int64, double, float
        """
        if self.instructionSet is None:
            return "%s[%d]" % (self.baseType, self.width)
        else:
            if self.baseType == createTypeFromString("int64"):
                return self.instructionSet['int']
            elif self.baseType == createTypeFromString("double"):
                return self.instructionSet['double']
            elif self.baseType == createTypeFromString("float"):
                return self.instructionSet['float']
            else:
                raise NotImplementedError()

    def __hash__(self):
        # Hash exactly what __eq__ compares. Hashing str(self) would make the hash depend on
        # the mutable class attribute 'instructionSet', so it could change (or raise) while
        # an instance is used as dictionary key or cache entry.
        return hash((self.baseType, self.width))
class PointerType(Type):
def __init__(self, baseType, const=False, restrict=True):
self._baseType = baseType
......
......@@ -4,3 +4,12 @@ class DotDict(dict):
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def allEqual(iterator):
    """Return True if all elements of the iterable compare equal to the first one.

    An empty iterable yields True.
    """
    remaining = iter(iterator)
    nothing = object()
    head = next(remaining, nothing)
    if head is nothing:
        return True
    for element in remaining:
        if not (head == element):
            return False
    return True
import sympy as sp
import warnings
from pystencils.transformations import filteredTreeIteration
from pystencils.types import TypedSymbol, VectorType, PointerType, BasicType, getTypeOfExpression, castFunc
import pystencils.astnodes as ast
from pystencils.utils import allEqual
def asVectorType(resolvedFieldAccess, vectorizationWidth):
    """Creates a copy of a ResolvedFieldAccess whose base pointer points to a vector type.

    The typed symbol of the access must have a pointer type with a basic (non-struct) base
    type. The new pointer type keeps the const and restrict settings of the old one.
    """
    oldSymbol = resolvedFieldAccess.typedSymbol
    pointerType = oldSymbol.dtype
    assert type(pointerType) is PointerType
    elementType = pointerType.baseType
    assert type(elementType) is BasicType, "Structs are not supported"

    vectorPointerType = PointerType(VectorType(elementType, vectorizationWidth),
                                    pointerType.const, pointerType.restrict)
    vectorSymbol = TypedSymbol(oldSymbol.name, vectorPointerType)
    return ast.ResolvedFieldAccess(vectorSymbol,
                                   resolvedFieldAccess.args[1],
                                   resolvedFieldAccess.field,
                                   resolvedFieldAccess.offsets,
                                   resolvedFieldAccess.idxCoordinateValues)
def vectorize(astNode, vectorWidth=4):
    """
    Visits all innermost loops of the AST. For each loop that
        - has a fixed (integer) range,
        - has a range that is a multiple of the vector width and
        - has step 1
    the step is set to the vector width and every resolved field access inside the loop is
    wrapped in a cast to a vector type. Other loops are skipped with a warning.
    """
    allLoops = astNode.atoms(ast.LoopOverCoordinate)
    # collect the candidates up front, the loops are modified below
    candidates = list(filter(lambda loop: loop.isInnermostLoop, allLoops))

    for loopNode in candidates:
        rangeSize = loopNode.stop - loopNode.start

        # Check restrictions
        hasSymbolicRange = isinstance(rangeSize, sp.Basic) and not rangeSize.is_integer
        if hasSymbolicRange:
            warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop")
            continue

        if rangeSize % vectorWidth != 0 or loopNode.step != 1:
            warnings.warn("Currently only loops with loop bounds that are multiples "
                          "of vectorization width can be vectorized")
            continue

        loopNode.step = vectorWidth

        # Every resolved field access inside the loop gets a cast to the vector type
        substitutions = {}
        for access in list(loopNode.atoms(ast.ResolvedFieldAccess)):
            vectorType = VectorType(BasicType(access.field.dtype), vectorWidth)
            substitutions[access] = castFunc(access, vectorType)
        loopNode.subs(substitutions)
def insertVectorCasts(astNode):
    """
    Adds the casts that are required wherever scalar and vector values meet.

    In sums and products that contain at least one vector operand, all scalar operands are
    cast to the vector type. Typed symbols that get a vector value assigned are replaced by
    symbols of vector type, also in the right-hand sides of the following assignments.
    A scalar value assigned to a vector cast on the left-hand side is cast to that type.

    Raises:
        NotImplementedError: if a right-hand side is a Piecewise expression
    """
    def castScalarOperands(expr):
        if expr.func == sp.Piecewise:
            raise NotImplementedError()
        if expr.func not in (sp.Add, sp.Mul):
            return expr

        operands = [castScalarOperands(a) for a in expr.args]
        operandTypes = [getTypeOfExpression(a) for a in operands]
        widths = [t.width for t in operandTypes if type(t) is VectorType]
        if not widths:
            # purely scalar expression: nothing to do
            return expr

        assert allEqual(widths), "Incompatible vector type widths"
        commonWidth = widths[0]

        castedOperands = []
        for operand, operandType in zip(operands, operandTypes):
            if type(operandType) is VectorType:
                castedOperands.append(operand)
            else:
                castedOperands.append(castFunc(operand, VectorType(operandType, commonWidth)))
        return expr.func(*castedOperands)

    # maps scalar left-hand side symbols to their replacements of vector type
    vectorizedSymbols = {}
    for assignment in filteredTreeIteration(astNode, ast.SympyAssignment):
        updatedRhs = assignment.rhs.subs(vectorizedSymbols)
        assignment.rhs = castScalarOperands(updatedRhs)

        rhsType = getTypeOfExpression(assignment.rhs)
        if isinstance(assignment.lhs, TypedSymbol):
            lhsType = assignment.lhs.dtype
            lhsNeedsVectorType = type(rhsType) is VectorType and type(lhsType) is not VectorType
            if lhsNeedsVectorType:
                vectorLhs = TypedSymbol(assignment.lhs.name, VectorType(lhsType, rhsType.width))
                vectorizedSymbols[assignment.lhs] = vectorLhs
                assignment.lhs = vectorLhs
        elif assignment.lhs.func == castFunc:
            lhsType = assignment.lhs.args[1]
            rhsNeedsCast = type(lhsType) is VectorType and type(rhsType) is not VectorType
            if rhsNeedsCast:
                assignment.rhs = castFunc(assignment.rhs, lhsType)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment