diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 73f34a1fd97642fdbad0e29bc41ab3d223b17fd4..a8384a01487056baa43fe8d86580ee58c92ce8b5 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -485,6 +485,44 @@ class PsDiv(PsBinOp): pass +class PsIntDiv(PsBinOp): + """C-like integer division (round to zero).""" + + # python_operator not implemented because both floordiv and truediv have + # different semantics. + pass + + +class PsLeftShift(PsBinOp): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.lshift + + +class PsRightShift(PsBinOp): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.rshift + + +class PsBitwiseAnd(PsBinOp): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.and_ + + +class PsBitwiseXor(PsBinOp): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.xor + + +class PsBitwiseOr(PsBinOp): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.or_ + + class PsArrayInitList(PsExpression): __match_args__ = ("items",) diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 190f4f001ae270123146aa9807ecd9fb5ebba6bd..b742c598db27e46d94655e309f277531bdcb75eb 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -13,22 +13,28 @@ from .ast.structural import ( ) from .ast.expressions import ( - PsSymbolExpr, + PsAdd, + PsAddressOf, + PsArrayInitList, + PsBinOp, + PsBitwiseAnd, + PsBitwiseOr, + PsBitwiseXor, + PsCall, + PsCast, PsConstantExpr, - PsSubscript, - PsVectorArrayAccess, + PsDeref, + PsDiv, + PsIntDiv, + PsLeftShift, PsLookup, - PsCall, - PsBinOp, - PsAdd, - PsSub, PsMul, - PsDiv, PsNeg, - PsDeref, - PsAddressOf, - PsCast, - PsArrayInitList, + PsRightShift, + PsSub, + PsSubscript, + PsSymbolExpr, + PsVectorArrayAccess, ) from .symbols import PsSymbol @@ -61,23 +67,32 @@ class Ops(Enum): See also https://en.cppreference.com/w/cpp/language/operator_precedence """ - Weakest = (0, LR.Middle) + Weakest = (17 - 17, LR.Middle) + + BitwiseOr = (17 - 13, LR.Left) + + BitwiseXor = (17 - 12, LR.Left) + + BitwiseAnd = (17 - 11, LR.Left) + + LeftShift = (17 - 7, LR.Left) + RightShift = (17 - 7, LR.Left) - Add = (1, LR.Left) - Sub = (1, LR.Left) + Add = (17 - 6, LR.Left) + Sub = (17 - 6, LR.Left) - Mul = (2, LR.Left) - Div = (2, LR.Left) - Rem = (2, LR.Left) + Mul = (17 - 5, LR.Left) + Div = (17 - 5, LR.Left) + Rem = (17 - 5, LR.Left) - Neg = (3, LR.Right) - AddressOf = (3, LR.Right) - Deref = (3, LR.Right) - Cast = (3, LR.Right) + Neg = (17 - 3, LR.Right) + AddressOf = (17 - 3, LR.Right) + Deref = (17 - 3, LR.Right) + Cast = (17 - 3, LR.Right) - Call = (4, LR.Left) - Subscript = (4, LR.Left) - Lookup = (4, LR.Left) + Call = (17 - 2, LR.Left) + Subscript = (17 - 2, LR.Left) + Lookup = (17 - 2, LR.Left) def __init__(self, pred: int, assoc: LR) -> None: self.precedence = pred @@ -312,7 +327,17 @@ class CAstPrinter: return ("-", Ops.Sub) case PsMul(): return ("*", Ops.Mul) - case PsDiv(): + case PsDiv() | PsIntDiv(): return ("/", Ops.Div) + case PsLeftShift(): + return ("<<", Ops.LeftShift) + case PsRightShift(): + return (">>", Ops.RightShift) + case PsBitwiseAnd(): + return ("&", Ops.BitwiseAnd) + case PsBitwiseXor(): + return ("^", Ops.BitwiseXor) + case PsBitwiseOr(): + return ("|", Ops.BitwiseOr) case _: assert False diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 2e571111f41c3ec1826df90fb88e96be735d6c4a..6ce0264e20c2ba4b0ecb84bb6c3485890a6fd3a2 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -4,7 +4,7 @@ from operator import add, mul, sub import sympy as sp -from ...sympyextensions import Assignment, AssignmentCollection +from ...sympyextensions import Assignment, AssignmentCollection, integer_functions from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc from ...field import Field, FieldType @@ -20,13 +20,19 @@ from ..ast.structural import ( ) from ..ast.expressions import ( PsArrayAccess, - PsVectorArrayAccess, - PsLookup, + PsArrayInitList, + PsBitwiseAnd, + PsBitwiseOr, + PsBitwiseXor, PsCall, + PsCast, PsConstantExpr, - PsArrayInitList, + PsIntDiv, + PsLeftShift, + PsLookup, + PsRightShift, PsSubscript, - PsCast, + PsVectorArrayAccess, ) from ..constants import PsConstant @@ -293,28 +299,48 @@ class FreezeExpressions: else: return PsArrayAccess(ptr, index) - def map_Function(self, func: sp.Function) -> PsCall: + def map_Function(self, func: sp.Function) -> PsExpression: """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols. - SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`. + If applicable, functions are mapped to binary operators, e.g. `backend.ast.expressions.PsBitwiseXor`. + Other SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`. """ + args = tuple(self.visit_expr(arg) for arg in func.args) + match func: case sp.Abs(): - func_symbol = PsMathFunction(MathFunctions.Abs) + return PsCall(PsMathFunction(MathFunctions.Abs), args) case sp.exp(): - func_symbol = PsMathFunction(MathFunctions.Exp) + return PsCall(PsMathFunction(MathFunctions.Exp), args) case sp.sin(): - func_symbol = PsMathFunction(MathFunctions.Sin) + return PsCall(PsMathFunction(MathFunctions.Sin), args) case sp.cos(): - func_symbol = PsMathFunction(MathFunctions.Cos) + return PsCall(PsMathFunction(MathFunctions.Cos), args) case sp.tan(): - func_symbol = PsMathFunction(MathFunctions.Tan) + return PsCall(PsMathFunction(MathFunctions.Tan), args) + case integer_functions.int_div(): + return PsIntDiv(*args) + case integer_functions.bit_shift_left(): + return PsLeftShift(*args) + case integer_functions.bit_shift_right(): + return PsRightShift(*args) + case integer_functions.bitwise_and(): + return PsBitwiseAnd(*args) + case integer_functions.bitwise_xor(): + return PsBitwiseXor(*args) + case integer_functions.bitwise_or(): + return PsBitwiseOr(*args) + case integer_functions.int_power_of_2(): + return PsLeftShift(PsExpression.make(PsConstant(1)), args[0]) + # TODO: what exactly are the semantics? + # case integer_functions.modulo_floor(): + # case integer_functions.div_floor() + # TODO: requires if *expression* + # case integer_functions.modulo_ceil(): + # case integer_functions.div_ceil(): case _: raise FreezeError(f"Unsupported function: {func}") - args = tuple(self.visit_expr(arg) for arg in func.args) - return PsCall(func_symbol, args) - def map_Min(self, expr: sp.Min) -> PsCall: args = tuple(self.visit_expr(arg) for arg in expr.args) return PsCall(PsMathFunction(MathFunctions.Min), args) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index af9a81d0a5b4cabcebc7c79eefd564f1cd8e0053..3fbb9c1a8052599186ec1128473773b0220696ad 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,15 +22,21 @@ from ..ast.structural import ( PsAssignment, ) from ..ast.expressions import ( - PsSymbolExpr, - PsConstantExpr, - PsBinOp, PsArrayAccess, - PsSubscript, - PsLookup, - PsCall, PsArrayInitList, + PsBinOp, + PsBitwiseAnd, + PsBitwiseOr, + PsBitwiseXor, + PsCall, PsCast, + PsConstantExpr, + PsIntDiv, + PsLeftShift, + PsLookup, + PsRightShift, + PsSubscript, + PsSymbolExpr, ) from ..functions import PsMathFunction @@ -258,6 +264,36 @@ class Typifier: tc.apply_and_check(expr, member.dtype) + # integer operations + case ( + PsIntDiv(op1, op2) + | PsLeftShift(op1, op2) + | PsRightShift(op1, op2) + | PsBitwiseAnd(op1, op2) + | PsBitwiseXor(op1, op2) + | PsBitwiseOr(op1, op2) + ): + if tc.target_type is not None and not isinstance( + tc.target_type, PsIntegerType + ): + raise TypificationError( + f"Integer expression used in non-integer context.\n" + f" Integer expression: {expr}\n" + f" Context type: {tc.target_type}" + ) + + self.visit_expr(op1, tc) + self.visit_expr(op2, tc) + + if tc.target_type is None: + raise TypificationError( + f"Unable to infer type of integer expression {expr}." + ) + elif not isinstance(tc.target_type, PsIntegerType): + raise TypificationError( + f"Argument(s) to integer function are non-integer in expression {expr}." + ) + case PsBinOp(op1, op2): self.visit_expr(op1, tc) self.visit_expr(op2, tc) diff --git a/src/pystencils/sympyextensions/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py index f9c156971c967b461664af64f0538953d6df7652..c3dd181083ecbfecf5dc6d445f5fa1d0ac45f601 100644 --- a/src/pystencils/sympyextensions/integer_functions.py +++ b/src/pystencils/sympyextensions/integer_functions.py @@ -51,6 +51,8 @@ class int_div(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming +# TODO: What do the *two* arguments mean? +# Apparently, the second is required but ignored? class int_power_of_2(IntegerFunctionTwoArgsMixIn): pass diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index c8f39ffc85a28d40549c39c2bdac31d9faf92f87..269435257bcd8dbb486290fe1d3f35aee8e21319 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -4,11 +4,19 @@ from pystencils import Assignment, fields from pystencils.backend.ast.structural import ( PsAssignment, + PsBlock, PsDeclaration, ) from pystencils.backend.ast.expressions import ( + PsArrayAccess, + PsBitwiseAnd, + PsBitwiseOr, + PsBitwiseXor, PsExpression, - PsArrayAccess + PsIntDiv, + PsLeftShift, + PsMul, + PsRightShift, ) from pystencils.backend.constants import PsConstant from pystencils.backend.kernelcreation import ( @@ -17,6 +25,17 @@ from pystencils.backend.kernelcreation import ( FullIterationSpace, ) +from pystencils.sympyextensions.integer_functions import ( + bit_shift_left, + bit_shift_right, + bitwise_and, + bitwise_or, + bitwise_xor, + int_div, + int_power_of_2, + modulo_floor, +) + def test_freeze_simple(): ctx = KernelCreationContext() @@ -63,9 +82,66 @@ def test_freeze_fields(): zero = PsExpression.make(PsConstant(0)) - lhs = PsArrayAccess(f_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) + zero * one) - rhs = PsArrayAccess(g_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) + zero * one) + lhs = PsArrayAccess( + f_arr.base_pointer, + (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) + + zero * one, + ) + rhs = PsArrayAccess( + g_arr.base_pointer, + (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) + + zero * one, + ) should = PsAssignment(lhs, rhs) assert fasm.structurally_equal(should) + + +def test_freeze_integer_binops(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + expr = bit_shift_left( + bit_shift_right(bitwise_and(x, y), bitwise_or(y, z)), bitwise_xor(x, z) + ) + + fexpr = freeze(expr) + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + should = PsLeftShift( + PsRightShift(PsBitwiseAnd(x2, y2), PsBitwiseOr(y2, z2)), PsBitwiseXor(x2, z2) + ) + + assert fexpr.structurally_equal(should) + + +def test_freeze_integer_functions(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x2 = PsExpression.make(ctx.get_symbol("x", ctx.index_dtype)) + y2 = PsExpression.make(ctx.get_symbol("y", ctx.index_dtype)) + z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype)) + + x, y, z = sp.symbols("x, y, z") + asms = [ + Assignment(z, int_div(x, y)), + Assignment(z, int_power_of_2(x, y)), + # Assignment(z, modulo_floor(x, y)), + ] + + fasms = [freeze(asm) for asm in asms] + + should = [ + PsDeclaration(z2, PsIntDiv(x2, y2)), + PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)), + # PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)), + ] + + for fasm, correct in zip(fasms, should): + assert fasm.structurally_equal(correct) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 7625e22e3ce25b2467110d91dd951c775e9dbdb2..9ff18623ea70a1da87308d1b7cc04dec9885489b 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -12,6 +12,14 @@ from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError +from pystencils.sympyextensions.integer_functions import ( + bit_shift_left, + bit_shift_right, + bitwise_and, + bitwise_xor, + bitwise_or, +) + def test_typify_simple(): ctx = KernelCreationContext() @@ -114,3 +122,65 @@ def test_erronous_typing(): fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) + + +def test_typify_integer_binops(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + ctx.get_symbol("x", ctx.index_dtype) + ctx.get_symbol("y", ctx.index_dtype) + ctx.get_symbol("z", ctx.index_dtype) + + x, y, z = sp.symbols("x, y, z") + expr = bit_shift_left( + bit_shift_right(bitwise_and(x, 2), bitwise_or(y, z)), bitwise_xor(2, 2) + ) # ^ + # TODO: x can not be a constant here, because then the typifier can not check that the arguments are integer. + expr = freeze(expr) + expr = typify(expr) + + def check(expr): + match expr: + case PsConstantExpr(cs): + assert cs.value == 2 + assert cs.dtype == constify(ctx.index_dtype) + case PsSymbolExpr(symb): + assert symb.name in "xyz" + assert symb.dtype == ctx.index_dtype + case PsBinOp(op1, op2): + check(op1) + check(op2) + case _: + pytest.fail(f"Unexpected expression: {expr}") + + check(expr) + + +def test_typify_integer_binops_floating_arg(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x = sp.Symbol("x") + expr = bit_shift_left(x, 2) + expr = freeze(expr) + + with pytest.raises(TypificationError): + expr = typify(expr) + + +def test_typify_integer_binops_in_floating_context(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + ctx.get_symbol("i", ctx.index_dtype) + + x, i = sp.symbols("x, i") + expr = x + bit_shift_left(i, 2) + expr = freeze(expr) + + with pytest.raises(TypificationError): + expr = typify(expr) diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 5e80cae223c78ab00c911df9144d42768cfb52e1..c8294c6dd4ac1a8683a5f779cb52d3e628cbc708 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -75,3 +75,27 @@ def test_arithmetic_precedence(): expr = (a / b) + (c / (d + e) * f) code = cprint(expr) assert code == "a / b + c / (d + e) * f" + + +def test_printing_integer_functions(): + (i, j, k) = [PsExpression.make(PsSymbol(x, UInt(64))) for x in "ijk"] + cprint = CAstPrinter() + + from pystencils.backend.ast.expressions import ( + PsLeftShift, + PsRightShift, + PsBitwiseAnd, + PsBitwiseOr, + PsBitwiseXor, + PsIntDiv, + ) + + expr = PsBitwiseAnd( + PsBitwiseXor( + PsBitwiseXor(j, k), + PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)), + ), + i, + ) + code = cprint(expr) + assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i"