Commit 866e9fc0 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent 501b2d7e
......@@ -213,6 +213,7 @@ class CustomSympyPrinter(CCodePrinter):
def __init__(self, constants_as_floats=False):
self._constantsAsFloats = constants_as_floats
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
......@@ -224,8 +225,6 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Rational(self, expr):
"""Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
res = str(expr.evalf().num)
if self._constantsAsFloats:
res += "f"
return res
def _print_Equality(self, expr):
......@@ -237,12 +236,6 @@ class CustomSympyPrinter(CCodePrinter):
result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "")
def _print_Float(self, expr):
res = str(expr)
if self._constantsAsFloats:
res += "f"
return res
def _print_Function(self, expr):
function_map = {
bitwise_xor: '^',
......@@ -255,7 +248,10 @@ class CustomSympyPrinter(CCodePrinter):
return expr.to_c(self._print)
if expr.func == cast_func:
arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type)
else:
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
elif expr.func == modulo_floor:
assert all(get_type_of_expression(e).is_int() for e in expr.args)
return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
......@@ -264,6 +260,17 @@ class CustomSympyPrinter(CCodePrinter):
else:
return super(CustomSympyPrinter, self)._print_Function(expr)
def _typed_number(self, number, dtype):
res = self._print(number)
if dtype.is_float:
if dtype == self._float_type:
if '.' not in res:
res += ".0f"
else:
res += "f"
return res
else:
return res
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
......@@ -20,7 +20,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'sqrt': 'sqrt[0]',
'makeVec': 'set[0,0,0,0]',
'makeVec': 'set[]',
'makeZero': 'setzero[]',
'loadU': 'loadu[0]',
......@@ -31,6 +31,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
}
headers = {
'avx512': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
......@@ -54,32 +55,37 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
("float", "avx512"): 16,
}
result = {}
result = {
'width': width[(data_type, instruction_set)],
}
pre = prefix[instruction_set]
suf = suffix[data_type]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
if intrinsic_id == 'makeVec':
arg_string = "({})".format(",".join(["{0}"] * result['width']))
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string
result['width'] = width[(data_type, instruction_set)]
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
bit_width = result['width'] * 64
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = "__m%dd" % (bit_width,)
result['float'] = "__m%d" % (bit_width,)
result['int'] = "__m%di" % (bit_width,)
......
......@@ -13,13 +13,13 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration
from pystencils.field import Field
def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx',
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False):
"""Explicit vectorization using SIMD vectorization via intrinsics.
Args:
kernel_ast: abstract syntax tree (KernelFunction node)
vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
used. If true, some of the loads are assumed to be from aligned memory addresses.
For example if x is the fastest coordinate, the access to center can be fetched via an
......@@ -42,7 +42,7 @@ def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx
float_size = field_float_dtypes.pop().numpy_dtype.itemsize
assert float_size in (8, 4)
vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
instruction_set=vector_instruction_set)
instruction_set=instruction_set)
vector_width = vector_is['width']
kernel_ast.instruction_set = vector_is
......
......@@ -289,7 +289,10 @@ def get_type_of_expression(expr):
from pystencils.astnodes import ResolvedFieldAccess
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type("int")
if expr == 1 or expr == -1:
return create_type("int16")
else:
return create_type("int")
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double")
elif isinstance(expr, ResolvedFieldAccess):
......@@ -316,6 +319,8 @@ def get_type_of_expression(expr):
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
elif isinstance(expr, sp.Expr):
types = tuple(get_type_of_expression(a) for a in expr.args)
return collate_types(types)
......
......@@ -73,7 +73,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
add_openmp(ast, num_threads=cpu_openmp)
if cpu_vectorize_info:
if cpu_vectorize_info is True:
vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None)
vectorize(ast, instruction_set='avx', assume_aligned=False, nontemporal=None)
elif isinstance(cpu_vectorize_info, dict):
vectorize(ast, **cpu_vectorize_info)
else:
......
......@@ -205,10 +205,17 @@ class LLVMPrinter(Printer):
node = self._print(conversion.args[0])
to_dtype = get_type_of_expression(conversion)
from_dtype = get_type_of_expression(conversion.args[0])
if from_dtype == to_dtype:
return self._print(conversion.args[0])
# (From, to)
decision = {
(create_composite_type_from_string("int16"),
create_composite_type_from_string("int64")): lambda: ir.Constant(self.integer, node),
(create_composite_type_from_string("int"),
create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
(create_composite_type_from_string("int16"),
create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
(create_composite_type_from_string("double"),
create_composite_type_from_string("int")): functools.partial(self.builder.fptosi, node, self.integer),
(create_composite_type_from_string("double *"),
......
......@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
pointer_arithmetic_func, get_type_of_expression, collate_types
pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.slicing import normalize_slice
import pystencils.astnodes as ast
......@@ -716,9 +716,18 @@ class KernelConstraintsCheck:
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
else:
new_args = [self.process_expression(arg) for arg in rhs.args]
elif isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
else:
new_args = [self.process_expression(arg) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
......@@ -800,10 +809,13 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
def insert_casts(node):
"""Checks the types and inserts casts and pointer arithmetic where necessary
"""Checks the types and inserts casts and pointer arithmetic where necessary.
:param node: the head node of the ast
:return: modified ast
Args:
node: the head node of the ast
Returns:
modified AST
"""
def cast(zipped_args_types, target_dtype):
"""
......@@ -839,7 +851,7 @@ def insert_casts(node):
new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
return pointer_arithmetic_func(pointer, new_args)
if isinstance(node, sp.AtomicExpr):
if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
return node
args = []
for arg in node.args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment