Skip to content
Snippets Groups Projects
Commit 5cf6febc authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

don't zero cachelines beyond the end of a field

parent 2d5fcf53
No related merge requests found
...@@ -8,7 +8,7 @@ import sympy as sp ...@@ -8,7 +8,7 @@ import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.logic.boolalg import BooleanFalse, BooleanTrue from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
...@@ -293,7 +293,14 @@ class CBackend: ...@@ -293,7 +293,14 @@ class CBackend:
pre_code = '' pre_code = ''
if nontemporal and 'cachelineZero' in self._vector_instruction_set: if nontemporal and 'cachelineZero' in self._vector_instruction_set:
pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "{\n\t" + \ first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
* node.lhs.args[0].field.spatial_strides[i] for i in
range(len(node.lhs.args[0].field.spatial_strides))])
size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
element_size = 8 if data_type.base_type.base_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n' self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n'
code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
......
...@@ -43,7 +43,7 @@ exclude_lines = ...@@ -43,7 +43,7 @@ exclude_lines =
if __name__ == .__main__.: if __name__ == .__main__.:
skip_covered = True skip_covered = True
fail_under = 88 fail_under = 87
[html] [html]
directory = coverage_report directory = coverage_report
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment