From f7cda45b2ce5d1befaec075ac897a1707190d186 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 15 Mar 2019 16:53:31 +0100
Subject: [PATCH] Operation counting: support for piecewise functions

- all branches are added up
---
 field.py                                  |  2 +-
 kerncraft_coupling/generate_benchmark.py  |  8 +++++++-
 kerncraft_coupling/kerncraft_interface.py | 13 +++++++------
 runhelper/db.py                           |  2 +-
 sympyextensions.py                        | 10 ++++++----
 5 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/field.py b/field.py
index 17c3c7e10..4b06592e1 100644
--- a/field.py
+++ b/field.py
@@ -548,7 +548,7 @@ class Field:
             offset_list[coord_id] += offset
             return Field.Access(self.field, tuple(offset_list), self.index)
 
-        def get_shifted(self, *shift)-> 'Field.Access':
+        def get_shifted(self, *shift) -> 'Field.Access':
             """Returns a new Access with changed spatial coordinates
 
             Example:
diff --git a/kerncraft_coupling/generate_benchmark.py b/kerncraft_coupling/generate_benchmark.py
index 2c27f3b0e..89565b571 100644
--- a/kerncraft_coupling/generate_benchmark.py
+++ b/kerncraft_coupling/generate_benchmark.py
@@ -1,5 +1,5 @@
 from jinja2 import Template
-from pystencils.backends.cbackend import generate_c
+from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.sympyextensions import prod
 from pystencils.data_types import get_base_type
 
@@ -9,6 +9,8 @@ benchmark_template = Template("""
 #include <stdint.h>
 #include <stdbool.h>
 #include <math.h>
+{{ includes }}
+
 {%- if likwid %}
 #include <likwid.h>
 {%- endif %}
@@ -98,6 +100,9 @@ def generate_benchmark(ast, likwid=False):
             fields.append((p.field_name, dtype, prod(field.shape)))
             call_parameters.append(p.field_name)
 
+    header_list = get_headers(ast)
+    includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
+
     args = {
         'likwid': likwid,
         'kernel_code': generate_c(ast, dialect='c'),
@@ -105,5 +110,6 @@ def generate_benchmark(ast, likwid=False):
         'fields': fields,
         'constants': constants,
         'call_argument_list': ",".join(call_parameters),
+        'includes': includes,
     }
     return benchmark_template.render(**args)
diff --git a/kerncraft_coupling/kerncraft_interface.py b/kerncraft_coupling/kerncraft_interface.py
index b8d3e5e4e..d40b42e9a 100644
--- a/kerncraft_coupling/kerncraft_interface.py
+++ b/kerncraft_coupling/kerncraft_interface.py
@@ -15,6 +15,7 @@ from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFie
 from pystencils.field import get_layout_from_strides
 from pystencils.sympyextensions import count_operations_in_ast
 from pystencils.utils import DotDict
+import warnings
 
 
 class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
@@ -30,9 +31,9 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
         Args:
             ast: pystencils ast
             machine: kerncraft machine model - specify this if kernel needs to be compiled
-            assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index coordinates is not
-                    known. In this case either a structures of array (SoA) or array of structures (AoS) layout
-                    is assumed
+            assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index
+                    coordinates is not known. In this case either a structures of array (SoA) or
+                    array of structures (AoS) layout is assumed
         """
         super(PyStencilsKerncraftKernel, self).__init__(machine)
 
@@ -43,9 +44,10 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
         inner_loops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop]
         if len(inner_loops) == 0:
             raise ValueError("No loop found in pystencils AST")
-        elif len(inner_loops) > 1:
-            raise ValueError("pystencils AST contains multiple inner loops - only one can be analyzed")
         else:
+            if len(inner_loops) > 1:
+                warnings.warn("pystencils AST contains multiple inner loops. "
+                              "Only one can be analyzed - choosing first one")
             inner_loop = inner_loops[0]
 
         self._loop_stack = []
@@ -97,7 +99,6 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
         self.datatype = list(self.variables.values())[0][0]
 
         # flops
-        # FIXME operation_count
         operation_count = count_operations_in_ast(inner_loop)
         self._flops = {
             '+': operation_count['adds'],
diff --git a/runhelper/db.py b/runhelper/db.py
index a19cc383a..e0655db32 100644
--- a/runhelper/db.py
+++ b/runhelper/db.py
@@ -74,7 +74,7 @@ class Database:
         document.save()
         self.backend.commit()
 
-    def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator[SimulationResult]:
+    def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['SimulationResult']:
         """Query using simulation parameters.
 
         See blitzdb documentation for filter
diff --git a/sympyextensions.py b/sympyextensions.py
index 1be095058..8ecc5ece9 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -439,10 +439,8 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
     elif isinstance(term, Assignment):
         term = term.rhs
 
-    if not hasattr(term, 'evalf'):
-        return result
-
-    term = term.evalf()
+    if hasattr(term, 'evalf'):
+        term = term.evalf()
 
     def check_type(e):
         if only_type is None:
@@ -495,6 +493,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
             else:
                 warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                               "counting will be inaccurate")
+        elif t.func is sp.Piecewise:
+            for child_term, condition in t.args:
+                visit(child_term)
+            visit_children = False
         else:
             warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")
 
-- 
GitLab