Commit 8828b59d authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Set assumptions for TypedSymbol/cast_func/IntegerFunctionTwoArgsMixIn the SymPy way

parent 0ff425f3
...@@ -40,6 +40,35 @@ def matrix_symbols(names, dtype, rows, cols): ...@@ -40,6 +40,35 @@ def matrix_symbols(names, dtype, rows, cols):
return tuple(matrices) return tuple(matrices)
def assumptions_from_dtype(dtype):
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
dtype (BasicType, np.dtype): a Numpy data type
Returns:
A dict of SymPy assumptions
"""
if hasattr(dtype, 'numpy_dtype'):
dtype = dtype.numpy_dtype
assumptions = dict()
try:
if np.issubdtype(dtype, np.integer):
assumptions.update({'integer': True})
if np.issubdtype(dtype, np.unsignedinteger):
assumptions.update({'negative': False})
if np.issubdtype(dtype, np.integer) or \
np.issubdtype(dtype, np.floating):
assumptions.update({'real': True})
except Exception:
pass
return assumptions
# noinspection PyPep8Naming # noinspection PyPep8Naming
class address_of(sp.Function): class address_of(sp.Function):
is_Atom = True is_Atom = True
...@@ -87,6 +116,7 @@ class cast_func(sp.Function): ...@@ -87,6 +116,7 @@ class cast_func(sp.Function):
# -> thus a separate class boolean_cast_func is introduced # -> thus a separate class boolean_cast_func is introduced
if isinstance(expr, Boolean): if isinstance(expr, Boolean):
cls = boolean_cast_func cls = boolean_cast_func
return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property @property
...@@ -184,7 +214,8 @@ class TypedSymbol(sp.Symbol): ...@@ -184,7 +214,8 @@ class TypedSymbol(sp.Symbol):
return obj return obj
def __new_stage2__(cls, name, dtype, *args, **kwargs): def __new_stage2__(cls, name, dtype, *args, **kwargs):
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs) assumptions = assumptions_from_dtype(dtype)
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
try: try:
obj._dtype = create_type(dtype) obj._dtype = create_type(dtype)
except (TypeError, ValueError): except (TypeError, ValueError):
...@@ -205,52 +236,6 @@ class TypedSymbol(sp.Symbol): ...@@ -205,52 +236,6 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.name, self.dtype return self.name, self.dtype
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
@property
def is_integer(self):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False
return super().is_negative
@property
def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False:
return True
else:
return super().is_nonnegative
@property
def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
super().is_real
else:
return super().is_real
def create_type(specification): def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type. """Creates a subclass of Type according to a string or an object of subclass Type.
......
...@@ -6,6 +6,8 @@ from pystencils.sympyextensions import is_integer_sequence ...@@ -6,6 +6,8 @@ from pystencils.sympyextensions import is_integer_sequence
class IntegerFunctionTwoArgsMixIn(sp.Function): class IntegerFunctionTwoArgsMixIn(sp.Function):
is_Integer = True
def __new__(cls, arg1, arg2): def __new__(cls, arg1, arg2):
args = [] args = []
for a in (arg1, arg2): for a in (arg1, arg2):
...@@ -25,10 +27,6 @@ class IntegerFunctionTwoArgsMixIn(sp.Function): ...@@ -25,10 +27,6 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
raise ValueError("Integer functions can only be constructed with typed expressions") raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args) return super().__new__(cls, *args)
@property
def is_integer(self):
return True
# noinspection PyPep8Naming # noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn): class bitwise_xor(IntegerFunctionTwoArgsMixIn):
...@@ -82,6 +80,7 @@ class modulo_floor(sp.Function): ...@@ -82,6 +80,7 @@ class modulo_floor(sp.Function):
'(int64_t)((a) / (b)) * (b)' '(int64_t)((a) / (b)) * (b)'
""" """
nargs = 2 nargs = 2
is_Integer = True
def __new__(cls, integer, divisor): def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)): if is_integer_sequence((integer, divisor)):
...@@ -113,6 +112,7 @@ class modulo_ceil(sp.Function): ...@@ -113,6 +112,7 @@ class modulo_ceil(sp.Function):
'((a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b))' '((a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b))'
""" """
nargs = 2 nargs = 2
is_Integer = True
def __new__(cls, integer, divisor): def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)): if is_integer_sequence((integer, divisor)):
...@@ -142,6 +142,7 @@ class div_ceil(sp.Function): ...@@ -142,6 +142,7 @@ class div_ceil(sp.Function):
'( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )' '( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )'
""" """
nargs = 2 nargs = 2
is_Integer = True
def __new__(cls, integer, divisor): def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)): if is_integer_sequence((integer, divisor)):
...@@ -171,6 +172,7 @@ class div_floor(sp.Function): ...@@ -171,6 +172,7 @@ class div_floor(sp.Function):
'((int64_t)(a) / (int64_t)(b))' '((int64_t)(a) / (int64_t)(b))'
""" """
nargs = 2 nargs = 2
is_Integer = True
def __new__(cls, integer, divisor): def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)): if is_integer_sequence((integer, divisor)):
......
from pystencils import data_types
from pystencils.data_types import *
import sympy as sp import sympy as sp
from pystencils import data_types
from pystencils.data_types import *
from pystencils.kernelparameters import FieldShapeSymbol
def test_parsing(): def test_parsing():
...@@ -22,6 +23,7 @@ def test_collation(): ...@@ -22,6 +23,7 @@ def test_collation():
assert collate_types([double4_type, float_type]) == double4_type assert collate_types([double4_type, float_type]) == double4_type
assert collate_types([double4_type, float4_type]) == double4_type assert collate_types([double4_type, float4_type]) == double4_type
def test_dtype_of_constants(): def test_dtype_of_constants():
# Some come constants are neither of type Integer,Float,Rational and don't have args # Some come constants are neither of type Integer,Float,Rational and don't have args
...@@ -34,3 +36,16 @@ def test_dtype_of_constants(): ...@@ -34,3 +36,16 @@ def test_dtype_of_constants():
# >>> pi.args # >>> pi.args
# () # ()
get_type_of_expression(sp.pi) get_type_of_expression(sp.pi)
def test_assumptions():
x = pystencils.fields('x: float32[3d]')
assert x.shape[0].is_nonnegative
assert (2 * x.shape[0]).is_nonnegative
assert (2 * x.shape[0]).is_integer
assert(TypedSymbol('a', create_type('uint64'))).is_nonnegative
assert (TypedSymbol('a', create_type('uint64'))).is_positive is None
assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive
assert (x.shape[0] + 1).is_real
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment