Commit 6942ed0b authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'remove_floor_ceiling_for_integers' into 'master'

Remove floor, ceiling for integer symbols

See merge request pycodegen/pystencils!14
parents 3d00d87f 51d36ff4
......@@ -49,6 +49,17 @@ else:
__str__ = assignment_str
_print_Assignment = print_assignment_latex
# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master
try:
sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) <= 1 and int(sympy_version[1]) <= 4:
def hash_fun(self):
return hash((self.lhs, self.rhs))
Assignment.__hash__ = hash_fun
except Exception:
pass
def assignment_from_stencil(stencil_array, input_field, output_field,
normalization_factor=None, order='visual') -> Assignment:
......
......@@ -127,6 +127,31 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self):
return self.name, self.dtype
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
@property
def is_integer(self):
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False
return super().is_positive
@property
def is_real(self):
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
super().is_real
else:
return super().is_real
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......
......@@ -10,7 +10,7 @@ import sympy as sp
from sympy.core.cache import cacheit
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import StructType, create_type
from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string
from pystencils.sympyextensions import is_integer_sequence
......@@ -410,7 +410,7 @@ class Field(AbstractField):
return self.hashable_contents() == other.hashable_contents()
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences
class Access(sp.Symbol, AbstractField.AbstractAccess):
class Access(TypedSymbol, AbstractField.AbstractAccess):
"""Class representing a relative access into a `Field`.
This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up
......@@ -462,7 +462,7 @@ class Field(AbstractField):
if superscript is not None:
symbol_name += "^" + superscript
obj = super(Field.Access, self).__xnew__(self, symbol_name)
obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
obj._field = field
obj._offsets = []
for o in offsets:
......
......@@ -362,6 +362,34 @@ class AssignmentCollection:
self.subexpressions = [Assignment(k, v)
for k, v in sub_expressions_dict.items()]
def find(self, *args, **kwargs):
return set.union(
*[a.find(*args, **kwargs) for a in self.all_assignments]
)
def match(self, *args, **kwargs):
rtn = {}
for a in self.all_assignments:
partial_result = a.match(*args, **kwargs)
if partial_result:
rtn.update(partial_result)
return rtn
def subs(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
)
def replace(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
)
def __eq__(self, other):
return set(self.all_assignments) == set(other.all_assignments)
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy as sp
import pystencils
from pystencils.data_types import create_type
def test_floor_ceil_int_optimization():
x, y = pystencils.fields('x,y: int32[2d]')
a, b, c = sp.symbols('a, b, c')
int_symbol = sp.Symbol('int_symbol', integer=True)
typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('int64'))
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: sp.ceiling(typed_symbol),
c: sp.floor(int_symbol),
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
assert(typed_symbol.is_integer)
print(sp.simplify(sp.ceiling(typed_symbol)))
print(assignments)
wild_floor = sp.floor(sp.Wild('w1'))
assert not sp.floor(int_symbol).match(wild_floor)
assert sp.floor(a).match(wild_floor)
assert not assignments.find(wild_floor)
def test_floor_ceil_float_no_optimization():
x, y = pystencils.fields('x,y: float32[2d]')
a, b, c = sp.symbols('a, b, c')
int_symbol = sp.Symbol('int_symbol', integer=True)
typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('float32'))
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: sp.ceiling(typed_symbol),
c: sp.floor(int_symbol),
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
assert not typed_symbol.is_integer
print(sp.simplify(sp.ceiling(typed_symbol)))
print(assignments)
wild_floor = sp.floor(sp.Wild('w1'))
assert not sp.floor(int_symbol).match(wild_floor)
assert sp.floor(a).match(wild_floor)
assert assignments.find(wild_floor)
def main():
test_floor_ceil_int_optimization()
test_floor_ceil_float_no_optimization()
if __name__ == '__main__':
main()
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy as sp
import pystencils
from pystencils.data_types import create_type
def test_wild_typed_symbol():
x = pystencils.fields('x: float32[3d]')
typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64'))
assert x.center().match(sp.Wild('w1'))
assert typed_symbol.match(sp.Wild('w1'))
wild_ceiling = sp.ceiling(sp.Wild('w1'))
assert sp.ceiling(x.center()).match(wild_ceiling)
assert sp.ceiling(typed_symbol).match(wild_ceiling)
def test_replace_and_subs_for_assignment_collection():
x, y = pystencils.fields('x, y: float32[3d]')
a, b, c, d = sp.symbols('a, b, c, d')
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: 2,
c: a + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
expected_assignments = pystencils.AssignmentCollection({
a: sp.floor(3),
b: 2,
c: a + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
assert expected_assignments == assignments.replace(1, 3)
assert expected_assignments == assignments.subs({1: 3})
expected_assignments = pystencils.AssignmentCollection({
d: sp.floor(1),
b: 2,
c: d + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
print(expected_assignments)
print(assignments.subs(a, d))
assert expected_assignments == assignments.subs(a, d)
def test_match_for_assignment_collection():
x, y = pystencils.fields('x, y: float32[3d]')
a, b, c, d = sp.symbols('a, b, c, d')
assignments = pystencils.AssignmentCollection({
a: sp.floor(1),
b: 2,
c: a + c,
y.center(): sp.ceiling(x.center()) + sp.floor(x.center())
})
w1 = sp.Wild('w1')
w2 = sp.Wild('w2')
w3 = sp.Wild('w3')
wild_ceiling = sp.ceiling(w1)
wild_addition = w1 + w2
assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2))[w1] == x.center()
assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2)) == {
w3: y.center(),
w2: sp.floor(x.center()),
w1: x.center()
}
assert assignments.find(wild_ceiling) == {sp.ceiling(x.center())}
assert len([a for a in assignments.find(wild_addition) if isinstance(a, sp.Add)]) == 2
def main():
test_wild_typed_symbol()
test_replace_and_subs_for_assignment_collection()
test_match_for_assignment_collection()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment