diff --git a/pystencils/assignment.py b/pystencils/assignment.py index 45a95d41ecd643618a3d3b6b4a8a9aa0455a2a02..0bf68799491be29886d215d5ff76010c034bd174 100644 --- a/pystencils/assignment.py +++ b/pystencils/assignment.py @@ -49,6 +49,17 @@ else: __str__ = assignment_str _print_Assignment = print_assignment_latex +# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master +try: + sympy_version = sp.__version__.split('.') + + if int(sympy_version[0]) <= 1 and int(sympy_version[1]) <= 4: + def hash_fun(self): + return hash((self.lhs, self.rhs)) + Assignment.__hash__ = hash_fun +except Exception: + pass + def assignment_from_stencil(stencil_array, input_field, output_field, normalization_factor=None, order='visual') -> Assignment: diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 8bf28e777e9bf9deec30f21cafe8cb3b7ac77b17..08c1da1c96f5ba63c1906cfddd6014f7d86f61fe 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -127,6 +127,31 @@ class TypedSymbol(sp.Symbol): def __getnewargs__(self): return self.name, self.dtype + # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html + @property + def is_integer(self): + if hasattr(self.dtype, 'numpy_dtype'): + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer + else: + return super().is_integer + + @property + def is_negative(self): + if hasattr(self.dtype, 'numpy_dtype'): + if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): + return False + + return super().is_positive + + @property + def is_real(self): + if hasattr(self.dtype, 'numpy_dtype'): + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ + np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ + super().is_real + else: + return super().is_real + def create_type(specification): """Creates a subclass of Type according to a string or an object of subclass Type. diff --git a/pystencils/field.py b/pystencils/field.py index 5fa456c68f6b8dd4799ea7c73f346885da9814f4..b6437a971398c70c0bcf392129269b87c912f0b8 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -10,7 +10,7 @@ import sympy as sp from sympy.core.cache import cacheit from pystencils.alignedarray import aligned_empty -from pystencils.data_types import StructType, create_type +from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol from pystencils.stencil import direction_string_to_offset, offset_to_direction_string from pystencils.sympyextensions import is_integer_sequence @@ -410,7 +410,7 @@ class Field(AbstractField): return self.hashable_contents() == other.hashable_contents() # noinspection PyAttributeOutsideInit,PyUnresolvedReferences - class Access(sp.Symbol, AbstractField.AbstractAccess): + class Access(TypedSymbol, AbstractField.AbstractAccess): """Class representing a relative access into a `Field`. This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up @@ -462,7 +462,7 @@ class Field(AbstractField): if superscript is not None: symbol_name += "^" + superscript - obj = super(Field.Access, self).__xnew__(self, symbol_name) + obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype) obj._field = field obj._offsets = [] for o in offsets: diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 8cddac623b7d4822f539753fffbc4231caf39d21..c5a1837915602db695f0bd46a7297f7f26cf4771 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -362,6 +362,34 @@ class AssignmentCollection: self.subexpressions = [Assignment(k, v) for k, v in sub_expressions_dict.items()] + def find(self, *args, **kwargs): + return set.union( + *[a.find(*args, **kwargs) for a in self.all_assignments] + ) + + def match(self, *args, **kwargs): + rtn = {} + for a in self.all_assignments: + partial_result = a.match(*args, **kwargs) + if partial_result: + rtn.update(partial_result) + return rtn + + def subs(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions] + ) + + def replace(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions] + ) + + def __eq__(self, other): + return set(self.all_assignments) == set(other.all_assignments) + class SymbolGen: """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" diff --git a/pystencils_tests/test_floor_ceil_int_optimization.py b/pystencils_tests/test_floor_ceil_int_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6efba7b1431a8f47c0dfe317d48ff9bff7b9a2 --- /dev/null +++ b/pystencils_tests/test_floor_ceil_int_optimization.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy as sp + +import pystencils +from pystencils.data_types import create_type + + +def test_floor_ceil_int_optimization(): + x, y = pystencils.fields('x,y: int32[2d]') + a, b, c = sp.symbols('a, b, c') + int_symbol = sp.Symbol('int_symbol', integer=True) + typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('int64')) + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: sp.ceiling(typed_symbol), + c: sp.floor(int_symbol), + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert(typed_symbol.is_integer) + print(sp.simplify(sp.ceiling(typed_symbol))) + + print(assignments) + + wild_floor = sp.floor(sp.Wild('w1')) + + assert not sp.floor(int_symbol).match(wild_floor) + assert sp.floor(a).match(wild_floor) + + assert not assignments.find(wild_floor) + + +def test_floor_ceil_float_no_optimization(): + x, y = pystencils.fields('x,y: float32[2d]') + a, b, c = sp.symbols('a, b, c') + int_symbol = sp.Symbol('int_symbol', integer=True) + typed_symbol = pystencils.TypedSymbol('typed_symbol', create_type('float32')) + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: sp.ceiling(typed_symbol), + c: sp.floor(int_symbol), + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert not typed_symbol.is_integer + print(sp.simplify(sp.ceiling(typed_symbol))) + + print(assignments) + + wild_floor = sp.floor(sp.Wild('w1')) + + assert not sp.floor(int_symbol).match(wild_floor) + assert sp.floor(a).match(wild_floor) + + assert assignments.find(wild_floor) + + +def main(): + test_floor_ceil_int_optimization() + test_floor_ceil_float_no_optimization() + + +if __name__ == '__main__': + main() diff --git a/pystencils_tests/test_match_subs_for_assignment_collection.py b/pystencils_tests/test_match_subs_for_assignment_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..81dec872d115f3b04876845dcfff040894bd5353 --- /dev/null +++ b/pystencils_tests/test_match_subs_for_assignment_collection.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy as sp + +import pystencils +from pystencils.data_types import create_type + + +def test_wild_typed_symbol(): + x = pystencils.fields('x: float32[3d]') + typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64')) + + assert x.center().match(sp.Wild('w1')) + assert typed_symbol.match(sp.Wild('w1')) + + wild_ceiling = sp.ceiling(sp.Wild('w1')) + assert sp.ceiling(x.center()).match(wild_ceiling) + assert sp.ceiling(typed_symbol).match(wild_ceiling) + + +def test_replace_and_subs_for_assignment_collection(): + + x, y = pystencils.fields('x, y: float32[3d]') + a, b, c, d = sp.symbols('a, b, c, d') + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + expected_assignments = pystencils.AssignmentCollection({ + a: sp.floor(3), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + assert expected_assignments == assignments.replace(1, 3) + assert expected_assignments == assignments.subs({1: 3}) + + expected_assignments = pystencils.AssignmentCollection({ + d: sp.floor(1), + b: 2, + c: d + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + print(expected_assignments) + print(assignments.subs(a, d)) + assert expected_assignments == assignments.subs(a, d) + + +def test_match_for_assignment_collection(): + + x, y = pystencils.fields('x, y: float32[3d]') + a, b, c, d = sp.symbols('a, b, c, d') + + assignments = pystencils.AssignmentCollection({ + a: sp.floor(1), + b: 2, + c: a + c, + y.center(): sp.ceiling(x.center()) + sp.floor(x.center()) + }) + + w1 = sp.Wild('w1') + w2 = sp.Wild('w2') + w3 = sp.Wild('w3') + + wild_ceiling = sp.ceiling(w1) + wild_addition = w1 + w2 + + assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2))[w1] == x.center() + assert assignments.match(pystencils.Assignment(w3, wild_ceiling + w2)) == { + w3: y.center(), + w2: sp.floor(x.center()), + w1: x.center() + } + assert assignments.find(wild_ceiling) == {sp.ceiling(x.center())} + assert len([a for a in assignments.find(wild_addition) if isinstance(a, sp.Add)]) == 2 + + +def main(): + test_wild_typed_symbol() + test_replace_and_subs_for_assignment_collection() + test_match_for_assignment_collection() + + +if __name__ == '__main__': + main()