From 866e9fc0019e733191c37533ae46ae0f3d366c29 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 14 May 2018 15:51:38 +0200
Subject: [PATCH] Fixes in vectorization to also support float kernels

---
 backends/cbackend.py              | 25 +++++++++++++--------
 backends/simd_instruction_sets.py | 36 ++++++++++++++++++-------------
 cpu/vectorization.py              |  6 +++---
 data_types.py                     |  7 +++++-
 kernelcreation.py                 |  2 +-
 llvm/llvm.py                      |  7 ++++++
 transformations.py                | 26 ++++++++++++++++------
 7 files changed, 73 insertions(+), 36 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 188679d47..76eb3ee74 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -213,6 +213,7 @@ class CustomSympyPrinter(CCodePrinter):
     def __init__(self, constants_as_floats=False):
         self._constantsAsFloats = constants_as_floats
         super(CustomSympyPrinter, self).__init__()
+        self._float_type = create_type("float32")
 
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
@@ -224,8 +225,6 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_Rational(self, expr):
         """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
         res = str(expr.evalf().num)
-        if self._constantsAsFloats:
-            res += "f"
         return res
 
     def _print_Equality(self, expr):
@@ -237,12 +236,6 @@ class CustomSympyPrinter(CCodePrinter):
         result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
         return result.replace("\n", "")
 
-    def _print_Float(self, expr):
-        res = str(expr)
-        if self._constantsAsFloats:
-            res += "f"
-        return res
-
     def _print_Function(self, expr):
         function_map = {
             bitwise_xor: '^',
@@ -255,7 +248,10 @@ class CustomSympyPrinter(CCodePrinter):
             return expr.to_c(self._print)
         if expr.func == cast_func:
             arg, data_type = expr.args
-            return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
+            if isinstance(arg, sp.Number):
+                return self._typed_number(arg, data_type)
+            else:
+                return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
         elif expr.func == modulo_floor:
             assert all(get_type_of_expression(e).is_int() for e in expr.args)
             return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
@@ -264,6 +260,17 @@ class CustomSympyPrinter(CCodePrinter):
         else:
             return super(CustomSympyPrinter, self)._print_Function(expr)
 
+    def _typed_number(self, number, dtype):
+        res = self._print(number)
+        if dtype.is_float:
+            if dtype == self._float_type:
+                if '.' not in res:
+                    res += ".0f"
+                else:
+                    res += "f"
+            return res
+        else:
+            return res
 
 # noinspection PyPep8Naming
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py
index 518e6a59a..28e1cc8b2 100644
--- a/backends/simd_instruction_sets.py
+++ b/backends/simd_instruction_sets.py
@@ -20,7 +20,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
 
         'sqrt': 'sqrt[0]',
 
-        'makeVec': 'set[0,0,0,0]',
+        'makeVec': 'set[]',
         'makeZero': 'setzero[]',
 
         'loadU': 'loadu[0]',
@@ -31,6 +31,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
     }
 
     headers = {
+        'avx512': ['<immintrin.h>'],
         'avx': ['<immintrin.h>'],
         'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
     }
@@ -54,32 +55,37 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         ("float", "avx512"): 16,
     }
 
-    result = {}
+    result = {
+        'width': width[(data_type, instruction_set)],
+    }
     pre = prefix[instruction_set]
     suf = suffix[data_type]
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
         name = function_shortcut[:function_shortcut.index('[')]
-        args = function_shortcut[function_shortcut.index('[') + 1: -1]
-        arg_string = "("
-        for arg in args.split(","):
-            arg = arg.strip()
-            if not arg:
-                continue
-            if arg in ('0', '1', '2', '3', '4', '5'):
-                arg_string += "{" + arg + "},"
-            else:
-                arg_string += arg + ","
-        arg_string = arg_string[:-1] + ")"
+
+        if intrinsic_id == 'makeVec':
+            arg_string = "({})".format(",".join(["{0}"] * result['width']))
+        else:
+            args = function_shortcut[function_shortcut.index('[') + 1: -1]
+            arg_string = "("
+            for arg in args.split(","):
+                arg = arg.strip()
+                if not arg:
+                    continue
+                if arg in ('0', '1', '2', '3', '4', '5'):
+                    arg_string += "{" + arg + "},"
+                else:
+                    arg_string += arg + ","
+            arg_string = arg_string[:-1] + ")"
         result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string
 
-    result['width'] = width[(data_type, instruction_set)]
     result['dataTypePrefix'] = {
         'double': "_" + pre + 'd',
        'float': "_" + pre,
     }
 
-    bit_width = result['width'] * 64
+    bit_width = result['width'] * (64 if data_type == 'double' else 32)
     result['double'] = "__m%dd" % (bit_width,)
     result['float'] = "__m%d" % (bit_width,)
     result['int'] = "__m%di" % (bit_width,)
diff --git a/cpu/vectorization.py b/cpu/vectorization.py
index 3745bb31c..a8542d0af 100644
--- a/cpu/vectorization.py
+++ b/cpu/vectorization.py
@@ -13,13 +13,13 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration
 from pystencils.field import Field
 
 
-def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx',
+def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
               assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False):
     """Explicit vectorization using SIMD vectorization via intrinsics.
 
     Args:
         kernel_ast: abstract syntax tree (KernelFunction node)
-        vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
+        instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
         assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                         used. If true, some of the loads are assumed to be from aligned memory addresses.
                         For example if x is the fastest coordinate, the access to center can be fetched via an
@@ -42,7 +42,7 @@ def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx
     float_size = field_float_dtypes.pop().numpy_dtype.itemsize
     assert float_size in (8, 4)
     vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
                                            instruction_set=vector_instruction_set)
+                                           instruction_set=instruction_set)
     vector_width = vector_is['width']
     kernel_ast.instruction_set = vector_is
 
diff --git a/data_types.py b/data_types.py
index 318a24752..d3dad765d 100644
--- a/data_types.py
+++ b/data_types.py
@@ -289,7 +289,10 @@ def get_type_of_expression(expr):
     from pystencils.astnodes import ResolvedFieldAccess
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
-        return create_type("int")
+        if expr == 1 or expr == -1:
+            return create_type("int16")
+        else:
+            return create_type("int")
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type("double")
     elif isinstance(expr, ResolvedFieldAccess):
@@ -316,6 +319,8 @@ def get_type_of_expression(expr):
         if vec_args:
             result = VectorType(result, width=vec_args[0].width)
         return result
+    elif isinstance(expr, sp.Pow):
+        return get_type_of_expression(expr.args[0])
     elif isinstance(expr, sp.Expr):
         types = tuple(get_type_of_expression(a) for a in expr.args)
         return collate_types(types)
diff --git a/kernelcreation.py b/kernelcreation.py
index f79865f2e..0a948f195 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -73,7 +73,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
             add_openmp(ast, num_threads=cpu_openmp)
         if cpu_vectorize_info:
             if cpu_vectorize_info is True:
-                vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None)
+                vectorize(ast, instruction_set='avx', assume_aligned=False, nontemporal=None)
             elif isinstance(cpu_vectorize_info, dict):
                 vectorize(ast, **cpu_vectorize_info)
             else:
diff --git a/llvm/llvm.py b/llvm/llvm.py
index b6a0f5895..1b165a2f9 100644
--- a/llvm/llvm.py
+++ b/llvm/llvm.py
@@ -205,10 +205,17 @@ class LLVMPrinter(Printer):
         node = self._print(conversion.args[0])
         to_dtype = get_type_of_expression(conversion)
         from_dtype = get_type_of_expression(conversion.args[0])
+        if from_dtype == to_dtype:
+            return self._print(conversion.args[0])
+
         # (From, to)
         decision = {
+            (create_composite_type_from_string("int16"),
+             create_composite_type_from_string("int64")): lambda: ir.Constant(self.integer, node),
             (create_composite_type_from_string("int"),
              create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
+            (create_composite_type_from_string("int16"),
+             create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
             (create_composite_type_from_string("double"),
              create_composite_type_from_string("int")): functools.partial(self.builder.fptosi, node, self.integer),
             (create_composite_type_from_string("double *"),
diff --git a/transformations.py b/transformations.py
index b2c6f5e27..ab7d98ab5 100644
--- a/transformations.py
+++ b/transformations.py
@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
 from pystencils.assignment import Assignment
 from pystencils.field import Field, FieldType
 from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
-    pointer_arithmetic_func, get_type_of_expression, collate_types
+    pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
 from pystencils.slicing import normalize_slice
 import pystencils.astnodes as ast
 
@@ -716,9 +716,18 @@ class KernelConstraintsCheck:
             return rhs
         elif isinstance(rhs, sp.Symbol):
             return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
-        else:
-            new_args = [self.process_expression(arg) for arg in rhs.args]
+        elif isinstance(rhs, sp.Number):
+            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
+        elif isinstance(rhs, sp.Mul):
+            new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
             return rhs.func(*new_args) if new_args else rhs
+        else:
+            if isinstance(rhs, sp.Pow):
+                # don't process exponents -> they should remain integers
+                return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
+            else:
+                new_args = [self.process_expression(arg) for arg in rhs.args]
+                return rhs.func(*new_args) if new_args else rhs
 
     @property
     def fields_written(self):
@@ -800,10 +809,13 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
 
 
 def insert_casts(node):
-    """Checks the types and inserts casts and pointer arithmetic where necessary
+    """Checks the types and inserts casts and pointer arithmetic where necessary.
 
-    :param node: the head node of the ast
-    :return: modified ast
+    Args:
+        node: the head node of the ast
+
+    Returns:
+        modified AST
     """
     def cast(zipped_args_types, target_dtype):
         """
@@ -839,7 +851,7 @@ def insert_casts(node):
         new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
         return pointer_arithmetic_func(pointer, new_args)
 
-    if isinstance(node, sp.AtomicExpr):
+    if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
         return node
     args = []
     for arg in node.args:
-- 
GitLab
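
Reviewer note (not part of the patch): the sketch below shows what the reworked intrinsic-set generation is expected to return for single precision. The import path follows the file layout visible in the diff; the printed values assume the usual AVX prefix/suffix tables (_mm256 / ps) and 8 float lanes per 256-bit register, which the prefix, suffix and width dicts (not shown in the diff) should contain.

# Sketch: query the patched get_vector_instruction_set() for single-precision AVX.
# 'width' now comes straight from the width table, the type strings use
# width * 32 bits for float, and 'makeVec' expands to one argument slot per lane.
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set

float_avx = get_vector_instruction_set(data_type='float', instruction_set='avx')
print(float_avx['width'])    # 8   (lanes per register, assumed entry of the width table)
print(float_avx['float'])    # __m256   (bit_width = 8 * 32)
print(float_avx['makeVec'])  # _mm256_set_ps({0},{0},{0},{0},{0},{0},{0},{0})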
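A second sketch, also not part of the patch: requesting a vectorized single-precision kernel through create_kernel. The field setup, array shapes and update rule are illustrative assumptions; the cpu_vectorize_info dict mirrors the keyword rename (instruction_set) applied in kernelcreation.py and cpu/vectorization.py.

# Sketch: build a float32 stencil kernel and ask for AVX vectorization.
import numpy as np

from pystencils.assignment import Assignment
from pystencils.field import Field
from pystencils.kernelcreation import create_kernel

src_arr = np.zeros((32, 32), dtype=np.float32)
dst_arr = np.zeros_like(src_arr)
src = Field.create_from_numpy_array('src', src_arr)
dst = Field.create_from_numpy_array('dst', dst_arr)

# The float32 fields make vectorize() pick the 'float' instruction set (float_size == 4);
# constants such as 0.25 should then be emitted with an 'f' suffix via _typed_number().
update = [Assignment(dst[0, 0],
                     0.25 * (src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1]))]

# cpu_vectorize_info is forwarded as keyword arguments to vectorize(),
# hence the dict uses the renamed 'instruction_set' key.
ast = create_kernel(update, target='cpu', data_type='float32',
                    cpu_vectorize_info={'instruction_set': 'avx',
                                        'assume_aligned': False,
                                        'nontemporal': False})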