Commit 85053df1 authored by Martin Bauer's avatar Martin Bauer
Browse files

More general vectorization

parent a822ffc9
......@@ -233,11 +233,17 @@ class CBackend:
self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
printed_mask = ""
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
arg, data_type, aligned, nontemporal = node.lhs.args
arg, data_type, aligned, nontemporal, mask = node.lhs.args
instr = 'storeU'
if aligned:
instr = 'stream' if nontemporal else 'storeA'
if mask != True:
instr = 'maskStore' if aligned else 'maskStoreU'
printed_mask = self.sympy_printer.doprint(mask)
if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d':
printed_mask = "_mm256_castpd_si256({})".format(printed_mask)
rhs_type = get_type_of_expression(node.rhs)
if type(rhs_type) is not VectorType:
......@@ -246,7 +252,8 @@ class CBackend:
rhs = node.rhs
return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
self.sympy_printer.doprint(rhs)) + ';'
self.sympy_printer.doprint(rhs),
printed_mask) + ';'
else:
return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
......@@ -450,7 +457,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _ = expr.args
arg, data_type, aligned, _, mask = expr.args
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg))
elif isinstance(expr, cast_func):
......
......@@ -32,7 +32,24 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]'
}
if instruction_set == 'avx512':
base_names.update({
'maskStore': 'mask_store[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]',
'maskLoad': 'mask_load[2, 1, 0]',
'maskLoadU': 'mask_loadu[2, 1, 0]'
})
if instruction_set == 'avx':
base_names.update({
'maskStore': 'maskstore[0, 2, 1]',
'maskStoreU': 'maskstore[0, 2, 1]',
'maskLoad': 'maskload[0, 1]',
'maskLoadU': 'maskloadu[0, 1]'
})
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,)
......
......@@ -123,7 +123,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
nontemporal = False
if hasattr(indexed, 'field'):
nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal)
substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True)
if not successful:
warnings.warn("Could not vectorize loop because of non-consecutive memory access")
continue
......@@ -136,6 +136,30 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
mask_conditionals(loop_node)
def mask_conditionals(loop_body):
def visit_node(node, mask):
if isinstance(node, ast.Conditional):
true_mask = sp.And(node.condition_expr, mask)
visit_node(node.true_block, true_mask)
if node.false_block:
false_mask = sp.And(sp.Not(node.condition_expr), mask)
visit_node(node, false_mask)
node.condition_expr = vec_any(node.condition_expr)
elif isinstance(node, ast.SympyAssignment):
if mask is not True:
s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4]))
for ma in node.atoms(vector_memory_access)}
node.subs(s)
else:
for arg in node.args:
visit_node(arg, mask)
visit_node(loop_body, mask=True)
def insert_vector_casts(ast_node):
"""Inserts necessary casts from scalar values to vector values."""
......@@ -143,8 +167,10 @@ def insert_vector_casts(ast_node):
handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)
def visit_expr(expr):
if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
if isinstance(expr, vector_memory_access):
return vector_memory_access(expr.args[0], expr.args[1], expr.args[2], expr.args[3],
visit_expr(expr.args[4]))
elif isinstance(expr, cast_func):
return expr
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
new_args = [visit_expr(a) for a in expr.args]
......@@ -199,10 +225,12 @@ def insert_vector_casts(ast_node):
new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
substitution_dict[assignment.lhs] = new_lhs
assignment.lhs = new_lhs
elif isinstance(assignment.lhs.func, cast_func):
lhs_type = assignment.lhs.args[1]
if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
assignment.rhs = cast_func(assignment.rhs, lhs_type)
elif isinstance(assignment.lhs, vector_memory_access):
assignment.lhs = visit_expr(assignment.lhs)
#elif isinstance(assignment.lhs, cast_func): # TODO check if necessary
# lhs_type = assignment.lhs.args[1]
# if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
# assignment.rhs = cast_func(assignment.rhs, lhs_type)
elif isinstance(arg, ast.Conditional):
arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict,
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
......
......@@ -190,7 +190,8 @@ class boolean_cast_func(cast_func, Boolean):
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
nargs = (4,)
# Arguments are: read/write expression, type, aligned, nontemporal, mask (or none)
nargs = (5,)
# noinspection PyPep8Naming
......
......@@ -891,10 +891,10 @@ class KernelConstraintsCheck:
if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1:
raise ValueError(
"Field {} is written at two different locations".format(
lhs.field.name))
#if len(self._field_writes[fai]) > 1:
# raise ValueError(
# "Field {} is written at two different locations".format(
# lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment