# -*- coding: utf-8 -*- # # Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> # # Distributed under terms of the GPLv3 license. """ """ import sympy as sp import pystencils from pystencils.data_types import create_type def test_wild_typed_symbol(): x = pystencils.fields('x: float32[3d]') typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64')) assert x.center().match(sp.Wild('w1')) assert typed_symbol.match(sp.Wild('w1')) wild_ceiling = sp.ceiling(sp.Wild('w1')) assert sp.ceiling(x.center()).match(wild_ceiling) assert sp.ceiling(typed_symbol).match(wild_ceiling) def test_replace_and_subs_for_assignment_collection(): x, y = pystencils.fields('x, y: float32[3d]') a, b, c, d = sp.symbols('a, b, c, d') assignments = pystencils.AssignmentCollection({ a: sp.floor(1), b: 2, c: a + c, y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) }) expected_assignments = pystencils.AssignmentCollection({ a: sp.floor(3), b: 2, c: a + c, y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) }) assert expected_assignments == assignments.replace(1, 3) assert expected_assignments == assignments.subs({1: 3}) expected_assignments = pystencils.AssignmentCollection({ d: sp.floor(1), b: 2, c: d + c, y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) }) print(expected_assignments) print(assignments.subs(a, d)) assert expected_assignments == assignments.subs(a, d) def test_match_for_assignment_collection(): x, y = pystencils.fields('x, y: float32[3d]') a, b, c, d = sp.symbols('a, b, c, d') assignments = pystencils.AssignmentCollection({ a: sp.floor(1), b: 2, c: a + c, y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) }) w1 = sp.Wild('w1') w2 = sp.Wild('w2') w3 = sp.Wild('w3') wild_ceiling = sp.ceiling(w1) wild_addition = w1 + w2 assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2))[w1] == x.center() assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2)) == { w3: y.center(), w2: sp.floor(x.center()), w1: x.center() } assert assignments.find(wild_ceiling) == {sp.ceiling(x.center())} assert len([a for a in assignments.find(wild_addition) if isinstance(a, sp.Add)]) == 2 def main(): test_wild_typed_symbol() test_replace_and_subs_for_assignment_collection() test_match_for_assignment_collection() if __name__ == '__main__': main()