From 939241f2804b4886ae89f9fb8cde4cd877b1386e Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Wed, 26 Jan 2022 15:04:19 +0100 Subject: [PATCH] Fixing vectorisation --- pystencils/backends/cbackend.py | 253 +++++++++----------- pystencils/backends/x86_instruction_sets.py | 8 +- pystencils/cpu/vectorization.py | 38 ++- pystencils/fast_approximation.py | 1 + pystencils/typing/leaf_typing.py | 7 + pystencils/typing/utilities.py | 24 +- pystencils/utils.py | 15 +- pystencils_tests/test_vectorization.py | 41 +++- 8 files changed, 211 insertions(+), 176 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 06eda124e..b631f0a8a 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -13,11 +13,12 @@ from sympy.functions.elementary.hyperbolic import HyperbolicFunction from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize -from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, - reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol) +from pystencils.typing import ( + PointerType, VectorType, CastFunc, create_type, get_type_of_expression, + ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.functions import DivFunc, AddressOf from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) @@ -32,8 +33,6 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy HEADER_REGEX = re.compile(r'^[<"].*[">]$') -KERNCRAFT_NO_TERNARY_MODE = False - def generate_c(ast_node: Node, signature_only: bool = False, @@ -221,7 +220,7 @@ class CBackend: return getattr(self, method_name)(node) raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}") - def _print_Type(self, node): + def _print_AbstractType(self, node): return str(node) def _print_KernelFunction(self, node): @@ -276,9 +275,9 @@ class CBackend: self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) else: - lhs_type = get_type_of_expression(node.lhs) + lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed printed_mask = "" - if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): + if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc): arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args instr = 'storeU' if aligned: @@ -291,20 +290,20 @@ class CBackend: self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs), '{1}', '{2}', **self._kwargs), **self._kwargs) printed_mask = self.sympy_printer.doprint(mask) - if data_type.base_type.base_name == 'double': + if data_type.base_type.c_name == 'double': if self._vector_instruction_set['double'] == '__m256d': printed_mask = f"_mm256_castpd_si256({printed_mask})" elif self._vector_instruction_set['double'] == '__m128d': printed_mask = f"_mm_castpd_si128({printed_mask})" - elif data_type.base_type.base_name == 'float': + elif data_type.base_type.c_name == 'float': if self._vector_instruction_set['float'] == '__m256': printed_mask = f"_mm256_castps_si256({printed_mask})" elif self._vector_instruction_set['float'] == '__m128': printed_mask = f"_mm_castps_si128({printed_mask})" - rhs_type = get_type_of_expression(node.rhs) + rhs_type = get_type_of_expression(node.rhs) # TOOD: vector only??? if type(rhs_type) is not VectorType: - rhs = cast_func(node.rhs, VectorType(rhs_type)) + rhs = CastFunc(node.rhs, VectorType(rhs_type)) else: rhs = node.rhs @@ -324,7 +323,7 @@ class CBackend: if stride == 1: offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1}) size = sp.Mul(*node.lhs.args[0].field.spatial_shape) - element_size = 8 if data_type.base_type.base_name == 'double' else 4 + element_size = 8 if data_type.base_type.c_name == 'double' else 4 size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}" pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \ self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n' @@ -418,7 +417,7 @@ class CBackend: return self._print_Block(node.true_block) elif type(node.condition_expr) is BooleanFalse: return self._print_Block(node.false_block) - cond_type = get_type_of_expression(node.condition_expr) + cond_type = get_type_of_expression(node.condition_expr) # TODO: Could be vector or bool? if isinstance(cond_type, VectorType): raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") condition_expr = self.sympy_printer.doprint(node.condition_expr) @@ -438,19 +437,15 @@ class CustomSympyPrinter(CCodePrinter): def __init__(self): super(CustomSympyPrinter, self).__init__() - self._float_type = create_type("float32") def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if not expr.free_symbols: - return self._typed_number(expr.evalf(17), get_type_of_expression(expr.base)) + raise NotImplementedError("This pow should be simplified already?") + # return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) + return super(CustomSympyPrinter, self)._print_Pow(expr) - if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" - elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: - return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" - else: - return super(CustomSympyPrinter, self)._print_Pow(expr) + # TODO don't print ones in sp.Mul def _print_Rational(self, expr): """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" @@ -485,15 +480,15 @@ class CustomSympyPrinter(CCodePrinter): } if hasattr(expr, 'to_c'): return expr.to_c(self._print) - if isinstance(expr, reinterpret_cast_func): + if isinstance(expr, ReinterpretCastFunc): arg, data_type = expr.args return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))" - elif isinstance(expr, address_of): + elif isinstance(expr, AddressOf): assert len(expr.args) == 1, "address_of must only have one argument" return f"&({self._print(expr.args[0])})" - elif isinstance(expr, cast_func): + elif isinstance(expr, CastFunc): arg, data_type = expr.args - if isinstance(arg, sp.Number) and arg.is_finite: + if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): return self._typed_number(arg, data_type) elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \ and data_type == BasicType('float32'): @@ -519,8 +514,6 @@ class CustomSympyPrinter(CCodePrinter): return f"({self._print(1 / sp.sqrt(expr.args[0]))})" elif isinstance(expr, sp.Abs): return f"abs({self._print(expr.args[0])})" - elif isinstance(expr, sp.Max): - return self._print(expr) elif isinstance(expr, sp.Mod): if expr.args[0].is_integer and expr.args[1].is_integer: return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})" @@ -532,6 +525,8 @@ class CustomSympyPrinter(CCodePrinter): return f"(1 << ({self._print(expr.args[0])}))" elif expr.func == int_div: return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))" + elif expr.func == DivFunc: + return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))' else: name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__ arg_str = ', '.join(self._print(a) for a in expr.args) @@ -554,52 +549,6 @@ class CustomSympyPrinter(CCodePrinter): else: return res - def _print_Sum(self, expr): - template = """[&]() {{ - {dtype} sum = ({dtype}) 0; - for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{ - sum += {expr}; - }} - return sum; -}}()""" - var = expr.limits[0][0] - start = expr.limits[0][1] - end = expr.limits[0][2] - code = template.format( - dtype=get_type_of_expression(expr.args[0]), - iterator_dtype='int', - var=self._print(var), - start=self._print(start), - end=self._print(end), - expr=self._print(expr.function), - increment=str(1), - condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' - ) - return code - - def _print_Product(self, expr): - template = """[&]() {{ - {dtype} product = ({dtype}) 1; - for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{ - product *= {expr}; - }} - return product; -}}()""" - var = expr.limits[0][0] - start = expr.limits[0][1] - end = expr.limits[0][2] - code = template.format( - dtype=get_type_of_expression(expr.args[0]), - iterator_dtype='int', - var=self._print(var), - start=self._print(start), - end=self._print(end), - expr=self._print(expr.function), - increment=str(1), - condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' - ) - return code - def _print_ConditionalFieldAccess(self, node): return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) @@ -623,27 +572,6 @@ class CustomSympyPrinter(CCodePrinter): return f"(({a} < {b}) ? {a} : {b})" return inner_print_min(expr.args) - def _print_re(self, expr): - return f"real({self._print(expr.args[0])})" - - def _print_im(self, expr): - return f"imag({self._print(expr.args[0])})" - - def _print_ImaginaryUnit(self, expr): - return "complex<double>{0,1}" - - def _print_TypedImaginaryUnit(self, expr): - if expr.dtype.numpy_dtype == np.complex64: - return "complex<float>{0,1}" - elif expr.dtype.numpy_dtype == np.complex128: - return "complex<double>{0,1}" - else: - raise NotImplementedError( - "only complex64 and complex128 supported") - - def _print_Complex(self, expr): - return self._typed_number(expr, np.complex64) - # noinspection PyPep8Naming class VectorizedCustomSympyPrinter(CustomSympyPrinter): @@ -662,40 +590,91 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Abs(self, expr): - if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access): + if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess): return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) return super()._print_Abs(expr) + def _typed_vectorized_number(self, expr, data_type): + basic_data_type = data_type.base_type + number = self._typed_number(expr, basic_data_type) + instruction = 'makeVecConst' + if basic_data_type.is_bool(): + instruction = 'makeVecConstBool' + # TODO is int, or sint, or uint? + elif basic_data_type.is_int(): + instruction = 'makeVecConstInt' + return self.instruction_set[instruction].format(number, **self._kwargs) + + def _typed_vectorized_symbol(self, expr, data_type): + if not isinstance(expr, TypedSymbol): + raise ValueError(f'{expr} is not a TypeSymbol. It is {expr.type=}') + basic_data_type = data_type.base_type + symbol = self._print(expr) + if basic_data_type != expr.dtype: + symbol = f'(({basic_data_type.data_type})({symbol}))' + + instruction = 'makeVecConst' + if basic_data_type.is_bool(): + instruction = 'makeVecConstBool' + # TODO is int, or sint, or uint? + elif basic_data_type.is_int(): + instruction = 'makeVecConstInt' + return self.instruction_set[instruction].format(symbol, **self._kwargs) + + def _print_CastFunc(self, expr): + arg, data_type = expr.args + if type(data_type) is VectorType: + # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func + assert not isinstance(arg, VectorMemoryAccess) # TODO Is this true for our new type system? + if isinstance(arg, sp.Tuple): + is_boolean = get_type_of_expression(arg[0]) == create_type("bool") + is_integer = get_type_of_expression(arg[0]) == create_type("int") + printed_args = [self._print(a) for a in arg] + instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec' + if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set: + increments = np.array(arg)[1:] - np.array(arg)[:-1] + if len(set(increments)) == 1: + return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0], + **self._kwargs) + return self.instruction_set[instruction].format(*printed_args, **self._kwargs) + else: + if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): + return self._typed_vectorized_number(arg, data_type) + elif isinstance(arg, TypedSymbol): + return self._typed_vectorized_symbol(arg, data_type) + elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \ + and data_type == BasicType('float32'): + raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet') + # known = self.known_functions[arg.__class__.__name__.lower()] + # code = self._print(arg) + # return code.replace(known, f"{known}f") + elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'): + raise NotImplementedError('Vectorizer cannot print casted aka. not double pow') + # known = ['sqrt', 'cbrt', 'pow'] + # code = self._print(arg) + # for k in known: + # if k in code: + # return code.replace(k, f'{k}f') + # raise ValueError(f"{code} doesn't give {known=} function back.") + else: + raise NotImplementedError('Vectorizer cannot cast between different datatypes') + # to_type = self.instruction_set['suffix'][data_type.base_type.c_name] + # from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name] + # return self.instruction_set['cast'].format(from_type, to_type, self._print(arg)) + else: + return self._scalarFallback('_print_Function', expr) + # raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.') + def _print_Function(self, expr): - if isinstance(expr, vector_memory_access): + if isinstance(expr, VectorMemoryAccess): arg, data_type, aligned, _, mask, stride = expr.args if stride != 1: return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format(f"& {self._print(arg)}", **self._kwargs) - elif isinstance(expr, cast_func): - arg, data_type = expr.args - if type(data_type) is VectorType: - # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func - assert not isinstance(arg, vector_memory_access) - if isinstance(arg, sp.Tuple): - is_boolean = get_type_of_expression(arg[0]) == create_type("bool") - is_integer = get_type_of_expression(arg[0]) == create_type("int") - printed_args = [self._print(a) for a in arg] - instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec' - if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set: - increments = np.array(arg)[1:] - np.array(arg)[:-1] - if len(set(increments)) == 1: - return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0], - **self._kwargs) - return self.instruction_set[instruction].format(*printed_args, **self._kwargs) - else: - is_boolean = get_type_of_expression(arg) == create_type("bool") - is_integer = get_type_of_expression(arg) == create_type("int") or \ - (isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int()) - instruction = 'makeVecConstBool' if is_boolean else \ - 'makeVecConstInt' if is_integer else 'makeVecConst' - return self.instruction_set[instruction].format(self._print(arg), **self._kwargs) + elif expr.func == DivFunc: + return self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend), + **self._kwargs) elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: @@ -761,12 +740,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization suffix = "" - if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) + if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]): - dtype = set([e.dtype for e in args if type(e) is cast_func]) + dtype = set([e.dtype for e in args if type(e) is CastFunc]) assert len(dtype) == 1 dtype = dtype.pop() - args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e + args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e for e in args] suffix = "int" @@ -798,19 +777,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs) - if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" - elif expr.exp == -1: + if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number: + exp = expr.exp.args[0] + else: + exp = expr.exp + + if exp.is_integer and exp.is_number and 0 < exp < 8: + return "(" + self._print(sp.Mul(*[expr.base] * exp, evaluate=False)) + ")" + elif exp == -1: one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs) return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs) - elif expr.exp == 0.5: + elif exp == 0.5: return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs) - elif expr.exp == -0.5: + elif exp == -0.5: root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs) return self.instruction_set['/'].format(one, root, **self._kwargs) - elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: + elif exp.is_integer and exp.is_number and - 8 < exp < 0: return self.instruction_set['/'].format(one, - self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)), + self._print(sp.Mul(*[expr.base] * (-exp), evaluate=False)), **self._kwargs) else: raise ValueError("Generic exponential not supported: " + str(expr)) @@ -894,12 +878,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): - if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): - if not KERNCRAFT_NO_TERNARY_MODE: - result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), - result, **self._kwargs) - else: - print("Warning - skipping ternary op") + if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"): + result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), + result, **self._kwargs) else: # noinspection SpellCheckingInspection result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition), diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py index f72b48266..7653c7c69 100644 --- a/pystencils/backends/x86_instruction_sets.py +++ b/pystencils/backends/x86_instruction_sets.py @@ -51,7 +51,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): 'makeVecConstBool': 'set[]', 'makeVecInt': 'set[]', 'makeVecConstInt': 'set[]', - + 'loadU': 'loadu[0]', 'loadA': 'load[0]', 'storeU': 'storeu[0,1]', @@ -93,7 +93,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ("float", "avx512"): 16, ("int", "avx512"): 16, } - result = { 'width': width[(data_type, instruction_set)], 'intwidth': width[('int', instruction_set)], @@ -114,11 +113,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else '' result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string - result['dataTypePrefix'] = { - 'double': "_" + pre + 'd', - 'float': "_" + pre, - } - bit_width = result['width'] * (64 if data_type == 'double' else 32) result['double'] = f"__m{bit_width}d" result['float'] = f"__m{bit_width}" diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index a161d5879..4069a2485 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -3,13 +3,14 @@ from typing import Container, Union import numpy as np import sympy as sp -from sympy.logic.boolalg import BooleanFunction +from sympy.logic.boolalg import BooleanFunction, BooleanAtom import pystencils.astnodes as ast from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set -from pystencils.typing import ( - PointerType, TypedSymbol, VectorType, CastFunc, collate_types, get_type_of_expression, VectorMemoryAccess) +from pystencils.typing import ( BasicType, PointerType, TypedSymbol, VectorType, CastFunc, collate_types, + get_type_of_expression, VectorMemoryAccess) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.functions import DivFunc from pystencils.field import Field from pystencils.integer_functions import modulo_ceil, modulo_floor from pystencils.sympyextensions import fast_subs @@ -121,6 +122,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', "to differently typed floating point fields") float_size = field_float_dtypes.pop().numpy_dtype.itemsize assert float_size in (8, 4) + # TODO: future work allow mixed precision fields default_float_type = 'double' if float_size == 8 else 'float' vector_is = get_vector_instruction_set(default_float_type, instruction_set=instruction_set) vector_width = vector_is['width'] @@ -129,12 +131,14 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', strided = 'storeS' in vector_is and 'loadS' in vector_is keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, - strided, keep_loop_stop, assume_sufficient_line_padding) - insert_vector_casts(kernel_ast, default_float_type) + strided, keep_loop_stop, assume_sufficient_line_padding, + default_float_type) + # is in vectorize_inner_loops_and_adapt_load_stores.. insert_vector_casts(kernel_ast, default_float_type) def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, - strided, keep_loop_stop, assume_sufficient_line_padding): + strided, keep_loop_stop, assume_sufficient_line_padding, + default_float_type): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) inner_loops = [n for n in all_loops if n.is_innermost_loop] @@ -157,6 +161,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a if len(loop_nodes) == 0: continue loop_node = loop_nodes[0] + # TODO loop_node is the vectorized one # Find all array accesses (indexed) that depend on the loop counter as offset loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over) @@ -214,6 +219,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)}) rng._symbols_defined = set(new_result_symbols) fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase)) + insert_vector_casts(loop_node, default_float_type) def mask_conditionals(loop_body): @@ -245,13 +251,18 @@ def mask_conditionals(loop_body): def insert_vector_casts(ast_node, default_float_type='double'): """Inserts necessary casts from scalar values to vector values.""" - handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all) + handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc, + sp.UnevaluatedExpr) - def visit_expr(expr, default_type='double'): + def visit_expr(expr, default_type='double'): # TODO get rid of default_type if isinstance(expr, VectorMemoryAccess): return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) elif isinstance(expr, CastFunc): - return expr + cast_type = expr.args[1] + arg = visit_expr(expr.args[0]) + assert cast_type in [BasicType('float32'), BasicType('float64')],\ + f'Vectorization cannot vectorize type {cast_type}' + return expr.func(arg, VectorType(cast_type)) elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0], default_type) base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ @@ -307,14 +318,21 @@ def insert_vector_casts(ast_node, default_float_type='double'): for a, t in zip(new_conditions, types_of_conditions)] return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) - else: + elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)): return expr + else: + # TODO better error string + raise NotImplementedError(f'Should I raise or should I return now? {type(expr)} {expr}') def visit_node(node, substitution_dict, default_type='double'): substitution_dict = substitution_dict.copy() for arg in node.args: if isinstance(arg, ast.SympyAssignment): + # TODO only if not remainder loop (? if no VectorAccess then remainder loop) assignment = arg + # If there is a remainder loop we do not vectorise it, thus lhs will indicate this + # if isinstance(assignment.lhs, ast.ResolvedFieldAccess): + # continue subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) assignment.rhs = visit_expr(subs_expr, default_type) diff --git a/pystencils/fast_approximation.py b/pystencils/fast_approximation.py index 9eee41a96..65f85a71a 100644 --- a/pystencils/fast_approximation.py +++ b/pystencils/fast_approximation.py @@ -9,6 +9,7 @@ from pystencils.assignment import Assignment # noinspection PyPep8Naming class fast_division(sp.Function): + # TODO how is this fast? The printer prints a normal division??? nargs = (2,) diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 20f92eabd..aa23de65d 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -21,6 +21,7 @@ from pystencils.typing.types import BasicType, create_type, PointerType from pystencils.typing.utilities import get_type_of_expression, collate_types from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc from pystencils.typing.typed_sympy import TypedSymbol +from pystencils.fast_approximation import fast_sqrt, fast_division, fast_inv_sqrt from pystencils.utils import ContextVar @@ -215,6 +216,12 @@ class TypeAdder: return new_func, collated_type else: return CastFunc(new_func, collated_type), collated_type + elif isinstance(expr, (fast_sqrt, fast_division, fast_inv_sqrt)): + args_types = [self.figure_out_type(arg) for arg in expr.args] + collated_type = BasicType('float32') + new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] + new_func = expr.func(*new_args) if new_args else expr + return CastFunc(new_func, collated_type), collated_type elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 1cc62c168..4f67435bb 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -12,6 +12,7 @@ from pystencils.cache import memorycache_if_hashable from pystencils.typing.types import BasicType, VectorType, PointerType, create_type from pystencils.typing.cast_functions import CastFunc, PointerArithmeticFunc from pystencils.typing.typed_sympy import TypedSymbol +from pystencils.utils import all_equal def typed_symbols(names, dtype, *args): @@ -33,14 +34,6 @@ def get_base_type(data_type): return data_type -def peel_off_type(dtype, type_to_peel_off): - # TODO: WTF is this??? DOCS!!! - # TODO: used only once.... can be a lambda there - while type(dtype) is type_to_peel_off: - dtype = dtype.base_type - return dtype - - ############################# This is basically our type system ######################################################## def result_type(*args: np.dtype): @@ -83,18 +76,25 @@ def collate_types(types: Sequence[Union[BasicType, VectorType]]): # # peel of vector types, if at least one vector type occurred the result will also be the vector type vector_type = [t for t in types if isinstance(t, VectorType)] - # if not all_equal(t.width for t in vector_type): - # raise ValueError("Collation failed because of vector types with different width") + if not all_equal(t.width for t in vector_type): + raise ValueError("Collation failed because of vector types with different width") + + # TODO: check if this is needed + # def peel_off_type(dtype, type_to_peel_off): + # while type(dtype) is type_to_peel_off: + # dtype = dtype.base_type + # return dtype # types = [peel_off_type(t, VectorType) for t in types] + types = [t.base_type if isinstance(t, VectorType) else t for t in types] + # now we should have a list of basic types - struct types are not yet supported assert all(type(t) is BasicType for t in types) result_numpy_type = result_type(*(t.numpy_dtype for t in types)) result = BasicType(result_numpy_type) if vector_type: - raise NotImplementedError("Vector type not implemented at the moment") - # result = VectorType(result, vector_type[0].width) + result = VectorType(result, vector_type[0].width) return result diff --git a/pystencils/utils.py b/pystencils/utils.py index dc8d35ee6..22d61d0ba 100644 --- a/pystencils/utils.py +++ b/pystencils/utils.py @@ -1,5 +1,6 @@ import os import itertools +from itertools import groupby from collections import Counter from contextlib import contextmanager from tempfile import NamedTemporaryFile @@ -23,13 +24,13 @@ class DotDict(dict): self[key] = value -def all_equal(iterator): - iterator = iter(iterator) - try: - first = next(iterator) - except StopIteration: - return True - return all(first == rest for rest in iterator) +def all_equal(iterable): + """ + Returns ``True`` if all the elements are equal to each other. + Copied from: more-itertools 8.12.0 + """ + g = groupby(iterable) + return next(g, True) and not next(g, False) def recursive_dict_update(d, u): diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index 478022d32..55070e547 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -1,10 +1,12 @@ import numpy as np +import pytest + import pystencils.config import sympy as sp import pystencils as ps -from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets +from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set from pystencils.cpu.vectorization import vectorize from pystencils.fast_approximation import insert_fast_sqrts, insert_fast_divisions from pystencils.enums import Target @@ -13,10 +15,22 @@ from pystencils.transformations import replace_inner_stride_with_one supported_instruction_sets = get_supported_instruction_sets() if supported_instruction_sets: instruction_set = supported_instruction_sets[-1] + instructions = get_vector_instruction_set(instruction_set=instruction_set) else: instruction_set = None +# CI: +# FAILED pystencils_tests/test_vectorization.py::test_vectorised_pow - NotImple... +# FAILED pystencils_tests/test_vectorization.py::test_inplace_update - NotImple... +# FAILED pystencils_tests/test_vectorization.py::test_vectorised_fast_approximations +# test_issue40 + +# Jan: +# test_vectorised_pow +# test_issue40 + +# TODO: Skip tests if no instruction set is available and check all codes if they are really vectorised ! def test_vector_type_propagation(instruction_set=instruction_set): a, b, c, d, e = sp.symbols("a b c d e") arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2)) @@ -30,6 +44,8 @@ def test_vector_type_propagation(instruction_set=instruction_set): ast = ps.create_kernel(update_rule) vectorize(ast, instruction_set=instruction_set) + # ps.show_code(ast) + func = ast.compile() dst = np.zeros_like(arr) func(g=dst, f=arr) @@ -64,6 +80,8 @@ def test_aligned_and_nt_stores(instruction_set=instruction_set, openmp=False): assert ast.instruction_set[instruction].split('{')[0] in ps.get_code_str(ast) kernel = ast.compile() + # ps.show_code(ast) + dh.run_kernel(kernel) np.testing.assert_equal(np.sum(dh.cpu_arrays['f']), np.prod(domain_size)) @@ -114,6 +132,10 @@ def test_vectorization_fixed_size(instruction_set=instruction_set): ast = ps.create_kernel(update_rule) vectorize(ast, instruction_set=instruction_set) + code = ps.get_code_str(ast) + add_instruction = instructions["+"][:instructions["+"].find("(")] + assert add_instruction in code + print(code) func = ast.compile() dst = np.zeros_like(arr) @@ -167,7 +189,9 @@ def test_piecewise2(instruction_set=instruction_set): g[0, 0] @= s.result ast = ps.create_kernel(test_kernel) + # ps.show_code(ast) vectorize(ast, instruction_set=instruction_set) + # ps.show_code(ast) func = ast.compile() func(f=arr, g=arr) np.testing.assert_equal(arr, np.ones_like(arr)) @@ -183,7 +207,9 @@ def test_piecewise3(instruction_set=instruction_set): g[0, 0] @= 1.0 / (s.b + s.k) if f[0, 0] > 0.0 else 1.0 ast = ps.create_kernel(test_kernel) + ps.show_code(ast) vectorize(ast, instruction_set=instruction_set) + ps.show_code(ast) ast.compile() @@ -262,6 +288,7 @@ def test_vectorised_pow(instruction_set=instruction_set): def test_vectorised_fast_approximations(instruction_set=instruction_set): + # fast_approximations are a gpu thing arr = np.zeros((24, 24)) f, g = ps.fields(f=arr, g=arr) @@ -269,18 +296,24 @@ def test_vectorised_fast_approximations(instruction_set=instruction_set): assignment = ps.Assignment(g[0, 0], insert_fast_sqrts(expr)) ast = ps.create_kernel(assignment) vectorize(ast, instruction_set=instruction_set) - ast.compile() + + with pytest.raises(Exception): + ast.compile() expr = f[0, 0] / f[1, 0] assignment = ps.Assignment(g[0, 0], insert_fast_divisions(expr)) ast = ps.create_kernel(assignment) vectorize(ast, instruction_set=instruction_set) - ast.compile() + + with pytest.raises(Exception): + ast.compile() assignment = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0])) ast = ps.create_kernel(insert_fast_sqrts(assignment)) vectorize(ast, instruction_set=instruction_set) - ast.compile() + + with pytest.raises(Exception): + ast.compile() def test_issue40(*_): -- GitLab