diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 31d49c1d9b1e209afb48f673b127a626a332cbcd..727539373856950a9312966f3fe27859ae17a8f4 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -226,6 +226,20 @@ class KernelFunction(Node): return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params) +class SkipIteration(Node): + @property + def args(self): + return [] + + @property + def symbols_defined(self): + return set() + + @property + def undefined_symbols(self): + return set() + + class Block(Node): def __init__(self, nodes: List[Node]): super(Block, self).__init__() @@ -627,3 +641,8 @@ class TemporaryMemoryFree(Node): @property def args(self): return [] + + +def early_out(condition): + from pystencils.cpu.vectorization import vec_all + return Conditional(vec_all(condition), Block([SkipIteration()])) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index cf96dcd098c8a1b988fc554820752028b6cc7629..0147efabad6e92f49b91cd3fc8fae0711ac0a2d5 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -98,8 +98,7 @@ class PrintNode(CustomCodeNode): # noinspection PyPep8Naming class CBackend: - def __init__(self, sympy_printer=None, - signature_only=False, vector_instruction_set=None, dialect='c'): + def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect) @@ -195,6 +194,12 @@ class CBackend: align = 64 return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) + def _print_SkipIteration(self, _): + if self._dialect == 'cuda': + return "return;" + else: + return "continue;" + def _print_CustomCodeNode(self, node): return node.get_code(self._dialect, self._vector_instruction_set) diff --git a/pystencils/placeholder_function.py b/pystencils/placeholder_function.py index 73d8bda78c9d96e52f32083029f36d4990e62b6d..67bae6ceafd2a9334ffabbe8abcabcc1a3c7eaa7 100644 --- a/pystencils/placeholder_function.py +++ b/pystencils/placeholder_function.py @@ -1,6 +1,7 @@ import sympy as sp from typing import List from pystencils.assignment import Assignment +from pystencils.astnodes import Node from pystencils.transformations import generic_visit @@ -54,7 +55,9 @@ def remove_placeholder_functions(expr): subexpressions = [] def visit(e): - if isinstance(e, PlaceholderFunction): + if isinstance(e, Node): + return e + elif isinstance(e, PlaceholderFunction): for se in e.subexpressions: if se.lhs not in {a.lhs for a in subexpressions}: subexpressions.append(se)