From ea847bc5e3b820794cdc3fdb2d8affa0c36e8883 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 6 Oct 2017 16:04:25 +0200 Subject: [PATCH] Vectorization & Type system overhaul - first vectorization tests are running - type system: use memoized getTypeOfExpression - casts are done using sp.Function('cast') - C backend adapted for vectorization support - AST nodes can require optional headers --- astnodes.py | 10 + backends/cbackend.py | 174 +++++++++++++++-- backends/cbackend_vectorized.py | 312 ------------------------------ backends/simd_instruction_sets.py | 91 +++++++++ cpu/cpujit.py | 15 +- gpucuda/indexing.py | 6 +- transformations.py | 14 +- types.py | 118 ++++++++++- utils.py | 9 + vectorization.py | 94 +++++++++ 10 files changed, 504 insertions(+), 339 deletions(-) delete mode 100644 backends/cbackend_vectorized.py create mode 100644 backends/simd_instruction_sets.py create mode 100644 vectorization.py diff --git a/astnodes.py b/astnodes.py index 25d4ba4a0..16c1e9829 100644 --- a/astnodes.py +++ b/astnodes.py @@ -6,6 +6,8 @@ from pystencils.types import TypedSymbol, createType, get_type_from_sympy, creat class ResolvedFieldAccess(sp.Indexed): def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues): + if not isinstance(base, IndexedBase): + base = IndexedBase(base, shape=(1,)) obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex) obj.field = field obj.offsets = offsets @@ -21,6 +23,14 @@ class ResolvedFieldAccess(sp.Indexed): superClassContents = super(ResolvedFieldAccess, self)._hashable_content() return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) + @property + def typedSymbol(self): + return self.base.label + + def __str__(self): + top = super(ResolvedFieldAccess, self).__str__() + return "%s (%s)" % (top, self.typedSymbol.dtype) + def __getnewargs__(self): return self.base, self.indices[0], self.field, self.offsets, 
self.idxCoordinateValues diff --git a/backends/cbackend.py b/backends/cbackend.py index c18a0544a..5bbb9b787 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -4,8 +4,13 @@ try: except ImportError: from sympy.printing.ccode import C99CodePrinter as CCodePrinter -from pystencils.astnodes import Node -from pystencils.types import createType, PointerType +from collections import namedtuple +from sympy.core.mul import _keep_coeff +from sympy.core import S + +from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment +from pystencils.types import createType, PointerType, getTypeOfExpression, VectorType, castFunc +from pystencils.backends.simd_instruction_sets import selectedInstructionSet def generateC(astNode, signatureOnly=False): @@ -14,10 +19,26 @@ def generateC(astNode, signatureOnly=False): """ fieldTypes = set([f.dtype for f in astNode.fieldsAccessed]) useFloatConstants = createType("double") not in fieldTypes - printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly) + + vectorIS = selectedInstructionSet['double'] + printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly, vectorInstructionSet=vectorIS) return printer(astNode) +def getHeaders(astNode): + headers = set() + + if hasattr(astNode, 'headers'): + headers.update(astNode.headers) + elif isinstance(astNode, SympyAssignment): + if type(getTypeOfExpression(astNode.rhs)) is VectorType: + headers.update(selectedInstructionSet['double']['headers']) + + for a in astNode.args: + if isinstance(a, Node): + headers.update(getHeaders(a)) + + return headers # --------------------------------------- Backend Specific Nodes ------------------------------------------------------- @@ -26,6 +47,7 @@ class CustomCppCode(Node): self._code = "\n" + code self._symbolsRead = set(symbolsRead) self._symbolsDefined = set(symbolsDefined) + self.headers = [] @property def code(self): @@ -48,24 +70,33 @@ class PrintNode(CustomCppCode): def 
__init__(self, symbolToPrint): code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name) super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set()) + self.headers.append("<iostream>") # ------------------------------------------- Printer ------------------------------------------------------------------ - class CBackend(object): - def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False): + def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None): if sympyPrinter is None: self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) + if vectorInstructionSet is not None: + self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats) + else: + self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) else: self.sympyPrinter = sympyPrinter + self._vectorInstructionSet = vectorInstructionSet self._indent = " " self._signatureOnly = signatureOnly def __call__(self, node): - return str(self._print(node)) + prevIs = VectorType.instructionSet + VectorType.instructionSet = self._vectorInstructionSet + result = str(self._print(node)) + VectorType.instructionSet = prevIs + return result def _print(self, node): for cls in type(node).__mro__: @@ -103,13 +134,16 @@ class CBackend(object): return "%s%s\n%s" % (prefix, loopStr, self._print(node.body)) def _print_SympyAssignment(self, node): - dtype = "" if node.isDeclaration: - if node.isConst: - dtype = "const " + str(node.lhs.dtype) + " " + dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " " + return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) + else: + lhsType = getTypeOfExpression(node.lhs) + if type(lhsType) is VectorType and node.lhs.func == castFunc: + return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]), + 
self.sympyPrinter.doprint(node.rhs)) + ';' else: - dtype = str(node.lhs.dtype) + " " - return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) + return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), @@ -177,3 +211,121 @@ class CustomSympyPrinter(CCodePrinter): else: return super(CustomSympyPrinter, self)._print_Function(expr) + +class VectorizedCustomSympyPrinter(CustomSympyPrinter): + SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) + + def __init__(self, instructionSet, constantsAsFloats=False): + super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats) + self.instructionSet = instructionSet + + def _print_Function(self, expr): + name = str(expr.func).lower() + if name == 'cast': + arg, dtype = expr.args + if type(dtype) is VectorType: + if type(arg) is ResolvedFieldAccess: + return self.instructionSet['loadU'].format("& " + self._print(arg)) + else: + return self.instructionSet['makeVec'].format(self._print(arg)) + + return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) + + def _print_Add(self, expr, order=None): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order) + assert self.instructionSet['width'] == exprType.width + + summands = [] + for term in expr.args: + if term.func == sp.Mul: + sign, t = self._print_Mul(term, insideAdd=True) + else: + t = self._print(term) + sign = 1 + summands.append(self.SummandInfo(sign, t)) + # Use positive terms first + summands.sort(key=lambda e: e.sign, reverse=True) + # if no positive term exists, prepend a zero + if summands[0].sign == -1: + summands.insert(0, self.SummandInfo(1, "0")) + + assert len(summands) >= 2 + processed = summands[0].term + for 
summand in summands[1:]: + func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+'] + processed = func.format(processed, summand.term) + return processed + + def _print_Mul(self, expr, insideAdd=False): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return super(VectorizedCustomSympyPrinter, self)._print_Mul(expr) + assert self.instructionSet['width'] == exprType.width + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = -1 + else: + sign = 1 + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + # Gather args for numerator/denominator + for item in expr.as_ordered_factors(): + if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: + if item.exp != -1: + b.append(sp.Pow(item.base, -item.exp, evaluate=False)) + else: + b.append(sp.Pow(item.base, -item.exp)) + else: + a.append(item) + + a = a or [S.One] + + a_str = [self._print(x) for x in a] + b_str = [self._print(x) for x in b] + + result = a_str[0] + for item in a_str[1:]: + result = self.intrinsics['*'].format(result, item) + + if len(b) > 0: + denominator_str = b_str[0] + for item in b_str[1:]: + denominator_str = self.intrinsics['*'].format(denominator_str, item) + result = self.intrinsics['/'].format(result, denominator_str) + + if insideAdd: + return sign, result + else: + if sign < 0: + return self.intrinsics['*'].format(self._print(S.NegativeOne), result) + else: + return result + +# def _print_Piecewise(self, expr): +# if expr.args[-1].cond != True: +# # We need the last conditional to be a True, otherwise the resulting +# # function may not return a result. +# raise ValueError("All Piecewise expressions must contain an " +# "(expr, True) statement to be used as a default " +# "condition. 
Without one, the generated " +# "expression may not evaluate to anything under " +# "some condition.") +# +# result = self._print(expr.args[-1][0]) +# for trueExpr, condition in reversed(expr.args[:-1]): +# result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition)) +# return result +# +# def _print_Relational(self, expr): +# return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs) +# +# def _print_Equality(self, expr): +# return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs)) +# + diff --git a/backends/cbackend_vectorized.py b/backends/cbackend_vectorized.py deleted file mode 100644 index 62a16b217..000000000 --- a/backends/cbackend_vectorized.py +++ /dev/null @@ -1,312 +0,0 @@ -from collections import namedtuple - -import sympy as sp -from sympy.core import S - -try: - from sympy.utilities.codegen import CCodePrinter -except ImportError: - from sympy.printing.ccode import C99CodePrinter as CCodePrinter - -from sympy.core.mul import _keep_coeff - -from pystencils.backends.cbackend import CustomSympyPrinter -from pystencils.types import getBaseType, createTypeFromString - - -def getInstructionSetInfoIntel(dataType='double', instructionSet='avx'): - baseNames = { - '+': 'add[0, 1]', - '-': 'sub[0, 1]', - '*': 'mul[0, 1]', - '/': 'div[0, 1]', - - '==': 'cmp[0, 1, _CMP_EQ_UQ ]', - '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]', - '>=': 'cmp[0, 1, _CMP_GE_OQ ]', - '<=': 'cmp[0, 1, _CMP_LE_OQ ]', - '<': 'cmp[0, 1, _CMP_NGE_UQ ]', - '>': 'cmp[0, 1, _CMP_NLE_UQ ]', - - 'blendv': 'blendv[0, 1, 2]', - - 'sqrt': 'sqrt[0]', - - 'makeVec': 'set[0,0,0,0]', - 'makeZero': 'setzero[]', - - 'loadU': 'loadu [0]', - 'loadA': 'load [0]', - 'storeU': 'storeu[0]', - 'storeA': 'store [0]', - } - - suffix = { - 'double': 'pd', - 'float': 'ps', - } - prefix = { - 'sse': '_mm', - 'avx': '_mm256', - 'avx512': '_mm512', - } - - width = { - ("double", "sse"): 2, - ("float", "sse"): 4, - ("double", "avx"): 4, - ("float", "avx"): 8, - 
("double", "avx512"): 8, - ("float", "avx512"): 16, - } - - result = {} - pre = prefix[instructionSet] - suf = suffix[dataType] - for intrinsicId, functionShortcut in baseNames.items(): - functionShortcut = functionShortcut.strip() - name = functionShortcut[:functionShortcut.index('[')] - args = functionShortcut[functionShortcut.index('[') + 1: -1] - argString = "(" - for arg in args.split(","): - arg = arg.strip() - if not arg: - continue - if arg in ('0', '1', '2', '3', '4', '5'): - argString += "{" + arg + "}," - else: - argString += arg - argString = argString[:-1] + ")" - result[intrinsicId] = pre + "_" + name + "_" + suf + argString - - result['width'] = width[(dataType, instructionSet)] - result['dataTypePrefix'] = { - 'double': "_" + pre + 'd', - 'float': "_" + pre, - } - - return result - - -class VectorizedCBackend(object): - - def __init__(self, astNode, instructionSet='avx'): - fieldTypes = set([getBaseType(f.dtype) for f in astNode.fieldsAccessed]) - if len(fieldTypes) != 1: - raise ValueError("Vectorized backend does not support kernels with mixed field types") - fieldType = fieldTypes.pop() - assert fieldType.is_float - dtypeName = str(fieldType) - - instructionSetInfo = getInstructionSetInfoIntel(dtypeName, instructionSet) - - self.vectorizationWidth = instructionSetInfo['width'] - self.sympyVecPrinter = CustomSympyPrinterVectorized(instructionSetInfo) - self.sympyPrinter = CustomSympyPrinter(constantsAsFloats=(dtypeName == 'float')) - - self._indent = " " - self._vecTypeName = instructionSetInfo['dataTypePrefix'][dtypeName] - self.dtypeName = dtypeName - - def __call__(self, node): - return str(self._print(node)) - - def _print(self, node): - for cls in type(node).__mro__: - methodName = "_print_" + cls.__name__ - if hasattr(self, methodName): - return getattr(self, methodName)(node) - raise NotImplementedError("CBackend does not support node of type " + cls.__name__) - - def _print_KernelFunction(self, node): - blockContents = 
"\n".join([self._print(child) for child in node.body.args]) - constantBlock = self.sympyVecPrinter.getConstantsBlock(self._vecTypeName) - - body = "{\n%s\n%s\n}" % (constantBlock, self._indent + self._indent.join(blockContents.splitlines(True))) - - functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] - funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments)) - return funcDeclaration + "\n" + body - - def _print_Block(self, node): - blockContents = "\n".join([self._print(child) for child in node.args]) - return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)),) - - def _print_PragmaBlock(self, node): - return "%s\n%s" % (node.pragmaLine, self._print_Block(node)) - - def _print_LoopOverCoordinate(self, node): - if node.isInnermostLoop: - iterRange = node.stop - node.start - if isinstance(iterRange, sp.Basic) and not iterRange.is_integer: - raise NotImplementedError("Vectorized backend currently only supports fixed size inner loops") - if iterRange % self.vectorizationWidth != 0 or node.step != 1: - raise NotImplementedError("Vectorized backend only supports loop bounds that are " - "multiples of vectorization width") - step = self.vectorizationWidth - else: - step = node.step - - counterVar = node.loopCounterName - start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start)) - condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop)) - update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(step),) - loopStr = "for (%s; %s; %s)" % (start, condition, update) - - prefix = "\n".join(node.prefixLines) - if prefix: - prefix += "\n" - return "%s%s\n%s" % (prefix, loopStr, self._print(node.body)) - - def _print_SympyAssignment(self, node): - dtype = "" - if node.isDeclaration: - assert str(getBaseType(node.lhs.dtype)) in (self.dtypeName, 'bool') - if node.lhs.dtype == createTypeFromString(self.dtypeName): - dtypeStr = self._vecTypeName - 
printer = self.sympyVecPrinter - else: - dtypeStr = str(node.lhs.dtype) - printer = self.sympyPrinter - - if node.isConst: - dtype = "const " + dtypeStr + " " - else: - dtype = dtypeStr + " " - else: - printer = self.sympyVecPrinter - return "%s %s = %s;" % (str(dtype), printer.doprint(node.lhs), printer.doprint(node.rhs)) - - def _print_TemporaryMemoryAllocation(self, node): - return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), - node.symbol.dtype, self.sympyPrinter.doprint(node.size)) - - def _print_TemporaryMemoryFree(self, node): - return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),) - - def _print_CustomCppCode(self, node): - return node.code - - -class CustomSympyPrinterVectorized(CCodePrinter): - SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) - - def __init__(self, instructionSetInfo): - super(CustomSympyPrinterVectorized, self).__init__() - self.intrinsics = instructionSetInfo - self.constantsDict = {} - - def getConstantsBlock(self, vecTypeStr): - result = "" - for value, symbol in self.constantsDict.items(): - rhsStr = self.intrinsics['makeVec'].format(self._print(value)) - result += "const %s %s = %s;\n" % (vecTypeStr, symbol.name, rhsStr) - return result - - def _print_Add(self, expr, order=None): - summands = [] - for term in expr.args: - if term.func == sp.Mul: - sign, t = self._print_Mul(term, insideAdd=True) - else: - t = self._print(term) - sign = 1 - summands.append(self.SummandInfo(sign, t)) - # Use positive terms first - summands.sort(key=lambda e: e.sign, reverse=True) - # if no positive term exists, prepend a zero - if summands[0].sign == -1: - summands.insert(0, self.SummandInfo(1, "0")) - - assert len(summands) >= 2 - processed = summands[0].term - for summand in summands[1:]: - func = self.intrinsics['-'] if summand.sign == -1 else self.intrinsics['+'] - processed = func.format(processed, summand.term) - return processed - - def _print_Mul(self, expr, insideAdd=False): - - c, 
e = expr.as_coeff_Mul() - if c < 0: - expr = _keep_coeff(-c, e) - sign = -1 - else: - sign = 1 - - a = [] # items in the numerator - b = [] # items that are in the denominator (if any) - - # Gather args for numerator/denominator - for item in expr.as_ordered_factors(): - if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: - if item.exp != -1: - b.append(sp.Pow(item.base, -item.exp, evaluate=False)) - else: - b.append(sp.Pow(item.base, -item.exp)) - else: - a.append(item) - - a = a or [S.One] - - a_str = [self._print(x) for x in a] - b_str = [self._print(x) for x in b] - - result = a_str[0] - for item in a_str[1:]: - result = self.intrinsics['*'].format(result, item) - - if len(b) > 0: - denominator_str = b_str[0] - for item in b_str[1:]: - denominator_str = self.intrinsics['*'].format(denominator_str, item) - result = self.intrinsics['/'].format(result, denominator_str) - - if insideAdd: - return sign, result - else: - if sign < 0: - return self.intrinsics['*'].format(self._print(S.NegativeOne), result) - else: - return result - - def _print_Pow(self, expr): - """Don't use std::pow function, for small integer exponents, write as multiplication""" - if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) - else: - return super(CustomSympyPrinterVectorized, self)._print_Pow(expr) - - def _print_Float(self, expr): - if expr not in self.constantsDict: - self.constantsDict[expr] = sp.Dummy() - symbol = self.constantsDict[expr] - return symbol.name - - def _print_Rational(self, expr): - if expr not in self.constantsDict: - self.constantsDict[expr] = sp.Symbol("__value_%d_%d" % (expr.p, expr.q)) - symbol = self.constantsDict[expr] - return symbol.name - - def _print_Piecewise(self, expr): - if expr.args[-1].cond != True: - # We need the last conditional to be a True, otherwise the resulting - # function may not return a result. 
- raise ValueError("All Piecewise expressions must contain an " - "(expr, True) statement to be used as a default " - "condition. Without one, the generated " - "expression may not evaluate to anything under " - "some condition.") - - result = self._print(expr.args[-1][0]) - for trueExpr, condition in reversed(expr.args[:-1]): - result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition)) - return result - - def _print_Relational(self, expr): - return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs) - - def _print_Equality(self, expr): - """Equality operator is not printable in default printer""" - return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs)) diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py new file mode 100644 index 000000000..213b4c481 --- /dev/null +++ b/backends/simd_instruction_sets.py @@ -0,0 +1,91 @@ + + +def x86VectorInstructionSet(dataType='double', instructionSet='avx'): + baseNames = { + '+': 'add[0, 1]', + '-': 'sub[0, 1]', + '*': 'mul[0, 1]', + '/': 'div[0, 1]', + + '==': 'cmp[0, 1, _CMP_EQ_UQ ]', + '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]', + '>=': 'cmp[0, 1, _CMP_GE_OQ ]', + '<=': 'cmp[0, 1, _CMP_LE_OQ ]', + '<': 'cmp[0, 1, _CMP_NGE_UQ ]', + '>': 'cmp[0, 1, _CMP_NLE_UQ ]', + + 'blendv': 'blendv[0, 1, 2]', + + 'sqrt': 'sqrt[0]', + + 'makeVec': 'set[0,0,0,0]', + 'makeZero': 'setzero[]', + + 'loadU': 'loadu[0]', + 'loadA': 'load[0]', + 'storeU': 'storeu[0,1]', + 'storeA': 'store [0,1]', + } + + headers = { + 'avx': ['<immintrin.h>'], + 'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>'] + } + + suffix = { + 'double': 'pd', + 'float': 'ps', + } + prefix = { + 'sse': '_mm', + 'avx': '_mm256', + 'avx512': '_mm512', + } + + width = { + ("double", "sse"): 2, + ("float", "sse"): 4, + ("double", "avx"): 4, + ("float", "avx"): 8, + ("double", "avx512"): 8, + ("float", "avx512"): 16, + } + + result = {} 
+ pre = prefix[instructionSet] + suf = suffix[dataType] + for intrinsicId, functionShortcut in baseNames.items(): + functionShortcut = functionShortcut.strip() + name = functionShortcut[:functionShortcut.index('[')] + args = functionShortcut[functionShortcut.index('[') + 1: -1] + argString = "(" + for arg in args.split(","): + arg = arg.strip() + if not arg: + continue + if arg in ('0', '1', '2', '3', '4', '5'): + argString += "{" + arg + "}," + else: + argString += arg + argString = argString[:-1] + ")" + result[intrinsicId] = pre + "_" + name + "_" + suf + argString + + result['width'] = width[(dataType, instructionSet)] + result['dataTypePrefix'] = { + 'double': "_" + pre + 'd', + 'float': "_" + pre, + } + + bitWidth = result['width'] * 64 + result['double'] = "__m%dd" % (bitWidth,) + result['float'] = "__m%d" % (bitWidth,) + result['int'] = "__m%di" % (bitWidth,) + + result['headers'] = headers[instructionSet] + return result + + +selectedInstructionSet = { + 'float': x86VectorInstructionSet('float', 'avx'), + 'double': x86VectorInstructionSet('double', 'avx'), +} diff --git a/cpu/cpujit.py b/cpu/cpujit.py index f853fa51b..c5919fc6c 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -33,7 +33,7 @@ Then 'cl.exe' is used to compile. where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat' - **'arch'**: 'x86' or 'x64' - **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated -- **'restrictQualifier'**: the restrict qualifier is not standardized accross compilers. +- **'restrictQualifier'**: the restrict qualifier is not standardized across compilers. 
For Windows compilers the qualifier should be ``__restrict`` @@ -70,7 +70,7 @@ import glob import atexit import shutil from ctypes import cdll -from pystencils.backends.cbackend import generateC +from pystencils.backends.cbackend import generateC, getHeaders from collections import OrderedDict, Mapping from pystencils.transformations import symbolNameToVariableName from pystencils.types import toCtypes, getBaseType, StructType @@ -276,10 +276,13 @@ def compileObjectCacheToSharedLibrary(): atexit.register(compileObjectCacheToSharedLibrary) -def generateCode(ast, includes, restrictQualifier, functionPrefix, targetFile): +def generateCode(ast, restrictQualifier, functionPrefix, targetFile): + headers = getHeaders(ast) + headers.update(['<cmath>', '<cstdint>']) + with open(targetFile, 'w') as sourceFile: code = generateC(ast) - includes = "\n".join(["#include <%s>" % (includeFile,) for includeFile in includes]) + includes = "\n".join(["#include %s" % (includeFile,) for includeFile in headers]) print(includes, file=sourceFile) print("#define RESTRICT %s" % (restrictQualifier,), file=sourceFile) print("#define FUNC_PREFIX %s" % (functionPrefix,), file=sourceFile) @@ -310,7 +313,7 @@ def compileLinux(ast, codeHashStr, srcFile, libFile): objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.o') # Compilation if not os.path.exists(objectFile): - generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'], '', srcFile) + generateCode(ast, compilerConfig['restrictQualifier'], '', srcFile) compileCmd = [compilerConfig['command'], '-c'] + compilerConfig['flags'].split() compileCmd += ['-o', objectFile, srcFile] runCompileStep(compileCmd) @@ -326,7 +329,7 @@ def compileWindows(ast, codeHashStr, srcFile, libFile): objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.obj') # Compilation if not os.path.exists(objectFile): - generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'], + 
generateCode(ast, compilerConfig['restrictQualifier'], '__declspec(dllexport)', srcFile) # /c compiles only, /EHsc turns of exception handling in c code diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py index 8291ebf4a..9c7f1be6b 100644 --- a/gpucuda/indexing.py +++ b/gpucuda/indexing.py @@ -1,15 +1,15 @@ import abc import sympy as sp -import math import pycuda.driver as cuda import pycuda.autoinit from pystencils.astnodes import Conditional, Block from pystencils.slicing import normalizeSlice +from pystencils.types import TypedSymbol, createTypeFromString -BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z")) -THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z")) +BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createTypeFromString("int")) for coord in ('x', 'y', 'z')] +THREAD_IDX = [TypedSymbol("threadIdx." + coord, createTypeFromString("int")) for coord in ('x', 'y', 'z')] class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): diff --git a/transformations.py b/transformations.py index 0975e96aa..38c6717b0 100644 --- a/transformations.py +++ b/transformations.py @@ -7,11 +7,18 @@ from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase from pystencils.field import Field, offsetComponentToDirectionString -from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString +from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc from pystencils.slicing import normalizeSlice import pystencils.astnodes as ast +def filteredTreeIteration(node, nodeType): + for arg in node.args: + if isinstance(arg, nodeType): + yield arg + yield from filteredTreeIteration(arg, nodeType) + + def fastSubs(term, subsDict): """Similar to sympy subs function. 
This version is much faster for big substitution dictionaries than sympy version""" @@ -332,9 +339,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn coordDict = createCoordinateDict(basePointerInfo[0]) _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) - baseArr = IndexedBase(lastPointer, shape=(1,)) - result = ast.ResolvedFieldAccess(baseArr, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index) - castFunc = sp.Function("cast") + result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index) + if isinstance(getBaseType(fieldAccess.field.dtype), StructType): newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0]) result = castFunc(result, newType) diff --git a/types.py b/types.py index 86ad051b3..23c487501 100644 --- a/types.py +++ b/types.py @@ -2,7 +2,11 @@ import ctypes import sympy as sp import numpy as np from sympy.core.cache import cacheit + from pystencils.cache import memorycache +from pystencils.utils import allEqual + +castFunc = sp.Function("cast") class TypedSymbol(sp.Symbol): @@ -28,7 +32,7 @@ class TypedSymbol(sp.Symbol): def _hashable_content(self): superClassContents = list(super(TypedSymbol, self)._hashable_content()) - return tuple(superClassContents + [hash(repr(self._dtype))]) + return tuple(superClassContents + [hash(str(self._dtype))]) def __getnewargs__(self): return self.name, self.dtype @@ -52,6 +56,7 @@ def createType(specification): return StructType(npDataType, const=False) +@memorycache(maxsize=64) def createTypeFromString(specification): """ Creates a new Type object from a c-like string specification @@ -131,10 +136,79 @@ toCtypes.map = { } +def peelOffType(dtype, typeToPeelOff): + while type(dtype) is typeToPeelOff: + dtype = dtype.baseType + return dtype + + +def collateTypes(types): + """ + Takes a sequence of types and returns their "common type" e.g. 
(float, double, float) -> double + Uses the collation rules from numpy. + """ + + # Pointer arithmetic case i.e. pointer + integer is allowed + if any(type(t) is PointerType for t in types): + pointerType = None + for t in types: + if type(t) is PointerType: + if pointerType is not None: + raise ValueError("Cannot collate the combination of two pointer types") + pointerType = t + elif type(t) is BasicType: + if not (t.is_int() or t.is_uint()): + raise ValueError("Invalid pointer arithmetic") + else: + raise ValueError("Invalid pointer arithmetic") + return pointerType + + # peel of vector types, if at least one vector type occurred the result will also be the vector type + vectorType = [t for t in types if type(t) is VectorType] + if not allEqual(t.width for t in vectorType): + raise ValueError("Collation failed because of vector types with different width") + types = [peelOffType(t, VectorType) for t in types] + + # now we should have a list of basic types - struct types are not yet supported + assert all(type(t) is BasicType for t in types) + + # use numpy collation -> create type from numpy type -> and, put vector type around if necessary + resultNumpyType = np.result_type(*(t.numpyDtype for t in types)) + result = BasicType(resultNumpyType) + if vectorType: + result = VectorType(result, vectorType[0].width) + return result + + +@memorycache(maxsize=2048) def getTypeOfExpression(expr): - if isinstance(expr, TypedSymbol): + from pystencils.astnodes import ResolvedFieldAccess + expr = sp.sympify(expr) + if isinstance(expr, sp.Integer): + return createTypeFromString("int") + elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): + return createTypeFromString("double") + elif isinstance(expr, ResolvedFieldAccess): + return expr.field.dtype + elif isinstance(expr, TypedSymbol): return expr.dtype - + elif isinstance(expr, sp.Symbol): + raise ValueError("All symbols inside this expression have to be typed!") + elif hasattr(expr, 'func') and expr.func == 
castFunc: + return expr.args[1] + elif hasattr(expr, 'func') and expr.func == sp.Piecewise: + branchResults = [a[0] for a in expr.args] + return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults)) + elif isinstance(expr, sp.Indexed): + typedSymbol = expr.base.label + return typedSymbol.dtype + elif isinstance(expr, sp.Expr): + types = tuple(getTypeOfExpression(a) for a in expr.args) + return collateTypes(types) + elif isinstance(expr, sp.boolalg.Boolean): + return createTypeFromString("bool") + + raise NotImplementedError("Could not determine type for " + str(expr)) class Type(sp.Basic): @@ -239,6 +313,44 @@ class BasicType(Type): return hash(str(self)) +class VectorType(Type): + instructionSet = None + + def __init__(self, baseType, width=4): + self._baseType = baseType + self.width = width + + @property + def baseType(self): + return self._baseType + + @property + def itemSize(self): + return self.width * self.baseType.itemSize + + def __eq__(self, other): + if not isinstance(other, VectorType): + return False + else: + return (self.baseType, self.width) == (other.baseType, other.width) + + def __str__(self): + if self.instructionSet is None: + return "%s[%d]" % (self.baseType, self.width) + else: + if self.baseType == createTypeFromString("int64"): + return self.instructionSet['int'] + elif self.baseType == createTypeFromString("double"): + return self.instructionSet['double'] + elif self.baseType == createTypeFromString("float"): + return self.instructionSet['float'] + else: + raise NotImplementedError() + + def __hash__(self): + return hash(str(self)) + + class PointerType(Type): def __init__(self, baseType, const=False, restrict=True): self._baseType = baseType diff --git a/utils.py b/utils.py index a48a963fa..2ea6b7dd9 100644 --- a/utils.py +++ b/utils.py @@ -4,3 +4,12 @@ class DotDict(dict): __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ + + +def allEqual(iterator): + iterator = iter(iterator) + try: + 
first = next(iterator) + except StopIteration: + return True + return all(first == rest for rest in iterator) diff --git a/vectorization.py b/vectorization.py new file mode 100644 index 000000000..54a9819f7 --- /dev/null +++ b/vectorization.py @@ -0,0 +1,94 @@ +import sympy as sp +import warnings + +from pystencils.transformations import filteredTreeIteration +from pystencils.types import TypedSymbol, VectorType, PointerType, BasicType, getTypeOfExpression, castFunc +import pystencils.astnodes as ast +from pystencils.utils import allEqual + + +def asVectorType(resolvedFieldAccess, vectorizationWidth): + """Returns a new ResolvedFieldAccess that has a vector type""" + dtype = resolvedFieldAccess.typedSymbol.dtype + assert type(dtype) is PointerType + basicType = dtype.baseType + assert type(basicType) is BasicType, "Structs are not supported" + + newDtype = VectorType(basicType, vectorizationWidth) + newDtype = PointerType(newDtype, dtype.const, dtype.restrict) + newTypedSymbol = TypedSymbol(resolvedFieldAccess.typedSymbol.name, newDtype) + return ast.ResolvedFieldAccess(newTypedSymbol, resolvedFieldAccess.args[1], resolvedFieldAccess.field, + resolvedFieldAccess.offsets, resolvedFieldAccess.idxCoordinateValues) + + +def vectorize(astNode, vectorWidth=4): + """ + Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if + - loop bounds are constant + - loop range is a multiple of vector width + """ + innerLoops = [n for n in astNode.atoms(ast.LoopOverCoordinate) if n.isInnermostLoop] + + for loopNode in innerLoops: + loopRange = loopNode.stop - loopNode.start + + # Check restrictions + if isinstance(loopRange, sp.Basic) and not loopRange.is_integer: + warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop") + continue + if loopRange % vectorWidth != 0 or loopNode.step != 1: + warnings.warn("Currently only loops with loop bounds that are multiples " + "of vectorization width can be 
vectorized") + continue + + loopNode.step = vectorWidth + + # All field accesses depending on loop coordinate are changed to vector type + fieldAccesses = [n for n in loopNode.atoms(ast.ResolvedFieldAccess)] + substitutions = {fa: castFunc(fa, VectorType(BasicType(fa.field.dtype), vectorWidth)) for fa in fieldAccesses} + loopNode.subs(substitutions) + + +def insertVectorCasts(astNode): + """ + Inserts necessary casts from scalar values to vector values + """ + def visitExpr(expr): + if expr.func in (sp.Add, sp.Mul): + newArgs = [visitExpr(a) for a in expr.args] + argTypes = [getTypeOfExpression(a) for a in newArgs] + if not any(type(t) is VectorType for t in argTypes): + return expr + else: + vectorWidths = [d.width for d in argTypes if type(d) is VectorType] + assert allEqual(vectorWidths), "Incompatible vector type widths" + vectorWidth = vectorWidths[0] + castedArgs = [castFunc(a, VectorType(t, vectorWidth)) if type(t) is not VectorType else a + for a, t in zip(newArgs, argTypes)] + return expr.func(*castedArgs) + elif expr.func == sp.Piecewise: + raise NotImplementedError() + else: + return expr + + substitutionDict = {} + for asmt in filteredTreeIteration(astNode, ast.SympyAssignment): + subsExpr = asmt.rhs.subs(substitutionDict) + asmt.rhs = visitExpr(subsExpr) + rhsType = getTypeOfExpression(asmt.rhs) + if isinstance(asmt.lhs, TypedSymbol): + lhsType = asmt.lhs.dtype + if type(rhsType) is VectorType and type(lhsType) is not VectorType: + newLhsType = VectorType(lhsType, rhsType.width) + newLhs = TypedSymbol(asmt.lhs.name, newLhsType) + substitutionDict[asmt.lhs] = newLhs + asmt.lhs = newLhs + elif asmt.lhs.func == castFunc: + lhsType = asmt.lhs.args[1] + if type(lhsType) is VectorType and type(rhsType) is not VectorType: + asmt.rhs = castFunc(asmt.rhs, lhsType) + + + + + -- GitLab