Commit a9d6eb0a authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test_floor_ceil_int_optimization

parent 18ab6246
Pipeline #16897 passed with stage
in 5 minutes and 21 seconds
......@@ -363,7 +363,9 @@ class AssignmentCollection:
for k, v in sub_expressions_dict.items()]
def find(self, *args, **kwargs):
return set.union(*[a.find(*args, **kwargs) for a in self.all_assignments])
return set.union(
*[a.find(*args, **kwargs) for a in self.all_assignments]
)
def match(self, *args, **kwargs):
rtn = {}
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy as sp
import pystencils
from pystencils.data_types import create_type
def test_floor_ceil_int_optimization():
x, y = pystencils.fields('x,y: int32[2d]')
a, b, c = sp.symbols('a, b, c')
int_symbol = sp.Symbol('int_symbol', integer=True)
typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('int64'))
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: sp.ceiling(typed_symbol),
c: sp.floor(int_symbol),
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
assert(typed_symbol.is_integer)
print(sp.simplify(sp.ceiling(typed_symbol)))
print(assignments)
wild_floor = sp.floor(sp.Wild('w1'))
assert not sp.floor(int_symbol).match(wild_floor)
assert sp.floor(a).match(wild_floor)
assert not assignments.find(wild_floor)
def test_floor_ceil_float_no_optimization():
x, y = pystencils.fields('x,y: float32[2d]')
a, b, c = sp.symbols('a, b, c')
int_symbol = sp.Symbol('int_symbol', integer=True)
typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('float32'))
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: sp.ceiling(typed_symbol),
c: sp.floor(int_symbol),
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
assert not typed_symbol.is_integer
print(sp.simplify(sp.ceiling(typed_symbol)))
print(assignments)
wild_floor = sp.floor(sp.Wild('w1'))
assert not sp.floor(int_symbol).match(wild_floor)
assert sp.floor(a).match(wild_floor)
assert assignments.find(wild_floor)
def main():
test_floor_ceil_int_optimization()
test_floor_ceil_float_no_optimization()
if __name__ == '__main__':
main()
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy as sp
import pystencils
from pystencils.data_types import create_type
def test_wild_typed_symbol():
x = pystencils.fields('x: float32[3d]')
typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64'))
assert x.center().match(sp.Wild('w1'))
assert typed_symbol.match(sp.Wild('w1'))
wild_ceiling = sp.ceiling(sp.Wild('w1'))
assert sp.ceiling(x.center()).match(wild_ceiling)
assert sp.ceiling(typed_symbol).match(wild_ceiling)
def test_replace_and_subs_for_assignment_collection():
x, y = pystencils.fields('x, y: float32[3d]')
a, b, c, d = sp.symbols('a, b, c, d')
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: 2,
c: a + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
expected_assignments = pystencils.AssignmentCollection({
a: sp.floor(3),
b: 2,
c: a + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
assert expected_assignments == assignments.replace(1, 3)
assert expected_assignments == assignments.subs({1: 3})
expected_assignments = pystencils.AssignmentCollection({
d: sp.floor(1),
b: 2,
c: d + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
print(expected_assignments)
print(assignments.subs(a, d))
assert expected_assignments == assignments.subs(a, d)
def test_match_for_assignment_collection():
x, y = pystencils.fields('x, y: float32[3d]')
a, b, c, d = sp.symbols('a, b, c, d')
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: 2,
c: a + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
w1 = sp.Wild('w1')
w2 = sp.Wild('w2')
w3 = sp.Wild('w3')
wild_ceiling = sp.ceiling(w1)
wild_addition = w1 + w2
assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2))[w1] == x.center()
assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2)) == {
w3: y.center(),
w2: sp.floor(x.center()),
w1: x.center()
}
assert assignments.find(wild_ceiling) == {sp.ceiling(x.center())}
assert len([a for a in assignments.find(wild_addition) if isinstance(a, sp.Add)]) == 2
def main():
test_wild_typed_symbol()
test_replace_and_subs_for_assignment_collection()
test_match_for_assignment_collection()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment