diff --git a/pystencils/kerncraft_coupling/kerncraft_interface.py b/pystencils/kerncraft_coupling/kerncraft_interface.py index d40b42e9aab16f2b0ad003b02cb83a5d89d2b277..ebdcccab2e63d7ab1a472dd603f7653b7b8625de 100644 --- a/pystencils/kerncraft_coupling/kerncraft_interface.py +++ b/pystencils/kerncraft_coupling/kerncraft_interface.py @@ -14,6 +14,7 @@ from pystencils.kerncraft_coupling.generate_benchmark import generate_benchmark from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFieldAccess, KernelFunction from pystencils.field import get_layout_from_strides from pystencils.sympyextensions import count_operations_in_ast +from pystencils.transformations import filtered_tree_iteration from pystencils.utils import DotDict import warnings @@ -25,7 +26,8 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): """ LIKWID_BASE = '/usr/local/likwid' - def __init__(self, ast: KernelFunction, machine: Optional[MachineModel] = None, assumed_layout='SoA'): + def __init__(self, ast: KernelFunction, machine: Optional[MachineModel] = None, assumed_layout='SoA', + debug_print=False): """Create a kerncraft kernel using a pystencils AST Args: @@ -41,7 +43,8 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): self.temporary_dir = TemporaryDirectory() # Loops - inner_loops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop] + inner_loops = [l for l in filtered_tree_iteration(ast, LoopOverCoordinate, stop_type=SympyAssignment) + if l.is_innermost_loop] if len(inner_loops) == 0: raise ValueError("No loop found in pystencils AST") else: @@ -109,6 +112,17 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): del self._flops[k] self.check() + if debug_print: + from pprint import pprint + print("----------------------------- Loop Stack --------------------------") + pprint(self._loop_stack) + print("----------------------------- Sources -----------------------------") + pprint(self.sources) + print("----------------------------- Destinations ------------------------") + pprint(self.destinations) + print("----------------------------- FLOPS -------------------------------") + pprint(self._flops) + def iaca_analysis(self, micro_architecture, asm_block='auto', pointer_increment='auto_with_manual_fallback', verbose=False): compiler, compiler_args = self._machine.get_compiler() diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 8ecc5ece9f8a1c9f8be639ee830b15c6bd7172b3..22a2035e9ae53b2e176a1c5355d663103b9d8d25 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -497,6 +497,8 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], for child_term, condition in t.args: visit(child_term) visit_children = False + elif isinstance(t, sp.Rel): + pass else: warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")