From 6fa41f7ccc8ae5f29d3e1dc4708e897019ffca7e Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Thu, 27 May 2021 15:24:26 +0000 Subject: [PATCH] Fix RNG vectorization for LB --- pystencils/backends/cbackend.py | 12 ++++++++++-- pystencils/cpu/vectorization.py | 33 ++++++++++++++------------------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index ba7426a14..271bf2322 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -192,7 +192,9 @@ class CBackend: def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.vector_sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.scalar_sympy_printer = CustomSympyPrinter() + self.sympy_printer = self.vector_sympy_printer else: self.sympy_printer = CustomSympyPrinter() else: @@ -259,6 +261,12 @@ class CBackend: prefix = "\n".join(node.prefix_lines) if prefix: prefix += "\n" + if self._vector_instruction_set and hasattr(node, 'instruction_set') and node.instruction_set is None: + # the tail loop must not be vectorized + self.sympy_printer = self.scalar_sympy_printer + code = f"{prefix}{loop_str}\n{self._print(node.body)}" + self.sympy_printer = self.vector_sympy_printer + return code return f"{prefix}{loop_str}\n{self._print(node.body)}" def _print_SympyAssignment(self, node): @@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): else: is_boolean = get_type_of_expression(arg) == create_type("bool") is_integer = get_type_of_expression(arg) == create_type("int") or \ - (isinstance(arg, TypedSymbol) and arg.dtype.is_int()) + (isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int()) instruction = 'makeVecConstBool' if is_boolean else \ 'makeVecConstInt' if is_integer else 'makeVecConst' return self.instruction_set[instruction].format(self._print(arg), **self._kwargs) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index b9fa2819e..b54b78d85 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', vector_width = vector_is['width'] kernel_ast.instruction_set = vector_is - vectorize_rng(kernel_ast, vector_width) strided = 'storeS' in vector_is and 'loadS' in vector_is keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, @@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', insert_vector_casts(kernel_ast) -def vectorize_rng(kernel_ast, vector_width): - """Replace scalar result symbols on RNG nodes with vectorial ones""" - from pystencils.rng import RNGBase - subst = {} - - def visit_node(node): - for arg in node.args: - if isinstance(arg, RNGBase): - new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) - for s in arg.result_symbols] - subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)}) - arg._symbols_defined = set(new_result_symbols) - else: - visit_node(arg) - visit_node(kernel_ast) - fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase)) - - def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, strided, keep_loop_stop, assume_sufficient_line_padding): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" @@ -173,6 +154,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)] assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width + if len(loop_nodes) == 2: + loop_nodes[1].instruction_set = None if len(loop_nodes) == 0: continue loop_node = loop_nodes[0] @@ -225,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a mask_conditionals(loop_node) + from pystencils.rng import RNGBase + substitutions = {} + for rng in loop_node.atoms(RNGBase): + new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) + for s in rng.result_symbols] + substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)}) + rng._symbols_defined = set(new_result_symbols) + fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase)) + def mask_conditionals(loop_body): def visit_node(node, mask): @@ -322,6 +314,9 @@ def insert_vector_casts(ast_node): return expr def visit_node(node, substitution_dict): + if hasattr(node, 'instruction_set') and node.instruction_set is None: + # the tail loop must not be vectorized + return substitution_dict = substitution_dict.copy() for arg in node.args: if isinstance(arg, ast.SympyAssignment): -- GitLab