From 501b2d7e0009ef7f9b7a1edb548fec589b511369 Mon Sep 17 00:00:00 2001
From: Martin Bauer <>
Date: Fri, 11 May 2018 16:30:05 +0200
Subject: [PATCH] Improved Vectorization

- support aligned load/stores
- nontemporal stores
- aligned memory allocation for arrays and temporary buffers
---                         |  2 +-                           | 48 +++++++++++++----
 backends/                  | 55 +++++++++++--------
 backends/     | 11 ++--
 cpu/                         |  2 +-
 cpu/                  | 78 +++++++++++++++++++++------                         | 21 ++++++--
 datahandling/ |  5 +-
 datahandling/   |  1 +                  | 31 +++++++++++                     | 20 +++----                    | 11 ++--
 12 files changed, 211 insertions(+), 74 deletions(-)

diff --git a/ b/
index 5eec6c027..dd9f09ba0 100644
--- a/
+++ b/
@@ -29,7 +29,7 @@ if Assignment:
     # back port for older sympy versions that don't have Assignment  yet
-    class Assignment(sp.Rel):
+    class Assignment(sp.Rel):  # pragma: no cover
         rel_op = ':='
         __slots__ = []
diff --git a/ b/
index 7f6eacf41..7a1c8198b 100644
--- a/
+++ b/
@@ -183,6 +183,7 @@ class KernelFunction(Node):
         # these variables are assumed to be global, so no automatic parameter is generated for them
         self.global_variables = set()
         self.backend = backend
+        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
     def symbols_defined(self):
@@ -437,11 +438,15 @@ class SympyAssignment(Node):
         super(SympyAssignment, self).__init__(parent=None)
         self._lhs_symbol = lhs_symbol
         self.rhs = rhs_expr
-        self._is_declaration = True
-        is_cast = self._lhs_symbol.func == cast_func
-        if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, ResolvedFieldAccess) or is_cast:
-            self._is_declaration = False
         self._is_const = is_const
+        self._is_declaration = self.__is_declaration()
+    def __is_declaration(self):
+        if isinstance(self._lhs_symbol, cast_func):
+            return False
+        if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
+            return False
+        return True
     def lhs(self):
@@ -450,10 +455,7 @@ class SympyAssignment(Node):
     def lhs(self, new_value):
         self._lhs_symbol = new_value
-        self._is_declaration = True
-        is_cast = self._lhs_symbol.func == cast_func
-        if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast:
-            self._is_declaration = False
+        self._is_declaration = self.__is_declaration()
     def subs(self, subs_dict):
         self.lhs = fast_subs(self.lhs, subs_dict)
@@ -548,10 +550,21 @@ class ResolvedFieldAccess(sp.Indexed):
 class TemporaryMemoryAllocation(Node):
-    def __init__(self, typed_symbol, size):
+    """Node for temporary memory buffer allocation.
+    Always allocates aligned memory.
+    Args:
+        typed_symbol: symbol used as pointer (has to be typed)
+        size: number of elements to allocate
+        align_offset: index of the element that is placed at an aligned address
+    """
+    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
         super(TemporaryMemoryAllocation, self).__init__(parent=None)
         self.symbol = typed_symbol
         self.size = size
+        self.headers = ['<stdlib.h>']
+        self._align_offset = align_offset
     def symbols_defined(self):
@@ -568,11 +581,24 @@ class TemporaryMemoryAllocation(Node):
     def args(self):
         return [self.symbol]
+    def offset(self, byte_alignment):
+        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment."""
+        np_dtype = self.symbol.dtype.base_type.numpy_dtype
+        assert byte_alignment % np_dtype.itemsize == 0
+        return -self._align_offset % (byte_alignment / np_dtype.itemsize)
 class TemporaryMemoryFree(Node):
-    def __init__(self, typed_symbol):
+    def __init__(self, alloc_node):
         super(TemporaryMemoryFree, self).__init__(parent=None)
-        self.symbol = typed_symbol
+        self.alloc_node = alloc_node
+    @property
+    def symbol(self):
+        return self.alloc_node.symbol
+    def offset(self, byte_alignment):
+        return self.alloc_node.offset(byte_alignment)
     def symbols_defined(self):
diff --git a/backends/ b/backends/
index 6b2b883d6..188679d47 100644
--- a/backends/
+++ b/backends/
@@ -9,10 +9,10 @@ except ImportError:
     from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
 from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
-    bitwise_or, modulo_floor
-from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
-from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func
-from pystencils.backends.simd_instruction_sets import selected_instruction_set
+    bitwise_or, modulo_floor, modulo_ceil
+from pystencils.astnodes import Node, ResolvedFieldAccess, KernelFunction
+from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
+    vector_memory_access
 __all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
@@ -37,9 +37,8 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants
         double = create_type('double')
         use_float_constants = double not in field_types
-    vector_is = selected_instruction_set['double']
     printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only,
-                       vector_instruction_set=vector_is)
+                       vector_instruction_set=ast_node.instruction_set)
     return printer(ast_node)
@@ -47,12 +46,11 @@ def get_headers(ast_node: Node) -> Set[str]:
     """Return a set of header files, necessary to compile the printed C-like code."""
     headers = set()
+    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
+        headers.update(ast_node.instruction_set['headers'])
     if hasattr(ast_node, 'headers'):
-    elif isinstance(ast_node, SympyAssignment):
-        if type(get_type_of_expression(ast_node.rhs)) is VectorType:
-            headers.update(selected_instruction_set['double']['headers'])
     for a in ast_node.args:
         if isinstance(a, Node):
@@ -165,18 +163,32 @@ class CBackend:
             lhs_type = get_type_of_expression(node.lhs)
-            if type(lhs_type) is VectorType and node.lhs.func == cast_func:
-                return self._vectorInstructionSet['storeU'].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
-                                                                   self.sympy_printer.doprint(node.rhs)) + ';'
+            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
+                arg, data_type, aligned, nontemporal = node.lhs.args
+                instr = 'storeU'
+                if aligned:
+                    instr = 'stream' if nontemporal else 'storeA'
+                return self._vectorInstructionSet[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
+                                                                self.sympy_printer.doprint(node.rhs)) + ';'
                 return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
     def _print_TemporaryMemoryAllocation(self, node):
-        return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympy_printer.doprint(,
-                                        node.symbol.dtype.base_type, self.sympy_printer.doprint(node.size))
+        align = 128
+        np_dtype = node.symbol.dtype.base_type.numpy_dtype
+        required_size = np_dtype.itemsize * node.size + align
+        size = modulo_ceil(required_size, align)
+        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
+        return code.format(dtype=node.symbol.dtype,
+                           name=self.sympy_printer.doprint(,
+                           size=int(size),
+                           offset=int(node.offset(align)),
+                           align=align)
     def _print_TemporaryMemoryFree(self, node):
-        return "delete [] %s;" % (self.sympy_printer.doprint(,)
+        align = 128
+        return "free(%s - %d);" % (self.sympy_printer.doprint(, node.offset(align))
     def _print_CustomCppCode(node):
@@ -270,13 +282,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             return None
     def _print_Function(self, expr):
-        if expr.func == cast_func:
+        if expr.func == vector_memory_access:
+            arg, data_type, aligned, _ = expr.args
+            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
+            return instruction.format("& " + self._print(arg))
+        elif expr.func == cast_func:
             arg, data_type = expr.args
             if type(data_type) is VectorType:
-                if type(arg) is ResolvedFieldAccess:
-                    return self.instruction_set['loadU'].format("& " + self._print(arg))
-                else:
-                    return self.instruction_set['makeVec'].format(self._print(arg))
+                return self.instruction_set['makeVec'].format(self._print(arg))
         return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
diff --git a/backends/ b/backends/
index d69b21dea..518e6a59a 100644
--- a/backends/
+++ b/backends/
@@ -1,7 +1,7 @@
 # noinspection SpellCheckingInspection
-def x86_vector_instruction_set(data_type='double', instruction_set='avx'):
+def get_vector_instruction_set(data_type='double', instruction_set='avx'):
     base_names = {
         '+': 'add[0, 1]',
         '-': 'sub[0, 1]',
@@ -26,7 +26,8 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'):
         'loadU': 'loadu[0]',
         'loadA': 'load[0]',
         'storeU': 'storeu[0,1]',
-        'storeA': 'store [0,1]',
+        'storeA': 'store[0,1]',
+        'stream': 'stream[0,1]',
     headers = {
@@ -86,9 +87,3 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'):
     result['headers'] = headers[instruction_set]
     return result
-selected_instruction_set = {
-    'float': x86_vector_instruction_set('float', 'avx'),
-    'double': x86_vector_instruction_set('double', 'avx'),
diff --git a/cpu/ b/cpu/
index e073543bb..cb840d3db 100644
--- a/cpu/
+++ b/cpu/
@@ -277,7 +277,7 @@ atexit.register(compile_object_cache_to_shared_library)
 def generate_code(ast, restrict_qualifier, function_prefix, source_file):
     headers = get_headers(ast)
-    headers.update(['<cmath>', '<cstdint>'])
+    headers.update(['<math.h>', '<stdint.h>'])
     code = generate_c(ast)
     includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
diff --git a/cpu/ b/cpu/
index ef0baf3ad..3745bb31c 100644
--- a/cpu/
+++ b/cpu/
@@ -1,20 +1,60 @@
 import sympy as sp
 import warnings
+from typing import Union, Container
+from pystencils.backends.simd_instruction_sets import get_vector_instruction_set
 from pystencils.integer_functions import modulo_floor
 from pystencils.sympyextensions import fast_subs
-from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, cast_func, collate_types, PointerType
+from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, vector_memory_access, cast_func, \
+    collate_types, PointerType
 import pystencils.astnodes as ast
-from pystencils.transformations import cut_loop
-def vectorize(ast_node, vector_width=4):
-    vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width)
-    insert_vector_casts(ast_node)
-def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4):
+from pystencils.transformations import cut_loop, filtered_tree_iteration
+from pystencils.field import Field
+def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx',
+              assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False):
+    """Explicit vectorization using SIMD vectorization via intrinsics.
+    Args:
+        kernel_ast: abstract syntax tree (KernelFunction node)
+        vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
+        assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
+                        used. If true, some of the loads are assumed to be from aligned memory addresses.
+                        For example if x is the fastest coordinate, the access to center can be fetched via an
+                        aligned-load instruction, for the west or east accesses potentially slower unaligned-load
+                        instructions have to be used.
+        nontemporal: a container of fields or field names for which nontemporal (streaming) stores are used.
+                     If true, nontemporal access instructions are used for all fields.
+    """
+    all_fields = kernel_ast.fields_accessed
+    if nontemporal is None or nontemporal is False:
+        nontemporal = {}
+    elif nontemporal is True:
+        nontemporal = all_fields
+    field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float)
+    if len(field_float_dtypes) != 1:
+        raise NotImplementedError("Cannot vectorize kernels that contain accesses "
+                                  "to differently typed floating point fields")
+    float_size = field_float_dtypes.pop().numpy_dtype.itemsize
+    assert float_size in (8, 4)
+    vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
+                                           instruction_set=vector_instruction_set)
+    vector_width = vector_is['width']
+    kernel_ast.instruction_set = vector_is
+    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal)
+    insert_vector_casts(kernel_ast)
+def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields):
     """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
-    inner_loops = [n for n in ast_node.atoms(ast.LoopOverCoordinate) if n.is_innermost_loop]
+    all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
+    inner_loops = [n for n in all_loops if n.is_innermost_loop]
+    zero_loop_counters = {l.loop_counter_symbol: 0 for l in all_loops}
     for loop_node in inner_loops:
         loop_range = loop_node.stop - loop_node.start
@@ -33,13 +73,20 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width=4):
             base, index = indexed.args
             if loop_counter_symbol in index.atoms(sp.Symbol):
                 loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
+                aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                 if not loop_counter_is_offset:
                     successful = False
                 typed_symbol = base.label
                 assert type(typed_symbol.dtype) is PointerType, \
                     "Type of access is {}, {}".format(typed_symbol.dtype, indexed)
-                substitutions[indexed] = cast_func(indexed, VectorType(typed_symbol.dtype.base_type, vector_width))
+                vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
+                use_aligned_access = aligned_access and assume_aligned
+                nontemporal = False
+                if hasattr(indexed, 'field'):
+                    nontemporal = (indexed.field in nontemporal_fields) or ( in nontemporal_fields)
+                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal)
         if not successful:
             warnings.warn("Could not vectorize loop because of non-consecutive memory access")
@@ -52,8 +99,9 @@ def insert_vector_casts(ast_node):
     """Inserts necessary casts from scalar values to vector values."""
     def visit_expr(expr):
-        if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == cast_func) or \
-                isinstance(expr, sp.boolalg.BooleanFunction):
+        if expr.func in (cast_func, vector_memory_access):
+            return expr
+        elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
             new_args = [visit_expr(a) for a in expr.args]
             arg_types = [get_type_of_expression(a) for a in new_args]
             if not any(type(t) is VectorType for t in arg_types):
@@ -104,7 +152,7 @@ def insert_vector_casts(ast_node):
                         new_lhs = TypedSymbol(, new_lhs_type)
                         substitution_dict[assignment.lhs] = new_lhs
                         assignment.lhs = new_lhs
-                elif assignment.lhs.func == cast_func:
+                elif isinstance(assignment.lhs.func, cast_func):
                     lhs_type = assignment.lhs.args[1]
                     if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
                         assignment.rhs = cast_func(assignment.rhs, lhs_type)
diff --git a/ b/
index 1e19748ea..318a24752 100644
--- a/
+++ b/
@@ -10,10 +10,13 @@ from sympy.core.cache import cacheit
 from pystencils.cache import memorycache
 from pystencils.utils import all_equal
+from sympy.logic.boolalg import Boolean
-# to work in conditions of sp.Piecewise cast_func has to be of type Relational as well
-class cast_func(sp.Function, sp.Rel):
+# noinspection PyPep8Naming
+class cast_func(sp.Function, Boolean):
+    # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
     def canonical(self):
         if hasattr(self.args[0], 'canonical'):
@@ -25,8 +28,18 @@ class cast_func(sp.Function, sp.Rel):
     def is_commutative(self):
         return self.args[0].is_commutative
+    @property
+    def dtype(self):
+        return self.args[1]
+# noinspection PyPep8Naming
+class vector_memory_access(cast_func):
+    nargs = (4,)
-class pointer_arithmetic_func(sp.Function, sp.Rel):
+# noinspection PyPep8Naming
+class pointer_arithmetic_func(sp.Function, Boolean):
     def canonical(self):
         if hasattr(self.args[0], 'canonical'):
@@ -285,7 +298,7 @@ def get_type_of_expression(expr):
         return expr.dtype
     elif isinstance(expr, sp.Symbol):
         raise ValueError("All symbols inside this expression have to be typed!")
-    elif hasattr(expr, 'func') and expr.func == cast_func:
+    elif isinstance(expr, cast_func):
         return expr.args[1]
     elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
         collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
diff --git a/datahandling/ b/datahandling/
index 6ac9e9241..f0a287898 100644
--- a/datahandling/
+++ b/datahandling/
@@ -85,7 +85,7 @@ class ParallelDataHandling(DataHandling):
     def add_array(self, name, values_per_cell=1, dtype=np.float64, latex_name=None, ghost_layers=None,
-                  layout=None, cpu=True, gpu=None):
+                  layout=None, cpu=True, gpu=None, alignment=False):
         if ghost_layers is None:
             ghost_layers = self.default_ghost_layers
         if gpu is None:
@@ -99,6 +99,9 @@ class ParallelDataHandling(DataHandling):
         if name in self.blocks[0] or self.GPU_DATA_PREFIX + name in self.blocks[0]:
             raise ValueError("Data with this name has already been added")
+        if alignment:
+            raise NotImplementedError("Aligned field allocated not yet supported in parallel data handling")
         self._fieldInformation[name] = {'ghost_layers': ghost_layers,
                                         'values_per_cell': values_per_cell,
                                         'layout': layout,
diff --git a/datahandling/ b/datahandling/
index b46845278..e91b37f1c 100644
--- a/datahandling/
+++ b/datahandling/
@@ -89,6 +89,7 @@ class SerialDataHandling(DataHandling):
             'values_per_cell': values_per_cell,
             'layout': layout,
             'dtype': dtype,
+            'alignment': alignment,
         if values_per_cell > 1:
diff --git a/ b/
index 5cf17bd18..db8358ff2 100644
--- a/
+++ b/
@@ -39,3 +39,34 @@ class modulo_floor(sp.Function):
         assert dtype.is_int()
         return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
                                                          print_func(self.args[1]), dtype=dtype)
+# noinspection PyPep8Naming
+class modulo_ceil(sp.Function):
+    """Returns the smallest integer that is not less than the given integer and divisible by the divisor.
+    Examples:
+        >>> modulo_ceil(9, 4)
+        12
+        >>> modulo_ceil(11, 4)
+        12
+        >>> modulo_ceil(12, 4)
+        12
+        >>> from pystencils import TypedSymbol
+        >>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
+        >>> modulo_ceil(a, b).to_c(str)
+        '(a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b)'
+    """
+    nargs = 2
+    def __new__(cls, integer, divisor):
+        if is_integer_sequence((integer, divisor)):
+            return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor
+        else:
+            return super().__new__(cls, integer, divisor)
+    def to_c(self, print_func):
+        dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
+        assert dtype.is_int()
+        code = "({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1})"
+        return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
diff --git a/ b/
index 10c12be27..f79865f2e 100644
--- a/
+++ b/
@@ -2,6 +2,7 @@ from types import MappingProxyType
 import sympy as sp
 from pystencils.assignment import Assignment
 from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAssignment
+from pystencils.cpu.vectorization import vectorize
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.transformations import remove_conditionals_in_staggered_kernel
@@ -25,9 +26,10 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
                      pairs ``[(x_lower_gl, x_upper_gl), .... ]``
         cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
-        cpu_vectorize_info: pair of instruction set name, i.e. one of 'sse, 'avx' or 'avx512'
-                            and data type 'float' or 'double'. For example ``('avx', 'double')``
-        gpu_indexing: either 'block' or 'line' , or custom indexing class, see `pystencils.gpucuda.AbstractIndexing`
+        cpu_vectorize_info: True to vectorize with the defaults, or a dictionary with the keys
+                            'vector_instruction_set', 'assume_aligned' and 'nontemporal' (see `vectorize`). Example:
+                            '{'vector_instruction_set': 'avx512', 'assume_aligned': True, 'nontemporal':True}'
+        gpu_indexing: either 'block' or 'line' , or custom indexing class, see `AbstractIndexing`
         gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class)
                              e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'
@@ -70,12 +72,12 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
         if cpu_openmp:
             add_openmp(ast, num_threads=cpu_openmp)
         if cpu_vectorize_info:
-            import pystencils.backends.simd_instruction_sets as vec
-            from pystencils.cpu.vectorization import vectorize
-            vec_params = cpu_vectorize_info
-            vec.selected_instruction_set = vec.x86_vector_instruction_set(instruction_set=vec_params[0],
-                                                                          data_type=vec_params[1])
-            vectorize(ast)
+            if cpu_vectorize_info is True:
+                vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None)
+            elif isinstance(cpu_vectorize_info, dict):
+                vectorize(ast, **cpu_vectorize_info)
+            else:
+                raise ValueError("Invalid value for cpu_vectorize_info")
         return ast
     elif target == 'llvm':
         from pystencils.llvm import create_kernel
diff --git a/ b/
index 8893c93f8..b2c6f5e27 100644
--- a/
+++ b/
@@ -13,10 +13,13 @@ from pystencils.slicing import normalize_slice
 import pystencils.astnodes as ast
-def filtered_tree_iteration(node, node_type):
+def filtered_tree_iteration(node, node_type, stop_type=None):
     for arg in node.args:
         if isinstance(arg, node_type):
             yield arg
+        elif stop_type and isinstance(node, stop_type):
+            continue
         yield from filtered_tree_iteration(arg, node_type)
@@ -590,8 +593,10 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
     for tmp_array in symbols_with_temporary_array:
         tmp_array_pointer = TypedSymbol(, PointerType(tmp_array.dtype))
-        outer_loop.parent.insert_front(ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop))
-        outer_loop.parent.append(ast.TemporaryMemoryFree(tmp_array_pointer))
+        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
+        free_node = ast.TemporaryMemoryFree(alloc_node)
+        outer_loop.parent.insert_front(alloc_node)
+        outer_loop.parent.append(free_node)
 def cut_loop(loop_node, cutting_points):