diff --git a/astnodes.py b/astnodes.py index 9ad56cdcd3f973b38897ceafe73be214abb2dd29..7f6eacf4149755990eb53be1e385492247bd81b5 100644 --- a/astnodes.py +++ b/astnodes.py @@ -29,10 +29,10 @@ class Node: """Symbols which are used but are not defined inside this node.""" raise NotImplementedError() - def subs(self, *args, **kwargs) -> None: + def subs(self, subs_dict) -> None: """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" for a in self.args: - a.subs(*args, **kwargs) + a.subs(subs_dict) @property def func(self): @@ -78,11 +78,11 @@ class Conditional(Node): self.true_block = handle_child(true_block) self.false_block = handle_child(false_block) - def subs(self, *args, **kwargs): - self.true_block.subs(*args, **kwargs) + def subs(self, subs_dict): + self.true_block.subs(subs_dict) if self.false_block: - self.false_block.subs(*args, **kwargs) - self.condition_expr = self.condition_expr.subs(*args, **kwargs) + self.false_block.subs(subs_dict) + self.condition_expr = self.condition_expr.subs(subs_dict) @property def args(self): @@ -238,6 +238,18 @@ class Block(Node): def args(self): return self._nodes + def subs(self, subs_dict) -> None: + new_args = [] + for a in self.args: + if isinstance(a, SympyAssignment) and a.is_declaration and a.rhs in subs_dict.keys(): + subs_dict[a.lhs] = subs_dict[a.rhs] + else: + new_args.append(a) + self._nodes = new_args + + for a in self.args: + a.subs(subs_dict) + def insert_front(self, node): node.parent = self self._nodes.insert(0, node) @@ -334,14 +346,14 @@ class LoopOverCoordinate(Node): result.prefix_lines = [l for l in self.prefix_lines] return result - def subs(self, *args, **kwargs): - self.body.subs(*args, **kwargs) + def subs(self, subs_dict): + self.body.subs(subs_dict) if hasattr(self.start, "subs"): - self.start = self.start.subs(*args, **kwargs) + self.start = self.start.subs(subs_dict) if hasattr(self.stop, "subs"): - self.stop = self.stop.subs(*args, **kwargs) + self.stop = self.stop.subs(subs_dict) if hasattr(self.step, "subs"): - self.step = self.step.subs(*args, **kwargs) + self.step = self.step.subs(subs_dict) @property def args(self): @@ -443,9 +455,9 @@ class SympyAssignment(Node): if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast: self._is_declaration = False - def subs(self, *args, **kwargs): - self.lhs = fast_subs(self.lhs, *args, **kwargs) - self.rhs = fast_subs(self.rhs, *args, **kwargs) + def subs(self, subs_dict): + self.lhs = fast_subs(self.lhs, subs_dict) + self.rhs = fast_subs(self.rhs, subs_dict) @property def args(self): diff --git a/backends/cbackend.py b/backends/cbackend.py index 2c2c14a7a05abffdd397d1c3e45860b75cf30d8e..6b2b883d627f7b3446b278d36f47c666a1882747 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -8,7 +8,8 @@ try: except ImportError: from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 -from pystencils.bitoperations import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, bitwise_or +from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ + bitwise_or, modulo_floor from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func from pystencils.backends.simd_instruction_sets import selected_instruction_set @@ -104,7 +105,6 @@ class CBackend: def __init__(self, constants_as_floats=False, sympy_printer=None, signature_only=False, vector_instruction_set=None): if sympy_printer is None: - self.sympy_printer = CustomSympyPrinter(constants_as_floats) if vector_instruction_set is not None: self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats) else: @@ -239,9 +239,14 @@ class CustomSympyPrinter(CCodePrinter): bitwise_or: '|', bitwise_and: '&', } + if hasattr(expr, 'to_c'): + return expr.to_c(self._print) if expr.func == cast_func: arg, data_type = expr.args return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg)) + elif expr.func == modulo_floor: + assert all(get_type_of_expression(e).is_int() for e in expr.args) + return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0])) elif expr.func in function_map: return "(%s %s %s)" % (self._print(expr.args[0]), function_map[expr.func], self._print(expr.args[1])) else: @@ -370,7 +375,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): a.append(item) a = a or [S.One] - # a = a or [cast_func(S.One, VectorType(create_type_from_string("double"), expr_type.width))] a_str = [self._print(x) for x in a] b_str = [self._print(x) for x in b] diff --git a/bitoperations.py b/bitoperations.py deleted file mode 100644 index 865772efb41aa17f4f6972150f3fce86c0951ca1..0000000000000000000000000000000000000000 --- a/bitoperations.py +++ /dev/null @@ -1,6 +0,0 @@ -import sympy as sp -bitwise_xor = sp.Function("bitwise_xor") -bit_shift_right = sp.Function("bit_shift_right") -bit_shift_left = sp.Function("bit_shift_left") -bitwise_and = sp.Function("bitwise_and") -bitwise_or = sp.Function("bitwise_or") diff --git a/boundaries/inkernel.py b/boundaries/inkernel.py index 82cab6384f762ed220686ca2e1d92bc0865c2ad0..6d0ff9e66c032d7f601585435eaade3346a66894 100644 --- a/boundaries/inkernel.py +++ b/boundaries/inkernel.py @@ -1,6 +1,6 @@ import sympy as sp from pystencils import Field, TypedSymbol -from pystencils.bitoperations import bitwise_and +from pystencils.integer_functions import bitwise_and from pystencils.boundaries.boundaryhandling import FlagInterface from pystencils.data_types import create_type diff --git a/cpu/cpujit.py b/cpu/cpujit.py index d20febad9e383de1b6384031a0b53311db355b00..e073543bb9c48ed4111dac290346774f77c36aaf 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -109,7 +109,7 @@ def make_python_function(kernel_function_node, argument_dict={}): return lambda: func(*args) -def set_compiler_config(config): +def set_config(config): """ Override the configuration provided in config file @@ -206,7 +206,7 @@ def read_config(): config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid()) if config['cache']['clear_cache_on_start']: - shutil.rmtree(config['cache']['object_cache'], ignore_errors=True) + clear_cache() create_folder(config['cache']['object_cache'], False) create_folder(config['cache']['shared_library'], True) @@ -238,6 +238,12 @@ def hash_to_function_name(h): return res.replace('-', 'm') +def clear_cache(): + cache_config = get_cache_config() + shutil.rmtree(cache_config['object_cache'], ignore_errors=True) + create_folder(cache_config['object_cache'], False) + + def compile_object_cache_to_shared_library(): compiler_config = get_compiler_config() cache_config = get_cache_config() diff --git a/cpu/vectorization.py b/cpu/vectorization.py index d841c50b93dc2900041c8b5080eef60e9a2b29fb..ef0baf3ad788381a70aca7a002b4ab97d746578b 100644 --- a/cpu/vectorization.py +++ b/cpu/vectorization.py @@ -1,9 +1,10 @@ import sympy as sp import warnings +from pystencils.integer_functions import modulo_floor from pystencils.sympyextensions import fast_subs -from pystencils.transformations import filtered_tree_iteration from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, cast_func, collate_types, PointerType import pystencils.astnodes as ast +from pystencils.transformations import cut_loop def vectorize(ast_node, vector_width=4): @@ -12,25 +13,18 @@ def vectorize(ast_node, vector_width=4): def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4): - """ - Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if - * loop bounds are constant - * loop range is a multiple of vector width - """ + """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop] for loop_node in inner_loops: loop_range = loop_node.stop - loop_node.start - # Check restrictions - if isinstance(loop_range, sp.Expr) and not loop_range.is_number: - warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop") - continue - if loop_range % vector_width != 0 or loop_node.step != 1: - warnings.warn("Currently only loops with loop bounds that are multiples " - "of vectorization width can be vectorized - skipping loop") - continue - + # cut off loop tail, that is not a multiple of four + cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start + loop_nodes = cut_loop(loop_node, [cutting_point]) + assert len(loop_nodes) in (1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width + loop_node = loop_nodes[0] + # Find all array accesses (indexed) that depend on the loop counter as offset loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over) substitutions = {} @@ -94,19 +88,27 @@ def insert_vector_casts(ast_node): else: return expr - substitution_dict = {} - for assignment in filtered_tree_iteration(ast_node, ast.SympyAssignment): - subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) - assignment.rhs = visit_expr(subs_expr) - rhs_type = get_type_of_expression(assignment.rhs) - if isinstance(assignment.lhs, TypedSymbol): - lhs_type = assignment.lhs.dtype - if type(rhs_type) is VectorType and type(lhs_type) is not VectorType: - new_lhs_type = VectorType(lhs_type, rhs_type.width) - new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) - substitution_dict[assignment.lhs] = new_lhs - assignment.lhs = new_lhs - elif assignment.lhs.func == cast_func: - lhs_type = assignment.lhs.args[1] - if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: - assignment.rhs = cast_func(assignment.rhs, lhs_type) + def visit_node(node, substitution_dict): + substitution_dict = substitution_dict.copy() + for arg in node.args: + if isinstance(arg, ast.SympyAssignment): + assignment = arg + subs_expr = fast_subs(assignment.rhs, substitution_dict, + skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) + assignment.rhs = visit_expr(subs_expr) + rhs_type = get_type_of_expression(assignment.rhs) + if isinstance(assignment.lhs, TypedSymbol): + lhs_type = assignment.lhs.dtype + if type(rhs_type) is VectorType and type(lhs_type) is not VectorType: + new_lhs_type = VectorType(lhs_type, rhs_type.width) + new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) + substitution_dict[assignment.lhs] = new_lhs + assignment.lhs = new_lhs + elif assignment.lhs.func == cast_func: + lhs_type = assignment.lhs.args[1] + if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: + assignment.rhs = cast_func(assignment.rhs, lhs_type) + else: + visit_node(arg, substitution_dict) + + visit_node(ast_node, {}) diff --git a/data_types.py b/data_types.py index 383d603b6fe789b6c00ccaec09f9c77549a1e79b..1e19748ea7c87c5a8e3566ff408e977074c642bf 100644 --- a/data_types.py +++ b/data_types.py @@ -65,10 +65,13 @@ class TypedSymbol(sp.Symbol): def create_type(specification): - """ - Create a subclass of Type according to a string or an object of subclass Type - :param specification: Type object, or a string - :return: Type object, or a new Type object parsed from the string + """Creates a subclass of Type according to a string or an object of subclass Type. + + Args: + specification: Type object, or a string + + Returns: + Type object, or a new Type object parsed from the string """ if isinstance(specification, Type): return specification @@ -82,10 +85,13 @@ def create_type(specification): @memorycache(maxsize=64) def create_composite_type_from_string(specification): - """ - Creates a new Type object from a c-like string specification - :param specification: Specification string - :return: Type object + """Creates a new Type object from a c-like string specification. + + Args: + specification: Specification string + + Returns: + Type object """ specification = specification.lower().split() parts = [] @@ -432,6 +438,9 @@ class VectorType(Type): def __hash__(self): return hash((self.base_type, self.width)) + def __getnewargs__(self): + return self._base_type, self.width + class PointerType(Type): def __init__(self, base_type, const=False, restrict=True): diff --git a/integer_functions.py b/integer_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf17bd18c1dd4deab24356aab93051f25fb5daf --- /dev/null +++ b/integer_functions.py @@ -0,0 +1,41 @@ +import sympy as sp + +from pystencils.data_types import get_type_of_expression, collate_types +from pystencils.sympyextensions import is_integer_sequence + +bitwise_xor = sp.Function("bitwise_xor") +bit_shift_right = sp.Function("bit_shift_right") +bit_shift_left = sp.Function("bit_shift_left") +bitwise_and = sp.Function("bitwise_and") +bitwise_or = sp.Function("bitwise_or") + + +# noinspection PyPep8Naming +class modulo_floor(sp.Function): + """Returns the next smaller integer divisible by given divisor. + + Examples: + >>> modulo_floor(9, 4) + 8 + >>> modulo_floor(11, 4) + 8 + >>> modulo_floor(12, 4) + 12 + >>> from pystencils import TypedSymbol + >>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32") + >>> modulo_floor(a, b).to_c(str) + '(int64_t)((a) / (b)) * (b)' + """ + nargs = 2 + + def __new__(cls, integer, divisor): + if is_integer_sequence((integer, divisor)): + return (int(integer) // int(divisor)) * divisor + else: + return super().__new__(cls, integer, divisor) + + def to_c(self, print_func): + dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) + assert dtype.is_int() + return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), + print_func(self.args[1]), dtype=dtype) diff --git a/transformations.py b/transformations.py index 6bdfcb20e48a79b9b854b01dff352fb7f5415a8a..8893c93f811691a887175c8f456a4fed52480fa0 100644 --- a/transformations.py +++ b/transformations.py @@ -595,8 +595,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): def cut_loop(loop_node, cutting_points): - """Cuts loop at given cutting points, that means one loop is transformed into len(cuttingPoints)+1 new loops - that range from old_begin to cutting_points[1], ..., cutting_points[-1] to old_end""" + """Cuts loop at given cutting points. + + One loop is transformed into len(cuttingPoints)+1 new loops that range from + old_begin to cutting_points[1], ..., cutting_points[-1] to old_end + + Modifies the ast in place + + Returns: + list of new loop nodes + """ if loop_node.step != 1: raise NotImplementedError("Can only split loops that have a step of 1") new_loops = [] @@ -607,12 +615,15 @@ def cut_loop(loop_node, cutting_points): new_body = deepcopy(loop_node.body) new_body.subs({loop_node.loop_counter_symbol: new_start}) new_loops.append(new_body) + elif new_end - new_start == 0: + pass else: new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over, new_start, new_end, loop_node.step) new_loops.append(new_loop) new_start = new_end loop_node.parent.replace(loop_node, new_loops) + return new_loops def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None: @@ -973,3 +984,25 @@ def get_loop_hierarchy(ast_node): if node: result.append(node.coordinate_to_loop_over) return reversed(result) + + +def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: + """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering). + + Variable sized kernels can handle arbitrary field sizes and field shapes. However, the kernel is most efficient + if the innermost loop accesses the fields with stride 1. The inner loop can also only be vectorized if the inner + stride is 1. This transformation hard codes this inner stride to one to enable e.g. vectorization. + + Warning: the assumption is not checked at runtime! + """ + inner_loop_counters = {l.coordinate_to_loop_over + for l in ast_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop} + if len(inner_loop_counters) != 1: + raise ValueError("Inner loops iterate over different coordinates") + inner_loop_counter = inner_loop_counters.pop() + + stride_params = [p for p in ast_node.parameters if p.is_field_stride_argument] + for stride_param in stride_params: + stride_symbol = stride_param.symbol + subs_dict = {IndexedBase(stride_symbol, shape=(1,))[inner_loop_counter]: 1} + ast_node.subs(subs_dict)