From 21fb888aa743c1f8c3bf117f778566688bb9c1ae Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 15 Mar 2024 10:07:01 +0100 Subject: [PATCH] move integer_functions and fast_approximations into sympyextensions module --- .../backend/kernelcreation/freeze.py | 3 +- src/pystencils/boundaries/inkernel.py | 2 +- src/pystencils/sympyextensions/astnodes.py | 2 +- .../fast_approximation.py | 0 .../integer_functions.py | 66 +++++++------------ src/pystencils/sympyextensions/math.py | 2 +- src/pystencils/timeloop.py | 2 +- tests/test_fast_approximation.py | 2 +- tests/test_math_functions.py | 2 +- tests/test_sympyextensions.py | 2 +- 10 files changed, 33 insertions(+), 50 deletions(-) rename src/pystencils/{ => sympyextensions}/fast_approximation.py (100%) rename src/pystencils/{ => sympyextensions}/integer_functions.py (65%) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ecdcf2f94..ebaf22812 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -53,9 +53,8 @@ class FreezeExpressions: - Augmented Assignments - AddressOf - - Conditionals (+ frontend class) - Relations (sp.Relational) - - pystencils.integer_functions + - pystencils.sympyextensions.integer_functions - pystencils.sympyextensions.bit_masks - GPU fast approximations (pystencils.fast_approximation) - ConditionalFieldAccess diff --git a/src/pystencils/boundaries/inkernel.py b/src/pystencils/boundaries/inkernel.py index 7cd9e628b..057704b13 100644 --- a/src/pystencils/boundaries/inkernel.py +++ b/src/pystencils/boundaries/inkernel.py @@ -4,7 +4,7 @@ from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE from pystencils.sympyextensions import TypedSymbol from pystencils.types import create_type from pystencils.field import Field -from pystencils.integer_functions import bitwise_and +from pystencils.sympyextensions.integer_functions import bitwise_and def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False): diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py index 8483977d8..4fdc0f612 100644 --- a/src/pystencils/sympyextensions/astnodes.py +++ b/src/pystencils/sympyextensions/astnodes.py @@ -4,7 +4,7 @@ import uuid from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union import sympy as sp -from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment +from sympy.codegen.ast import Assignment, AugmentedAssignment from sympy.printing.latex import LatexPrinter import numpy as np diff --git a/src/pystencils/fast_approximation.py b/src/pystencils/sympyextensions/fast_approximation.py similarity index 100% rename from src/pystencils/fast_approximation.py rename to src/pystencils/sympyextensions/fast_approximation.py diff --git a/src/pystencils/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py similarity index 65% rename from src/pystencils/integer_functions.py rename to src/pystencils/sympyextensions/integer_functions.py index 90e79798d..f9c156971 100644 --- a/src/pystencils/integer_functions.py +++ b/src/pystencils/sympyextensions/integer_functions.py @@ -1,9 +1,4 @@ -# TODO #47 move to a module functions -import numpy as np import sympy as sp - -from pystencils.sympyextensions import CastFunc -from pystencils.types import create_type from pystencils.sympyextensions import is_integer_sequence @@ -11,22 +6,7 @@ class IntegerFunctionTwoArgsMixIn(sp.Function): is_integer = True def __new__(cls, arg1, arg2): - args = [] - for a in (arg1, arg2): - if isinstance(a, sp.Number) or isinstance(a, int): - args.append(CastFunc(a, create_type("int"))) - elif isinstance(a, np.generic): - args.append(CastFunc(a, a.dtype)) - else: - args.append(a) - - for a in args: - try: - dtype = get_type_of_expression(a) - if not dtype.is_int(): - raise ValueError("Argument to integer function is not an int but " + str(dtype)) - except NotImplementedError: - raise ValueError("Integer functions can only be constructed with typed expressions") + args = [arg1, arg2] return super().__new__(cls, *args) def _eval_evalf(self, *pargs, **kwargs): @@ -100,11 +80,12 @@ class modulo_floor(sp.Function): else: return super().__new__(cls, integer, divisor) - def to_c(self, print_func): - dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - assert dtype.is_int() - return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), - print_func(self.args[1]), dtype=dtype) + # TODO: Implement this in FreezeExpressions + # def to_c(self, print_func): + # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) + # assert dtype.is_int() + # return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), + # print_func(self.args[1]), dtype=dtype) # noinspection PyPep8Naming @@ -132,11 +113,12 @@ class modulo_ceil(sp.Function): else: return super().__new__(cls, integer, divisor) - def to_c(self, print_func): - dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - assert dtype.is_int() - code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))" - return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + # TODO: Implement this in FreezeExpressions + # def to_c(self, print_func): + # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) + # assert dtype.is_int() + # code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))" + # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) # noinspection PyPep8Naming @@ -162,11 +144,12 @@ class div_ceil(sp.Function): else: return super().__new__(cls, integer, divisor) - def to_c(self, print_func): - dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - assert dtype.is_int() - code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )" - return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + # TODO: Implement this in FreezeExpressions + # def to_c(self, print_func): + # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) + # assert dtype.is_int() + # code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )" + # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) # noinspection PyPep8Naming @@ -192,8 +175,9 @@ class div_floor(sp.Function): else: return super().__new__(cls, integer, divisor) - def to_c(self, print_func): - dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - assert dtype.is_int() - code = "(({dtype})({0}) / ({dtype})({1}))" - return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + # TODO: Implement this in FreezeExpressions + # def to_c(self, print_func): + # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) + # assert dtype.is_int() + # code = "(({dtype})({0}) / ({dtype})({1}))" + # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 508132ce5..21a98ad78 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -549,7 +549,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], Returns: dict with 'adds', 'muls' and 'divs' keys """ - from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division + from pystencils.sympyextensions.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0, 'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0} diff --git a/src/pystencils/timeloop.py b/src/pystencils/timeloop.py index 55129afb6..5c3438680 100644 --- a/src/pystencils/timeloop.py +++ b/src/pystencils/timeloop.py @@ -1,6 +1,6 @@ import time -from pystencils.integer_functions import modulo_ceil +from pystencils.sympyextensions.integer_functions import modulo_ceil class TimeLoop: diff --git a/tests/test_fast_approximation.py b/tests/test_fast_approximation.py index e211d6897..375bcabac 100644 --- a/tests/test_fast_approximation.py +++ b/tests/test_fast_approximation.py @@ -2,7 +2,7 @@ import pytest import sympy as sp import pystencils as ps -from pystencils.fast_approximation import ( +from pystencils.sympyextensions.fast_approximation import ( fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) diff --git a/tests/test_math_functions.py b/tests/test_math_functions.py index 1fd393788..6bd80644a 100644 --- a/tests/test_math_functions.py +++ b/tests/test_math_functions.py @@ -2,7 +2,7 @@ import pytest import sympy as sp import numpy as np import pystencils as ps -from pystencils.fast_approximation import fast_division +from pystencils.sympyextensions.fast_approximation import fast_division @pytest.mark.parametrize('dtype', ["float64", "float32"]) diff --git a/tests/test_sympyextensions.py b/tests/test_sympyextensions.py index 7afa99810..05c119968 100644 --- a/tests/test_sympyextensions.py +++ b/tests/test_sympyextensions.py @@ -15,7 +15,7 @@ from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import kronecker_delta from pystencils import Assignment -from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, +from pystencils.sympyextensions.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) -- GitLab