import sympy as sp import pytest from pystencils import Assignment, fields from pystencils.backend.ast.structural import ( PsAssignment, PsBlock, PsDeclaration, ) from pystencils.backend.ast.expressions import ( PsArrayAccess, PsBitwiseAnd, PsBitwiseOr, PsBitwiseXor, PsExpression, PsIntDiv, PsLeftShift, PsRightShift, PsAnd, PsOr, PsNot, PsEq, PsNe, PsLt, PsLe, PsGt, PsGe ) from pystencils.backend.constants import PsConstant from pystencils.backend.kernelcreation import ( KernelCreationContext, FreezeExpressions, FullIterationSpace, ) from pystencils.sympyextensions.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, ) def test_freeze_simple(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) x, y, z = sp.symbols("x, y, z") asm = Assignment(z, 2 * x + y) fasm = freeze(asm) x2 = PsExpression.make(ctx.get_symbol("x")) y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) two = PsExpression.make(PsConstant(2)) should = PsDeclaration(z2, y2 + two * x2) assert fasm.structurally_equal(should) assert not fasm.structurally_equal(PsAssignment(z2, two * x2 + y2)) def test_freeze_fields(): ctx = KernelCreationContext() zero = PsExpression.make(PsConstant(0, ctx.index_dtype)) forty_two = PsExpression.make(PsConstant(42, ctx.index_dtype)) one = PsExpression.make(PsConstant(1, ctx.index_dtype)) counter = ctx.get_symbol("ctr", ctx.index_dtype) ispace = FullIterationSpace( ctx, [FullIterationSpace.Dimension(zero, forty_two, one, counter)] ) ctx.set_iteration_space(ispace) freeze = FreezeExpressions(ctx) f, g = fields("f, g : [1D]") asm = Assignment(f.center(0), g.center(0)) f_arr = ctx.get_array(f) g_arr = ctx.get_array(g) fasm = freeze(asm) zero = PsExpression.make(PsConstant(0)) lhs = PsArrayAccess( f_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) + zero * one, ) rhs = PsArrayAccess( g_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) + zero * one, ) should = PsAssignment(lhs, rhs) assert fasm.structurally_equal(should) def test_freeze_integer_binops(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) x, y, z = sp.symbols("x, y, z") expr = bit_shift_left( bit_shift_right(bitwise_and(x, y), bitwise_or(y, z)), bitwise_xor(x, z) ) fexpr = freeze(expr) x2 = PsExpression.make(ctx.get_symbol("x")) y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) should = PsLeftShift( PsRightShift(PsBitwiseAnd(x2, y2), PsBitwiseOr(y2, z2)), PsBitwiseXor(x2, z2) ) assert fexpr.structurally_equal(should) def test_freeze_integer_functions(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) x2 = PsExpression.make(ctx.get_symbol("x", ctx.index_dtype)) y2 = PsExpression.make(ctx.get_symbol("y", ctx.index_dtype)) z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype)) x, y, z = sp.symbols("x, y, z") asms = [ Assignment(z, int_div(x, y)), Assignment(z, int_power_of_2(x, y)), # Assignment(z, modulo_floor(x, y)), ] fasms = [freeze(asm) for asm in asms] should = [ PsDeclaration(z2, PsIntDiv(x2, y2)), PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)), # PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)), ] for fasm, correct in zip(fasms, should): assert fasm.structurally_equal(correct) def test_freeze_booleans(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) x2 = PsExpression.make(ctx.get_symbol("x")) y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) x, y, z = sp.symbols("x, y, z") expr1 = freeze(sp.Not(sp.And(x, y))) assert expr1.structurally_equal(PsNot(PsAnd(x2, y2))) expr2 = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) assert expr2.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) @pytest.mark.parametrize("rel_pair", [ (sp.Eq, PsEq), (sp.Ne, PsNe), (sp.Lt, PsLt), (sp.Gt, PsGt), (sp.Le, PsLe), (sp.Ge, PsGe) ]) def test_freeze_relations(rel_pair): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) sp_op, ps_op = rel_pair x2 = PsExpression.make(ctx.get_symbol("x")) y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) x, y, z = sp.symbols("x, y, z") expr1 = freeze(sp_op(x, y + z)) assert expr1.structurally_equal(ps_op(x2, y2 + z2))