From a822ffc940f27e766d5426acbc9c78ee83dcd6a7 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 11 Oct 2019 08:09:07 +0200 Subject: [PATCH] Better boolean support in sympy printer - bugfix: loop counter of vectorized loop now correctly stored as SIMD vector with entries i, i+1, i+2, ... - basis for in-kernel boundary handling --- pystencils/astnodes.py | 18 ++++++++++++++++-- pystencils/backends/cbackend.py | 15 +++++++++++---- pystencils/backends/simd_instruction_sets.py | 19 ++++++++++++++++++- pystencils/cpu/vectorization.py | 13 +++++++++---- pystencils/sympyextensions.py | 2 +- pystencils/transformations.py | 7 ++++++- 6 files changed, 61 insertions(+), 13 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 47f1fd7d1..3f67248f6 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -293,6 +293,10 @@ class Block(Node): for a in self.args: a.subs(subs_dict) + def fast_subs(self, subs_dict, skip=None): + self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes] + return self + def insert_front(self, node): if isinstance(node, collections.abc.Iterable): node = list(node) @@ -408,6 +412,16 @@ class LoopOverCoordinate(Node): if hasattr(self.step, "subs"): self.step = self.step.subs(subs_dict) + def fast_subs(self, subs_dict, skip=None): + self.body = fast_subs(self.body, subs_dict, skip) + if isinstance(self.start, sp.Basic): + self.start = fast_subs(self.start, subs_dict, skip) + if isinstance(self.stop, sp.Basic): + self.stop = fast_subs(self.stop, subs_dict, skip) + if isinstance(self.step, sp.Basic): + self.step = fast_subs(self.step, subs_dict, skip) + return self + @property def args(self): result = [self.body] @@ -538,7 +552,7 @@ class SympyAssignment(Node): @property def args(self): - return [self._lhs_symbol, self.rhs] + return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)] @property def symbols_defined(self): @@ -603,7 +617,7 @@ class ResolvedFieldAccess(sp.Indexed): self.args[1].subs(old, new), self.field, self.offsets, self.idx_coordinate_values) - def fast_subs(self, substitutions): + def fast_subs(self, substitutions, skip=None): if self in substitutions: return substitutions[self] return ResolvedFieldAccess(self.args[0].subs(substitutions), diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 556065251..0ae5e3640 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -5,7 +5,6 @@ import numpy as np import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter - from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( @@ -457,7 +456,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif isinstance(expr, cast_func): arg, data_type = expr.args if type(data_type) is VectorType: - return self.instruction_set['makeVec'].format(self._print(arg)) + if isinstance(arg, sp.Tuple): + is_boolean = get_type_of_expression(arg[0]) == create_type("bool") + printed_args = [self._print(a) for a in arg] + instruction = 'makeVecBool' if is_boolean else 'makeVec' + return self.instruction_set[instruction].format(*printed_args) + else: + is_boolean = get_type_of_expression(arg) == create_type("bool") + instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst' + return self.instruction_set[instruction].format(self._print(arg)) elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: @@ -542,12 +549,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if result: return result - one = self.instruction_set['makeVec'].format(1.0) + one = self.instruction_set['makeVecConst'].format(1.0) if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" elif expr.exp == -1: - one = self.instruction_set['makeVec'].format(1.0) + one = self.instruction_set['makeVecConst'].format(1.0) return self.instruction_set['/'].format(one, self._print(expr.base)) elif expr.exp == 0.5: return self.instruction_set['sqrt'].format(self._print(expr.base)) diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index f72b0fec3..4415b9ab6 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -21,7 +21,10 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): 'sqrt': 'sqrt[0]', + 'makeVecConst': 'set[]', 'makeVec': 'set[]', + 'makeVecBool': 'set[]', + 'makeVecConstBool': 'set[]', 'makeZero': 'setzero[]', 'loadU': 'loadu[0]', @@ -68,8 +71,17 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): function_shortcut = function_shortcut.strip() name = function_shortcut[:function_shortcut.index('[')] - if intrinsic_id == 'makeVec': + if intrinsic_id == 'makeVecConst': arg_string = "({})".format(",".join(["{0}"] * result['width'])) + elif intrinsic_id == 'makeVec': + params = ["{" + str(i) + "}" for i in reversed(range(result['width']))] + arg_string = "({})".format(",".join(params)) + elif intrinsic_id == 'makeVecBool': + params = ["(({{{i}}} ? -1.0 : 0.0)".format(i=i) for i in reversed(range(result['width']))] + arg_string = "({})".format(",".join(params)) + elif intrinsic_id == 'makeVecConstBool': + params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])] + arg_string = "({})".format(",".join(params)) else: args = function_shortcut[function_shortcut.index('[') + 1: -1] arg_string = "(" @@ -111,6 +123,11 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,) result['bool'] = "__mmask%d" % (size,) + params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)]) + result['makeVecBool'] = "__mmask8(({}) )".format(params) + params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) + result['makeVecConstBool'] = "__mmask8(({}) )".format(params) + if instruction_set == 'avx' and data_type == 'float': result['rsqrt'] = "_mm256_rsqrt_ps({0})" diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 6bf3a26de..d2f206722 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -18,12 +18,12 @@ from pystencils.transformations import ( # noinspection PyPep8Naming class vec_any(sp.Function): - nargs = (1, ) + nargs = (1,) # noinspection PyPep8Naming class vec_all(sp.Function): - nargs = (1, ) + nargs = (1,) def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', @@ -53,7 +53,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', """ if instruction_set is None: return - + all_fields = kernel_ast.fields_accessed if nontemporal is None or nontemporal is False: nontemporal = {} @@ -101,7 +101,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a if len(loop_nodes) == 0: continue loop_node = loop_nodes[0] - + # Find all array accesses (indexed) that depend on the loop counter as offset loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over) substitutions = {} @@ -130,6 +130,11 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a loop_node.step = vector_width loop_node.subs(substitutions) + vector_loop_counter = cast_func(tuple(loop_counter_symbol + i for i in range(vector_width)), + VectorType(loop_counter_symbol.dtype, vector_width)) + + fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, + skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access)) def insert_vector_casts(ast_node): diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index afdf0fde3..7d25f49c7 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -160,7 +160,7 @@ def fast_subs(expression: T, substitutions: Dict, if skip and skip(expr): return expr if hasattr(expr, "fast_subs"): - return expr.fast_subs(substitutions) + return expr.fast_subs(substitutions, skip) if expr in substitutions: return substitutions[expr] if not hasattr(expr, 'args'): diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 1bfb0511a..f43b5bdc6 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -481,6 +481,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), if isinstance(field.dtype, StructType): assert field.index_dimensions == 1 accessed_field_name = field_access.index[0] + if isinstance(accessed_field_name, sp.Symbol): + accessed_field_name = accessed_field_name.name assert isinstance(accessed_field_name, str) coordinates[e] = field.dtype.get_element_offset(accessed_field_name) else: @@ -504,7 +506,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), field_access.offsets, field_access.index) if isinstance(get_base_type(field_access.field.dtype), StructType): - new_type = field_access.field.dtype.get_element_type(field_access.index[0]) + accessed_field_name = field_access.index[0] + if isinstance(accessed_field_name, sp.Symbol): + accessed_field_name = accessed_field_name.name + new_type = field_access.field.dtype.get_element_type(accessed_field_name) result = reinterpret_cast_func(result, new_type) return visit_sympy_expr(result, enclosing_block, sympy_assignment) -- GitLab