Commit 7cf5b5bf authored by Martin Bauer's avatar Martin Bauer
Browse files

Improved handling of integer functions and typing of constants

- numpy constants get directly their numpy type
- integer functions check for integer types at construction
parent f875fbc0
...@@ -47,6 +47,11 @@ class cast_func(sp.Function): ...@@ -47,6 +47,11 @@ class cast_func(sp.Function):
is_Atom = True is_Atom = True
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if len(args) != 2:
pass
expr, dtype, *other_args = args
if not isinstance(dtype, Type):
dtype = create_type(dtype)
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality # to problems when for example comparing cast_func's for equality
...@@ -55,9 +60,9 @@ class cast_func(sp.Function): ...@@ -55,9 +60,9 @@ class cast_func(sp.Function):
# rhs = cast_func(0, 'int') # rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class boolean_cast_func is introduced # -> thus a separate class boolean_cast_func is introduced
if isinstance(args[0], Boolean): if isinstance(expr, Boolean):
cls = boolean_cast_func cls = boolean_cast_func
return sp.Function.__new__(cls, *args, **kwargs) return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
@property @property
def canonical(self): def canonical(self):
...@@ -471,7 +476,7 @@ class BasicType(Type): ...@@ -471,7 +476,7 @@ class BasicType(Type):
return 1 return 1
def is_int(self): def is_int(self):
return self.numpy_dtype in np.sctypes['int'] return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint']
def is_float(self): def is_float(self):
return self.numpy_dtype in np.sctypes['float'] return self.numpy_dtype in np.sctypes['float']
......
import numpy as np
import sympy as sp import sympy as sp
from pystencils.data_types import collate_types, get_type_of_expression from pystencils.data_types import cast_func, collate_types, create_type, get_type_of_expression
from pystencils.sympyextensions import is_integer_sequence from pystencils.sympyextensions import is_integer_sequence
bitwise_xor = sp.Function("bitwise_xor")
bit_shift_right = sp.Function("bit_shift_right") class IntegerFunctionTwoArgsMixIn(sp.Function):
bit_shift_left = sp.Function("bit_shift_left") def __new__(cls, arg1, arg2):
bitwise_and = sp.Function("bitwise_and") args = []
bitwise_or = sp.Function("bitwise_or") for a in (arg1, arg2):
int_div = sp.Function("int_div") if isinstance(a, sp.Number) or isinstance(a, int):
int_power_of_2 = sp.Function("int_power_of_2") args.append(cast_func(a, create_type("int")))
elif isinstance(a, np.generic):
args.append(cast_func(a, a.dtype))
else:
args.append(a)
for a in args:
try:
type = get_type_of_expression(a)
if not type.is_int():
raise ValueError("Argument to integer function is not an int but " + str(type))
except NotImplementedError:
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args)
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_right(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_left(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_and(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_or(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class int_power_of_2(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming # noinspection PyPep8Naming
......
...@@ -5,6 +5,7 @@ from collections import OrderedDict, defaultdict, namedtuple ...@@ -5,6 +5,7 @@ from collections import OrderedDict, defaultdict, namedtuple
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
import numpy as np
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
...@@ -802,6 +803,8 @@ class KernelConstraintsCheck: ...@@ -802,6 +803,8 @@ class KernelConstraintsCheck:
return rhs return rhs
elif isinstance(rhs, sp.Symbol): elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, np.generic):
return cast_func(rhs, create_type(rhs.dtype))
elif type_constants and isinstance(rhs, sp.Number): elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant'])) return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul): elif isinstance(rhs, sp.Mul):
...@@ -809,6 +812,8 @@ class KernelConstraintsCheck: ...@@ -809,6 +812,8 @@ class KernelConstraintsCheck:
return rhs.func(*new_args) if new_args else rhs return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed): elif isinstance(rhs, sp.Indexed):
return rhs return rhs
elif isinstance(rhs, cast_func):
return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype)
else: else:
if isinstance(rhs, sp.Pow): if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers # don't process exponents -> they should remain integers
......
""" """
Test of pystencils.data_types.address_of Test of pystencils.data_types.address_of
""" """
import sympy
import pystencils import pystencils
from pystencils.data_types import PointerType, address_of, cast_func from pystencils.data_types import PointerType, address_of, cast_func
from pystencils.simp.simplifications import sympy_cse from pystencils.simp.simplifications import sympy_cse
...@@ -48,12 +45,3 @@ def test_address_of_with_cse(): ...@@ -48,12 +45,3 @@ def test_address_of_with_cse():
ast = pystencils.create_kernel(assignments_cse) ast = pystencils.create_kernel(assignments_cse)
code = pystencils.show_code(ast) code = pystencils.show_code(ast)
print(code) print(code)
def main():
test_address_of()
test_address_of_with_cse()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment