diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index f425eef885203c6305e69fbdd56226dbf0b36aa9..486bd126025612914804a1011aaecc10e2912f59 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -443,9 +443,8 @@ class CustomSympyPrinter(CCodePrinter): def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" - if not expr.free_symbols: + if isinstance(expr.exp, sp.Integer) and (-8 < expr.exp < 8): raise NotImplementedError("This pow should be simplified already?") - # return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) return super(CustomSympyPrinter, self)._print_Pow(expr) # TODO don't print ones in sp.Mul @@ -508,13 +507,13 @@ class CustomSympyPrinter(CCodePrinter): else: return f"(({data_type})({self._print(arg)}))" elif isinstance(expr, fast_division): - return f"({self._print(expr.args[0] / expr.args[1])})" + raise ValueError("fast_division is only supported for Taget.GPU") elif isinstance(expr, fast_sqrt): - return f"({self._print(sp.sqrt(expr.args[0]))})" + raise ValueError("fast_sqrt is only supported for Taget.GPU") + elif isinstance(expr, fast_inv_sqrt): + raise ValueError("fast_inv_sqrt is only supported for Taget.GPU") elif isinstance(expr, vec_any) or isinstance(expr, vec_all): return self._print(expr.args[0]) - elif isinstance(expr, fast_inv_sqrt): - return f"({self._print(1 / sp.sqrt(expr.args[0]))})" elif isinstance(expr, sp.Abs): return f"abs({self._print(expr.args[0])})" elif isinstance(expr, sp.Mod): @@ -681,21 +680,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend), **self._kwargs) return result - elif expr.func == fast_division: - result = self._scalarFallback('_print_Function', expr) - if not result: - result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]), - **self._kwargs) - return result - elif expr.func == fast_sqrt: - return f"({self._print(sp.sqrt(expr.args[0]))})" - elif expr.func == fast_inv_sqrt: - result = self._scalarFallback('_print_Function', expr) - if not result: - if 'rsqrt' in self.instruction_set: - return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs) - else: - return f"({self._print(1 / sp.sqrt(expr.args[0]))})" + elif isinstance(expr, fast_division): + raise ValueError("fast_division is only supported for Taget.GPU") + elif isinstance(expr, fast_sqrt): + raise ValueError("fast_sqrt is only supported for Taget.GPU") + elif isinstance(expr, fast_inv_sqrt): + raise ValueError("fast_inv_sqrt is only supported for Taget.GPU") elif isinstance(expr, vec_any) or isinstance(expr, vec_all): instr = 'any' if isinstance(expr, vec_any) else 'all' expr_type = get_type_of_expression(expr.args[0]) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index b07707edc8dafc7da5313bb6acef65f2e4381ad5..1a99fa8eb7f67fa89e650795f79d9f8ad2c3cb01 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -10,6 +10,7 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from pystencils.assignment import Assignment +from pystencils.functions import DivFunc from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType from pystencils.typing.typed_sympy import FieldPointerSymbol @@ -158,17 +159,23 @@ def fast_subs(expression: T, substitutions: Dict, if type(expression) is sp.Matrix: return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions)) - def visit(expr): + def visit(expr, evaluate=True): if skip and skip(expr): return expr - if hasattr(expr, "fast_subs"): + elif hasattr(expr, "fast_subs"): return expr.fast_subs(substitutions, skip) - if expr in substitutions: + elif expr in substitutions: return substitutions[expr] - if not hasattr(expr, 'args'): + elif not hasattr(expr, 'args'): return expr - param_list = [visit(a) for a in expr.args] - return expr if not param_list else expr.func(*param_list) + elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)): + args = [visit(a, False) for a in expr.args] + return expr.func(*args) + else: + param_list = [visit(a, evaluate) for a in expr.args] + if isinstance(expr, (sp.Mul, sp.Add)): + return expr if not param_list else expr.func(*param_list, evaluate=evaluate) + return expr if not param_list else expr.func(*param_list) if len(substitutions) == 0: return expression diff --git a/pystencils/transformations.py b/pystencils/transformations.py index c022e728db0c7df47368be26941842e9664b2c76..43beefd25f79eebe8b78c40d46186e69ae195d64 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -9,8 +9,8 @@ import sympy as sp import pystencils.astnodes as ast from pystencils.assignment import Assignment -from pystencils.typing import ( - PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type) +from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type, + ReinterpretCastFunc, get_next_parent_of_type, parents_of_type) from pystencils.field import Field, FieldType from pystencils.typing import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection @@ -607,13 +607,7 @@ def move_constants_before_loop(ast_node): get_blocks(ast_node, all_blocks) for block in all_blocks: children = block.take_child_nodes() - # Every time a symbol can be replaced in the current block because the assignment - # was found in a parent block, but with a different lhs symbol (same rhs) - # the outer symbol is inserted here as key. - substitute_variables = {} for child in children: - # Before traversing the next child, all symbols are substituted first. - child.subs(substitute_variables) if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments block.append(child) @@ -629,14 +623,7 @@ def move_constants_before_loop(ast_node): exists_already = False if not exists_already: - rhs_identical = check_if_assignment_already_in_block(child, target, True) - if rhs_identical: - # there is already an assignment out there with the same rhs - # -> replace all lhs symbols in this block with the lhs of the outer assignment - # -> remove the local assignment (do not re-append child to the former block) - substitute_variables[child.lhs] = rhs_identical.lhs - else: - target.insert_before(child, child_to_insert_before) + target.insert_before(child, child_to_insert_before) elif exists_already and exists_already.rhs == child.rhs: if target.args.index(exists_already) > target.args.index(child_to_insert_before): assert target.args.count(exists_already) == 1 @@ -650,7 +637,6 @@ def move_constants_before_loop(ast_node): new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype) target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const), child_to_insert_before) - substitute_variables[child.lhs] = new_symbol def split_inner_loop(ast_node: ast.Node, symbol_groups): @@ -771,12 +757,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa This analysis needs the integer set library (ISL) islpy, so it is not done by default. """ + from sympy.codegen.rewriting import ReplaceOptim, optimize + remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr) + for conditional in node.atoms(ast.Conditional): # TODO simplify conditional before the type system! Casts make it very hard here - # conditional.condition_expr = sp.simplify(conditional.condition_expr) - if conditional.condition_expr == sp.true: + condition_expression = optimize(conditional.condition_expr, [remove_casts]) + condition_expression = sp.simplify(condition_expression) + if condition_expression == sp.true: conditional.parent.replace(conditional, [conditional.true_block]) - elif conditional.condition_expr == sp.false: + elif condition_expression == sp.false: conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else []) elif loop_counter_simplification: try: diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index ddffd61ced02b3603e7a21a784860d49127e1b5f..b0928d0b79ef657f7a3882cbbd39433c5a2b9fe1 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -236,6 +236,10 @@ class TypeAdder: else: raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}') new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] - return expr.func(*new_args) if new_args else expr, collated_type + + if isinstance(expr, (sp.Add, sp.Mul)): + return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type + else: + return expr.func(*new_args) if new_args else expr, collated_type else: raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing') diff --git a/pystencils_tests/test_create_kernel_config.py b/pystencils_tests/test_create_kernel_config.py index e8ad310c778e2ace3e49681acc3aef552b8f22bc..c3a211b80e90515c99082edd87025f4cbe478766 100644 --- a/pystencils_tests/test_create_kernel_config.py +++ b/pystencils_tests/test_create_kernel_config.py @@ -1,4 +1,5 @@ import numpy as np +import sympy as sp import pystencils as ps import pystencils.config @@ -25,3 +26,21 @@ def test_kernel_decorator_config(): a[0] @= b[0] + c[0] ps.create_kernel(**test) + + +def test_kernel_decorator2(): + h = sp.symbols("h") + dtype = "float64" + + src, dst = ps.fields(f"src, src_tmp: {dtype}[3D]") + + @ps.kernel + def kernel_func(): + dst[0, 0, 0] @= (src[1, 0, 0] + src[-1, 0, 0] + + src[0, 1, 0] + src[0, -1, 0] + + src[0, 0, 1] + src[0, 0, -1]) / (6 * h ** 2) + + # assignments = ps.assignment_from_stencil(stencil, src, dst, normalization_factor=2) + ast = ps.create_kernel(kernel_func) + + code = ps.get_code_str(ast) diff --git a/pystencils_tests/test_fvm.py b/pystencils_tests/test_fvm.py index 9c7c1323311eeda5243b72fa42b87bf4734ff5eb..e4e3cacd53d34a6ea2079e1acbc82f93b8e23585 100644 --- a/pystencils_tests/test_fvm.py +++ b/pystencils_tests/test_fvm.py @@ -622,3 +622,20 @@ def test_source_stencil(stencil): assert len(diff.atoms(ps.field.Field.Access)) == 1 else: assert len(diff.atoms(ps.field.Field.Access)) == 2 + + +def test_fvm_staggered_simplification(): + D = sp.Symbol("D") + data_type = "float64" + + c = ps.fields(f"c: {data_type}[2D]", layout='fzyx') + j = ps.fields(f"j(2): {data_type}[2D]", layout='fzyx', field_type=ps.FieldType.STAGGERED_FLUX) + + grad_c = sp.Matrix([ps.fd.diff(c, i) for i in range(c.spatial_dimensions)]) + + ek = ps.fd.FVM1stOrder(c, flux=-D * grad_c) + + ast = ps.create_staggered_kernel(ek.discrete_flux(j)) + code = ps.get_code_str(ast) + + assert '_size_c_0 - 1 < _size_c_0 - 1' not in code diff --git a/pystencils_tests/test_math_functions.py b/pystencils_tests/test_math_functions.py index 5655fbda60012d612ff7686f2d7288784b3a004f..eacb490e9d84c8a47be7b16750951b7924adc80f 100644 --- a/pystencils_tests/test_math_functions.py +++ b/pystencils_tests/test_math_functions.py @@ -2,6 +2,7 @@ import pytest import sympy as sp import numpy as np import pystencils as ps +from pystencils.fast_approximation import fast_division @pytest.mark.parametrize('dtype', ["float64", "float32"]) @@ -66,3 +67,47 @@ def test_single_arguments(dtype, func, target): np.testing.assert_allclose(dh.gather_array("x")[0, 0], float(func(1.0).evalf()), rtol=10**-3 if dtype == 'float32' else 10**-5) + + +@pytest.mark.parametrize('a', [sp.Symbol('a'), ps.fields('a: float64[2d]').center]) +def test_avoid_pow(a): + x = ps.fields('x: float64[2d]') + + up = ps.Assignment(x.center_vector[0], 2 * a ** 2 / 3) + ast = ps.create_kernel(up) + + code = ps.get_code_str(ast) + + assert "pow" not in code + + +def test_avoid_pow_fast_div(): + x = ps.fields('x: float64[2d]') + a = ps.fields('a: float64[2d]').center + + up = ps.Assignment(x.center_vector[0], fast_division(1, (a**2))) + ast = ps.create_kernel(up, config=ps.CreateKernelConfig(target=ps.Target.GPU)) + # ps.show_code(ast) + + code = ps.get_code_str(ast) + + assert "pow" not in code + + +def test_avoid_pow_move_constants(): + # At the end of the kernel creation the function move_constants_before_loop will be called + # This function additionally contains substitutions for symbols with the same value + # Thus it simplifies the equations again + x = ps.fields('x: float64[2d]') + a, b, c = sp.symbols("a, b, c") + + up = [ps.Assignment(a, 0.0), + ps.Assignment(b, 0.0), + ps.Assignment(c, 0.0), + ps.Assignment(x.center_vector[0], a**2/18 - a*b/6 - a/18 + b**2/18 + b/18 - c**2/36)] + ast = ps.create_kernel(up) + + code = ps.get_code_str(ast) + ps.show_code(ast) + + assert "pow" not in code