Commit 696eb2d5 authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for conditionals in vectorized loops (by specifying all/any)

parent a2582a17
...@@ -65,7 +65,6 @@ class Conditional(Node): ...@@ -65,7 +65,6 @@ class Conditional(Node):
false_block: Optional['Block'] = None) -> None: false_block: Optional['Block'] = None) -> None:
super(Conditional, self).__init__(parent=None) super(Conditional, self).__init__(parent=None)
assert condition_expr.is_Boolean or condition_expr.is_Relational
self.condition_expr = condition_expr self.condition_expr = condition_expr
def handle_child(c): def handle_child(c):
......
...@@ -4,6 +4,7 @@ from sympy.core import S ...@@ -4,6 +4,7 @@ from sympy.core import S
from typing import Set from typing import Set
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt
try: try:
...@@ -198,6 +199,9 @@ class CBackend: ...@@ -198,6 +199,9 @@ class CBackend:
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set)
def _print_Conditional(self, node): def _print_Conditional(self, node):
cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType):
raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
condition_expr = self.sympy_printer.doprint(node.condition_expr) condition_expr = self.sympy_printer.doprint(node.condition_expr)
true_block = self._print_Block(node.true_block) true_block = self._print_Block(node.true_block)
result = "if (%s)\n%s " % (condition_expr, true_block) result = "if (%s)\n%s " % (condition_expr, true_block)
...@@ -274,6 +278,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -274,6 +278,8 @@ class CustomSympyPrinter(CCodePrinter):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else: else:
return "({})".format(self._print(sp.sqrt(expr.args[0]))) return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
if self._dialect == "cuda": if self._dialect == "cuda":
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
...@@ -328,7 +334,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -328,7 +334,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif expr.func == fast_division: elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr) result = self._scalarFallback('_print_Function', expr)
if not result: if not result:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
return result
elif expr.func == fast_sqrt: elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0]))) return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif expr.func == fast_inv_sqrt: elif expr.func == fast_inv_sqrt:
...@@ -338,6 +345,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -338,6 +345,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else: else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any):
expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType:
return self._print(expr.args[0])
else:
return self.instruction_set['any'].format(self._print(expr.args[0]))
elif isinstance(expr, vec_all):
expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType:
return self._print(expr.args[0])
else:
return self.instruction_set['all'].format(self._print(expr.args[0]))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr): def _print_And(self, expr):
......
...@@ -98,11 +98,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -98,11 +98,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['bool'] = "__m%dd" % (bit_width,) result['bool'] = "__m%dd" % (bit_width,)
result['headers'] = headers[instruction_set] result['headers'] = headers[instruction_set]
result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf)
result['all'] = "%s_movemask_%s({0}) == 0xF" % (pre, suf)
if instruction_set == 'avx512': if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16 size = 8 if data_type == 'double' else 16
result['&'] = '_kand_mask%d({0}, {1})' % (size,) result['&'] = '_kand_mask%d({0}, {1})' % (size,)
result['|'] = '_kor_mask%d({0}, {1})' % (size,) result['|'] = '_kor_mask%d({0}, {1})' % (size,)
result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf) result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,) result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,) result['bool'] = "__mmask%d" % (size,)
......
...@@ -12,6 +12,16 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration, replac ...@@ -12,6 +12,16 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration, replac
from pystencils.field import Field from pystencils.field import Field
# noinspection PyPep8Naming
class vec_any(sp.Function):
nargs = (1, )
# noinspection PyPep8Naming
class vec_all(sp.Function):
nargs = (1, )
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False, assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False,
assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True): assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True):
...@@ -119,7 +129,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -119,7 +129,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
def insert_vector_casts(ast_node): def insert_vector_casts(ast_node):
"""Inserts necessary casts from scalar values to vector values.""" """Inserts necessary casts from scalar values to vector values."""
handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt) handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)
def visit_expr(expr): def visit_expr(expr):
...@@ -182,6 +192,11 @@ def insert_vector_casts(ast_node): ...@@ -182,6 +192,11 @@ def insert_vector_casts(ast_node):
lhs_type = assignment.lhs.args[1] lhs_type = assignment.lhs.args[1]
if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
assignment.rhs = cast_func(assignment.rhs, lhs_type) assignment.rhs = cast_func(assignment.rhs, lhs_type)
elif isinstance(arg, ast.Conditional):
arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict,
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
arg.condition_expr = visit_expr(arg.condition_expr)
visit_node(arg, substitution_dict)
else: else:
visit_node(arg, substitution_dict) visit_node(arg, substitution_dict)
......
import ctypes import ctypes
import sympy as sp import sympy as sp
import numpy as np import numpy as np
try: try:
import llvmlite.ir as ir import llvmlite.ir as ir
except ImportError as e: except ImportError as e:
...@@ -311,6 +312,8 @@ def collate_types(types): ...@@ -311,6 +312,8 @@ def collate_types(types):
@memorycache(maxsize=2048) @memorycache(maxsize=2048)
def get_type_of_expression(expr): def get_type_of_expression(expr):
from pystencils.astnodes import ResolvedFieldAccess from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
expr = sp.sympify(expr) expr = sp.sympify(expr)
if isinstance(expr, sp.Integer): if isinstance(expr, sp.Integer):
return create_type("int") return create_type("int")
...@@ -324,6 +327,8 @@ def get_type_of_expression(expr): ...@@ -324,6 +327,8 @@ def get_type_of_expression(expr):
raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
return expr.args[1] return expr.args[1]
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise: elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args)) collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args)) collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
......
...@@ -6,7 +6,8 @@ from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAss ...@@ -6,7 +6,8 @@ from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAss
from pystencils.cpu.vectorization import vectorize from pystencils.cpu.vectorization import vectorize
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.transformations import remove_conditionals_in_staggered_kernel, loop_blocking from pystencils.transformations import remove_conditionals_in_staggered_kernel, loop_blocking, \
move_constants_before_loop
def create_kernel(assignments, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None, def create_kernel(assignments, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None,
...@@ -248,6 +249,7 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar ...@@ -248,6 +249,7 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
if target == 'cpu': if target == 'cpu':
remove_conditionals_in_staggered_kernel(ast) remove_conditionals_in_staggered_kernel(ast)
move_constants_before_loop(ast)
if blocking: if blocking:
loop_blocking(ast, blocking) loop_blocking(ast, blocking)
if cpu_vectorize_info is True: if cpu_vectorize_info is True:
......
...@@ -7,7 +7,6 @@ import hashlib ...@@ -7,7 +7,6 @@ import hashlib
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
...@@ -514,6 +513,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -514,6 +513,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
assert type(enclosing_block) is ast.Block assert type(enclosing_block) is ast.Block
sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast) sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast) sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
elif isinstance(sub_ast, ast.Conditional):
enclosing_block = sub_ast.parent
assert type(enclosing_block) is ast.Block
sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
visit_node(sub_ast.true_block)
if sub_ast.false_block:
visit_node(sub_ast.false_block)
else: else:
for i, a in enumerate(sub_ast.args): for i, a in enumerate(sub_ast.args):
visit_node(a) visit_node(a)
...@@ -545,7 +551,7 @@ def move_constants_before_loop(ast_node): ...@@ -545,7 +551,7 @@ def move_constants_before_loop(ast_node):
last_block_child = prev_element last_block_child = prev_element
if isinstance(element, ast.Conditional): if isinstance(element, ast.Conditional):
critical_symbols = element.condition_expr.atoms(sp.Symbol) break
else: else:
critical_symbols = element.symbols_defined critical_symbols = element.symbols_defined
if node.undefined_symbols.intersection(critical_symbols): if node.undefined_symbols.intersection(critical_symbols):
...@@ -1000,6 +1006,9 @@ def typing_from_sympy_inspection(eqs, default_type="double"): ...@@ -1000,6 +1006,9 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
continue continue
else: else:
from pystencils.cpu.vectorization import vec_all, vec_any
if isinstance(eq.rhs, vec_all) or isinstance(eq.rhs, vec_any):
result[eq.lhs.name] = "bool"
# problematic case here is when rhs is a symbol: then it is impossible to decide here without # problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then # further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
......
import pystencils as ps
import sympy as sp
import numpy as np
from pystencils.astnodes import Conditional, Block
from pystencils.cpu.vectorization import vec_all, vec_any
def test_vec_any():
data_arr = np.zeros((15, 15))
data_arr[3:9, 2:7] = 1.0
data = ps.fields("data: double[2D]", data=data_arr)
c = [
ps.Assignment(sp.Symbol("t1"), vec_any(data.center() > 0.0)),
Conditional(vec_any(data.center() > 0.0), Block([
ps.Assignment(data.center(), 2.0)
]))
]
ast = ps.create_kernel(c, target='cpu',
cpu_vectorize_info={'instruction_set': 'avx'})
kernel = ast.compile()
kernel(data=data_arr)
np.testing.assert_equal(data_arr[3:9, 0:8], 2.0)
def test_vec_all():
data_arr = np.zeros((15, 15))
data_arr[3:9, 2:7] = 1.0
data = ps.fields("data: double[2D]", data=data_arr)
c = [
Conditional(vec_all(data.center() > 0.0), Block([
ps.Assignment(data.center(), 2.0)
]))
]
ast = ps.create_kernel(c, target='cpu',
cpu_vectorize_info={'instruction_set': 'avx'})
kernel = ast.compile()
before = data_arr.copy()
kernel(data=data_arr)
np.testing.assert_equal(data_arr, before)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment