Skip to content
Snippets Groups Projects
Commit 22fd030d authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for early outs

parent 696eb2d5
Branches
Tags
No related merge requests found
......@@ -226,6 +226,20 @@ class KernelFunction(Node):
return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)
class SkipIteration(Node):
@property
def args(self):
return []
@property
def symbols_defined(self):
return set()
@property
def undefined_symbols(self):
return set()
class Block(Node):
def __init__(self, nodes: List[Node]):
super(Block, self).__init__()
......@@ -627,3 +641,8 @@ class TemporaryMemoryFree(Node):
@property
def args(self):
return []
def early_out(condition):
from pystencils.cpu.vectorization import vec_all
return Conditional(vec_all(condition), Block([SkipIteration()]))
......@@ -98,8 +98,7 @@ class PrintNode(CustomCodeNode):
# noinspection PyPep8Naming
class CBackend:
def __init__(self, sympy_printer=None,
signature_only=False, vector_instruction_set=None, dialect='c'):
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
......@@ -195,6 +194,12 @@ class CBackend:
align = 64
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
def _print_SkipIteration(self, _):
if self._dialect == 'cuda':
return "return;"
else:
return "continue;"
def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set)
......
import sympy as sp
from typing import List
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.transformations import generic_visit
......@@ -54,7 +55,9 @@ def remove_placeholder_functions(expr):
subexpressions = []
def visit(e):
if isinstance(e, PlaceholderFunction):
if isinstance(e, Node):
return e
elif isinstance(e, PlaceholderFunction):
for se in e.subexpressions:
if se.lhs not in {a.lhs for a in subexpressions}:
subexpressions.append(se)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment