From 6bc218dfbdb3f9d72ddb1c863dd3d109e7d827db Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 12 Mar 2024 13:07:26 +0100 Subject: [PATCH] More cleanup: - remove DivFunc - Fix various imports - Move node_collection, isl, kernel_constraints_check to old --- src/pystencils/functions.py | 26 ------------------- .../{ => old}/integer_set_analysis.py | 0 .../{ => old}/kernel_contrains_check.py | 0 src/pystencils/{ => old}/node_collection.py | 0 src/pystencils/old/transformations.py | 15 ----------- src/pystencils/placeholder_function.py | 10 +++---- src/pystencils/simplificationfactory.py | 11 ++++++-- src/pystencils/slicing.py | 2 +- src/pystencils/sympyextensions/math.py | 5 +--- tests/test_sympyextensions.py | 9 +++---- 10 files changed, 19 insertions(+), 59 deletions(-) rename src/pystencils/{ => old}/integer_set_analysis.py (100%) rename src/pystencils/{ => old}/kernel_contrains_check.py (100%) rename src/pystencils/{ => old}/node_collection.py (100%) diff --git a/src/pystencils/functions.py b/src/pystencils/functions.py index df8d0ef6f..76fbb7074 100644 --- a/src/pystencils/functions.py +++ b/src/pystencils/functions.py @@ -2,32 +2,6 @@ import sympy as sp from .types import PsPointerType -class DivFunc(sp.Function): - """ - DivFunc represents a division operation, since sympy represents divisions with ^-1 - """ - is_Atom = True - is_real = True - - def __new__(cls, *args, **kwargs): - if len(args) != 2: - raise ValueError(f'{cls} takes only 2 arguments, instead {len(args)} received!') - divisor, dividend, *other_args = args - - return sp.Function.__new__(cls, divisor, dividend, *other_args, **kwargs) - - def _eval_evalf(self, *args, **kwargs): - return self.divisor.evalf() / self.dividend.evalf() - - @property - def divisor(self): - return self.args[0] - - @property - def dividend(self): - return self.args[1] - - class AddressOf(sp.Function): """ AddressOf is the '&' operation in C. It gets the address of a lvalue. diff --git a/src/pystencils/integer_set_analysis.py b/src/pystencils/old/integer_set_analysis.py similarity index 100% rename from src/pystencils/integer_set_analysis.py rename to src/pystencils/old/integer_set_analysis.py diff --git a/src/pystencils/kernel_contrains_check.py b/src/pystencils/old/kernel_contrains_check.py similarity index 100% rename from src/pystencils/kernel_contrains_check.py rename to src/pystencils/old/kernel_contrains_check.py diff --git a/src/pystencils/node_collection.py b/src/pystencils/old/node_collection.py similarity index 100% rename from src/pystencils/node_collection.py rename to src/pystencils/old/node_collection.py diff --git a/src/pystencils/old/transformations.py b/src/pystencils/old/transformations.py index 76d356f9f..4399de6aa 100644 --- a/src/pystencils/old/transformations.py +++ b/src/pystencils/old/transformations.py @@ -34,21 +34,6 @@ def filtered_tree_iteration(node, node_type, stop_type=None): yield from filtered_tree_iteration(arg, node_type) -def generic_visit(term, visitor): - if isinstance(term, AssignmentCollection): - new_main_assignments = generic_visit(term.main_assignments, visitor) - new_subexpressions = generic_visit(term.subexpressions, visitor) - return term.copy(new_main_assignments, new_subexpressions) - elif isinstance(term, list): - return [generic_visit(e, visitor) for e in term] - elif isinstance(term, Assignment): - return Assignment(term.lhs, generic_visit(term.rhs, visitor)) - elif isinstance(term, sp.Matrix): - return term.applyfunc(lambda e: generic_visit(e, visitor)) - else: - return visitor(term) - - def iterate_loops_by_depth(node, nesting_depth): """Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth. diff --git a/src/pystencils/placeholder_function.py b/src/pystencils/placeholder_function.py index 8d675f033..00acb17bd 100644 --- a/src/pystencils/placeholder_function.py +++ b/src/pystencils/placeholder_function.py @@ -2,10 +2,9 @@ from typing import List import sympy as sp -from pystencils.sympyextensions.assignmentcollection.assignment import Assignment -from pystencils.sympyextensions.astnodes import Node +from pystencils.sympyextensions import Assignment from pystencils.sympyextensions import is_constant -from pystencils.transformations import generic_visit +from pystencils.sympyextensions.astnodes import generic_visit class PlaceholderFunction: @@ -56,11 +55,10 @@ def to_placeholder_function(expr, name): def remove_placeholder_functions(expr): subexpressions = [] + # TODO: Seems broken to me. Is this ever used? def visit(e): - if isinstance(e, Node): - return e - elif isinstance(e, PlaceholderFunction): + if isinstance(e, PlaceholderFunction): for se in e.subexpressions: if se.lhs not in {a.lhs for a in subexpressions}: subexpressions.append(se) diff --git a/src/pystencils/simplificationfactory.py b/src/pystencils/simplificationfactory.py index 50ee2d7f8..869454ecf 100644 --- a/src/pystencils/simplificationfactory.py +++ b/src/pystencils/simplificationfactory.py @@ -1,5 +1,12 @@ -from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one, - insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros) +from pystencils.sympyextensions import ( + SimplificationStrategy, + insert_constants, + insert_symbol_times_minus_one, + insert_constant_multiples, + insert_constant_additions, + insert_squares, + insert_zeros, +) def create_simplification_strategy(): diff --git a/src/pystencils/slicing.py b/src/pystencils/slicing.py index 64b9d308f..2aed3a899 100644 --- a/src/pystencils/slicing.py +++ b/src/pystencils/slicing.py @@ -26,7 +26,7 @@ class SlicedGetterDataHandling: def __getitem__(self, slice_obj): if slice_obj is None: - slice_obj = make_slice[:, :] if self.data_handling.dim == 2 else make_slice[:, :, 0.5] + slice_obj = make_slice[:, :] if self.dh.dim == 2 else make_slice[:, :, 0.5] return self.dh.gather_array(self.name, slice_obj).squeeze() diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 3e7eeaabc..508132ce5 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -11,7 +11,6 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from .astnodes import Assignment -from pystencils.functions import DivFunc from .typed_sympy import CastFunc, FieldPointerSymbol from ..types import PsPointerType, PsVectorType @@ -169,7 +168,7 @@ def fast_subs(expression: T, substitutions: Dict, return substitutions[expr] elif not hasattr(expr, 'args'): return expr - elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)): + elif isinstance(expr, sp.UnevaluatedExpr): args = [visit(a, False) for a in expr.args] return expr.func(*args) else: @@ -641,8 +640,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], visit_children = False elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)): pass - elif isinstance(t, DivFunc): - result["divs"] += 1 else: warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate") diff --git a/tests/test_sympyextensions.py b/tests/test_sympyextensions.py index 1929cc066..7afa99810 100644 --- a/tests/test_sympyextensions.py +++ b/tests/test_sympyextensions.py @@ -15,7 +15,6 @@ from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import kronecker_delta from pystencils import Assignment -from pystencils.functions import DivFunc from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) @@ -164,11 +163,11 @@ def test_count_operations(): assert ops['divs'] == 1 assert ops['sqrts'] == 1 - expr = DivFunc(x, y) + expr = x / y ops = count_operations(expr, only_type=None) assert ops['divs'] == 1 - expr = DivFunc(x + z, y + z) + expr = x + z / y + z ops = count_operations(expr, only_type=None) assert ops['adds'] == 2 assert ops['divs'] == 1 @@ -177,12 +176,12 @@ def test_count_operations(): ops = count_operations(expr, only_type=None) assert ops['muls'] == 99 - expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) + expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) ops = count_operations(expr, only_type=None) assert ops['divs'] == 1 assert ops['muls'] == 99 - expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) + expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) ops = count_operations(expr, only_type=None) assert ops['adds'] == 1 assert ops['divs'] == 1 -- GitLab