Skip to content
Snippets Groups Projects
Commit 0cc3b825 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'sympy-assumptions-from-dtype' into 'master'

Set assumptions for TypedSymbol/cast_func/IntegerFunctionTwoArgsMixIn the SymPy way

See merge request !66
parents 67002956 8828b59d
Branches
Tags
No related merge requests found
......@@ -40,6 +40,35 @@ def matrix_symbols(names, dtype, rows, cols):
return tuple(matrices)
def assumptions_from_dtype(dtype):
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
dtype (BasicType, np.dtype): a Numpy data type
Returns:
A dict of SymPy assumptions
"""
if hasattr(dtype, 'numpy_dtype'):
dtype = dtype.numpy_dtype
assumptions = dict()
try:
if np.issubdtype(dtype, np.integer):
assumptions.update({'integer': True})
if np.issubdtype(dtype, np.unsignedinteger):
assumptions.update({'negative': False})
if np.issubdtype(dtype, np.integer) or \
np.issubdtype(dtype, np.floating):
assumptions.update({'real': True})
except Exception:
pass
return assumptions
# noinspection PyPep8Naming
class address_of(sp.Function):
is_Atom = True
......@@ -87,6 +116,7 @@ class cast_func(sp.Function):
# -> thus a separate class boolean_cast_func is introduced
if isinstance(expr, Boolean):
cls = boolean_cast_func
return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property
......@@ -184,7 +214,8 @@ class TypedSymbol(sp.Symbol):
return obj
def __new_stage2__(cls, name, dtype, *args, **kwargs):
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
assumptions = assumptions_from_dtype(dtype)
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
try:
obj._dtype = create_type(dtype)
except (TypeError, ValueError):
......@@ -205,52 +236,6 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self):
return self.name, self.dtype
# For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
@property
def is_integer(self):
"""
Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
else:
return super().is_integer
@property
def is_negative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
return False
return super().is_negative
@property
def is_nonnegative(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if self.is_negative is False:
return True
else:
return super().is_nonnegative
@property
def is_real(self):
"""
See :func:`.TypedSymbol.is_integer`
"""
if hasattr(self.dtype, 'numpy_dtype'):
return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
super().is_real
else:
return super().is_real
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......
......@@ -6,6 +6,8 @@ from pystencils.sympyextensions import is_integer_sequence
class IntegerFunctionTwoArgsMixIn(sp.Function):
is_Integer = True
def __new__(cls, arg1, arg2):
args = []
for a in (arg1, arg2):
......@@ -25,10 +27,6 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args)
@property
def is_integer(self):
return True
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
......@@ -82,6 +80,7 @@ class modulo_floor(sp.Function):
'(int64_t)((a) / (b)) * (b)'
"""
nargs = 2
is_Integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
......@@ -113,6 +112,7 @@ class modulo_ceil(sp.Function):
'((a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b))'
"""
nargs = 2
is_Integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
......@@ -142,6 +142,7 @@ class div_ceil(sp.Function):
'( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )'
"""
nargs = 2
is_Integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
......@@ -171,6 +172,7 @@ class div_floor(sp.Function):
'((int64_t)(a) / (int64_t)(b))'
"""
nargs = 2
is_Integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
......
from pystencils import data_types
from pystencils.data_types import *
import sympy as sp
from pystencils import data_types
from pystencils.data_types import *
from pystencils.kernelparameters import FieldShapeSymbol
def test_parsing():
......@@ -22,6 +23,7 @@ def test_collation():
assert collate_types([double4_type, float_type]) == double4_type
assert collate_types([double4_type, float4_type]) == double4_type
def test_dtype_of_constants():
# Some come constants are neither of type Integer,Float,Rational and don't have args
......@@ -34,3 +36,16 @@ def test_dtype_of_constants():
# >>> pi.args
# ()
get_type_of_expression(sp.pi)
def test_assumptions():
x = pystencils.fields('x: float32[3d]')
assert x.shape[0].is_nonnegative
assert (2 * x.shape[0]).is_nonnegative
assert (2 * x.shape[0]).is_integer
assert(TypedSymbol('a', create_type('uint64'))).is_nonnegative
assert (TypedSymbol('a', create_type('uint64'))).is_positive is None
assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive
assert (x.shape[0] + 1).is_real
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment