diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index ba7426a143f243f4a3991d575443e3b6b247f567..271bf23226fd6b5abf29bd22d8d0013d7918d22a 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -192,7 +192,9 @@ class CBackend: def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.vector_sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.scalar_sympy_printer = CustomSympyPrinter() + self.sympy_printer = self.vector_sympy_printer else: self.sympy_printer = CustomSympyPrinter() else: @@ -259,6 +261,12 @@ class CBackend: prefix = "\n".join(node.prefix_lines) if prefix: prefix += "\n" + if self._vector_instruction_set and hasattr(node, 'instruction_set') and node.instruction_set is None: + # the tail loop must not be vectorized + self.sympy_printer = self.scalar_sympy_printer + code = f"{prefix}{loop_str}\n{self._print(node.body)}" + self.sympy_printer = self.vector_sympy_printer + return code return f"{prefix}{loop_str}\n{self._print(node.body)}" def _print_SympyAssignment(self, node): @@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): else: is_boolean = get_type_of_expression(arg) == create_type("bool") is_integer = get_type_of_expression(arg) == create_type("int") or \ - (isinstance(arg, TypedSymbol) and arg.dtype.is_int()) + (isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int()) instruction = 'makeVecConstBool' if is_boolean else \ 'makeVecConstInt' if is_integer else 'makeVecConst' return self.instruction_set[instruction].format(self._print(arg), **self._kwargs) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index b9fa2819e3f32241becd6113da340fc0c6de157d..b54b78d8583ff86ff256652a2eade14a2e285189 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', vector_width = vector_is['width'] kernel_ast.instruction_set = vector_is - vectorize_rng(kernel_ast, vector_width) strided = 'storeS' in vector_is and 'loadS' in vector_is keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, @@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', insert_vector_casts(kernel_ast) -def vectorize_rng(kernel_ast, vector_width): - """Replace scalar result symbols on RNG nodes with vectorial ones""" - from pystencils.rng import RNGBase - subst = {} - - def visit_node(node): - for arg in node.args: - if isinstance(arg, RNGBase): - new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) - for s in arg.result_symbols] - subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)}) - arg._symbols_defined = set(new_result_symbols) - else: - visit_node(arg) - visit_node(kernel_ast) - fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase)) - - def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, strided, keep_loop_stop, assume_sufficient_line_padding): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" @@ -173,6 +154,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)] assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width + if len(loop_nodes) == 2: + loop_nodes[1].instruction_set = None if len(loop_nodes) == 0: continue loop_node = loop_nodes[0] @@ -225,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a mask_conditionals(loop_node) + from pystencils.rng import RNGBase + substitutions = {} + for rng in loop_node.atoms(RNGBase): + new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) + for s in rng.result_symbols] + substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)}) + rng._symbols_defined = set(new_result_symbols) + fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase)) + def mask_conditionals(loop_body): def visit_node(node, mask): @@ -322,6 +314,9 @@ def insert_vector_casts(ast_node): return expr def visit_node(node, substitution_dict): + if hasattr(node, 'instruction_set') and node.instruction_set is None: + # the tail loop must not be vectorized + return substitution_dict = substitution_dict.copy() for arg in node.args: if isinstance(arg, ast.SympyAssignment):