Commit a822ffc9 authored by Martin Bauer's avatar Martin Bauer
Browse files

Better boolean support in sympy printer

- bugfix: loop counter of vectorized loop now correctly stored as SIMD vector with entries i, i+1, i+2, ...
- basis for in-kernel boundary handling
parent 609c4b08
...@@ -293,6 +293,10 @@ class Block(Node): ...@@ -293,6 +293,10 @@ class Block(Node):
for a in self.args: for a in self.args:
a.subs(subs_dict) a.subs(subs_dict)
def fast_subs(self, subs_dict, skip=None):
self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
return self
def insert_front(self, node): def insert_front(self, node):
if isinstance(node, collections.abc.Iterable): if isinstance(node, collections.abc.Iterable):
node = list(node) node = list(node)
...@@ -408,6 +412,16 @@ class LoopOverCoordinate(Node): ...@@ -408,6 +412,16 @@ class LoopOverCoordinate(Node):
if hasattr(self.step, "subs"): if hasattr(self.step, "subs"):
self.step = self.step.subs(subs_dict) self.step = self.step.subs(subs_dict)
def fast_subs(self, subs_dict, skip=None):
self.body = fast_subs(self.body, subs_dict, skip)
if isinstance(self.start, sp.Basic):
self.start = fast_subs(self.start, subs_dict, skip)
if isinstance(self.stop, sp.Basic):
self.stop = fast_subs(self.stop, subs_dict, skip)
if isinstance(self.step, sp.Basic):
self.step = fast_subs(self.step, subs_dict, skip)
return self
@property @property
def args(self): def args(self):
result = [self.body] result = [self.body]
...@@ -538,7 +552,7 @@ class SympyAssignment(Node): ...@@ -538,7 +552,7 @@ class SympyAssignment(Node):
@property @property
def args(self): def args(self):
return [self._lhs_symbol, self.rhs] return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
@property @property
def symbols_defined(self): def symbols_defined(self):
...@@ -603,7 +617,7 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -603,7 +617,7 @@ class ResolvedFieldAccess(sp.Indexed):
self.args[1].subs(old, new), self.args[1].subs(old, new),
self.field, self.offsets, self.idx_coordinate_values) self.field, self.offsets, self.idx_coordinate_values)
def fast_subs(self, substitutions): def fast_subs(self, substitutions, skip=None):
if self in substitutions: if self in substitutions:
return substitutions[self] return substitutions[self]
return ResolvedFieldAccess(self.args[0].subs(substitutions), return ResolvedFieldAccess(self.args[0].subs(substitutions),
......
...@@ -5,7 +5,6 @@ import numpy as np ...@@ -5,7 +5,6 @@ import numpy as np
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.data_types import (
...@@ -457,7 +456,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -457,7 +456,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
arg, data_type = expr.args arg, data_type = expr.args
if type(data_type) is VectorType: if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg)) if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
printed_args = [self._print(a) for a in arg]
instruction = 'makeVecBool' if is_boolean else 'makeVec'
return self.instruction_set[instruction].format(*printed_args)
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg))
elif expr.func == fast_division: elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr) result = self._scalarFallback('_print_Function', expr)
if not result: if not result:
...@@ -542,12 +549,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -542,12 +549,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result: if result:
return result return result
one = self.instruction_set['makeVec'].format(1.0) one = self.instruction_set['makeVecConst'].format(1.0)
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp == -1: elif expr.exp == -1:
one = self.instruction_set['makeVec'].format(1.0) one = self.instruction_set['makeVecConst'].format(1.0)
return self.instruction_set['/'].format(one, self._print(expr.base)) return self.instruction_set['/'].format(one, self._print(expr.base))
elif expr.exp == 0.5: elif expr.exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base)) return self.instruction_set['sqrt'].format(self._print(expr.base))
......
...@@ -21,7 +21,10 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -21,7 +21,10 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'sqrt': 'sqrt[0]', 'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]', 'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'makeZero': 'setzero[]', 'makeZero': 'setzero[]',
'loadU': 'loadu[0]', 'loadU': 'loadu[0]',
...@@ -68,8 +71,17 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -68,8 +71,17 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
function_shortcut = function_shortcut.strip() function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')] name = function_shortcut[:function_shortcut.index('[')]
if intrinsic_id == 'makeVec': if intrinsic_id == 'makeVecConst':
arg_string = "({})".format(",".join(["{0}"] * result['width'])) arg_string = "({})".format(",".join(["{0}"] * result['width']))
elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(result['width']))]
arg_string = "({})".format(",".join(params))
elif intrinsic_id == 'makeVecBool':
params = ["(({{{i}}} ? -1.0 : 0.0)".format(i=i) for i in reversed(range(result['width']))]
arg_string = "({})".format(",".join(params))
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])]
arg_string = "({})".format(",".join(params))
else: else:
args = function_shortcut[function_shortcut.index('[') + 1: -1] args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "(" arg_string = "("
...@@ -111,6 +123,11 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -111,6 +123,11 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,) result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,) result['bool'] = "__mmask%d" % (size,)
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = "__mmask8(({}) )".format(params)
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = "__mmask8(({}) )".format(params)
if instruction_set == 'avx' and data_type == 'float': if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})" result['rsqrt'] = "_mm256_rsqrt_ps({0})"
......
...@@ -18,12 +18,12 @@ from pystencils.transformations import ( ...@@ -18,12 +18,12 @@ from pystencils.transformations import (
# noinspection PyPep8Naming # noinspection PyPep8Naming
class vec_any(sp.Function): class vec_any(sp.Function):
nargs = (1, ) nargs = (1,)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class vec_all(sp.Function): class vec_all(sp.Function):
nargs = (1, ) nargs = (1,)
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
...@@ -53,7 +53,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', ...@@ -53,7 +53,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
""" """
if instruction_set is None: if instruction_set is None:
return return
all_fields = kernel_ast.fields_accessed all_fields = kernel_ast.fields_accessed
if nontemporal is None or nontemporal is False: if nontemporal is None or nontemporal is False:
nontemporal = {} nontemporal = {}
...@@ -101,7 +101,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -101,7 +101,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
if len(loop_nodes) == 0: if len(loop_nodes) == 0:
continue continue
loop_node = loop_nodes[0] loop_node = loop_nodes[0]
# Find all array accesses (indexed) that depend on the loop counter as offset # Find all array accesses (indexed) that depend on the loop counter as offset
loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over) loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
substitutions = {} substitutions = {}
...@@ -130,6 +130,11 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -130,6 +130,11 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_node.step = vector_width loop_node.step = vector_width
loop_node.subs(substitutions) loop_node.subs(substitutions)
vector_loop_counter = cast_func(tuple(loop_counter_symbol + i for i in range(vector_width)),
VectorType(loop_counter_symbol.dtype, vector_width))
fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
def insert_vector_casts(ast_node): def insert_vector_casts(ast_node):
......
...@@ -160,7 +160,7 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -160,7 +160,7 @@ def fast_subs(expression: T, substitutions: Dict,
if skip and skip(expr): if skip and skip(expr):
return expr return expr
if hasattr(expr, "fast_subs"): if hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions) return expr.fast_subs(substitutions, skip)
if expr in substitutions: if expr in substitutions:
return substitutions[expr] return substitutions[expr]
if not hasattr(expr, 'args'): if not hasattr(expr, 'args'):
......
...@@ -481,6 +481,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -481,6 +481,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(field.dtype, StructType): if isinstance(field.dtype, StructType):
assert field.index_dimensions == 1 assert field.index_dimensions == 1
accessed_field_name = field_access.index[0] accessed_field_name = field_access.index[0]
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
assert isinstance(accessed_field_name, str) assert isinstance(accessed_field_name, str)
coordinates[e] = field.dtype.get_element_offset(accessed_field_name) coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
else: else:
...@@ -504,7 +506,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -504,7 +506,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_access.offsets, field_access.index) field_access.offsets, field_access.index)
if isinstance(get_base_type(field_access.field.dtype), StructType): if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0]) accessed_field_name = field_access.index[0]
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
new_type = field_access.field.dtype.get_element_type(accessed_field_name)
result = reinterpret_cast_func(result, new_type) result = reinterpret_cast_func(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment) return visit_sympy_expr(result, enclosing_block, sympy_assignment)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment