From 22fd030d220cacefb7e46a7c08b2a5b0dbbb3d43 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 29 Mar 2019 18:14:42 +0100
Subject: [PATCH] Support for early outs

---
 pystencils/astnodes.py             | 19 +++++++++++++++++++
 pystencils/backends/cbackend.py    |  9 +++++++--
 pystencils/placeholder_function.py |  5 ++++-
 3 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 31d49c1d9..727539373 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -226,6 +226,20 @@ class KernelFunction(Node):
         return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)
 
 
+class SkipIteration(Node):
+    @property
+    def args(self):
+        return []
+
+    @property
+    def symbols_defined(self):
+        return set()
+
+    @property
+    def undefined_symbols(self):
+        return set()
+
+
 class Block(Node):
     def __init__(self, nodes: List[Node]):
         super(Block, self).__init__()
@@ -627,3 +641,8 @@ class TemporaryMemoryFree(Node):
     @property
     def args(self):
         return []
+
+
+def early_out(condition):
+    from pystencils.cpu.vectorization import vec_all
+    return Conditional(vec_all(condition), Block([SkipIteration()]))
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index cf96dcd09..0147efaba 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -98,8 +98,7 @@ class PrintNode(CustomCodeNode):
 # noinspection PyPep8Naming
 class CBackend:
 
-    def __init__(self, sympy_printer=None,
-                 signature_only=False, vector_instruction_set=None, dialect='c'):
+    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
         if sympy_printer is None:
             if vector_instruction_set is not None:
                 self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
@@ -195,6 +194,12 @@ class CBackend:
         align = 64
         return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
 
+    def _print_SkipIteration(self, _):
+        if self._dialect == 'cuda':
+            return "return;"
+        else:
+            return "continue;"
+
     def _print_CustomCodeNode(self, node):
         return node.get_code(self._dialect, self._vector_instruction_set)
 
diff --git a/pystencils/placeholder_function.py b/pystencils/placeholder_function.py
index 73d8bda78..67bae6cea 100644
--- a/pystencils/placeholder_function.py
+++ b/pystencils/placeholder_function.py
@@ -1,6 +1,7 @@
 import sympy as sp
 from typing import List
 from pystencils.assignment import Assignment
+from pystencils.astnodes import Node
 from pystencils.transformations import generic_visit
 
 
@@ -54,7 +55,9 @@ def remove_placeholder_functions(expr):
     subexpressions = []
 
     def visit(e):
-        if isinstance(e, PlaceholderFunction):
+        if isinstance(e, Node):
+            return e
+        elif isinstance(e, PlaceholderFunction):
             for se in e.subexpressions:
                 if se.lhs not in {a.lhs for a in subexpressions}:
                     subexpressions.append(se)
-- 
GitLab