diff --git a/pystencils/data_types.py b/pystencils/data_types.py index efee3d9c2baad351ee82eab63ec3bf481c469c75..a4752b148afa43bff9ac39a999b5dc4ca33fd430 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -40,6 +40,35 @@ def matrix_symbols(names, dtype, rows, cols): return tuple(matrices) +def assumptions_from_dtype(dtype): + """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype + + Args: + dtype (BasicType, np.dtype): a Numpy data type + Returns: + A dict of SymPy assumptions + """ + if hasattr(dtype, 'numpy_dtype'): + dtype = dtype.numpy_dtype + + assumptions = dict() + + try: + if np.issubdtype(dtype, np.integer): + assumptions.update({'integer': True}) + + if np.issubdtype(dtype, np.unsignedinteger): + assumptions.update({'negative': False}) + + if np.issubdtype(dtype, np.integer) or \ + np.issubdtype(dtype, np.floating): + assumptions.update({'real': True}) + except Exception: + pass + + return assumptions + + # noinspection PyPep8Naming class address_of(sp.Function): is_Atom = True @@ -87,6 +116,7 @@ class cast_func(sp.Function): # -> thus a separate class boolean_cast_func is introduced if isinstance(expr, Boolean): cls = boolean_cast_func + return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) @property @@ -184,7 +214,8 @@ class TypedSymbol(sp.Symbol): return obj def __new_stage2__(cls, name, dtype, *args, **kwargs): - obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs) + assumptions = assumptions_from_dtype(dtype) + obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs) try: obj._dtype = create_type(dtype) except (TypeError, ValueError): @@ -205,52 +236,6 @@ class TypedSymbol(sp.Symbol): def __getnewargs__(self): return self.name, self.dtype - # For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html - @property - def is_integer(self): - """ - Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate - - For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html - """ - if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer - else: - return super().is_integer - - @property - def is_negative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if hasattr(self.dtype, 'numpy_dtype'): - if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): - return False - - return super().is_negative - - @property - def is_nonnegative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if self.is_negative is False: - return True - else: - return super().is_nonnegative - - @property - def is_real(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ - np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ - super().is_real - else: - return super().is_real - def create_type(specification): """Creates a subclass of Type according to a string or an object of subclass Type. diff --git a/pystencils/integer_functions.py b/pystencils/integer_functions.py index fa5a4d739433658e1dff26389396eb296c6e4099..1cd9b197a61e6667212a20ae53e1b14020b7d0e8 100644 --- a/pystencils/integer_functions.py +++ b/pystencils/integer_functions.py @@ -6,6 +6,8 @@ from pystencils.sympyextensions import is_integer_sequence class IntegerFunctionTwoArgsMixIn(sp.Function): + is_Integer = True + def __new__(cls, arg1, arg2): args = [] for a in (arg1, arg2): @@ -25,10 +27,6 @@ class IntegerFunctionTwoArgsMixIn(sp.Function): raise ValueError("Integer functions can only be constructed with typed expressions") return super().__new__(cls, *args) - @property - def is_integer(self): - return True - # noinspection PyPep8Naming class bitwise_xor(IntegerFunctionTwoArgsMixIn): @@ -82,6 +80,7 @@ class modulo_floor(sp.Function): '(int64_t)((a) / (b)) * (b)' """ nargs = 2 + is_Integer = True def __new__(cls, integer, divisor): if is_integer_sequence((integer, divisor)): @@ -113,6 +112,7 @@ class modulo_ceil(sp.Function): '((a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b))' """ nargs = 2 + is_Integer = True def __new__(cls, integer, divisor): if is_integer_sequence((integer, divisor)): @@ -142,6 +142,7 @@ class div_ceil(sp.Function): '( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )' """ nargs = 2 + is_Integer = True def __new__(cls, integer, divisor): if is_integer_sequence((integer, divisor)): @@ -171,6 +172,7 @@ class div_floor(sp.Function): '((int64_t)(a) / (int64_t)(b))' """ nargs = 2 + is_Integer = True def __new__(cls, integer, divisor): if is_integer_sequence((integer, divisor)): diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 887f802c91eb3a82ebd8ea43f6fbc17d18d18cef..322e04db1d47b93efa89d9d024f4dfda325b5dc4 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,7 +1,8 @@ -from pystencils import data_types -from pystencils.data_types import * import sympy as sp +from pystencils import data_types +from pystencils.data_types import * +from pystencils.kernelparameters import FieldShapeSymbol def test_parsing(): @@ -22,6 +23,7 @@ def test_collation(): assert collate_types([double4_type, float_type]) == double4_type assert collate_types([double4_type, float4_type]) == double4_type + def test_dtype_of_constants(): # Some come constants are neither of type Integer,Float,Rational and don't have args @@ -34,3 +36,16 @@ def test_dtype_of_constants(): # >>> pi.args # () get_type_of_expression(sp.pi) + + +def test_assumptions(): + + x = pystencils.fields('x: float32[3d]') + + assert x.shape[0].is_nonnegative + assert (2 * x.shape[0]).is_nonnegative + assert (2 * x.shape[0]).is_integer + assert(TypedSymbol('a', create_type('uint64'))).is_nonnegative + assert (TypedSymbol('a', create_type('uint64'))).is_positive is None + assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive + assert (x.shape[0] + 1).is_real