Commit 6fa41f7c authored by Michael Kuron's avatar Michael Kuron Committed by Markus Holzer
Browse files

Fix RNG vectorization for LB

parent 584b4255
......@@ -192,7 +192,9 @@ class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
self.vector_sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
self.scalar_sympy_printer = CustomSympyPrinter()
self.sympy_printer = self.vector_sympy_printer
else:
self.sympy_printer = CustomSympyPrinter()
else:
......@@ -259,6 +261,12 @@ class CBackend:
prefix = "\n".join(node.prefix_lines)
if prefix:
prefix += "\n"
if self._vector_instruction_set and hasattr(node, 'instruction_set') and node.instruction_set is None:
# the tail loop must not be vectorized
self.sympy_printer = self.scalar_sympy_printer
code = f"{prefix}{loop_str}\n{self._print(node.body)}"
self.sympy_printer = self.vector_sympy_printer
return code
return f"{prefix}{loop_str}\n{self._print(node.body)}"
def _print_SympyAssignment(self, node):
......@@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
is_integer = get_type_of_expression(arg) == create_type("int") or \
(isinstance(arg, TypedSymbol) and arg.dtype.is_int())
(isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int())
instruction = 'makeVecConstBool' if is_boolean else \
'makeVecConstInt' if is_integer else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
......
......@@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
vector_width = vector_is['width']
kernel_ast.instruction_set = vector_is
vectorize_rng(kernel_ast, vector_width)
strided = 'storeS' in vector_is and 'loadS' in vector_is
keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
......@@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
insert_vector_casts(kernel_ast)
def vectorize_rng(kernel_ast, vector_width):
"""Replace scalar result symbols on RNG nodes with vectorial ones"""
from pystencils.rng import RNGBase
subst = {}
def visit_node(node):
for arg in node.args:
if isinstance(arg, RNGBase):
new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
for s in arg.result_symbols]
subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)})
arg._symbols_defined = set(new_result_symbols)
else:
visit_node(arg)
visit_node(kernel_ast)
fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase))
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
strided, keep_loop_stop, assume_sufficient_line_padding):
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
......@@ -173,6 +154,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)]
assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width
if len(loop_nodes) == 2:
loop_nodes[1].instruction_set = None
if len(loop_nodes) == 0:
continue
loop_node = loop_nodes[0]
......@@ -225,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
mask_conditionals(loop_node)
from pystencils.rng import RNGBase
substitutions = {}
for rng in loop_node.atoms(RNGBase):
new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
for s in rng.result_symbols]
substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
rng._symbols_defined = set(new_result_symbols)
fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
def mask_conditionals(loop_body):
def visit_node(node, mask):
......@@ -322,6 +314,9 @@ def insert_vector_casts(ast_node):
return expr
def visit_node(node, substitution_dict):
if hasattr(node, 'instruction_set') and node.instruction_set is None:
# the tail loop must not be vectorized
return
substitution_dict = substitution_dict.copy()
for arg in node.args:
if isinstance(arg, ast.SympyAssignment):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment