Commit a822ffc9 authored by Martin Bauer's avatar Martin Bauer
Browse files

Better boolean support in sympy printer

- bugfix: loop counter of vectorized loop now correctly stored as SIMD vector with entries i, i+1, i+2, ...
- basis for in-kernel boundary handling
parent 609c4b08
......@@ -293,6 +293,10 @@ class Block(Node):
for a in self.args:
a.subs(subs_dict)
def fast_subs(self, subs_dict, skip=None):
self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
return self
def insert_front(self, node):
if isinstance(node, collections.abc.Iterable):
node = list(node)
......@@ -408,6 +412,16 @@ class LoopOverCoordinate(Node):
if hasattr(self.step, "subs"):
self.step = self.step.subs(subs_dict)
def fast_subs(self, subs_dict, skip=None):
self.body = fast_subs(self.body, subs_dict, skip)
if isinstance(self.start, sp.Basic):
self.start = fast_subs(self.start, subs_dict, skip)
if isinstance(self.stop, sp.Basic):
self.stop = fast_subs(self.stop, subs_dict, skip)
if isinstance(self.step, sp.Basic):
self.step = fast_subs(self.step, subs_dict, skip)
return self
@property
def args(self):
result = [self.body]
......@@ -538,7 +552,7 @@ class SympyAssignment(Node):
@property
def args(self):
return [self._lhs_symbol, self.rhs]
return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
@property
def symbols_defined(self):
......@@ -603,7 +617,7 @@ class ResolvedFieldAccess(sp.Indexed):
self.args[1].subs(old, new),
self.field, self.offsets, self.idx_coordinate_values)
def fast_subs(self, substitutions):
def fast_subs(self, substitutions, skip=None):
if self in substitutions:
return substitutions[self]
return ResolvedFieldAccess(self.args[0].subs(substitutions),
......
......@@ -5,7 +5,6 @@ import numpy as np
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
......@@ -457,7 +456,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif isinstance(expr, cast_func):
arg, data_type = expr.args
if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg))
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
printed_args = [self._print(a) for a in arg]
instruction = 'makeVecBool' if is_boolean else 'makeVec'
return self.instruction_set[instruction].format(*printed_args)
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg))
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr)
if not result:
......@@ -542,12 +549,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
return result
one = self.instruction_set['makeVec'].format(1.0)
one = self.instruction_set['makeVecConst'].format(1.0)
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp == -1:
one = self.instruction_set['makeVec'].format(1.0)
one = self.instruction_set['makeVecConst'].format(1.0)
return self.instruction_set['/'].format(one, self._print(expr.base))
elif expr.exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base))
......
......@@ -21,7 +21,10 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'makeZero': 'setzero[]',
'loadU': 'loadu[0]',
......@@ -68,8 +71,17 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
if intrinsic_id == 'makeVec':
if intrinsic_id == 'makeVecConst':
arg_string = "({})".format(",".join(["{0}"] * result['width']))
elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(result['width']))]
arg_string = "({})".format(",".join(params))
elif intrinsic_id == 'makeVecBool':
params = ["(({{{i}}} ? -1.0 : 0.0)".format(i=i) for i in reversed(range(result['width']))]
arg_string = "({})".format(",".join(params))
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])]
arg_string = "({})".format(",".join(params))
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
......@@ -111,6 +123,11 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,)
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = "__mmask8(({}) )".format(params)
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = "__mmask8(({}) )".format(params)
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})"
......
......@@ -18,12 +18,12 @@ from pystencils.transformations import (
# noinspection PyPep8Naming
class vec_any(sp.Function):
nargs = (1, )
nargs = (1,)
# noinspection PyPep8Naming
class vec_all(sp.Function):
nargs = (1, )
nargs = (1,)
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
......@@ -53,7 +53,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
"""
if instruction_set is None:
return
all_fields = kernel_ast.fields_accessed
if nontemporal is None or nontemporal is False:
nontemporal = {}
......@@ -101,7 +101,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
if len(loop_nodes) == 0:
continue
loop_node = loop_nodes[0]
# Find all array accesses (indexed) that depend on the loop counter as offset
loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
substitutions = {}
......@@ -130,6 +130,11 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_node.step = vector_width
loop_node.subs(substitutions)
vector_loop_counter = cast_func(tuple(loop_counter_symbol + i for i in range(vector_width)),
VectorType(loop_counter_symbol.dtype, vector_width))
fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
def insert_vector_casts(ast_node):
......
......@@ -160,7 +160,7 @@ def fast_subs(expression: T, substitutions: Dict,
if skip and skip(expr):
return expr
if hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions)
return expr.fast_subs(substitutions, skip)
if expr in substitutions:
return substitutions[expr]
if not hasattr(expr, 'args'):
......
......@@ -481,6 +481,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(field.dtype, StructType):
assert field.index_dimensions == 1
accessed_field_name = field_access.index[0]
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
assert isinstance(accessed_field_name, str)
coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
else:
......@@ -504,7 +506,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_access.offsets, field_access.index)
if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0])
accessed_field_name = field_access.index[0]
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
new_type = field_access.field.dtype.get_element_type(accessed_field_name)
result = reinterpret_cast_func(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment