Commit 57a3c27e authored by Martin Bauer's avatar Martin Bauer
Browse files

Generalized vectorization

- vectorization for loops with ranges that are not a multiple of vector width
- vectorization for variable-sized loops, provided the special transformation
  replace_inner_stride_with_one has been run
parent d0a19b3d
...@@ -29,10 +29,10 @@ class Node: ...@@ -29,10 +29,10 @@ class Node:
"""Symbols which are used but are not defined inside this node.""" """Symbols which are used but are not defined inside this node."""
raise NotImplementedError() raise NotImplementedError()
def subs(self, *args, **kwargs) -> None: def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace.""" """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args: for a in self.args:
a.subs(*args, **kwargs) a.subs(subs_dict)
@property @property
def func(self): def func(self):
...@@ -78,11 +78,11 @@ class Conditional(Node): ...@@ -78,11 +78,11 @@ class Conditional(Node):
self.true_block = handle_child(true_block) self.true_block = handle_child(true_block)
self.false_block = handle_child(false_block) self.false_block = handle_child(false_block)
def subs(self, *args, **kwargs): def subs(self, subs_dict):
self.true_block.subs(*args, **kwargs) self.true_block.subs(subs_dict)
if self.false_block: if self.false_block:
self.false_block.subs(*args, **kwargs) self.false_block.subs(subs_dict)
self.condition_expr = self.condition_expr.subs(*args, **kwargs) self.condition_expr = self.condition_expr.subs(subs_dict)
@property @property
def args(self): def args(self):
...@@ -238,6 +238,18 @@ class Block(Node): ...@@ -238,6 +238,18 @@ class Block(Node):
def args(self): def args(self):
return self._nodes return self._nodes
def subs(self, subs_dict) -> None:
    """Substitute inside all nodes of this block, modifying the block in place.

    Declarations whose right-hand side is itself a key of ``subs_dict`` are
    removed from the block; their left-hand side is then substituted by the
    same replacement in the remaining nodes.

    Args:
        subs_dict: mapping from expressions to their replacements. The mapping
            passed in by the caller is left untouched.
    """
    # Work on a copy: the aliases collected below belong to this block and the
    # nodes nested in it, so they must not leak into the caller's dictionary
    # (and from there into sibling blocks that share the same dictionary).
    subs_dict = dict(subs_dict)

    remaining_nodes = []
    for node in self.args:
        if isinstance(node, SympyAssignment) and node.is_declaration and node.rhs in subs_dict:
            # the declared symbol is just an alias for a substituted expression
            subs_dict[node.lhs] = subs_dict[node.rhs]
        else:
            remaining_nodes.append(node)
    self._nodes = remaining_nodes

    for node in self.args:
        node.subs(subs_dict)
def insert_front(self, node): def insert_front(self, node):
node.parent = self node.parent = self
self._nodes.insert(0, node) self._nodes.insert(0, node)
...@@ -334,14 +346,14 @@ class LoopOverCoordinate(Node): ...@@ -334,14 +346,14 @@ class LoopOverCoordinate(Node):
result.prefix_lines = [l for l in self.prefix_lines] result.prefix_lines = [l for l in self.prefix_lines]
return result return result
def subs(self, *args, **kwargs): def subs(self, subs_dict):
self.body.subs(*args, **kwargs) self.body.subs(subs_dict)
if hasattr(self.start, "subs"): if hasattr(self.start, "subs"):
self.start = self.start.subs(*args, **kwargs) self.start = self.start.subs(subs_dict)
if hasattr(self.stop, "subs"): if hasattr(self.stop, "subs"):
self.stop = self.stop.subs(*args, **kwargs) self.stop = self.stop.subs(subs_dict)
if hasattr(self.step, "subs"): if hasattr(self.step, "subs"):
self.step = self.step.subs(*args, **kwargs) self.step = self.step.subs(subs_dict)
@property @property
def args(self): def args(self):
...@@ -443,9 +455,9 @@ class SympyAssignment(Node): ...@@ -443,9 +455,9 @@ class SympyAssignment(Node):
if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast: if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast:
self._is_declaration = False self._is_declaration = False
def subs(self, *args, **kwargs): def subs(self, subs_dict):
self.lhs = fast_subs(self.lhs, *args, **kwargs) self.lhs = fast_subs(self.lhs, subs_dict)
self.rhs = fast_subs(self.rhs, *args, **kwargs) self.rhs = fast_subs(self.rhs, subs_dict)
@property @property
def args(self): def args(self):
......
...@@ -8,7 +8,8 @@ try: ...@@ -8,7 +8,8 @@ try:
except ImportError: except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.bitoperations import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, bitwise_or from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_floor
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func
from pystencils.backends.simd_instruction_sets import selected_instruction_set from pystencils.backends.simd_instruction_sets import selected_instruction_set
...@@ -104,7 +105,6 @@ class CBackend: ...@@ -104,7 +105,6 @@ class CBackend:
def __init__(self, constants_as_floats=False, sympy_printer=None, def __init__(self, constants_as_floats=False, sympy_printer=None,
signature_only=False, vector_instruction_set=None): signature_only=False, vector_instruction_set=None):
if sympy_printer is None: if sympy_printer is None:
self.sympy_printer = CustomSympyPrinter(constants_as_floats)
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats) self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
else: else:
...@@ -239,9 +239,14 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -239,9 +239,14 @@ class CustomSympyPrinter(CCodePrinter):
bitwise_or: '|', bitwise_or: '|',
bitwise_and: '&', bitwise_and: '&',
} }
if hasattr(expr, 'to_c'):
return expr.to_c(self._print)
if expr.func == cast_func: if expr.func == cast_func:
arg, data_type = expr.args arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg)) return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
elif expr.func == modulo_floor:
assert all(get_type_of_expression(e).is_int() for e in expr.args)
return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
elif expr.func in function_map: elif expr.func in function_map:
return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1])) return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1]))
else: else:
...@@ -370,7 +375,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -370,7 +375,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a.append(item) a.append(item)
a = a or [S.One] a = a or [S.One]
# a = a or [cast_func(S.One, VectorType(create_type_from_string("double"), expr_type.width))]
a_str = [self._print(x) for x in a] a_str = [self._print(x) for x in a]
b_str = [self._print(x) for x in b] b_str = [self._print(x) for x in b]
......
import sympy as sp
bitwise_xor = sp.Function("bitwise_xor")
bit_shift_right = sp.Function("bit_shift_right")
bit_shift_left = sp.Function("bit_shift_left")
bitwise_and = sp.Function("bitwise_and")
bitwise_or = sp.Function("bitwise_or")
import sympy as sp import sympy as sp
from pystencils import Field, TypedSymbol from pystencils import Field, TypedSymbol
from pystencils.bitoperations import bitwise_and from pystencils.integer_functions import bitwise_and
from pystencils.boundaries.boundaryhandling import FlagInterface from pystencils.boundaries.boundaryhandling import FlagInterface
from pystencils.data_types import create_type from pystencils.data_types import create_type
......
...@@ -109,7 +109,7 @@ def make_python_function(kernel_function_node, argument_dict={}): ...@@ -109,7 +109,7 @@ def make_python_function(kernel_function_node, argument_dict={}):
return lambda: func(*args) return lambda: func(*args)
def set_compiler_config(config): def set_config(config):
""" """
Override the configuration provided in config file Override the configuration provided in config file
...@@ -206,7 +206,7 @@ def read_config(): ...@@ -206,7 +206,7 @@ def read_config():
config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid()) config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())
if config['cache']['clear_cache_on_start']: if config['cache']['clear_cache_on_start']:
shutil.rmtree(config['cache']['object_cache'], ignore_errors=True) clear_cache()
create_folder(config['cache']['object_cache'], False) create_folder(config['cache']['object_cache'], False)
create_folder(config['cache']['shared_library'], True) create_folder(config['cache']['shared_library'], True)
...@@ -238,6 +238,12 @@ def hash_to_function_name(h): ...@@ -238,6 +238,12 @@ def hash_to_function_name(h):
return res.replace('-', 'm') return res.replace('-', 'm')
def clear_cache():
    """Removes the configured object cache directory and hands the same path to create_folder."""
    object_cache_dir = get_cache_config()['object_cache']
    # errors during removal are ignored
    shutil.rmtree(object_cache_dir, ignore_errors=True)
    create_folder(object_cache_dir, False)
def compile_object_cache_to_shared_library(): def compile_object_cache_to_shared_library():
compiler_config = get_compiler_config() compiler_config = get_compiler_config()
cache_config = get_cache_config() cache_config = get_cache_config()
......
import sympy as sp import sympy as sp
import warnings import warnings
from pystencils.integer_functions import modulo_floor
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
from pystencils.transformations import filtered_tree_iteration
from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, cast_func, collate_types, PointerType from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, cast_func, collate_types, PointerType
import pystencils.astnodes as ast import pystencils.astnodes as ast
from pystencils.transformations import cut_loop
def vectorize(ast_node, vector_width=4): def vectorize(ast_node, vector_width=4):
...@@ -12,25 +13,18 @@ def vectorize(ast_node, vector_width=4): ...@@ -12,25 +13,18 @@ def vectorize(ast_node, vector_width=4):
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4): def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4):
""" """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if
* loop bounds are constant
* loop range is a multiple of vector width
"""
inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop] inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop]
for loop_node in inner_loops: for loop_node in inner_loops:
loop_range = loop_node.stop - loop_node.start loop_range = loop_node.stop - loop_node.start
# Check restrictions # cut off loop tail, that is not a multiple of four
if isinstance(loop_range, sp.Expr) and not loop_range.is_number: cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop") loop_nodes = cut_loop(loop_node, [cutting_point])
continue assert len(loop_nodes) in (1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width
if loop_range % vector_width != 0 or loop_node.step != 1: loop_node = loop_nodes[0]
warnings.warn("Currently only loops with loop bounds that are multiples "
"of vectorization width can be vectorized - skipping loop")
continue
# Find all array accesses (indexed) that depend on the loop counter as offset # Find all array accesses (indexed) that depend on the loop counter as offset
loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over) loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
substitutions = {} substitutions = {}
...@@ -94,19 +88,27 @@ def insert_vector_casts(ast_node): ...@@ -94,19 +88,27 @@ def insert_vector_casts(ast_node):
else: else:
return expr return expr
substitution_dict = {} def visit_node(node, substitution_dict):
for assignment in filtered_tree_iteration(ast_node, ast.SympyAssignment): substitution_dict = substitution_dict.copy()
subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) for arg in node.args:
assignment.rhs = visit_expr(subs_expr) if isinstance(arg, ast.SympyAssignment):
rhs_type = get_type_of_expression(assignment.rhs) assignment = arg
if isinstance(assignment.lhs, TypedSymbol): subs_expr = fast_subs(assignment.rhs, substitution_dict,
lhs_type = assignment.lhs.dtype skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
if type(rhs_type) is VectorType and type(lhs_type) is not VectorType: assignment.rhs = visit_expr(subs_expr)
new_lhs_type = VectorType(lhs_type, rhs_type.width) rhs_type = get_type_of_expression(assignment.rhs)
new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) if isinstance(assignment.lhs, TypedSymbol):
substitution_dict[assignment.lhs] = new_lhs lhs_type = assignment.lhs.dtype
assignment.lhs = new_lhs if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
elif assignment.lhs.func == cast_func: new_lhs_type = VectorType(lhs_type, rhs_type.width)
lhs_type = assignment.lhs.args[1] new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: substitution_dict[assignment.lhs] = new_lhs
assignment.rhs = cast_func(assignment.rhs, lhs_type) assignment.lhs = new_lhs
elif assignment.lhs.func == cast_func:
lhs_type = assignment.lhs.args[1]
if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
assignment.rhs = cast_func(assignment.rhs, lhs_type)
else:
visit_node(arg, substitution_dict)
visit_node(ast_node, {})
...@@ -65,10 +65,13 @@ class TypedSymbol(sp.Symbol): ...@@ -65,10 +65,13 @@ class TypedSymbol(sp.Symbol):
def create_type(specification): def create_type(specification):
""" """Creates a subclass of Type according to a string or an object of subclass Type.
Create a subclass of Type according to a string or an object of subclass Type
:param specification: Type object, or a string Args:
:return: Type object, or a new Type object parsed from the string specification: Type object, or a string
Returns:
Type object, or a new Type object parsed from the string
""" """
if isinstance(specification, Type): if isinstance(specification, Type):
return specification return specification
...@@ -82,10 +85,13 @@ def create_type(specification): ...@@ -82,10 +85,13 @@ def create_type(specification):
@memorycache(maxsize=64) @memorycache(maxsize=64)
def create_composite_type_from_string(specification): def create_composite_type_from_string(specification):
""" """Creates a new Type object from a c-like string specification.
Creates a new Type object from a c-like string specification
:param specification: Specification string Args:
:return: Type object specification: Specification string
Returns:
Type object
""" """
specification = specification.lower().split() specification = specification.lower().split()
parts = [] parts = []
...@@ -432,6 +438,9 @@ class VectorType(Type): ...@@ -432,6 +438,9 @@ class VectorType(Type):
def __hash__(self): def __hash__(self):
return hash((self.base_type, self.width)) return hash((self.base_type, self.width))
def __getnewargs__(self):
return self._base_type, self.width
class PointerType(Type): class PointerType(Type):
def __init__(self, base_type, const=False, restrict=True): def __init__(self, base_type, const=False, restrict=True):
......
import sympy as sp
from pystencils.data_types import get_type_of_expression, collate_types
from pystencils.sympyextensions import is_integer_sequence
# Symbolic placeholders for integer bit operations, created as undefined sympy functions.
bitwise_xor = sp.Function("bitwise_xor")
bit_shift_right = sp.Function("bit_shift_right")
bit_shift_left = sp.Function("bit_shift_left")
bitwise_and = sp.Function("bitwise_and")
bitwise_or = sp.Function("bitwise_or")
# noinspection PyPep8Naming
class modulo_floor(sp.Function):
    """Rounds an integer down to a multiple of a divisor: ``(integer // divisor) * divisor``.

    If both arguments pass ``is_integer_sequence`` the result is computed right away,
    otherwise an unevaluated function object is created that can be printed with `to_c`.

    Examples:
        >>> modulo_floor(9, 4)
        8
        >>> modulo_floor(11, 4)
        8
        >>> modulo_floor(12, 4)
        12
        >>> from pystencils import TypedSymbol
        >>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
        >>> modulo_floor(a, b).to_c(str)
        '(int64_t)((a) / (b)) * (b)'
    """
    nargs = 2

    def __new__(cls, integer, divisor):
        if not is_integer_sequence((integer, divisor)):
            # keep the expression symbolic
            return super().__new__(cls, integer, divisor)
        quotient = int(integer) // int(divisor)
        return quotient * divisor

    def to_c(self, print_func):
        """Returns the C expression for this node.

        Args:
            print_func: callable that turns a single argument expression into its C string

        NOTE(review): the emitted expression relies on C integer division, which truncates
        toward zero; for negative operands this differs from Python's ``//`` used in
        ``__new__`` - presumably only non-negative values occur, to be confirmed.
        """
        numerator, denominator = self.args
        operand_types = (get_type_of_expression(numerator), get_type_of_expression(denominator))
        dtype = collate_types(operand_types)
        assert dtype.is_int()
        numerator_str = print_func(numerator)
        denominator_str = print_func(denominator)
        return "({dtype})(({0}) / ({1})) * ({1})".format(numerator_str, denominator_str, dtype=dtype)
...@@ -595,8 +595,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -595,8 +595,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
def cut_loop(loop_node, cutting_points): def cut_loop(loop_node, cutting_points):
"""Cuts loop at given cutting points, that means one loop is transformed into len(cuttingPoints)+1 new loops """Cuts loop at given cutting points.
that range from old_begin to cutting_points[1], ..., cutting_points[-1] to old_end"""
One loop is transformed into up to len(cutting_points)+1 new loops that range from
old_begin to cutting_points[0], ..., cutting_points[-1] to old_end (empty ranges are dropped).
Modifies the ast in place
Returns:
list of new loop nodes
"""
if loop_node.step != 1: if loop_node.step != 1:
raise NotImplementedError("Can only split loops that have a step of 1") raise NotImplementedError("Can only split loops that have a step of 1")
new_loops = [] new_loops = []
...@@ -607,12 +615,15 @@ def cut_loop(loop_node, cutting_points): ...@@ -607,12 +615,15 @@ def cut_loop(loop_node, cutting_points):
new_body = deepcopy(loop_node.body) new_body = deepcopy(loop_node.body)
new_body.subs({loop_node.loop_counter_symbol: new_start}) new_body.subs({loop_node.loop_counter_symbol: new_start})
new_loops.append(new_body) new_loops.append(new_body)
elif new_end - new_start == 0:
pass
else: else:
new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over, new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
new_start, new_end, loop_node.step) new_start, new_end, loop_node.step)
new_loops.append(new_loop) new_loops.append(new_loop)
new_start = new_end new_start = new_end
loop_node.parent.replace(loop_node, new_loops) loop_node.parent.replace(loop_node, new_loops)
return new_loops
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None: def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None:
...@@ -973,3 +984,25 @@ def get_loop_hierarchy(ast_node): ...@@ -973,3 +984,25 @@ def get_loop_hierarchy(ast_node):
if node: if node:
result.append(node.coordinate_to_loop_over) result.append(node.coordinate_to_loop_over)
return reversed(result) return reversed(result)
def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
    """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).

    Variable sized kernels can handle arbitrary field sizes and field shapes. However, the kernel is most efficient
    if the innermost loop accesses the fields with stride 1. The inner loop can also only be vectorized if the inner
    stride is 1. This transformation hard codes this inner stride to one to enable e.g. vectorization.

    Warning: the assumption is not checked at runtime!

    Args:
        ast_node: kernel function node, modified in place

    Raises:
        ValueError: if the kernel contains no innermost loop, or if its innermost loops
                    iterate over different coordinates
    """
    inner_loop_counters = {loop.coordinate_to_loop_over
                           for loop in ast_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop}
    if not inner_loop_counters:
        raise ValueError("Kernel contains no innermost loop")
    if len(inner_loop_counters) > 1:
        raise ValueError("Inner loops iterate over different coordinates")
    inner_loop_counter = inner_loop_counters.pop()

    # Collect the inner-stride access of every stride parameter first, so that the AST
    # is traversed only once and every parameter is guaranteed to be handled.
    subs_dict = {IndexedBase(param.symbol, shape=(1,))[inner_loop_counter]: 1
                 for param in ast_node.parameters if param.is_field_stride_argument}
    if subs_dict:
        ast_node.subs(subs_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment