Commit f7cda45b authored by Martin Bauer's avatar Martin Bauer
Browse files

Operation counting: support for piecewise functions

- all branches are added up
parent 8a89d0bc
......@@ -548,7 +548,7 @@ class Field:
offset_list[coord_id] += offset
return Field.Access(self.field, tuple(offset_list), self.index)
def get_shifted(self, *shift)-> 'Field.Access':
def get_shifted(self, *shift) -> 'Field.Access':
"""Returns a new Access with changed spatial coordinates
Example:
......
from jinja2 import Template
from pystencils.backends.cbackend import generate_c
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.sympyextensions import prod
from pystencils.data_types import get_base_type
......@@ -9,6 +9,8 @@ benchmark_template = Template("""
#include <stdint.h>
#include <stdbool.h>
#include <math.h>
{{ includes }}
{%- if likwid %}
#include <likwid.h>
{%- endif %}
......@@ -98,6 +100,9 @@ def generate_benchmark(ast, likwid=False):
fields.append((p.field_name, dtype, prod(field.shape)))
call_parameters.append(p.field_name)
header_list = get_headers(ast)
includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
args = {
'likwid': likwid,
'kernel_code': generate_c(ast, dialect='c'),
......@@ -105,5 +110,6 @@ def generate_benchmark(ast, likwid=False):
'fields': fields,
'constants': constants,
'call_argument_list': ",".join(call_parameters),
'includes': includes,
}
return benchmark_template.render(**args)
......@@ -15,6 +15,7 @@ from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFie
from pystencils.field import get_layout_from_strides
from pystencils.sympyextensions import count_operations_in_ast
from pystencils.utils import DotDict
import warnings
class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
......@@ -30,9 +31,9 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
Args:
ast: pystencils ast
machine: kerncraft machine model - specify this if kernel needs to be compiled
assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index coordinates is not
known. In this case either a structures of array (SoA) or array of structures (AoS) layout
is assumed
assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index
coordinates is not known. In this case either a structures of array (SoA) or
array of structures (AoS) layout is assumed
"""
super(PyStencilsKerncraftKernel, self).__init__(machine)
......@@ -43,9 +44,10 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
inner_loops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop]
if len(inner_loops) == 0:
raise ValueError("No loop found in pystencils AST")
elif len(inner_loops) > 1:
raise ValueError("pystencils AST contains multiple inner loops - only one can be analyzed")
else:
if len(inner_loops) > 1:
warnings.warn("pystencils AST contains multiple inner loops. "
"Only one can be analyzed - choosing first one")
inner_loop = inner_loops[0]
self._loop_stack = []
......@@ -97,7 +99,6 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
self.datatype = list(self.variables.values())[0][0]
# flops
# FIXME operation_count
operation_count = count_operations_in_ast(inner_loop)
self._flops = {
'+': operation_count['adds'],
......
......@@ -74,7 +74,7 @@ class Database:
document.save()
self.backend.commit()
def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator[SimulationResult]:
def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['SimulationResult']:
"""Query using simulation parameters.
See blitzdb documentation for filter
......
......@@ -439,9 +439,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment):
term = term.rhs
if not hasattr(term, 'evalf'):
return result
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e):
......@@ -495,6 +493,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate")
elif t.func is sp.Piecewise:
for child_term, condition in t.args:
visit(child_term)
visit_children = False
else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment