Commit f7cda45b authored by Martin Bauer's avatar Martin Bauer
Browse files

Operation counting: support for piecewise functions

- all branches are added up
parent 8a89d0bc
...@@ -548,7 +548,7 @@ class Field: ...@@ -548,7 +548,7 @@ class Field:
offset_list[coord_id] += offset offset_list[coord_id] += offset
return Field.Access(self.field, tuple(offset_list), self.index) return Field.Access(self.field, tuple(offset_list), self.index)
def get_shifted(self, *shift)-> 'Field.Access': def get_shifted(self, *shift) -> 'Field.Access':
"""Returns a new Access with changed spatial coordinates """Returns a new Access with changed spatial coordinates
Example: Example:
......
from jinja2 import Template from jinja2 import Template
from pystencils.backends.cbackend import generate_c from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.sympyextensions import prod from pystencils.sympyextensions import prod
from pystencils.data_types import get_base_type from pystencils.data_types import get_base_type
...@@ -9,6 +9,8 @@ benchmark_template = Template(""" ...@@ -9,6 +9,8 @@ benchmark_template = Template("""
#include <stdint.h> #include <stdint.h>
#include <stdbool.h> #include <stdbool.h>
#include <math.h> #include <math.h>
{{ includes }}
{%- if likwid %} {%- if likwid %}
#include <likwid.h> #include <likwid.h>
{%- endif %} {%- endif %}
...@@ -98,6 +100,9 @@ def generate_benchmark(ast, likwid=False): ...@@ -98,6 +100,9 @@ def generate_benchmark(ast, likwid=False):
fields.append((p.field_name, dtype, prod(field.shape))) fields.append((p.field_name, dtype, prod(field.shape)))
call_parameters.append(p.field_name) call_parameters.append(p.field_name)
header_list = get_headers(ast)
includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
args = { args = {
'likwid': likwid, 'likwid': likwid,
'kernel_code': generate_c(ast, dialect='c'), 'kernel_code': generate_c(ast, dialect='c'),
...@@ -105,5 +110,6 @@ def generate_benchmark(ast, likwid=False): ...@@ -105,5 +110,6 @@ def generate_benchmark(ast, likwid=False):
'fields': fields, 'fields': fields,
'constants': constants, 'constants': constants,
'call_argument_list': ",".join(call_parameters), 'call_argument_list': ",".join(call_parameters),
'includes': includes,
} }
return benchmark_template.render(**args) return benchmark_template.render(**args)
...@@ -15,6 +15,7 @@ from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFie ...@@ -15,6 +15,7 @@ from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFie
from pystencils.field import get_layout_from_strides from pystencils.field import get_layout_from_strides
from pystencils.sympyextensions import count_operations_in_ast from pystencils.sympyextensions import count_operations_in_ast
from pystencils.utils import DotDict from pystencils.utils import DotDict
import warnings
class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
...@@ -30,9 +31,9 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): ...@@ -30,9 +31,9 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
Args: Args:
ast: pystencils ast ast: pystencils ast
machine: kerncraft machine model - specify this if kernel needs to be compiled machine: kerncraft machine model - specify this if kernel needs to be compiled
assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index coordinates is not assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index
known. In this case either a structures of array (SoA) or array of structures (AoS) layout coordinates is not known. In this case either a structures of array (SoA) or
is assumed array of structures (AoS) layout is assumed
""" """
super(PyStencilsKerncraftKernel, self).__init__(machine) super(PyStencilsKerncraftKernel, self).__init__(machine)
...@@ -43,9 +44,10 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): ...@@ -43,9 +44,10 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
inner_loops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop] inner_loops = [l for l in ast.atoms(LoopOverCoordinate) if l.is_innermost_loop]
if len(inner_loops) == 0: if len(inner_loops) == 0:
raise ValueError("No loop found in pystencils AST") raise ValueError("No loop found in pystencils AST")
elif len(inner_loops) > 1:
raise ValueError("pystencils AST contains multiple inner loops - only one can be analyzed")
else: else:
if len(inner_loops) > 1:
warnings.warn("pystencils AST contains multiple inner loops. "
"Only one can be analyzed - choosing first one")
inner_loop = inner_loops[0] inner_loop = inner_loops[0]
self._loop_stack = [] self._loop_stack = []
...@@ -97,7 +99,6 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): ...@@ -97,7 +99,6 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
self.datatype = list(self.variables.values())[0][0] self.datatype = list(self.variables.values())[0][0]
# flops # flops
# FIXME operation_count
operation_count = count_operations_in_ast(inner_loop) operation_count = count_operations_in_ast(inner_loop)
self._flops = { self._flops = {
'+': operation_count['adds'], '+': operation_count['adds'],
......
...@@ -74,7 +74,7 @@ class Database: ...@@ -74,7 +74,7 @@ class Database:
document.save() document.save()
self.backend.commit() self.backend.commit()
def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator[SimulationResult]: def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['SimulationResult']:
"""Query using simulation parameters. """Query using simulation parameters.
See blitzdb documentation for filter See blitzdb documentation for filter
......
...@@ -439,10 +439,8 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -439,10 +439,8 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment): elif isinstance(term, Assignment):
term = term.rhs term = term.rhs
if not hasattr(term, 'evalf'): if hasattr(term, 'evalf'):
return result term = term.evalf()
term = term.evalf()
def check_type(e): def check_type(e):
if only_type is None: if only_type is None:
...@@ -495,6 +493,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -495,6 +493,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
else: else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, " warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate") "counting will be inaccurate")
elif t.func is sp.Piecewise:
for child_term, condition in t.args:
visit(child_term)
visit_children = False
else: else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment