Skip to content
Snippets Groups Projects
Commit 2e015cf5 authored by Martin Bauer's avatar Martin Bauer
Browse files

Vectorization: Piecewise/blend support

- sympys piecewise defined functions are mapped to blend instructions
- cast function is now a class
- several bugfixes
parent ea847bc5
Branches
Tags
No related merge requests found
import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
class ResolvedFieldAccess(sp.Indexed):
......@@ -383,8 +383,8 @@ class SympyAssignment(Node):
self._lhsSymbol = lhsSymbol
self.rhs = rhsTerm
self._isDeclaration = True
isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast:
isCast = self._lhsSymbol.func == castFunc
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
self._isDeclaration = False
self._isConst = isConst
......@@ -396,7 +396,7 @@ class SympyAssignment(Node):
def lhs(self, newValue):
self._lhsSymbol = newValue
self._isDeclaration = True
isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False
isCast = self._lhsSymbol.func == castFunc
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
self._isDeclaration = False
......
......@@ -39,6 +39,8 @@ def getHeaders(astNode):
headers.update(getHeaders(a))
return headers
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------
......@@ -75,6 +77,7 @@ class PrintNode(CustomCppCode):
# ------------------------------------------- Printer ------------------------------------------------------------------
class CBackend(object):
def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
......@@ -204,8 +207,7 @@ class CustomSympyPrinter(CCodePrinter):
return res
def _print_Function(self, expr):
name = str(expr.func).lower()
if name == 'cast':
if expr.func == castFunc:
arg, type = expr.args
return "*((%s)(& %s))" % (PointerType(type), self._print(arg))
else:
......@@ -220,8 +222,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
self.instructionSet = instructionSet
def _print_Function(self, expr):
name = str(expr.func).lower()
if name == 'cast':
if expr.func == castFunc:
arg, dtype = expr.args
if type(dtype) is VectorType:
if type(arg) is ResolvedFieldAccess:
......@@ -291,41 +292,57 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = a_str[0]
for item in a_str[1:]:
result = self.intrinsics['*'].format(result, item)
result = self.instructionSet['*'].format(result, item)
if len(b) > 0:
denominator_str = b_str[0]
for item in b_str[1:]:
denominator_str = self.intrinsics['*'].format(denominator_str, item)
result = self.intrinsics['/'].format(result, denominator_str)
denominator_str = self.instructionSet['*'].format(denominator_str, item)
result = self.instructionSet['/'].format(result, denominator_str)
if insideAdd:
return sign, result
else:
if sign < 0:
return self.intrinsics['*'].format(self._print(S.NegativeOne), result)
return self.instructionSet['*'].format(self._print(S.NegativeOne), result)
else:
return result
# def _print_Piecewise(self, expr):
# if expr.args[-1].cond != True:
# # We need the last conditional to be a True, otherwise the resulting
# # function may not return a result.
# raise ValueError("All Piecewise expressions must contain an "
# "(expr, True) statement to be used as a default "
# "condition. Without one, the generated "
# "expression may not evaluate to anything under "
# "some condition.")
#
# result = self._print(expr.args[-1][0])
# for trueExpr, condition in reversed(expr.args[:-1]):
# result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
# return result
#
# def _print_Relational(self, expr):
# return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs)
#
# def _print_Equality(self, expr):
# return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
#
def _print_Relational(self, expr):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Relational(expr)
assert self.instructionSet['width'] == exprType.width
return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
def _print_Equality(self, expr):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Equality(expr)
assert self.instructionSet['width'] == exprType.width
return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))
def _print_Piecewise(self, expr):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr)
assert self.instructionSet['width'] == exprType.width
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
result = self._print(expr.args[-1][0])
for trueExpr, condition in reversed(expr.args[:-1]):
result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
return result
......@@ -66,7 +66,7 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
if arg in ('0', '1', '2', '3', '4', '5'):
argString += "{" + arg + "},"
else:
argString += arg
argString += arg + ","
argString = argString[:-1] + ")"
result[intrinsicId] = pre + "_" + name + "_" + suf + argString
......@@ -80,6 +80,7 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
result['double'] = "__m%dd" % (bitWidth,)
result['float'] = "__m%d" % (bitWidth,)
result['int'] = "__m%di" % (bitWidth,)
result['bool'] = "__m%dd" % (bitWidth,)
result['headers'] = headers[instructionSet]
return result
......
......@@ -302,7 +302,7 @@ def runCompileStep(command):
subprocess.check_output(command, env=compileEnvironment, stderr=subprocess.STDOUT, shell=shell)
except subprocess.CalledProcessError as e:
print(" ".join(command))
print(e.output)
print(e.output.decode('utf8'))
raise e
......
......@@ -6,7 +6,16 @@ from sympy.core.cache import cacheit
from pystencils.cache import memorycache
from pystencils.utils import allEqual
castFunc = sp.Function("cast")
# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class castFunc(sp.Function, sp.Rel):
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
return self.args[0].canonical
else:
raise NotImplementedError()
class TypedSymbol(sp.Symbol):
......@@ -202,11 +211,16 @@ def getTypeOfExpression(expr):
elif isinstance(expr, sp.Indexed):
typedSymbol = expr.base.label
return typedSymbol.dtype
elif isinstance(expr, sp.boolalg.Boolean):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = createTypeFromString("bool")
vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)]
if vecArgs:
result = VectorType(result, width=vecArgs[0].width)
return result
elif isinstance(expr, sp.Expr):
types = tuple(getTypeOfExpression(a) for a in expr.args)
return collateTypes(types)
elif isinstance(expr, sp.boolalg.Boolean):
return createTypeFromString("bool")
raise NotImplementedError("Could not determine type for " + str(expr))
......@@ -344,6 +358,8 @@ class VectorType(Type):
return self.instructionSet['double']
elif self.baseType == createTypeFromString("float"):
return self.instructionSet['float']
elif self.baseType == createTypeFromString("bool"):
return self.instructionSet['bool']
else:
raise NotImplementedError()
......
......@@ -2,26 +2,16 @@ import sympy as sp
import warnings
from pystencils.transformations import filteredTreeIteration
from pystencils.types import TypedSymbol, VectorType, PointerType, BasicType, getTypeOfExpression, castFunc
from pystencils.types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes
import pystencils.astnodes as ast
from pystencils.utils import allEqual
def asVectorType(resolvedFieldAccess, vectorizationWidth):
"""Returns a new ResolvedFieldAccess that has a vector type"""
dtype = resolvedFieldAccess.typedSymbol.dtype
assert type(dtype) is PointerType
basicType = dtype.baseType
assert type(basicType) is BasicType, "Structs are not supported"
newDtype = VectorType(basicType, vectorizationWidth)
newDtype = PointerType(newDtype, dtype.const, dtype.restrict)
newTypedSymbol = TypedSymbol(resolvedFieldAccess.typedSymbol.name, newDtype)
return ast.ResolvedFieldAccess(newTypedSymbol, resolvedFieldAccess.args[1], resolvedFieldAccess.field,
resolvedFieldAccess.offsets, resolvedFieldAccess.idxCoordinateValues)
def vectorize(astNode, vectorWidth=4):
vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth)
insertVectorCasts(astNode)
def vectorize(astNode, vectorWidth=4):
def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4):
"""
Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if
- loop bounds are constant
......@@ -54,20 +44,31 @@ def insertVectorCasts(astNode):
Inserts necessary casts from scalar values to vector values
"""
def visitExpr(expr):
if expr.func in (sp.Add, sp.Mul):
if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc):
newArgs = [visitExpr(a) for a in expr.args]
argTypes = [getTypeOfExpression(a) for a in newArgs]
if not any(type(t) is VectorType for t in argTypes):
return expr
else:
vectorWidths = [d.width for d in argTypes if type(d) is VectorType]
assert allEqual(vectorWidths), "Incompatible vector type widths"
vectorWidth = vectorWidths[0]
castedArgs = [castFunc(a, VectorType(t, vectorWidth)) if type(t) is not VectorType else a
targetType = collateTypes(argTypes)
castedArgs = [castFunc(a, targetType) if t != targetType else a
for a, t in zip(newArgs, argTypes)]
return expr.func(*castedArgs)
elif expr.func == sp.Piecewise:
raise NotImplementedError()
newResults = [visitExpr(a[0]) for a in expr.args]
newConditions = [visitExpr(a[1]) for a in expr.args]
typesOfResults = [getTypeOfExpression(a) for a in newResults]
typesOfConditions = [getTypeOfExpression(a) for a in newConditions]
resultTargetType = getTypeOfExpression(expr)
castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a
for a, t in zip(newResults, typesOfResults)]
conditionTargetType = collateTypes(typesOfConditions)
castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a
for a, t in zip(newConditions, typesOfConditions)]
return sp.Piecewise(*[(r, c) for r, c in zip(castedResults, castedConditions)])
else:
return expr
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment