Skip to content
Snippets Groups Projects
Commit 075ae357 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'bauerd/integer-fun' into 'backend-rework'

Integer functions

See merge request !368
parents 579299cb 7568c463
1 merge request!368Integer functions
Pipeline #64764 failed with stages
in 3 minutes and 49 seconds
......@@ -485,6 +485,44 @@ class PsDiv(PsBinOp):
pass
class PsIntDiv(PsBinOp):
"""C-like integer division (round to zero)."""
# python_operator not implemented because both floordiv and truediv have
# different semantics.
pass
class PsLeftShift(PsBinOp):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.lshift
class PsRightShift(PsBinOp):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.rshift
class PsBitwiseAnd(PsBinOp):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.and_
class PsBitwiseXor(PsBinOp):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.xor
class PsBitwiseOr(PsBinOp):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.or_
class PsArrayInitList(PsExpression):
__match_args__ = ("items",)
......
......@@ -13,22 +13,28 @@ from .ast.structural import (
)
from .ast.expressions import (
PsSymbolExpr,
PsAdd,
PsAddressOf,
PsArrayInitList,
PsBinOp,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsCall,
PsCast,
PsConstantExpr,
PsSubscript,
PsVectorArrayAccess,
PsDeref,
PsDiv,
PsIntDiv,
PsLeftShift,
PsLookup,
PsCall,
PsBinOp,
PsAdd,
PsSub,
PsMul,
PsDiv,
PsNeg,
PsDeref,
PsAddressOf,
PsCast,
PsArrayInitList,
PsRightShift,
PsSub,
PsSubscript,
PsSymbolExpr,
PsVectorArrayAccess,
)
from .symbols import PsSymbol
......@@ -61,23 +67,32 @@ class Ops(Enum):
See also https://en.cppreference.com/w/cpp/language/operator_precedence
"""
Weakest = (0, LR.Middle)
Weakest = (17 - 17, LR.Middle)
BitwiseOr = (17 - 13, LR.Left)
BitwiseXor = (17 - 12, LR.Left)
BitwiseAnd = (17 - 11, LR.Left)
LeftShift = (17 - 7, LR.Left)
RightShift = (17 - 7, LR.Left)
Add = (1, LR.Left)
Sub = (1, LR.Left)
Add = (17 - 6, LR.Left)
Sub = (17 - 6, LR.Left)
Mul = (2, LR.Left)
Div = (2, LR.Left)
Rem = (2, LR.Left)
Mul = (17 - 5, LR.Left)
Div = (17 - 5, LR.Left)
Rem = (17 - 5, LR.Left)
Neg = (3, LR.Right)
AddressOf = (3, LR.Right)
Deref = (3, LR.Right)
Cast = (3, LR.Right)
Neg = (17 - 3, LR.Right)
AddressOf = (17 - 3, LR.Right)
Deref = (17 - 3, LR.Right)
Cast = (17 - 3, LR.Right)
Call = (4, LR.Left)
Subscript = (4, LR.Left)
Lookup = (4, LR.Left)
Call = (17 - 2, LR.Left)
Subscript = (17 - 2, LR.Left)
Lookup = (17 - 2, LR.Left)
def __init__(self, pred: int, assoc: LR) -> None:
self.precedence = pred
......@@ -312,7 +327,17 @@ class CAstPrinter:
return ("-", Ops.Sub)
case PsMul():
return ("*", Ops.Mul)
case PsDiv():
case PsDiv() | PsIntDiv():
return ("/", Ops.Div)
case PsLeftShift():
return ("<<", Ops.LeftShift)
case PsRightShift():
return (">>", Ops.RightShift)
case PsBitwiseAnd():
return ("&", Ops.BitwiseAnd)
case PsBitwiseXor():
return ("^", Ops.BitwiseXor)
case PsBitwiseOr():
return ("|", Ops.BitwiseOr)
case _:
assert False
......@@ -4,7 +4,7 @@ from operator import add, mul, sub
import sympy as sp
from ...sympyextensions import Assignment, AssignmentCollection
from ...sympyextensions import Assignment, AssignmentCollection, integer_functions
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
from ...field import Field, FieldType
......@@ -20,13 +20,19 @@ from ..ast.structural import (
)
from ..ast.expressions import (
PsArrayAccess,
PsVectorArrayAccess,
PsLookup,
PsArrayInitList,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsCall,
PsCast,
PsConstantExpr,
PsArrayInitList,
PsIntDiv,
PsLeftShift,
PsLookup,
PsRightShift,
PsSubscript,
PsCast,
PsVectorArrayAccess,
)
from ..constants import PsConstant
......@@ -293,28 +299,48 @@ class FreezeExpressions:
else:
return PsArrayAccess(ptr, index)
def map_Function(self, func: sp.Function) -> PsCall:
def map_Function(self, func: sp.Function) -> PsExpression:
"""Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
If applicable, functions are mapped to binary operators, e.g. `backend.ast.expressions.PsBitwiseXor`.
Other SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
"""
args = tuple(self.visit_expr(arg) for arg in func.args)
match func:
case sp.Abs():
func_symbol = PsMathFunction(MathFunctions.Abs)
return PsCall(PsMathFunction(MathFunctions.Abs), args)
case sp.exp():
func_symbol = PsMathFunction(MathFunctions.Exp)
return PsCall(PsMathFunction(MathFunctions.Exp), args)
case sp.sin():
func_symbol = PsMathFunction(MathFunctions.Sin)
return PsCall(PsMathFunction(MathFunctions.Sin), args)
case sp.cos():
func_symbol = PsMathFunction(MathFunctions.Cos)
return PsCall(PsMathFunction(MathFunctions.Cos), args)
case sp.tan():
func_symbol = PsMathFunction(MathFunctions.Tan)
return PsCall(PsMathFunction(MathFunctions.Tan), args)
case integer_functions.int_div():
return PsIntDiv(*args)
case integer_functions.bit_shift_left():
return PsLeftShift(*args)
case integer_functions.bit_shift_right():
return PsRightShift(*args)
case integer_functions.bitwise_and():
return PsBitwiseAnd(*args)
case integer_functions.bitwise_xor():
return PsBitwiseXor(*args)
case integer_functions.bitwise_or():
return PsBitwiseOr(*args)
case integer_functions.int_power_of_2():
return PsLeftShift(PsExpression.make(PsConstant(1)), args[0])
# TODO: what exactly are the semantics?
# case integer_functions.modulo_floor():
# case integer_functions.div_floor()
# TODO: requires if *expression*
# case integer_functions.modulo_ceil():
# case integer_functions.div_ceil():
case _:
raise FreezeError(f"Unsupported function: {func}")
args = tuple(self.visit_expr(arg) for arg in func.args)
return PsCall(func_symbol, args)
def map_Min(self, expr: sp.Min) -> PsCall:
args = tuple(self.visit_expr(arg) for arg in expr.args)
return PsCall(PsMathFunction(MathFunctions.Min), args)
......
......@@ -22,15 +22,21 @@ from ..ast.structural import (
PsAssignment,
)
from ..ast.expressions import (
PsSymbolExpr,
PsConstantExpr,
PsBinOp,
PsArrayAccess,
PsSubscript,
PsLookup,
PsCall,
PsArrayInitList,
PsBinOp,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsCall,
PsCast,
PsConstantExpr,
PsIntDiv,
PsLeftShift,
PsLookup,
PsRightShift,
PsSubscript,
PsSymbolExpr,
)
from ..functions import PsMathFunction
......@@ -258,6 +264,36 @@ class Typifier:
tc.apply_and_check(expr, member.dtype)
# integer operations
case (
PsIntDiv(op1, op2)
| PsLeftShift(op1, op2)
| PsRightShift(op1, op2)
| PsBitwiseAnd(op1, op2)
| PsBitwiseXor(op1, op2)
| PsBitwiseOr(op1, op2)
):
if tc.target_type is not None and not isinstance(
tc.target_type, PsIntegerType
):
raise TypificationError(
f"Integer expression used in non-integer context.\n"
f" Integer expression: {expr}\n"
f" Context type: {tc.target_type}"
)
self.visit_expr(op1, tc)
self.visit_expr(op2, tc)
if tc.target_type is None:
raise TypificationError(
f"Unable to infer type of integer expression {expr}."
)
elif not isinstance(tc.target_type, PsIntegerType):
raise TypificationError(
f"Argument(s) to integer function are non-integer in expression {expr}."
)
case PsBinOp(op1, op2):
self.visit_expr(op1, tc)
self.visit_expr(op2, tc)
......
......@@ -51,6 +51,8 @@ class int_div(IntegerFunctionTwoArgsMixIn):
# noinspection PyPep8Naming
# TODO: What do the *two* arguments mean?
# Apparently, the second is required but ignored?
class int_power_of_2(IntegerFunctionTwoArgsMixIn):
pass
......
......@@ -4,11 +4,19 @@ from pystencils import Assignment, fields
from pystencils.backend.ast.structural import (
PsAssignment,
PsBlock,
PsDeclaration,
)
from pystencils.backend.ast.expressions import (
PsArrayAccess,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsExpression,
PsArrayAccess
PsIntDiv,
PsLeftShift,
PsMul,
PsRightShift,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import (
......@@ -17,6 +25,17 @@ from pystencils.backend.kernelcreation import (
FullIterationSpace,
)
from pystencils.sympyextensions.integer_functions import (
bit_shift_left,
bit_shift_right,
bitwise_and,
bitwise_or,
bitwise_xor,
int_div,
int_power_of_2,
modulo_floor,
)
def test_freeze_simple():
ctx = KernelCreationContext()
......@@ -63,9 +82,66 @@ def test_freeze_fields():
zero = PsExpression.make(PsConstant(0))
lhs = PsArrayAccess(f_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) + zero * one)
rhs = PsArrayAccess(g_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) + zero * one)
lhs = PsArrayAccess(
f_arr.base_pointer,
(PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0])
+ zero * one,
)
rhs = PsArrayAccess(
g_arr.base_pointer,
(PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0])
+ zero * one,
)
should = PsAssignment(lhs, rhs)
assert fasm.structurally_equal(should)
def test_freeze_integer_binops():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
x, y, z = sp.symbols("x, y, z")
expr = bit_shift_left(
bit_shift_right(bitwise_and(x, y), bitwise_or(y, z)), bitwise_xor(x, z)
)
fexpr = freeze(expr)
x2 = PsExpression.make(ctx.get_symbol("x"))
y2 = PsExpression.make(ctx.get_symbol("y"))
z2 = PsExpression.make(ctx.get_symbol("z"))
should = PsLeftShift(
PsRightShift(PsBitwiseAnd(x2, y2), PsBitwiseOr(y2, z2)), PsBitwiseXor(x2, z2)
)
assert fexpr.structurally_equal(should)
def test_freeze_integer_functions():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
x2 = PsExpression.make(ctx.get_symbol("x", ctx.index_dtype))
y2 = PsExpression.make(ctx.get_symbol("y", ctx.index_dtype))
z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype))
x, y, z = sp.symbols("x, y, z")
asms = [
Assignment(z, int_div(x, y)),
Assignment(z, int_power_of_2(x, y)),
# Assignment(z, modulo_floor(x, y)),
]
fasms = [freeze(asm) for asm in asms]
should = [
PsDeclaration(z2, PsIntDiv(x2, y2)),
PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)),
# PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)),
]
for fasm, correct in zip(fasms, should):
assert fasm.structurally_equal(correct)
......@@ -12,6 +12,14 @@ from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
from pystencils.sympyextensions.integer_functions import (
bit_shift_left,
bit_shift_right,
bitwise_and,
bitwise_xor,
bitwise_or,
)
def test_typify_simple():
ctx = KernelCreationContext()
......@@ -114,3 +122,65 @@ def test_erronous_typing():
fasm = freeze(asm)
with pytest.raises(TypificationError):
typify(fasm)
def test_typify_integer_binops():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
ctx.get_symbol("x", ctx.index_dtype)
ctx.get_symbol("y", ctx.index_dtype)
ctx.get_symbol("z", ctx.index_dtype)
x, y, z = sp.symbols("x, y, z")
expr = bit_shift_left(
bit_shift_right(bitwise_and(x, 2), bitwise_or(y, z)), bitwise_xor(2, 2)
) # ^
# TODO: x can not be a constant here, because then the typifier can not check that the arguments are integer.
expr = freeze(expr)
expr = typify(expr)
def check(expr):
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.index_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.index_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_typify_integer_binops_floating_arg():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x = sp.Symbol("x")
expr = bit_shift_left(x, 2)
expr = freeze(expr)
with pytest.raises(TypificationError):
expr = typify(expr)
def test_typify_integer_binops_in_floating_context():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
ctx.get_symbol("i", ctx.index_dtype)
x, i = sp.symbols("x, i")
expr = x + bit_shift_left(i, 2)
expr = freeze(expr)
with pytest.raises(TypificationError):
expr = typify(expr)
......@@ -75,3 +75,27 @@ def test_arithmetic_precedence():
expr = (a / b) + (c / (d + e) * f)
code = cprint(expr)
assert code == "a / b + c / (d + e) * f"
def test_printing_integer_functions():
(i, j, k) = [PsExpression.make(PsSymbol(x, UInt(64))) for x in "ijk"]
cprint = CAstPrinter()
from pystencils.backend.ast.expressions import (
PsLeftShift,
PsRightShift,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsIntDiv,
)
expr = PsBitwiseAnd(
PsBitwiseXor(
PsBitwiseXor(j, k),
PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)),
),
i,
)
code = cprint(expr)
assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i"
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment