Commit 22fd030d authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for early outs

parent 696eb2d5
...@@ -226,6 +226,20 @@ class KernelFunction(Node): ...@@ -226,6 +226,20 @@ class KernelFunction(Node):
return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params) return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)
class SkipIteration(Node):
def args(self):
return []
def symbols_defined(self):
return set()
def undefined_symbols(self):
return set()
class Block(Node): class Block(Node):
def __init__(self, nodes: List[Node]): def __init__(self, nodes: List[Node]):
super(Block, self).__init__() super(Block, self).__init__()
...@@ -627,3 +641,8 @@ class TemporaryMemoryFree(Node): ...@@ -627,3 +641,8 @@ class TemporaryMemoryFree(Node):
@property @property
def args(self): def args(self):
return [] return []
def early_out(condition):
from pystencils.cpu.vectorization import vec_all
return Conditional(vec_all(condition), Block([SkipIteration()]))
...@@ -98,8 +98,7 @@ class PrintNode(CustomCodeNode): ...@@ -98,8 +98,7 @@ class PrintNode(CustomCodeNode):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CBackend: class CBackend:
def __init__(self, sympy_printer=None, def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None: if sympy_printer is None:
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect) self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
...@@ -195,6 +194,12 @@ class CBackend: ...@@ -195,6 +194,12 @@ class CBackend:
align = 64 align = 64
return "free(%s - %d);" % (self.sympy_printer.doprint(, node.offset(align)) return "free(%s - %d);" % (self.sympy_printer.doprint(, node.offset(align))
def _print_SkipIteration(self, _):
if self._dialect == 'cuda':
return "return;"
return "continue;"
def _print_CustomCodeNode(self, node): def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set)
import sympy as sp import sympy as sp
from typing import List from typing import List
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.transformations import generic_visit from pystencils.transformations import generic_visit
...@@ -54,7 +55,9 @@ def remove_placeholder_functions(expr): ...@@ -54,7 +55,9 @@ def remove_placeholder_functions(expr):
subexpressions = [] subexpressions = []
def visit(e): def visit(e):
if isinstance(e, PlaceholderFunction): if isinstance(e, Node):
return e
elif isinstance(e, PlaceholderFunction):
for se in e.subexpressions: for se in e.subexpressions:
if se.lhs not in {a.lhs for a in subexpressions}: if se.lhs not in {a.lhs for a in subexpressions}:
subexpressions.append(se) subexpressions.append(se)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment