# integer_functions.py (2.5 KB)
 Martin Bauer committed May 11, 2018 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 ``````import sympy as sp from pystencils.data_types import get_type_of_expression, collate_types from pystencils.sympyextensions import is_integer_sequence bitwise_xor = sp.Function("bitwise_xor") bit_shift_right = sp.Function("bit_shift_right") bit_shift_left = sp.Function("bit_shift_left") bitwise_and = sp.Function("bitwise_and") bitwise_or = sp.Function("bitwise_or") # noinspection PyPep8Naming class modulo_floor(sp.Function): """Returns the next smaller integer divisible by given divisor. Examples: >>> modulo_floor(9, 4) 8 >>> modulo_floor(11, 4) 8 >>> modulo_floor(12, 4) 12 >>> from pystencils import TypedSymbol >>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32") >>> modulo_floor(a, b).to_c(str) '(int64_t)((a) / (b)) * (b)' """ nargs = 2 def __new__(cls, integer, divisor): if is_integer_sequence((integer, divisor)): return (int(integer) // int(divisor)) * divisor else: return super().__new__(cls, integer, divisor) def to_c(self, print_func): dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) assert dtype.is_int() return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) `````` Martin Bauer committed May 13, 2018 42 43 44 45 `````` # noinspection PyPep8Naming class modulo_ceil(sp.Function): `````` Martin Bauer committed Jun 05, 2018 46 `````` """Returns the next bigger integer divisible by given divisor. `````` Martin Bauer committed May 13, 2018 47 48 49 50 51 52 53 54 55 56 57 `````` Examples: >>> modulo_ceil(9, 4) 12 >>> modulo_ceil(11, 4) 12 >>> modulo_ceil(12, 4) 12 >>> from pystencils import TypedSymbol >>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32") >>> modulo_ceil(a, b).to_c(str) `````` Martin Bauer committed Jun 07, 2018 58 `````` '((a) % (b) == 0 ? 
a : ((int64_t)((a) / (b))+1) * (b))' `````` Martin Bauer committed May 13, 2018 59 60 61 62 63 64 65 66 67 68 69 70 `````` """ nargs = 2 def __new__(cls, integer, divisor): if is_integer_sequence((integer, divisor)): return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor else: return super().__new__(cls, integer, divisor) def to_c(self, print_func): dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) assert dtype.is_int() `````` Martin Bauer committed Jun 05, 2018 71 `````` code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))" `````` Martin Bauer committed May 13, 2018 72 `` return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)``