diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 8e99d5dde0876f607c0af4a71c23984c4c9f2229..ddb06f7148ac3a576299cf11e0dee8da5b7cfa6c 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -363,7 +363,9 @@ class AssignmentCollection: for k, v in sub_expressions_dict.items()] def find(self, *args, **kwargs): - return set.union(*[a.find(*args, **kwargs) for a in self.all_assignments]) + return set.union( + *[a.find(*args, **kwargs) for a in self.all_assignments] + ) def match(self, *args, **kwargs): rtn = {} diff --git a/pystencils_tests/test_floor_ceil_int_optimization.py b/pystencils_tests/test_floor_ceil_int_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6efba7b1431a8f47c0dfe317d48ff9bff7b9a2 --- /dev/null +++ b/pystencils_tests/test_floor_ceil_int_optimization.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy as sp + +import pystencils +from pystencils.data_types import create_type + + +def test_floor_ceil_int_optimization(): + x, y = pystencils.fields('x,y: int32[2d]') + a, b, c = sp.symbols('a, b, c') + int_symbol = sp.Symbol('int_symbol', integer=True) + typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('int64')) + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: sp.ceiling(typed_symbol), + c: sp.floor(int_symbol), + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert(typed_symbol.is_integer) + print(sp.simplify(sp.ceiling(typed_symbol))) + + print(assignments) + + wild_floor = sp.floor(sp.Wild('w1')) + + assert not sp.floor(int_symbol).match(wild_floor) + assert sp.floor(a).match(wild_floor) + + assert not assignments.find(wild_floor) + + +def test_floor_ceil_float_no_optimization(): + x, y = pystencils.fields('x,y: float32[2d]') + a, b, c = sp.symbols('a, b, c') + int_symbol = sp.Symbol('int_symbol', integer=True) + typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('float32')) + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: sp.ceiling(typed_symbol), + c: sp.floor(int_symbol), + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert not typed_symbol.is_integer + print(sp.simplify(sp.ceiling(typed_symbol))) + + print(assignments) + + wild_floor = sp.floor(sp.Wild('w1')) + + assert not sp.floor(int_symbol).match(wild_floor) + assert sp.floor(a).match(wild_floor) + + assert assignments.find(wild_floor) + + +def main(): + test_floor_ceil_int_optimization() + test_floor_ceil_float_no_optimization() + + +if __name__ == '__main__': + main() diff --git a/pystencils_tests/test_match_subs_for_assignment_collection.py b/pystencils_tests/test_match_subs_for_assignment_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..81dec872d115f3b04876845dcfff040894bd5353 --- /dev/null +++ b/pystencils_tests/test_match_subs_for_assignment_collection.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy as sp + +import pystencils +from pystencils.data_types import create_type + + +def test_wild_typed_symbol(): + x = pystencils.fields('x: float32[3d]') + typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64')) + + assert x.center().match(sp.Wild('w1')) + assert typed_symbol.match(sp.Wild('w1')) + + wild_ceiling = sp.ceiling(sp.Wild('w1')) + assert sp.ceiling(x.center()).match(wild_ceiling) + assert sp.ceiling(typed_symbol).match(wild_ceiling) + + +def test_replace_and_subs_for_assignment_collection(): + + x, y = pystencils.fields('x, y: float32[3d]') + a, b, c, d = sp.symbols('a, b, c, d') + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + expected_assignments = pystencils.AssignmentCollection({ + a: sp.floor(3), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert expected_assignments == assignments.replace(1, 3) + assert expected_assignments == assignments.subs({1: 3}) + + expected_assignments = pystencils.AssignmentCollection({ + d: sp.floor(1), + b: 2, + c: d + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + print(expected_assignments) + print(assignments.subs(a, d)) + assert expected_assignments == assignments.subs(a, d) + + +def test_match_for_assignment_collection(): + + x, y = pystencils.fields('x, y: float32[3d]') + a, b, c, d = sp.symbols('a, b, c, d') + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + w1 = sp.Wild('w1') + w2 = sp.Wild('w2') + w3 = sp.Wild('w3') + + wild_ceiling = sp.ceiling(w1) + wild_addition = w1 + w2 + + assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2))[w1] == x.center() + assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2)) == { + w3: y.center(), + w2: sp.floor(x.center()), + w1: x.center() + } + assert assignments.find(wild_ceiling) == {sp.ceiling(x.center())} + assert len([a for a in assignments.find(wild_addition) if isinstance(a, sp.Add)]) == 2 + + +def main(): + test_wild_typed_symbol() + test_replace_and_subs_for_assignment_collection() + test_match_for_assignment_collection() + + +if __name__ == '__main__': + main()