Skip to content
Snippets Groups Projects
Commit 5f364431 authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'Revisions' into 'master'

Revisions

See merge request pycodegen/pystencils!293
parents 46ff7269 a624abc2
Branches
Tags
No related merge requests found
......@@ -443,9 +443,8 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols:
if isinstance(expr.exp, sp.Integer) and (-8 < expr.exp < 8):
raise NotImplementedError("This pow should be simplified already?")
# return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
return super(CustomSympyPrinter, self)._print_Pow(expr)
# TODO don't print ones in sp.Mul
......@@ -508,13 +507,13 @@ class CustomSympyPrinter(CCodePrinter):
else:
return f"(({data_type})({self._print(arg)}))"
elif isinstance(expr, fast_division):
return f"({self._print(expr.args[0] / expr.args[1])})"
raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt):
return f"({self._print(sp.sqrt(expr.args[0]))})"
raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs):
return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Mod):
......@@ -681,21 +680,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
**self._kwargs)
return result
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr)
if not result:
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
**self._kwargs)
return result
elif expr.func == fast_sqrt:
return f"({self._print(sp.sqrt(expr.args[0]))})"
elif expr.func == fast_inv_sqrt:
result = self._scalarFallback('_print_Function', expr)
if not result:
if 'rsqrt' in self.instruction_set:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
else:
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, fast_division):
raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt):
raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
instr = 'any' if isinstance(expr, vec_any) else 'all'
expr_type = get_type_of_expression(expr.args[0])
......
......@@ -10,6 +10,7 @@ from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
......@@ -158,17 +159,23 @@ def fast_subs(expression: T, substitutions: Dict,
if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
def visit(expr):
def visit(expr, evaluate=True):
if skip and skip(expr):
return expr
if hasattr(expr, "fast_subs"):
elif hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions, skip)
if expr in substitutions:
elif expr in substitutions:
return substitutions[expr]
if not hasattr(expr, 'args'):
elif not hasattr(expr, 'args'):
return expr
param_list = [visit(a) for a in expr.args]
return expr if not param_list else expr.func(*param_list)
elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
param_list = [visit(a, evaluate) for a in expr.args]
if isinstance(expr, (sp.Mul, sp.Add)):
return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
return expr if not param_list else expr.func(*param_list)
if len(substitutions) == 0:
return expression
......
......@@ -9,8 +9,8 @@ import sympy as sp
import pystencils.astnodes as ast
from pystencils.assignment import Assignment
from pystencils.typing import (
PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.field import Field, FieldType
from pystencils.typing import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -607,13 +607,7 @@ def move_constants_before_loop(ast_node):
get_blocks(ast_node, all_blocks)
for block in all_blocks:
children = block.take_child_nodes()
# Every time a symbol can be replaced in the current block because the assignment
# was found in a parent block, but with a different lhs symbol (same rhs)
# the outer symbol is inserted here as key.
substitute_variables = {}
for child in children:
# Before traversing the next child, all symbols are substituted first.
child.subs(substitute_variables)
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
block.append(child)
......@@ -629,14 +623,7 @@ def move_constants_before_loop(ast_node):
exists_already = False
if not exists_already:
rhs_identical = check_if_assignment_already_in_block(child, target, True)
if rhs_identical:
# there is already an assignment out there with the same rhs
# -> replace all lhs symbols in this block with the lhs of the outer assignment
# -> remove the local assignment (do not re-append child to the former block)
substitute_variables[child.lhs] = rhs_identical.lhs
else:
target.insert_before(child, child_to_insert_before)
target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs:
if target.args.index(exists_already) > target.args.index(child_to_insert_before):
assert target.args.count(exists_already) == 1
......@@ -650,7 +637,6 @@ def move_constants_before_loop(ast_node):
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before)
substitute_variables[child.lhs] = new_symbol
def split_inner_loop(ast_node: ast.Node, symbol_groups):
......@@ -771,12 +757,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
This analysis needs the integer set library (ISL) islpy, so it is not done by
default.
"""
from sympy.codegen.rewriting import ReplaceOptim, optimize
remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)
for conditional in node.atoms(ast.Conditional):
# TODO simplify conditional before the type system! Casts make it very hard here
# conditional.condition_expr = sp.simplify(conditional.condition_expr)
if conditional.condition_expr == sp.true:
condition_expression = optimize(conditional.condition_expr, [remove_casts])
condition_expression = sp.simplify(condition_expression)
if condition_expression == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
elif condition_expression == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification:
try:
......
......@@ -236,6 +236,10 @@ class TypeAdder:
else:
raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}')
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
return expr.func(*new_args) if new_args else expr, collated_type
if isinstance(expr, (sp.Add, sp.Mul)):
return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type
else:
return expr.func(*new_args) if new_args else expr, collated_type
else:
raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
import numpy as np
import sympy as sp
import pystencils as ps
import pystencils.config
......@@ -25,3 +26,21 @@ def test_kernel_decorator_config():
a[0] @= b[0] + c[0]
ps.create_kernel(**test)
def test_kernel_decorator2():
h = sp.symbols("h")
dtype = "float64"
src, dst = ps.fields(f"src, src_tmp: {dtype}[3D]")
@ps.kernel
def kernel_func():
dst[0, 0, 0] @= (src[1, 0, 0] + src[-1, 0, 0]
+ src[0, 1, 0] + src[0, -1, 0]
+ src[0, 0, 1] + src[0, 0, -1]) / (6 * h ** 2)
# assignments = ps.assignment_from_stencil(stencil, src, dst, normalization_factor=2)
ast = ps.create_kernel(kernel_func)
code = ps.get_code_str(ast)
......@@ -622,3 +622,20 @@ def test_source_stencil(stencil):
assert len(diff.atoms(ps.field.Field.Access)) == 1
else:
assert len(diff.atoms(ps.field.Field.Access)) == 2
def test_fvm_staggered_simplification():
D = sp.Symbol("D")
data_type = "float64"
c = ps.fields(f"c: {data_type}[2D]", layout='fzyx')
j = ps.fields(f"j(2): {data_type}[2D]", layout='fzyx', field_type=ps.FieldType.STAGGERED_FLUX)
grad_c = sp.Matrix([ps.fd.diff(c, i) for i in range(c.spatial_dimensions)])
ek = ps.fd.FVM1stOrder(c, flux=-D * grad_c)
ast = ps.create_staggered_kernel(ek.discrete_flux(j))
code = ps.get_code_str(ast)
assert '_size_c_0 - 1 < _size_c_0 - 1' not in code
......@@ -2,6 +2,7 @@ import pytest
import sympy as sp
import numpy as np
import pystencils as ps
from pystencils.fast_approximation import fast_division
@pytest.mark.parametrize('dtype', ["float64", "float32"])
......@@ -66,3 +67,47 @@ def test_single_arguments(dtype, func, target):
np.testing.assert_allclose(dh.gather_array("x")[0, 0], float(func(1.0).evalf()),
rtol=10**-3 if dtype == 'float32' else 10**-5)
@pytest.mark.parametrize('a', [sp.Symbol('a'), ps.fields('a: float64[2d]').center])
def test_avoid_pow(a):
x = ps.fields('x: float64[2d]')
up = ps.Assignment(x.center_vector[0], 2 * a ** 2 / 3)
ast = ps.create_kernel(up)
code = ps.get_code_str(ast)
assert "pow" not in code
def test_avoid_pow_fast_div():
x = ps.fields('x: float64[2d]')
a = ps.fields('a: float64[2d]').center
up = ps.Assignment(x.center_vector[0], fast_division(1, (a**2)))
ast = ps.create_kernel(up, config=ps.CreateKernelConfig(target=ps.Target.GPU))
# ps.show_code(ast)
code = ps.get_code_str(ast)
assert "pow" not in code
def test_avoid_pow_move_constants():
# At the end of the kernel creation the function move_constants_before_loop will be called
# This function additionally contains substitutions for symbols with the same value
# Thus it simplifies the equations again
x = ps.fields('x: float64[2d]')
a, b, c = sp.symbols("a, b, c")
up = [ps.Assignment(a, 0.0),
ps.Assignment(b, 0.0),
ps.Assignment(c, 0.0),
ps.Assignment(x.center_vector[0], a**2/18 - a*b/6 - a/18 + b**2/18 + b/18 - c**2/36)]
ast = ps.create_kernel(up)
code = ps.get_code_str(ast)
ps.show_code(ast)
assert "pow" not in code
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment