From 3a8cc5ae34aeefc2eced5aef335a67f3d48e7a53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Wed, 8 Dec 2021 15:55:31 +0100 Subject: [PATCH] Minor fixes --- pystencils/backends/cbackend.py | 38 +++++---------------- pystencils/functions.py | 26 +++++++++++++++ pystencils/simp/assignment_collection.py | 19 ++++++++--- pystencils/typing/cast_functions.py | 8 +++-- pystencils/typing/leaf_typing.py | 12 +++---- pystencils/typing/types.py | 7 +--- pystencils_tests/test_types.py | 42 +++++------------------- 7 files changed, 70 insertions(+), 82 deletions(-) create mode 100644 pystencils/functions.py diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index b5dc66e10..4aa1d0964 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -16,6 +16,7 @@ from pystencils.typing import ( ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.functions import DivFunc from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) @@ -436,19 +437,14 @@ class CustomSympyPrinter(CCodePrinter): def __init__(self): super(CustomSympyPrinter, self).__init__() - self._float_type = create_type("float32") def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if not expr.free_symbols: return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) + return super(CustomSympyPrinter, self)._print_Pow(expr) - if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" - elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: - return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" - else: - return super(CustomSympyPrinter, self)._print_Pow(expr) + # TODO don't print ones in sp.Mul def _print_Rational(self, expr): """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" @@ -491,7 +487,10 @@ class CustomSympyPrinter(CCodePrinter): return f"&({self._print(expr.args[0])})" elif isinstance(expr, CastFunc): arg, data_type = expr.args - return f"(({data_type})({self._print(arg)}))" + if arg.is_Number: + return self._typed_number(arg, data_type) + else: + return f"(({data_type})({self._print(arg)}))" elif isinstance(expr, fast_division): return f"({self._print(expr.args[0] / expr.args[1])})" elif isinstance(expr, fast_sqrt): @@ -515,6 +514,8 @@ class CustomSympyPrinter(CCodePrinter): return f"(1 << ({self._print(expr.args[0])}))" elif expr.func == int_div: return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))" + elif expr.func == DivFunc: + return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))' else: name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__ arg_str = ', '.join(self._print(a) for a in expr.args) @@ -606,27 +607,6 @@ class CustomSympyPrinter(CCodePrinter): return f"(({a} < {b}) ? {a} : {b})" return inner_print_min(expr.args) - def _print_re(self, expr): - return f"real({self._print(expr.args[0])})" - - def _print_im(self, expr): - return f"imag({self._print(expr.args[0])})" - - def _print_ImaginaryUnit(self, expr): - return "complex<double>{0,1}" - - def _print_TypedImaginaryUnit(self, expr): - if expr.dtype.numpy_dtype == np.complex64: - return "complex<float>{0,1}" - elif expr.dtype.numpy_dtype == np.complex128: - return "complex<double>{0,1}" - else: - raise NotImplementedError( - "only complex64 and complex128 supported") - - def _print_Complex(self, expr): - return self._typed_number(expr, np.complex64) - # noinspection PyPep8Naming class VectorizedCustomSympyPrinter(CustomSympyPrinter): diff --git a/pystencils/functions.py b/pystencils/functions.py new file mode 100644 index 000000000..b1f349622 --- /dev/null +++ b/pystencils/functions.py @@ -0,0 +1,26 @@ +import sympy as sp + + +class DivFunc(sp.Function): + # TODO: documentation + is_Atom = True + is_real = True + + def __new__(cls, *args, **kwargs): + if len(args) != 2: + raise ValueError(f'{cls} takes only 2 arguments, instead {len(args)} received!') + divisor, dividend, *other_args = args + + return sp.Function.__new__(cls, divisor, dividend, *other_args, **kwargs) + + def _eval_evalf(self, *args, **kwargs): + return self.divisor.evalf() / self.dividend.evalf() + + @property + def divisor(self): + return self.args[0] + + @property + def dividend(self): + return self.args[1] + diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 7309e7d87..5a6f0d010 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -7,6 +7,7 @@ from sympy.codegen.rewriting import ReplaceOptim, optimize import pystencils from pystencils.assignment import Assignment +from pystencils.functions import DivFunc from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) from pystencils.sympyextensions import count_operations, fast_subs @@ -371,15 +372,23 @@ class AssignmentCollection: lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, lambda p: p.evalf()) - sympy_optimisations = [evaluate_constant_terms] + evaluate_pow = ReplaceOptim( + lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8, + lambda p: ( + sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else + DivFunc(sp.Integer(1), sp.Mul(*([p.base] * -p.exp), evaluate=False)) + )) + + sympy_optimisations = [evaluate_constant_terms, evaluate_pow] self.subexpressions = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) - if hasattr(a, 'lhs') - else a for a in self.subexpressions] + if hasattr(a, 'lhs') + else a for a in self.subexpressions] self.main_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) - if hasattr(a, 'lhs') - else a for a in self.main_assignments] + if hasattr(a, 'lhs') + else a for a in self.main_assignments] + # ----------------------------------------- Display and Printing ------------------------------------------------- def _repr_html_(self): diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index e93a410a8..2e29b74bb 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -8,16 +8,18 @@ from pystencils.typing.typed_sympy import TypedSymbol class CastFunc(sp.Function): # TODO: documentation - # TODO: move function to `functions.py` is_Atom = True def __new__(cls, *args, **kwargs): if len(args) != 2: pass expr, dtype, *other_args = args + + # If we have two consecutive casts, throw the inner one away + if isinstance(expr, CastFunc): + expr = expr.args[0] if not isinstance(dtype, AbstractType): - raise NotImplementedError(f'{dtype} is not a subclass of AbstractType') - dtype = create_type(dtype) + dtype = BasicType(dtype) # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index df36c0d91..ae66d3849 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -115,7 +115,7 @@ class TypeAdder: data_type = self.default_number_float.get() else: assert False, f'{sp.Number} is neither Float nor Integer' - return expr, data_type + return CastFunc(expr, data_type), data_type elif isinstance(expr, BooleanAtom): return expr, BasicType('bool') elif isinstance(expr, sp.Equality): @@ -130,16 +130,16 @@ class TypeAdder: elif isinstance(expr, flag_cond): # do not process the arguments to the bit shift - they must remain integers raise NotImplementedError('flag_cond') - elif isinstance(expr, sp.Mul): - raise NotImplementedError('sp.Mul') - # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? - # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] + #elif isinstance(expr, sp.Mul): + # raise NotImplementedError('sp.Mul') + # # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? + # # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] elif isinstance(expr, sp.Indexed): raise NotImplementedError('sp.Indexed') elif isinstance(expr, sp.Pow): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) - return expr, collated_type + return expr.func(*[a for a, _ in args_types]), collated_type elif isinstance(expr, ExprCondPair): expr_expr, expr_type = self.figure_out_type(expr.expr) condition, condition_type = self.figure_out_type(expr.cond) diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py index 9ec46d5c1..9bab35c5f 100644 --- a/pystencils/typing/types.py +++ b/pystencils/typing/types.py @@ -1,9 +1,8 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Union import numpy as np import sympy as sp -import sympy.codegen.ast def is_supported_type(dtype: np.dtype): @@ -86,10 +85,6 @@ class BasicType(AbstractType): def base_type(self): return None - @property - def sympy_dtype(self): - return getattr(sympy.codegen.ast, str(self.numpy_dtype)) - @property def item_size(self): # TODO: what is this? Do we want self.numpy_type.itemsize???? return 1 diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index d1b0d33cb..631855d17 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -172,6 +172,7 @@ def test_sqrt_of_integer(dtype): assert constant not in code +# TODO this @pytest.mark.parametrize('dtype', ('float64', 'float32')) def test_integer_comparision(dtype): f = ps.fields(f"f: {dtype}[2D]") @@ -182,7 +183,6 @@ def test_integer_comparision(dtype): ast = ps.create_kernel(ur) code = ps.get_code_str(ast) - print(code) # There should be an explicit cast for the integer zero to the type of the field on the rhs if dtype == 'float64': t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((double)(0))): (_data_f_00[_stride_f_1*ctr_1]));" @@ -192,44 +192,21 @@ def test_integer_comparision(dtype): assert t in code -# TODO this -def test_Basic_data_type(): +def test_typed_symbols_dtype(): assert typed_symbols(("s", "f"), np.uint) == typed_symbols("s, f", np.uint) t_symbols = typed_symbols(("s", "f"), np.uint) s = t_symbols[0] assert t_symbols[0] == TypedSymbol("s", np.uint) assert s.dtype.is_uint() - assert s.dtype.is_complex() == 0 - - assert typed_symbols("s", str).dtype.is_other() - assert typed_symbols("s", bool).dtype.is_other() - assert typed_symbols("s", np.void).dtype.is_other() assert typed_symbols("s", np.float64).dtype.c_name == 'double' - # removed for old sympy version - # assert typed_symbols(("s"), np.float64).dtype.sympy_dtype == typed_symbols(("s"), float).dtype.sympy_dtype - - f, g = ps.fields("f, g : double[2D]") - - expr = ps.Assignment(f.center(), 2 * g.center() + 5) - new_expr = type_all_numbers(expr, np.float64) - - assert "cast_func(2, double)" in str(new_expr) - assert "cast_func(5, double)" in str(new_expr) - - m = matrix_symbols("a, b", np.uint, 3, 3) - assert len(m) == 2 - m = m[0] - for i, elem in enumerate(m): - assert elem == TypedSymbol(f"a{i}", np.uint) - assert elem.dtype.is_uint() + assert typed_symbols("s", np.float32).dtype.c_name == 'float' assert TypedSymbol("s", np.uint).canonical == TypedSymbol("s", np.uint) assert TypedSymbol("s", np.uint).reversed == TypedSymbol("s", np.uint) -# TODO this def test_cast_func(): assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical @@ -242,21 +219,19 @@ def test_pointer_arithmetic_func(): assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical -# TODO this def test_division(): f = ps.fields('f(10): float32[2D]') m, tau = sp.symbols("m, tau") - up = [ps.Assignment(tau, 1.0 / (0.5 + (3.0 * m))), + up = [ps.Assignment(tau, 1 / (0.5 + (3.0 * m))), ps.Assignment(f.center, tau)] - - ast = ps.create_kernel(up, config=pystencils.config.CreateKernelConfig(data_type="float32")) + config = pystencils.config.CreateKernelConfig(data_type='float32', default_number_float='float32') + ast = ps.create_kernel(up, config=config) code = ps.get_code_str(ast) - assert "1.0f" in code + assert "((1.0f) / (m*3.0f + 0.5f))" in code -# TODO this def test_pow(): f = ps.fields('f(10): float32[2D]') m, tau = sp.symbols("m, tau") @@ -264,7 +239,8 @@ def test_pow(): up = [ps.Assignment(tau, m ** 1.5), ps.Assignment(f.center, tau)] - ast = ps.create_kernel(up, config=pystencils.config.CreateKernelConfig(data_type="float32")) + config = pystencils.config.CreateKernelConfig(data_type="float32", default_number_float='float32') + ast = ps.create_kernel(up, config=config) code = ps.get_code_str(ast) assert "1.5f" in code -- GitLab