diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 018100eb24263b2ae951fad1fefdad5520d307ea..ee424b92c2fa801e272db7e9d0e037ae68481fb5 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -437,7 +437,6 @@ def run_compile_step(command): config_env = compiler_config['env'] if 'env' in compiler_config else {} compile_environment = os.environ.copy() compile_environment.update(config_env) - try: shell = True if compiler_config['os'].lower() == 'windows' else False subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 9556494576df91c2f614ea241ad44baf2e6f9f41..bfd90ffc4cbfb8d6cd6ac713695ff0911edf574a 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -143,7 +143,9 @@ def mask_conditionals(loop_body): def visit_node(node, mask): if isinstance(node, ast.Conditional): - true_mask = sp.And(node.condition_expr, mask) + cond = node.condition_expr + cond = cond if loop_body.loop_counter_symbol in cond.atoms(sp.Symbol) else True + true_mask = sp.And(cond, mask) visit_node(node.true_block, true_mask) if node.false_block: false_mask = sp.And(sp.Not(node.condition_expr), mask) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index a5e876c974e372fdeeb946a9ed2fb1cf1ba08c50..6fb15be18f5fd229b9e0a3646d01b2cefa5e2390 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -27,6 +27,9 @@ def typed_symbols(names, dtype, *args): else: return TypedSymbol(str(symbols), dtype) +def type_all_numbers(expr, dtype): + substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)} + return expr.subs(substitutions) def matrix_symbols(names, dtype, rows, cols): if isinstance(names, str): diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index ce9f1c9fecaafb62a2f34789fdec5e7a02c537ef..bf4074ed9287d5c9c77ef4776c62ceb1e4ffff4b 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -345,7 +345,8 @@ class AssignmentCollection: return result def __repr__(self): - return "Assignment Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments]) + return "Assignment Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments + if isinstance(eq, Assignment)]) def __str__(self): result = "Subexpressions:\n" diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 1c244be0537b401587b76c348f89ce64c04cdee2..c790995a31b9a3906427f40a0b4d8795d7469647 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -659,7 +659,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): outer_loop = outer_loop[0] symbols_with_temporary_array = OrderedDict() - assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args) + assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs')) assignment_groups = [] for symbol_group in symbol_groups: @@ -690,13 +690,10 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): if assignment.lhs in symbols_resolved: new_rhs = assignment.rhs.subs( symbols_with_temporary_array.items()) - if not isinstance(assignment.lhs, AbstractField.AbstractAccess - ) and assignment.lhs in symbol_group: + if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: assert type(assignment.lhs) is TypedSymbol - new_ts = TypedSymbol(assignment.lhs.name, - PointerType(assignment.lhs.dtype)) - new_lhs = sp.IndexedBase( - new_ts, shape=(1, ))[inner_loop.loop_counter_symbol] + new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) + new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol] else: new_lhs = assignment.lhs assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs)) @@ -805,13 +802,14 @@ class KernelConstraintsCheck: """ FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - def __init__(self, type_for_symbol, check_independence_condition): + def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True): self._type_for_symbol = type_for_symbol self.scopes = NestedScopes() self._field_writes = defaultdict(set) self.fields_read = set() self.check_independence_condition = check_independence_condition + self.check_double_write_condition = check_double_write_condition def process_assignment(self, assignment): # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 @@ -891,19 +889,15 @@ class KernelConstraintsCheck: if isinstance(lhs, AbstractField.AbstractAccess): fai = self.FieldAndIndex(lhs.field, lhs.index) self._field_writes[fai].add(lhs.offsets) - #if len(self._field_writes[fai]) > 1: - # raise ValueError( - # "Field {} is written at two different locations".format( - # lhs.field.name)) + if self.check_double_write_condition and len(self._field_writes[fai]) > 1: + raise ValueError( + "Field {} is written at two different locations".format( + lhs.field.name)) elif isinstance(lhs, sp.Symbol): if self.scopes.is_defined_locally(lhs): - raise ValueError( - "Assignments not in SSA form, multiple assignments to {}". - format(lhs.name)) + raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name)) if lhs in self.scopes.free_parameters: - raise ValueError( - "Symbol {} is written, after it has been read".format( - lhs.name)) + raise ValueError("Symbol {} is written, after it has been read".format(lhs.name)) self.scopes.define_symbol(lhs) def _update_accesses_rhs(self, rhs): @@ -947,12 +941,16 @@ def add_types(eqs, type_for_symbol, check_independence_condition): return check.process_assignment(obj) elif isinstance(obj, ast.Conditional): check.scopes.push() + # Disable double write check inside conditionals + # would be triggered by e.g. in-kernel boundaries + check.check_double_write_condition = False false_block = None if obj.false_block is None else visit( obj.false_block) result = ast.Conditional(check.process_expression( obj.condition_expr, type_constants=False), true_block=visit(obj.true_block), false_block=false_block) + check.check_double_write_condition = True check.scopes.pop() return result elif isinstance(obj, ast.Block):