From 696eb2d5cdadb31787f19d92ccc5d42d2dbf6e6f Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 29 Mar 2019 17:51:58 +0100 Subject: [PATCH] Support for conditionals in vectorized loops (by specifying all/any) --- pystencils/astnodes.py | 1 - pystencils/backends/cbackend.py | 22 +++++++++- pystencils/backends/simd_instruction_sets.py | 4 ++ pystencils/cpu/vectorization.py | 17 +++++++- pystencils/data_types.py | 5 +++ pystencils/kernelcreation.py | 4 +- pystencils/transformations.py | 13 +++++- pystencils_tests/test_conditional_vec.py | 43 ++++++++++++++++++++ 8 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 pystencils_tests/test_conditional_vec.py diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 992cbf2..31d49c1 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -65,7 +65,6 @@ class Conditional(Node): false_block: Optional['Block'] = None) -> None: super(Conditional, self).__init__(parent=None) - assert condition_expr.is_Boolean or condition_expr.is_Relational self.condition_expr = condition_expr def handle_child(c): diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index f0787bc..cf96dcd 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -4,6 +4,7 @@ from sympy.core import S from typing import Set from sympy.printing.ccode import C89CodePrinter +from pystencils.cpu.vectorization import vec_any, vec_all from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt try: @@ -198,6 +199,9 @@ class CBackend: return node.get_code(self._dialect, self._vector_instruction_set) def _print_Conditional(self, node): + cond_type = get_type_of_expression(node.condition_expr) + if isinstance(cond_type, VectorType): + raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") condition_expr = self.sympy_printer.doprint(node.condition_expr) true_block = self._print_Block(node.true_block) result = "if (%s)\n%s " % (condition_expr, true_block) @@ -274,6 +278,8 @@ class CustomSympyPrinter(CCodePrinter): return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) else: return "({})".format(self._print(sp.sqrt(expr.args[0]))) + elif isinstance(expr, vec_any) or isinstance(expr, vec_all): + return self._print(expr.args[0]) elif isinstance(expr, fast_inv_sqrt): if self._dialect == "cuda": return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) @@ -328,7 +334,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: - return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) + result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) + return result elif expr.func == fast_sqrt: return "({})".format(self._print(sp.sqrt(expr.args[0]))) elif expr.func == fast_inv_sqrt: @@ -338,6 +345,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) else: return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) + elif isinstance(expr, vec_any): + expr_type = get_type_of_expression(expr.args[0]) + if type(expr_type) is not VectorType: + return self._print(expr.args[0]) + else: + return self.instruction_set['any'].format(self._print(expr.args[0])) + elif isinstance(expr, vec_all): + expr_type = get_type_of_expression(expr.args[0]) + if type(expr_type) is not VectorType: + return self._print(expr.args[0]) + else: + return self.instruction_set['all'].format(self._print(expr.args[0])) + return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) def _print_And(self, expr): diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index 0220750..f72b0fe 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -98,11 +98,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): result['bool'] = "__m%dd" % (bit_width,) result['headers'] = headers[instruction_set] + result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf) + result['all'] = "%s_movemask_%s({0}) == 0xF" % (pre, suf) if instruction_set == 'avx512': size = 8 if data_type == 'double' else 16 result['&'] = '_kand_mask%d({0}, {1})' % (size,) result['|'] = '_kor_mask%d({0}, {1})' % (size,) + result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, ) + result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, ) result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf) result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,) result['bool'] = "__mmask%d" % (size,) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 6a55b69..2790fe3 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -12,6 +12,16 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration, replac from pystencils.field import Field +# noinspection PyPep8Naming +class vec_any(sp.Function): + nargs = (1, ) + + +# noinspection PyPep8Naming +class vec_all(sp.Function): + nargs = (1, ) + + def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False, assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True): @@ -119,7 +129,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a def insert_vector_casts(ast_node): """Inserts necessary casts from scalar values to vector values.""" - handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt) + handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all) def visit_expr(expr): @@ -182,6 +192,11 @@ def insert_vector_casts(ast_node): lhs_type = assignment.lhs.args[1] if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: assignment.rhs = cast_func(assignment.rhs, lhs_type) + elif isinstance(arg, ast.Conditional): + arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict, + skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) + arg.condition_expr = visit_expr(arg.condition_expr) + visit_node(arg, substitution_dict) else: visit_node(arg, substitution_dict) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 56d19d4..7bdc9d3 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -1,6 +1,7 @@ import ctypes import sympy as sp import numpy as np + try: import llvmlite.ir as ir except ImportError as e: @@ -311,6 +312,8 @@ def collate_types(types): @memorycache(maxsize=2048) def get_type_of_expression(expr): from pystencils.astnodes import ResolvedFieldAccess + from pystencils.cpu.vectorization import vec_all, vec_any + expr = sp.sympify(expr) if isinstance(expr, sp.Integer): return create_type("int") @@ -324,6 +327,8 @@ def get_type_of_expression(expr): raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) elif isinstance(expr, cast_func): return expr.args[1] + elif isinstance(expr, vec_any) or isinstance(expr, vec_all): + return create_type("bool") elif hasattr(expr, 'func') and expr.func == sp.Piecewise: collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args)) collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args)) diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 5044049..024ee4a 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -6,7 +6,8 @@ from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAss from pystencils.cpu.vectorization import vectorize from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.gpucuda.indexing import indexing_creator_from_params -from pystencils.transformations import remove_conditionals_in_staggered_kernel, loop_blocking +from pystencils.transformations import remove_conditionals_in_staggered_kernel, loop_blocking, \ + move_constants_before_loop def create_kernel(assignments, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None, @@ -248,6 +249,7 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar if target == 'cpu': remove_conditionals_in_staggered_kernel(ast) + move_constants_before_loop(ast) if blocking: loop_blocking(ast, blocking) if cpu_vectorize_info is True: diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 6fc9f83..6411d6a 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -7,7 +7,6 @@ import hashlib import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase - from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.assignment import Assignment from pystencils.field import Field, FieldType @@ -514,6 +513,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), assert type(enclosing_block) is ast.Block sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast) sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast) + elif isinstance(sub_ast, ast.Conditional): + enclosing_block = sub_ast.parent + assert type(enclosing_block) is ast.Block + sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast) + visit_node(sub_ast.true_block) + if sub_ast.false_block: + visit_node(sub_ast.false_block) else: for i, a in enumerate(sub_ast.args): visit_node(a) @@ -545,7 +551,7 @@ def move_constants_before_loop(ast_node): last_block_child = prev_element if isinstance(element, ast.Conditional): - critical_symbols = element.condition_expr.atoms(sp.Symbol) + break else: critical_symbols = element.symbols_defined if node.undefined_symbols.intersection(critical_symbols): @@ -1000,6 +1006,9 @@ def typing_from_sympy_inspection(eqs, default_type="double"): elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): continue else: + from pystencils.cpu.vectorization import vec_all, vec_any + if isinstance(eq.rhs, vec_all) or isinstance(eq.rhs, vec_any): + result[eq.lhs.name] = "bool" # problematic case here is when rhs is a symbol: then it is impossible to decide here without # further information what type the left hand side is - default fallback is the dict value then if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): diff --git a/pystencils_tests/test_conditional_vec.py b/pystencils_tests/test_conditional_vec.py new file mode 100644 index 0000000..c1d11f1 --- /dev/null +++ b/pystencils_tests/test_conditional_vec.py @@ -0,0 +1,43 @@ +import pystencils as ps +import sympy as sp +import numpy as np +from pystencils.astnodes import Conditional, Block +from pystencils.cpu.vectorization import vec_all, vec_any + + +def test_vec_any(): + data_arr = np.zeros((15, 15)) + + data_arr[3:9, 2:7] = 1.0 + data = ps.fields("data: double[2D]", data=data_arr) + + c = [ + ps.Assignment(sp.Symbol("t1"), vec_any(data.center() > 0.0)), + Conditional(vec_any(data.center() > 0.0), Block([ + ps.Assignment(data.center(), 2.0) + ])) + ] + ast = ps.create_kernel(c, target='cpu', + cpu_vectorize_info={'instruction_set': 'avx'}) + kernel = ast.compile() + kernel(data=data_arr) + np.testing.assert_equal(data_arr[3:9, 0:8], 2.0) + + +def test_vec_all(): + data_arr = np.zeros((15, 15)) + + data_arr[3:9, 2:7] = 1.0 + data = ps.fields("data: double[2D]", data=data_arr) + + c = [ + Conditional(vec_all(data.center() > 0.0), Block([ + ps.Assignment(data.center(), 2.0) + ])) + ] + ast = ps.create_kernel(c, target='cpu', + cpu_vectorize_info={'instruction_set': 'avx'}) + kernel = ast.compile() + before = data_arr.copy() + kernel(data=data_arr) + np.testing.assert_equal(data_arr, before) -- GitLab