diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 0ae5e36400b7d05dd57c11874657096d36181078..ab0aa8a4333ccdc3a19c379d3e8d361514fed93d 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -233,11 +233,17 @@ class CBackend: self.sympy_printer.doprint(node.rhs)) else: lhs_type = get_type_of_expression(node.lhs) + printed_mask = "" if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): - arg, data_type, aligned, nontemporal = node.lhs.args + arg, data_type, aligned, nontemporal, mask = node.lhs.args instr = 'storeU' if aligned: instr = 'stream' if nontemporal else 'storeA' + if mask != True: + instr = 'maskStore' if aligned else 'maskStoreU' + printed_mask = self.sympy_printer.doprint(mask) + if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d': + printed_mask = "_mm256_castpd_si256({})".format(printed_mask) rhs_type = get_type_of_expression(node.rhs) if type(rhs_type) is not VectorType: @@ -246,7 +252,8 @@ class CBackend: rhs = node.rhs return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), - self.sympy_printer.doprint(rhs)) + ';' + self.sympy_printer.doprint(rhs), + printed_mask) + ';' else: return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) @@ -450,7 +457,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): def _print_Function(self, expr): if isinstance(expr, vector_memory_access): - arg, data_type, aligned, _ = expr.args + arg, data_type, aligned, _, mask = expr.args instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format("& " + self._print(arg)) elif isinstance(expr, cast_func): diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index 4415b9ab66e22af1bd671dd581d66ed7e159d855..be8523e90913d7f62eb908ace8cbdf0189c31fbf 100644 --- 
a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -32,7 +32,24 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): 'storeU': 'storeu[0,1]', 'storeA': 'store[0,1]', 'stream': 'stream[0,1]', + 'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]', + 'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]' } + if instruction_set == 'avx512': + base_names.update({ + 'maskStore': 'mask_store[0, 2, 1]', + 'maskStoreU': 'mask_storeu[0, 2, 1]', + 'maskLoad': 'mask_load[2, 1, 0]', + 'maskLoadU': 'mask_loadu[2, 1, 0]' + }) + if instruction_set == 'avx': + base_names.update({ + 'maskStore': 'maskstore[0, 2, 1]', + 'maskStoreU': 'maskstore[0, 2, 1]', + 'maskLoad': 'maskload[0, 1]', + 'maskLoadU': 'maskload[0, 1]' + }) + for comparison_op, constant in comparisons.items(): base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index d2f206722972cdc62cdc9f398012cf579b27a96d..9556494576df91c2f614ea241ad44baf2e6f9f41 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -123,7 +123,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a nontemporal = False if hasattr(indexed, 'field'): nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields) - substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal) + substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True) if not successful: warnings.warn("Could not vectorize loop because of non-consecutive memory access") continue @@ -136,6 +136,30 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, skip=lambda e: isinstance(e, 
ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access)) + mask_conditionals(loop_node) + + +def mask_conditionals(loop_body): + + def visit_node(node, mask): + if isinstance(node, ast.Conditional): + true_mask = sp.And(node.condition_expr, mask) + visit_node(node.true_block, true_mask) + if node.false_block: + false_mask = sp.And(sp.Not(node.condition_expr), mask) + visit_node(node.false_block, false_mask) + node.condition_expr = vec_any(node.condition_expr) + elif isinstance(node, ast.SympyAssignment): + if mask is not True: + s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4])) + for ma in node.atoms(vector_memory_access)} + node.subs(s) + else: + for arg in node.args: + visit_node(arg, mask) + + visit_node(loop_body, mask=True) + def insert_vector_casts(ast_node): """Inserts necessary casts from scalar values to vector values.""" @@ -143,8 +167,10 @@ def insert_vector_casts(ast_node): handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all) def visit_expr(expr): - - if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access): + if isinstance(expr, vector_memory_access): + return vector_memory_access(expr.args[0], expr.args[1], expr.args[2], expr.args[3], + visit_expr(expr.args[4])) + elif isinstance(expr, cast_func): return expr elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): new_args = [visit_expr(a) for a in expr.args] @@ -199,10 +225,12 @@ def insert_vector_casts(ast_node): new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) substitution_dict[assignment.lhs] = new_lhs assignment.lhs = new_lhs - elif isinstance(assignment.lhs.func, cast_func): - lhs_type = assignment.lhs.args[1] - if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: - assignment.rhs = cast_func(assignment.rhs, lhs_type) + elif isinstance(assignment.lhs, vector_memory_access): + assignment.lhs = 
visit_expr(assignment.lhs) + #elif isinstance(assignment.lhs, cast_func): # TODO check if necessary + # lhs_type = assignment.lhs.args[1] + # if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: + # assignment.rhs = cast_func(assignment.rhs, lhs_type) elif isinstance(arg, ast.Conditional): arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 20eb94d6b83474bfab46dc491a05b3b1ed2191d9..a5e876c974e372fdeeb946a9ed2fb1cf1ba08c50 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -190,7 +190,8 @@ class boolean_cast_func(cast_func, Boolean): # noinspection PyPep8Naming class vector_memory_access(cast_func): - nargs = (4,) + # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none) + nargs = (5,) # noinspection PyPep8Naming diff --git a/pystencils/transformations.py b/pystencils/transformations.py index f43b5bdc6e400fe004e84e2acab188f88a2562ca..1c244be0537b401587b76c348f89ce64c04cdee2 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -891,10 +891,10 @@ class KernelConstraintsCheck: if isinstance(lhs, AbstractField.AbstractAccess): fai = self.FieldAndIndex(lhs.field, lhs.index) self._field_writes[fai].add(lhs.offsets) - if len(self._field_writes[fai]) > 1: - raise ValueError( - "Field {} is written at two different locations".format( - lhs.field.name)) + #if len(self._field_writes[fai]) > 1: + # raise ValueError( + # "Field {} is written at two different locations".format( + # lhs.field.name)) elif isinstance(lhs, sp.Symbol): if self.scopes.is_defined_locally(lhs): raise ValueError(