diff --git a/assignment.py b/assignment.py index 5eec6c027b366afd8eb9cb29d2e54a433212730a..dd9f09ba03f825dc433016dcfc2e500e6bcc5c2a 100644 --- a/assignment.py +++ b/assignment.py @@ -29,7 +29,7 @@ if Assignment: else: # back port for older sympy versions that don't have Assignment yet - class Assignment(sp.Rel): + class Assignment(sp.Rel): # pragma: no cover rel_op = ':=' __slots__ = [] diff --git a/astnodes.py b/astnodes.py index 7f6eacf4149755990eb53be1e385492247bd81b5..7a1c8198bce798a7e406961f62c964006c5a49a2 100644 --- a/astnodes.py +++ b/astnodes.py @@ -183,6 +183,7 @@ class KernelFunction(Node): # these variables are assumed to be global, so no automatic parameter is generated for them self.global_variables = set() self.backend = backend + self.instruction_set = None # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use @property def symbols_defined(self): @@ -437,11 +438,15 @@ class SympyAssignment(Node): super(SympyAssignment, self).__init__(parent=None) self._lhs_symbol = lhs_symbol self.rhs = rhs_expr - self._is_declaration = True - is_cast = self._lhs_symbol.func == cast_func - if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, ResolvedFieldAccess) or is_cast: - self._is_declaration = False self._is_const = is_const + self._is_declaration = self.__is_declaration() + + def __is_declaration(self): + if isinstance(self._lhs_symbol, cast_func): + return False + if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)): + return False + return True @property def lhs(self): @@ -450,10 +455,7 @@ class SympyAssignment(Node): @lhs.setter def lhs(self, new_value): self._lhs_symbol = new_value - self._is_declaration = True - is_cast = self._lhs_symbol.func == cast_func - if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast: - self._is_declaration = False + self._is_declaration = self.__is_declaration() def subs(self, subs_dict): self.lhs = fast_subs(self.lhs, subs_dict) @@ -548,10 +550,21 @@ class ResolvedFieldAccess(sp.Indexed): class TemporaryMemoryAllocation(Node): - def __init__(self, typed_symbol, size): + """Node for temporary memory buffer allocation. + + Always allocates aligned memory. + + Args: + typed_symbol: symbol used as pointer (has to be typed) + size: number of elements to allocate + align_offset: the align_offset's element is aligned + """ + def __init__(self, typed_symbol: TypedSymbol, size, align_offset): super(TemporaryMemoryAllocation, self).__init__(parent=None) self.symbol = typed_symbol self.size = size + self.headers = ['<stdlib.h>'] + self._align_offset = align_offset @property def symbols_defined(self): @@ -568,11 +581,24 @@ class TemporaryMemoryAllocation(Node): def args(self): return [self.symbol] + def offset(self, byte_alignment): + """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.""" + np_dtype = self.symbol.dtype.base_type.numpy_dtype + assert byte_alignment % np_dtype.itemsize == 0 + return -self._align_offset % (byte_alignment / np_dtype.itemsize) + class TemporaryMemoryFree(Node): - def __init__(self, typed_symbol): + def __init__(self, alloc_node): super(TemporaryMemoryFree, self).__init__(parent=None) - self.symbol = typed_symbol + self.alloc_node = alloc_node + + @property + def symbol(self): + return self.alloc_node.symbol + + def offset(self, byte_alignment): + return self.alloc_node.offset(byte_alignment) @property def symbols_defined(self): diff --git a/backends/cbackend.py b/backends/cbackend.py index 6b2b883d627f7b3446b278d36f47c666a1882747..188679d47256222c6a147007434c0d26ce50a017 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -9,10 +9,10 @@ except ImportError: from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \ - bitwise_or, modulo_floor -from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment -from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func -from pystencils.backends.simd_instruction_sets import selected_instruction_set + bitwise_or, modulo_floor, modulo_ceil +from pystencils.astnodes import Node, ResolvedFieldAccess, KernelFunction +from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ + vector_memory_access __all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -37,9 +37,8 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants double = create_type('double') use_float_constants = double not in field_types - vector_is = selected_instruction_set['double'] printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only, - vector_instruction_set=vector_is) + vector_instruction_set=ast_node.instruction_set) return printer(ast_node) @@ -47,12 +46,11 @@ def get_headers(ast_node: Node) -> Set[str]: """Return a set of header files, necessary to compile the printed C-like code.""" headers = set() + if isinstance(ast_node, KernelFunction) and ast_node.instruction_set: + headers.update(ast_node.instruction_set['headers']) + if hasattr(ast_node, 'headers'): headers.update(ast_node.headers) - elif isinstance(ast_node, SympyAssignment): - if type(get_type_of_expression(ast_node.rhs)) is VectorType: - headers.update(selected_instruction_set['double']['headers']) - for a in ast_node.args: if isinstance(a, Node): headers.update(get_headers(a)) @@ -165,18 +163,32 @@ class CBackend: self.sympy_printer.doprint(node.rhs)) else: lhs_type = get_type_of_expression(node.lhs) - if type(lhs_type) is VectorType and node.lhs.func == cast_func: - return self._vectorInstructionSet['storeU'].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), - self.sympy_printer.doprint(node.rhs)) + ';' + if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): + arg, data_type, aligned, nontemporal = node.lhs.args + instr = 'storeU' + if aligned: + instr = 'stream' if nontemporal else 'storeA' + + return self._vectorInstructionSet[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), + self.sympy_printer.doprint(node.rhs)) + ';' else: return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): - return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympy_printer.doprint(node.symbol.name), - node.symbol.dtype.base_type, self.sympy_printer.doprint(node.size)) + align = 128 + np_dtype = node.symbol.dtype.base_type.numpy_dtype + required_size = np_dtype.itemsize * node.size + align + size = modulo_ceil(required_size, align) + code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};" + return code.format(dtype=node.symbol.dtype, + name=self.sympy_printer.doprint(node.symbol.name), + size=int(size), + offset=int(node.offset(align)), + align=align) def _print_TemporaryMemoryFree(self, node): - return "delete [] %s;" % (self.sympy_printer.doprint(node.symbol.name),) + align = 128 + return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) @staticmethod def _print_CustomCppCode(node): @@ -270,13 +282,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Function(self, expr): - if expr.func == cast_func: + if expr.func == vector_memory_access: + arg, data_type, aligned, _ = expr.args + instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] + return instruction.format("& " + self._print(arg)) + elif expr.func == cast_func: arg, data_type = expr.args if type(data_type) is VectorType: - if type(arg) is ResolvedFieldAccess: - return self.instruction_set['loadU'].format("& " + self._print(arg)) - else: - return self.instruction_set['makeVec'].format(self._print(arg)) + return self.instruction_set['makeVec'].format(self._print(arg)) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index d69b21dea97ca3fcaba21ca79b7df2a484c88f4a..518e6a59a774ae6554ea2f9c25f65a444236ceed 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -1,7 +1,7 @@ # noinspection SpellCheckingInspection -def x86_vector_instruction_set(data_type='double', instruction_set='avx'): +def get_vector_instruction_set(data_type='double', instruction_set='avx'): base_names = { '+': 'add[0, 1]', '-': 'sub[0, 1]', @@ -26,7 +26,8 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'): 'loadU': 'loadu[0]', 'loadA': 'load[0]', 'storeU': 'storeu[0,1]', - 'storeA': 'store [0,1]', + 'storeA': 'store[0,1]', + 'stream': 'stream[0,1]', } headers = { @@ -86,9 +87,3 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'): result['headers'] = headers[instruction_set] return result - - -selected_instruction_set = { - 'float': x86_vector_instruction_set('float', 'avx'), - 'double': x86_vector_instruction_set('double', 'avx'), -} diff --git a/cpu/cpujit.py b/cpu/cpujit.py index e073543bb9c48ed4111dac290346774f77c36aaf..cb840d3db2cd78dc90bc0613cf389de7371fb814 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -277,7 +277,7 @@ atexit.register(compile_object_cache_to_shared_library) def generate_code(ast, restrict_qualifier, function_prefix, source_file): headers = get_headers(ast) - headers.update(['<cmath>', '<cstdint>']) + headers.update(['<math.h>', '<stdint.h>']) code = generate_c(ast) includes = "\n".join(["#include %s" % (include_file,) for include_file in headers]) diff --git a/cpu/vectorization.py b/cpu/vectorization.py index ef0baf3ad788381a70aca7a002b4ab97d746578b..3745bb31c232800bf18aa96ac9dae32f00a74ef5 100644 --- a/cpu/vectorization.py +++ b/cpu/vectorization.py @@ -1,20 +1,60 @@ import sympy as sp import warnings + +from typing import Union, Container + +from pystencils.backends.simd_instruction_sets import get_vector_instruction_set from pystencils.integer_functions import modulo_floor from pystencils.sympyextensions import fast_subs -from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, cast_func, collate_types, PointerType +from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, vector_memory_access, cast_func, \ + collate_types, PointerType import pystencils.astnodes as ast -from pystencils.transformations import cut_loop - - -def vectorize(ast_node, vector_width=4): - vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width) - insert_vector_casts(ast_node) - - -def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4): +from pystencils.transformations import cut_loop, filtered_tree_iteration +from pystencils.field import Field + + +def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx', + assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False): + """Explicit vectorization using SIMD vectorization via intrinsics. + + Args: + kernel_ast: abstract syntax tree (KernelFunction node) + vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512') + assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are + used. If true, some of the loads are assumed to be from aligned memory addresses. + For example if x is the fastest coordinate, the access to center can be fetched via an + aligned-load instruction, for the west or east accesses potentially slower unaligend-load + instructions have to be used. + nontemporal: a container of fields or field names for which nontemporal (streaming) stores are used. + If true, nontemporal access instructions are used for all fields. + + """ + all_fields = kernel_ast.fields_accessed + if nontemporal is None or nontemporal is False: + nontemporal = {} + elif nontemporal is True: + nontemporal = all_fields + + field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float) + if len(field_float_dtypes) != 1: + raise NotImplementedError("Cannot vectorize kernels that contain accesses " + "to differently typed floating point fields") + float_size = field_float_dtypes.pop().numpy_dtype.itemsize + assert float_size in (8, 4) + vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float', + instruction_set=vector_instruction_set) + vector_width = vector_is['width'] + kernel_ast.instruction_set = vector_is + + vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal) + insert_vector_casts(kernel_ast) + + +def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" - inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop] + all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) + inner_loops = [n for n in all_loops if n.is_innermost_loop] + zero_loop_counters = {l.loop_counter_symbol: 0 for l in all_loops} for loop_node in inner_loops: loop_range = loop_node.stop - loop_node.start @@ -33,13 +73,20 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4): base, index = indexed.args if loop_counter_symbol in index.atoms(sp.Symbol): loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms() + aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0 if not loop_counter_is_offset: successful = False break typed_symbol = base.label assert type(typed_symbol.dtype) is PointerType, \ "Type of access is {}, {}".format(typed_symbol.dtype, indexed) - substitutions[indexed] = cast_func(indexed, VectorType(typed_symbol.dtype.base_type, vector_width)) + + vec_type = VectorType(typed_symbol.dtype.base_type, vector_width) + use_aligned_access = aligned_access and assume_aligned + nontemporal = False + if hasattr(indexed, 'field'): + nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields) + substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal) if not successful: warnings.warn("Could not vectorize loop because of non-consecutive memory access") continue @@ -52,8 +99,9 @@ def insert_vector_casts(ast_node): """Inserts necessary casts from scalar values to vector values.""" def visit_expr(expr): - if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == cast_func) or \ - isinstance(expr, sp.boolalg.BooleanFunction): + if expr.func in (cast_func, vector_memory_access): + return expr + elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): new_args = [visit_expr(a) for a in expr.args] arg_types = [get_type_of_expression(a) for a in new_args] if not any(type(t) is VectorType for t in arg_types): @@ -104,7 +152,7 @@ def insert_vector_casts(ast_node): new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) substitution_dict[assignment.lhs] = new_lhs assignment.lhs = new_lhs - elif assignment.lhs.func == cast_func: + elif isinstance(assignment.lhs.func, cast_func): lhs_type = assignment.lhs.args[1] if type(lhs_type) is VectorType and type(rhs_type) is not VectorType: assignment.rhs = cast_func(assignment.rhs, lhs_type) diff --git a/data_types.py b/data_types.py index 1e19748ea7c87c5a8e3566ff408e977074c642bf..318a2475296872c8e1e6d755179ef86f4762cc34 100644 --- a/data_types.py +++ b/data_types.py @@ -10,10 +10,13 @@ from sympy.core.cache import cacheit from pystencils.cache import memorycache from pystencils.utils import all_equal +from sympy.logic.boolalg import Boolean -# to work in conditions of sp.Piecewise cast_func has to be of type Relational as well -class cast_func(sp.Function, sp.Rel): +# noinspection PyPep8Naming +class cast_func(sp.Function, Boolean): + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well + @property def canonical(self): if hasattr(self.args[0], 'canonical'): @@ -25,8 +28,18 @@ class cast_func(sp.Function, sp.Rel): def is_commutative(self): return self.args[0].is_commutative + @property + def dtype(self): + return self.args[1] + + +# noinspection PyPep8Naming +class vector_memory_access(cast_func): + nargs = (4,) + -class pointer_arithmetic_func(sp.Function, sp.Rel): +# noinspection PyPep8Naming +class pointer_arithmetic_func(sp.Function, Boolean): @property def canonical(self): if hasattr(self.args[0], 'canonical'): @@ -285,7 +298,7 @@ def get_type_of_expression(expr): return expr.dtype elif isinstance(expr, sp.Symbol): raise ValueError("All symbols inside this expression have to be typed!") - elif hasattr(expr, 'func') and expr.func == cast_func: + elif isinstance(expr, cast_func): return expr.args[1] elif hasattr(expr, 'func') and expr.func == sp.Piecewise: collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args)) diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py index 6ac9e9241693a49245b2a78bdd530a81f7d5bd78..f0a287898c69f4785cad311c28ab79fe925160cd 100644 --- a/datahandling/parallel_datahandling.py +++ b/datahandling/parallel_datahandling.py @@ -85,7 +85,7 @@ class ParallelDataHandling(DataHandling): self._custom_data_names.append(name) def add_array(self, name, values_per_cell=1, dtype=np.float64, latex_name=None, ghost_layers=None, - layout=None, cpu=True, gpu=None): + layout=None, cpu=True, gpu=None, alignment=False): if ghost_layers is None: ghost_layers = self.default_ghost_layers if gpu is None: @@ -99,6 +99,9 @@ class ParallelDataHandling(DataHandling): if name in self.blocks[0] or self.GPU_DATA_PREFIX + name in self.blocks[0]: raise ValueError("Data with this name has already been added") + if alignment: + raise NotImplementedError("Aligned field allocated not yet supported in parallel data handling") + self._fieldInformation[name] = {'ghost_layers': ghost_layers, 'values_per_cell': values_per_cell, 'layout': layout, diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py index b46845278d8dfa5b7edc361d867f68e6e345bc67..e91b37f1c5b209a3d6257e0c357aa96bea284bf4 100644 --- a/datahandling/serial_datahandling.py +++ b/datahandling/serial_datahandling.py @@ -89,6 +89,7 @@ class SerialDataHandling(DataHandling): 'values_per_cell': values_per_cell, 'layout': layout, 'dtype': dtype, + 'alignment': alignment, } if values_per_cell > 1: diff --git a/integer_functions.py b/integer_functions.py index 5cf17bd18c1dd4deab24356aab93051f25fb5daf..db8358ff2b23a88e94287ec313eecd9cce41f427 100644 --- a/integer_functions.py +++ b/integer_functions.py @@ -39,3 +39,34 @@ class modulo_floor(sp.Function): assert dtype.is_int() return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + + +# noinspection PyPep8Naming +class modulo_ceil(sp.Function): + """Returns the next smaller integer divisible by given divisor. + + Examples: + >>> modulo_ceil(9, 4) + 12 + >>> modulo_ceil(11, 4) + 12 + >>> modulo_ceil(12, 4) + 12 + >>> from pystencils import TypedSymbol + >>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32") + >>> modulo_ceil(a, b).to_c(str) + '(a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b)' + """ + nargs = 2 + + def __new__(cls, integer, divisor): + if is_integer_sequence((integer, divisor)): + return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor + else: + return super().__new__(cls, integer, divisor) + + def to_c(self, print_func): + dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) + assert dtype.is_int() + code = "({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1})" + return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) diff --git a/kernelcreation.py b/kernelcreation.py index 10c12be27d972e553fd707f8af563c6c51df54a1..f79865f2ebc2fea30349082fd32310b2e600dbe5 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -2,6 +2,7 @@ from types import MappingProxyType import sympy as sp from pystencils.assignment import Assignment from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAssignment +from pystencils.cpu.vectorization import vectorize from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.transformations import remove_conditionals_in_staggered_kernel @@ -25,9 +26,10 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice pairs ``[(x_lower_gl, x_upper_gl), .... ]`` cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP - cpu_vectorize_info: pair of instruction set name, i.e. one of 'sse, 'avx' or 'avx512' - and data type 'float' or 'double'. For example ``('avx', 'double')`` - gpu_indexing: either 'block' or 'line' , or custom indexing class, see `pystencils.gpucuda.AbstractIndexing` + cpu_vectorize_info: a dictionary with keys, 'vector_instruction_set', 'assume_aligned' and 'nontemporal' + for documentation of these parameters see vectorize function. Example: + '{'vector_instruction_set': 'avx512', 'assume_aligned': True, 'nontemporal':True}' + gpu_indexing: either 'block' or 'line' , or custom indexing class, see `AbstractIndexing` gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class) e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }' @@ -70,12 +72,12 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice if cpu_openmp: add_openmp(ast, num_threads=cpu_openmp) if cpu_vectorize_info: - import pystencils.backends.simd_instruction_sets as vec - from pystencils.cpu.vectorization import vectorize - vec_params = cpu_vectorize_info - vec.selected_instruction_set = vec.x86_vector_instruction_set(instruction_set=vec_params[0], - data_type=vec_params[1]) - vectorize(ast) + if cpu_vectorize_info is True: + vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None) + elif isinstance(cpu_vectorize_info, dict): + vectorize(ast, **cpu_vectorize_info) + else: + raise ValueError("Invalid value for cpu_vectorize_info") return ast elif target == 'llvm': from pystencils.llvm import create_kernel diff --git a/transformations.py b/transformations.py index 8893c93f811691a887175c8f456a4fed52480fa0..b2c6f5e2727e90cb9c0fa5019a674620c5e6b587 100644 --- a/transformations.py +++ b/transformations.py @@ -13,10 +13,13 @@ from pystencils.slicing import normalize_slice import pystencils.astnodes as ast -def filtered_tree_iteration(node, node_type): +def filtered_tree_iteration(node, node_type, stop_type=None): for arg in node.args: if isinstance(arg, node_type): yield arg + elif stop_type and isinstance(node, stop_type): + continue + yield from filtered_tree_iteration(arg, node_type) @@ -590,8 +593,10 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): for tmp_array in symbols_with_temporary_array: tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype)) - outer_loop.parent.insert_front(ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop)) - outer_loop.parent.append(ast.TemporaryMemoryFree(tmp_array_pointer)) + alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start) + free_node = ast.TemporaryMemoryFree(alloc_node) + outer_loop.parent.insert_front(alloc_node) + outer_loop.parent.append(free_node) def cut_loop(loop_node, cutting_points):