From a9d6eb0a04f2df080fe5ddc4b6b4721f1c85debe Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 2 Aug 2019 18:30:50 +0200 Subject: [PATCH] Add test_floor_ceil_int_optimization --- pystencils/simp/assignment_collection.py | 4 +- .../test_floor_ceil_int_optimization.py | 75 ++++++++++++++ ...st_match_subs_for_assignment_collection.py | 99 +++++++++++++++++++ 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 pystencils_tests/test_floor_ceil_int_optimization.py create mode 100644 pystencils_tests/test_match_subs_for_assignment_collection.py diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 8e99d5dde..ddb06f714 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -363,7 +363,9 @@ class AssignmentCollection: for k, v in sub_expressions_dict.items()] def find(self, *args, **kwargs): - return set.union(*[a.find(*args, **kwargs) for a in self.all_assignments]) + return set.union( + *[a.find(*args, **kwargs) for a in self.all_assignments] + ) def match(self, *args, **kwargs): rtn = {} diff --git a/pystencils_tests/test_floor_ceil_int_optimization.py b/pystencils_tests/test_floor_ceil_int_optimization.py new file mode 100644 index 000000000..2d6efba7b --- /dev/null +++ b/pystencils_tests/test_floor_ceil_int_optimization.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy as sp + +import pystencils +from pystencils.data_types import create_type + + +def test_floor_ceil_int_optimization(): + x, y = pystencils.fields('x,y: int32[2d]') + a, b, c = sp.symbols('a, b, c') + int_symbol = sp.Symbol('int_symbol', integer=True) + typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('int64')) + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: sp.ceiling(typed_symbol), + c: sp.floor(int_symbol), + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert(typed_symbol.is_integer) + print(sp.simplify(sp.ceiling(typed_symbol))) + + print(assignments) + + wild_floor = sp.floor(sp.Wild('w1')) + + assert not sp.floor(int_symbol).match(wild_floor) + assert sp.floor(a).match(wild_floor) + + assert not assignments.find(wild_floor) + + +def test_floor_ceil_float_no_optimization(): + x, y = pystencils.fields('x,y: float32[2d]') + a, b, c = sp.symbols('a, b, c') + int_symbol = sp.Symbol('int_symbol', integer=True) + typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('float32')) + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: sp.ceiling(typed_symbol), + c: sp.floor(int_symbol), + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert not typed_symbol.is_integer + print(sp.simplify(sp.ceiling(typed_symbol))) + + print(assignments) + + wild_floor = sp.floor(sp.Wild('w1')) + + assert not sp.floor(int_symbol).match(wild_floor) + assert sp.floor(a).match(wild_floor) + + assert assignments.find(wild_floor) + + +def main(): + test_floor_ceil_int_optimization() + test_floor_ceil_float_no_optimization() + + +if __name__ == '__main__': + main() diff --git a/pystencils_tests/test_match_subs_for_assignment_collection.py b/pystencils_tests/test_match_subs_for_assignment_collection.py new file mode 100644 index 000000000..81dec872d --- /dev/null +++ b/pystencils_tests/test_match_subs_for_assignment_collection.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy as sp + +import pystencils +from pystencils.data_types import create_type + + +def test_wild_typed_symbol(): + x = pystencils.fields('x: float32[3d]') + typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64')) + + assert x.center().match(sp.Wild('w1')) + assert typed_symbol.match(sp.Wild('w1')) + + wild_ceiling = sp.ceiling(sp.Wild('w1')) + assert sp.ceiling(x.center()).match(wild_ceiling) + assert sp.ceiling(typed_symbol).match(wild_ceiling) + + +def test_replace_and_subs_for_assignment_collection(): + + x, y = pystencils.fields('x, y: float32[3d]') + a, b, c, d = sp.symbols('a, b, c, d') + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + expected_assignments = pystencils.AssignmentCollection({ + a: sp.floor(3), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert expected_assignments == assignments.replace(1, 3) + assert expected_assignments == assignments.subs({1: 3}) + + expected_assignments = pystencils.AssignmentCollection({ + d: sp.floor(1), + b: 2, + c: d + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + print(expected_assignments) + print(assignments.subs(a, d)) + assert expected_assignments == assignments.subs(a, d) + + +def test_match_for_assignment_collection(): + + x, y = pystencils.fields('x, y: float32[3d]') + a, b, c, d = sp.symbols('a, b, c, d') + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + w1 = sp.Wild('w1') + w2 = sp.Wild('w2') + w3 = sp.Wild('w3') + + wild_ceiling = sp.ceiling(w1) + wild_addition = w1 + w2 + + assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2))[w1] == x.center() + assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2)) == { + w3: y.center(), + w2: sp.floor(x.center()), + w1: x.center() + } + assert assignments.find(wild_ceiling) == {sp.ceiling(x.center())} + assert len([a for a in assignments.find(wild_addition) if isinstance(a, sp.Add)]) == 2 + + +def main(): + test_wild_typed_symbol() + test_replace_and_subs_for_assignment_collection() + test_match_for_assignment_collection() + + +if __name__ == '__main__': + main() -- GitLab