From d0a0696392116eed8be18c52c852f3ada6ab4723 Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Sun, 9 Aug 2020 08:32:26 +0200 Subject: [PATCH] Fixed operation count --- pystencils/sympyextensions.py | 5 +++-- pystencils_tests/test_datahandling.py | 6 +++--- pystencils_tests/test_sympyextensions.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 55fe74967..31e42224a 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -481,7 +481,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], pass elif t.func is sp.Mul: if check_type(t): - result['muls'] += len(t.args) + result['muls'] += len(t.args) - 1 for a in t.args: if a == 1 or a == -1: result['muls'] -= 1 @@ -509,7 +509,8 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], if t.exp >= 0: result['muls'] += int(t.exp) - 1 else: - result['muls'] -= 1 + if result['muls'] > 0: + result['muls'] -= 1 result['divs'] += 1 result['muls'] += (-int(t.exp)) - 1 elif sp.nsimplify(t.exp) == sp.Rational(1, 2): diff --git a/pystencils_tests/test_datahandling.py b/pystencils_tests/test_datahandling.py index 4d6dd72a4..6e53d1e8b 100644 --- a/pystencils_tests/test_datahandling.py +++ b/pystencils_tests/test_datahandling.py @@ -116,9 +116,6 @@ def kernel_execution_jacobi(dh, target): assert dh.is_on_gpu('f') assert dh.is_on_gpu('tmp') - with pytest.raises(ValueError): - dh.add_array('f', gpu=test_gpu) - stencil_2d = [(1, 0), (-1, 0), (0, 1), (0, -1)] stencil_3d = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)] stencil = stencil_2d if dh.dim == 2 else stencil_3d @@ -263,6 +260,9 @@ def test_get_kwarg(): dh.fill("src", 1.0, ghost_layers=True) dh.fill("dst", 0.0, ghost_layers=True) + with pytest.raises(ValueError): + dh.add_array('src') + ur = ps.Assignment(src.center, dst.center) kernel = ps.create_kernel(ur).compile() diff --git a/pystencils_tests/test_sympyextensions.py b/pystencils_tests/test_sympyextensions.py index 1636df632..2135ee88e 100644 --- a/pystencils_tests/test_sympyextensions.py +++ b/pystencils_tests/test_sympyextensions.py @@ -119,7 +119,7 @@ def test_count_operations(): expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100) ops = count_operations(expr, only_type=None) assert ops['adds'] == 1 - assert ops['muls'] == 100 + assert ops['muls'] == 99 assert ops['divs'] == 1 assert ops['sqrts'] == 1 -- GitLab