Commit 57a3c27e authored by Martin Bauer's avatar Martin Bauer
Browse files

Generalized vectorization

- vectorization for loops with ranges that are not a multiple of vector width
- vectorization for variable sized loops if special transformation
  replace_inner_stride_with_one is run
parent d0a19b3d
......@@ -29,10 +29,10 @@ class Node:
"""Symbols which are used but are not defined inside this node."""
raise NotImplementedError()
def subs(self, *args, **kwargs) -> None:
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
a.subs(*args, **kwargs)
def func(self):
......@@ -78,11 +78,11 @@ class Conditional(Node):
self.true_block = handle_child(true_block)
self.false_block = handle_child(false_block)
def subs(self, *args, **kwargs):
self.true_block.subs(*args, **kwargs)
def subs(self, subs_dict):
if self.false_block:
self.false_block.subs(*args, **kwargs)
self.condition_expr = self.condition_expr.subs(*args, **kwargs)
self.condition_expr = self.condition_expr.subs(subs_dict)
def args(self):
......@@ -238,6 +238,18 @@ class Block(Node):
def args(self):
return self._nodes
def subs(self, subs_dict) -> None:
new_args = []
for a in self.args:
if isinstance(a, SympyAssignment) and a.is_declaration and a.rhs in subs_dict.keys():
subs_dict[a.lhs] = subs_dict[a.rhs]
self._nodes = new_args
for a in self.args:
def insert_front(self, node):
node.parent = self
self._nodes.insert(0, node)
......@@ -334,14 +346,14 @@ class LoopOverCoordinate(Node):
result.prefix_lines = [l for l in self.prefix_lines]
return result
def subs(self, *args, **kwargs):
self.body.subs(*args, **kwargs)
def subs(self, subs_dict):
if hasattr(self.start, "subs"):
self.start = self.start.subs(*args, **kwargs)
self.start = self.start.subs(subs_dict)
if hasattr(self.stop, "subs"):
self.stop = self.stop.subs(*args, **kwargs)
self.stop = self.stop.subs(subs_dict)
if hasattr(self.step, "subs"):
self.step = self.step.subs(*args, **kwargs)
self.step = self.step.subs(subs_dict)
def args(self):
......@@ -443,9 +455,9 @@ class SympyAssignment(Node):
if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast:
self._is_declaration = False
def subs(self, *args, **kwargs):
self.lhs = fast_subs(self.lhs, *args, **kwargs)
self.rhs = fast_subs(self.rhs, *args, **kwargs)
def subs(self, subs_dict):
self.lhs = fast_subs(self.lhs, subs_dict)
self.rhs = fast_subs(self.rhs, subs_dict)
def args(self):
......@@ -8,7 +8,8 @@ try:
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.bitoperations import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, bitwise_or
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_floor
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func
from pystencils.backends.simd_instruction_sets import selected_instruction_set
......@@ -104,7 +105,6 @@ class CBackend:
def __init__(self, constants_as_floats=False, sympy_printer=None,
signature_only=False, vector_instruction_set=None):
if sympy_printer is None:
self.sympy_printer = CustomSympyPrinter(constants_as_floats)
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
......@@ -239,9 +239,14 @@ class CustomSympyPrinter(CCodePrinter):
bitwise_or: '|',
bitwise_and: '&',
if hasattr(expr, 'to_c'):
return expr.to_c(self._print)
if expr.func == cast_func:
arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
elif expr.func == modulo_floor:
assert all(get_type_of_expression(e).is_int() for e in expr.args)
return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
elif expr.func in function_map:
return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1]))
......@@ -370,7 +375,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a = a or [S.One]
# a = a or [cast_func(S.One, VectorType(create_type_from_string("double"), expr_type.width))]
a_str = [self._print(x) for x in a]
b_str = [self._print(x) for x in b]
import sympy as sp
bitwise_xor = sp.Function("bitwise_xor")
bit_shift_right = sp.Function("bit_shift_right")
bit_shift_left = sp.Function("bit_shift_left")
bitwise_and = sp.Function("bitwise_and")
bitwise_or = sp.Function("bitwise_or")
import sympy as sp
from pystencils import Field, TypedSymbol
from pystencils.bitoperations import bitwise_and
from pystencils.integer_functions import bitwise_and
from pystencils.boundaries.boundaryhandling import FlagInterface
from pystencils.data_types import create_type
......@@ -109,7 +109,7 @@ def make_python_function(kernel_function_node, argument_dict={}):
return lambda: func(*args)
def set_compiler_config(config):
def set_config(config):
Override the configuration provided in config file
......@@ -206,7 +206,7 @@ def read_config():
config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())
if config['cache']['clear_cache_on_start']:
shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)
create_folder(config['cache']['object_cache'], False)
create_folder(config['cache']['shared_library'], True)
......@@ -238,6 +238,12 @@ def hash_to_function_name(h):
return res.replace('-', 'm')
def clear_cache():
cache_config = get_cache_config()
shutil.rmtree(cache_config['object_cache'], ignore_errors=True)
create_folder(cache_config['object_cache'], False)
def compile_object_cache_to_shared_library():
compiler_config = get_compiler_config()
cache_config = get_cache_config()
import sympy as sp
import warnings
from pystencils.integer_functions import modulo_floor
from pystencils.sympyextensions import fast_subs
from pystencils.transformations import filtered_tree_iteration
from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, cast_func, collate_types, PointerType
import pystencils.astnodes as ast
from pystencils.transformations import cut_loop
def vectorize(ast_node, vector_width=4):
......@@ -12,24 +13,17 @@ def vectorize(ast_node, vector_width=4):
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4):
Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if
* loop bounds are constant
* loop range is a multiple of vector width
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop]
for loop_node in inner_loops:
loop_range = loop_node.stop - loop_node.start
# Check restrictions
if isinstance(loop_range, sp.Expr) and not loop_range.is_number:
warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop")
if loop_range % vector_width != 0 or loop_node.step != 1:
warnings.warn("Currently only loops with loop bounds that are multiples "
"of vectorization width can be vectorized - skipping loop")
# cut off loop tail, that is not a multiple of four
cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
loop_nodes = cut_loop(loop_node, [cutting_point])
assert len(loop_nodes) in (1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width
loop_node = loop_nodes[0]
# Find all array accesses (indexed) that depend on the loop counter as offset
loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
......@@ -94,9 +88,13 @@ def insert_vector_casts(ast_node):
return expr
substitution_dict = {}
for assignment in filtered_tree_iteration(ast_node, ast.SympyAssignment):
subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
def visit_node(node, substitution_dict):
substitution_dict = substitution_dict.copy()
for arg in node.args:
if isinstance(arg, ast.SympyAssignment):
assignment = arg
subs_expr = fast_subs(assignment.rhs, substitution_dict,
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
assignment.rhs = visit_expr(subs_expr)
rhs_type = get_type_of_expression(assignment.rhs)
if isinstance(assignment.lhs, TypedSymbol):
......@@ -110,3 +108,7 @@ def insert_vector_casts(ast_node):
lhs_type = assignment.lhs.args[1]
if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
assignment.rhs = cast_func(assignment.rhs, lhs_type)
visit_node(arg, substitution_dict)
visit_node(ast_node, {})
......@@ -65,10 +65,13 @@ class TypedSymbol(sp.Symbol):
def create_type(specification):
Create a subclass of Type according to a string or an object of subclass Type
:param specification: Type object, or a string
:return: Type object, or a new Type object parsed from the string
"""Creates a subclass of Type according to a string or an object of subclass Type.
specification: Type object, or a string
Type object, or a new Type object parsed from the string
if isinstance(specification, Type):
return specification
......@@ -82,10 +85,13 @@ def create_type(specification):
def create_composite_type_from_string(specification):
Creates a new Type object from a c-like string specification
:param specification: Specification string
:return: Type object
"""Creates a new Type object from a c-like string specification.
specification: Specification string
Type object
specification = specification.lower().split()
parts = []
......@@ -432,6 +438,9 @@ class VectorType(Type):
def __hash__(self):
return hash((self.base_type, self.width))
def __getnewargs__(self):
return self._base_type, self.width
class PointerType(Type):
def __init__(self, base_type, const=False, restrict=True):
import sympy as sp
from pystencils.data_types import get_type_of_expression, collate_types
from pystencils.sympyextensions import is_integer_sequence
bitwise_xor = sp.Function("bitwise_xor")
bit_shift_right = sp.Function("bit_shift_right")
bit_shift_left = sp.Function("bit_shift_left")
bitwise_and = sp.Function("bitwise_and")
bitwise_or = sp.Function("bitwise_or")
# noinspection PyPep8Naming
class modulo_floor(sp.Function):
"""Returns the next smaller integer divisible by given divisor.
>>> modulo_floor(9, 4)
>>> modulo_floor(11, 4)
>>> modulo_floor(12, 4)
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> modulo_floor(a, b).to_c(str)
'(int64_t)((a) / (b)) * (b)'
nargs = 2
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return (int(integer) // int(divisor)) * divisor
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
print_func(self.args[1]), dtype=dtype)
......@@ -595,8 +595,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
def cut_loop(loop_node, cutting_points):
"""Cuts loop at given cutting points, that means one loop is transformed into len(cuttingPoints)+1 new loops
that range from old_begin to cutting_points[1], ..., cutting_points[-1] to old_end"""
"""Cuts loop at given cutting points.
One loop is transformed into len(cuttingPoints)+1 new loops that range from
old_begin to cutting_points[1], ..., cutting_points[-1] to old_end
Modifies the ast in place
list of new loop nodes
if loop_node.step != 1:
raise NotImplementedError("Can only split loops that have a step of 1")
new_loops = []
......@@ -607,12 +615,15 @@ def cut_loop(loop_node, cutting_points):
new_body = deepcopy(loop_node.body)
new_body.subs({loop_node.loop_counter_symbol: new_start})
elif new_end - new_start == 0:
new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
new_start, new_end, loop_node.step)
new_start = new_end
loop_node.parent.replace(loop_node, new_loops)
return new_loops
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None:
......@@ -973,3 +984,25 @@ def get_loop_hierarchy(ast_node):
if node:
return reversed(result)
def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
"""Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).
Variable sized kernels can handle arbitrary field sizes and field shapes. However, the kernel is most efficient
if the innermost loop accesses the fields with stride 1. The inner loop can also only be vectorized if the inner
stride is 1. This transformation hard codes this inner stride to one to enable e.g. vectorization.
Warning: the assumption is not checked at runtime!
inner_loop_counters = {l.coordinate_to_loop_over
for l in ast_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop}
if len(inner_loop_counters) != 1:
raise ValueError("Inner loops iterate over different coordinates")
inner_loop_counter = inner_loop_counters.pop()
stride_params = [p for p in ast_node.parameters if p.is_field_stride_argument]
for stride_param in stride_params:
stride_symbol = stride_param.symbol
subs_dict = {IndexedBase(stride_symbol, shape=(1,))[inner_loop_counter]: 1}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment