diff --git a/field.py b/field.py index 17c3c7e10def98d9cf0e1caf3058894bc9690023..4b06592e1aafc298da5f066cbfece678c945c39f 100644 --- a/field.py +++ b/field.py @@ -548,7 +548,7 @@ class Field: offset_list[coord_id] += offset return Field.Access(self.field, tuple(offset_list), self.index) - def get_shifted(self, *shift)-> 'Field.Access': + def get_shifted(self, *shift) -> 'Field.Access': """Returns a new Access with changed spatial coordinates Example: diff --git a/kerncraft_coupling/generate_benchmark.py b/kerncraft_coupling/generate_benchmark.py index 2c27f3b0e7e5412fb7cdda248cf1072b8e835ef7..89565b5711569f02d1b29dd55c06d70c776a171e 100644 --- a/kerncraft_coupling/generate_benchmark.py +++ b/kerncraft_coupling/generate_benchmark.py @@ -1,5 +1,5 @@ from jinja2 import Template -from pystencils.backends.cbackend import generate_c +from pystencils.backends.cbackend import generate_c, get_headers from pystencils.sympyextensions import prod from pystencils.data_types import get_base_type @@ -9,6 +9,8 @@ benchmark_template = Template(""" #include <stdint.h> #include <stdbool.h> #include <math.h> +{{ includes }} + {%- if likwid %} #include <likwid.h> {%- endif %} @@ -98,6 +100,9 @@ def generate_benchmark(ast, likwid=False): fields.append((p.field_name, dtype, prod(field.shape))) call_parameters.append(p.field_name) + header_list = get_headers(ast) + includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list]) + args = { 'likwid': likwid, 'kernel_code': generate_c(ast, dialect='c'), @@ -105,5 +110,6 @@ def generate_benchmark(ast, likwid=False): 'fields': fields, 'constants': constants, 'call_argument_list': ",".join(call_parameters), + 'includes': includes, } return benchmark_template.render(**args) diff --git a/kerncraft_coupling/kerncraft_interface.py b/kerncraft_coupling/kerncraft_interface.py index b8d3e5e4ea2076ba8bf85650e1673befd8008ded..d40b42e9aab16f2b0ad003b02cb83a5d89d2b277 100644 --- a/kerncraft_coupling/kerncraft_interface.py +++ b/kerncraft_coupling/kerncraft_interface.py @@ -15,6 +15,7 @@ from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFie from pystencils.field import get_layout_from_strides from pystencils.sympyextensions import count_operations_in_ast from pystencils.utils import DotDict +import warnings class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): @@ -30,9 +31,9 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): Args: ast: pystencils ast machine: kerncraft machine model - specify this if kernel needs to be compiled - assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index coordinates is not - known. In this case either a structures of array (SoA) or array of structures (AoS) layout - is assumed + assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index + coordinates is not known. In this case either a structures of array (SoA) or + array of structures (AoS) layout is assumed """ super(PyStencilsKerncraftKernel, self).__init__(machine) @@ -43,9 +44,10 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): inner_loops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop] if len(inner_loops) == 0: raise ValueError("No loop found in pystencils AST") - elif len(inner_loops) > 1: - raise ValueError("pystencils AST contains multiple inner loops - only one can be analyzed") else: + if len(inner_loops) > 1: + warnings.warn("pystencils AST contains multiple inner loops. " + "Only one can be analyzed - choosing first one") inner_loop = inner_loops[0] self._loop_stack = [] @@ -97,7 +99,6 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): self.datatype = list(self.variables.values())[0][0] # flops - # FIXME operation_count operation_count = count_operations_in_ast(inner_loop) self._flops = { '+': operation_count['adds'], diff --git a/runhelper/db.py b/runhelper/db.py index a19cc383aa50889b2fd94d6c0c0d712caa8657b9..e0655db3238b4ea9f4d70fb47695201bb0f00c38 100644 --- a/runhelper/db.py +++ b/runhelper/db.py @@ -74,7 +74,7 @@ class Database: document.save() self.backend.commit() - def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator[SimulationResult]: + def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['SimulationResult']: """Query using simulation parameters. See blitzdb documentation for filter diff --git a/sympyextensions.py b/sympyextensions.py index 1be095058ff0b155b6285f78d368de69e3ef2c7b..8ecc5ece9f8a1c9f8be639ee830b15c6bd7172b3 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -439,10 +439,8 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], elif isinstance(term, Assignment): term = term.rhs - if not hasattr(term, 'evalf'): - return result - - term = term.evalf() + if hasattr(term, 'evalf'): + term = term.evalf() def check_type(e): if only_type is None: @@ -495,6 +493,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], else: warnings.warn("Counting operations: only integer exponents are supported in Pow, " "counting will be inaccurate") + elif t.func is sp.Piecewise: + for child_term, condition in t.args: + visit(child_term) + visit_children = False else: warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")