diff --git a/src/pystencils/functions.py b/src/pystencils/functions.py index df8d0ef6f611285d80a94ed08de8b72d41f2322f..76fbb70742352cf9d5627fe012652bb064acdfed 100644 --- a/src/pystencils/functions.py +++ b/src/pystencils/functions.py @@ -2,32 +2,6 @@ import sympy as sp from .types import PsPointerType -class DivFunc(sp.Function): - """ - DivFunc represents a division operation, since sympy represents divisions with ^-1 - """ - is_Atom = True - is_real = True - - def __new__(cls, *args, **kwargs): - if len(args) != 2: - raise ValueError(f'{cls} takes only 2 arguments, instead {len(args)} received!') - divisor, dividend, *other_args = args - - return sp.Function.__new__(cls, divisor, dividend, *other_args, **kwargs) - - def _eval_evalf(self, *args, **kwargs): - return self.divisor.evalf() / self.dividend.evalf() - - @property - def divisor(self): - return self.args[0] - - @property - def dividend(self): - return self.args[1] - - class AddressOf(sp.Function): """ AddressOf is the '&' operation in C. It gets the address of a lvalue. diff --git a/src/pystencils/integer_set_analysis.py b/src/pystencils/old/integer_set_analysis.py similarity index 100% rename from src/pystencils/integer_set_analysis.py rename to src/pystencils/old/integer_set_analysis.py diff --git a/src/pystencils/kernel_contrains_check.py b/src/pystencils/old/kernel_contrains_check.py similarity index 100% rename from src/pystencils/kernel_contrains_check.py rename to src/pystencils/old/kernel_contrains_check.py diff --git a/src/pystencils/node_collection.py b/src/pystencils/old/node_collection.py similarity index 100% rename from src/pystencils/node_collection.py rename to src/pystencils/old/node_collection.py diff --git a/src/pystencils/old/transformations.py b/src/pystencils/old/transformations.py index 76d356f9fb85caed903c0247de337670bf106430..4399de6aa3a14f51cb0c48247181f4eea7e5bdee 100644 --- a/src/pystencils/old/transformations.py +++ b/src/pystencils/old/transformations.py @@ -34,21 +34,6 @@ def filtered_tree_iteration(node, node_type, stop_type=None): yield from filtered_tree_iteration(arg, node_type) -def generic_visit(term, visitor): - if isinstance(term, AssignmentCollection): - new_main_assignments = generic_visit(term.main_assignments, visitor) - new_subexpressions = generic_visit(term.subexpressions, visitor) - return term.copy(new_main_assignments, new_subexpressions) - elif isinstance(term, list): - return [generic_visit(e, visitor) for e in term] - elif isinstance(term, Assignment): - return Assignment(term.lhs, generic_visit(term.rhs, visitor)) - elif isinstance(term, sp.Matrix): - return term.applyfunc(lambda e: generic_visit(e, visitor)) - else: - return visitor(term) - - def iterate_loops_by_depth(node, nesting_depth): """Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth. diff --git a/src/pystencils/placeholder_function.py b/src/pystencils/placeholder_function.py index 8d675f0338de2607438069495dffb9bd5a2a725a..00acb17bd71cdd7cfb628d89e5e1c85034c449ce 100644 --- a/src/pystencils/placeholder_function.py +++ b/src/pystencils/placeholder_function.py @@ -2,10 +2,9 @@ from typing import List import sympy as sp -from pystencils.sympyextensions.assignmentcollection.assignment import Assignment -from pystencils.sympyextensions.astnodes import Node +from pystencils.sympyextensions import Assignment from pystencils.sympyextensions import is_constant -from pystencils.transformations import generic_visit +from pystencils.sympyextensions.astnodes import generic_visit class PlaceholderFunction: @@ -56,11 +55,10 @@ def to_placeholder_function(expr, name): def remove_placeholder_functions(expr): subexpressions = [] + # TODO: Seems broken to me. Is this ever used? def visit(e): - if isinstance(e, Node): - return e - elif isinstance(e, PlaceholderFunction): + if isinstance(e, PlaceholderFunction): for se in e.subexpressions: if se.lhs not in {a.lhs for a in subexpressions}: subexpressions.append(se) diff --git a/src/pystencils/simplificationfactory.py b/src/pystencils/simplificationfactory.py index 50ee2d7f8175ff7445b9ec0f13071164c50099b1..869454ecf0f4e1ab05f79b7370a2d22f30c1dcdb 100644 --- a/src/pystencils/simplificationfactory.py +++ b/src/pystencils/simplificationfactory.py @@ -1,5 +1,12 @@ -from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one, - insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros) +from pystencils.sympyextensions import ( + SimplificationStrategy, + insert_constants, + insert_symbol_times_minus_one, + insert_constant_multiples, + insert_constant_additions, + insert_squares, + insert_zeros, +) def create_simplification_strategy(): diff --git a/src/pystencils/slicing.py b/src/pystencils/slicing.py index 64b9d308f21dfbea63d4096323fa014b725dc572..2aed3a899c76003d67bab3ed9d80a249f2908f99 100644 --- a/src/pystencils/slicing.py +++ b/src/pystencils/slicing.py @@ -26,7 +26,7 @@ class SlicedGetterDataHandling: def __getitem__(self, slice_obj): if slice_obj is None: - slice_obj = make_slice[:, :] if self.data_handling.dim == 2 else make_slice[:, :, 0.5] + slice_obj = make_slice[:, :] if self.dh.dim == 2 else make_slice[:, :, 0.5] return self.dh.gather_array(self.name, slice_obj).squeeze() diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 3e7eeaabcc070eb28bce99a998051ad7963dad9d..508132ce56412b35b27d0e2e5e8b57e294ca51f5 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -11,7 +11,6 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from .astnodes import Assignment -from pystencils.functions import DivFunc from .typed_sympy import CastFunc, FieldPointerSymbol from ..types import PsPointerType, PsVectorType @@ -169,7 +168,7 @@ def fast_subs(expression: T, substitutions: Dict, return substitutions[expr] elif not hasattr(expr, 'args'): return expr - elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)): + elif isinstance(expr, sp.UnevaluatedExpr): args = [visit(a, False) for a in expr.args] return expr.func(*args) else: @@ -641,8 +640,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], visit_children = False elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)): pass - elif isinstance(t, DivFunc): - result["divs"] += 1 else: warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate") diff --git a/tests/test_sympyextensions.py b/tests/test_sympyextensions.py index 1929cc0666a302de3e10423c40794f219bfa38ae..7afa998108b4ef64937767e3a706d4dcbf67eaf5 100644 --- a/tests/test_sympyextensions.py +++ b/tests/test_sympyextensions.py @@ -15,7 +15,6 @@ from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import kronecker_delta from pystencils import Assignment -from pystencils.functions import DivFunc from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) @@ -164,11 +163,11 @@ def test_count_operations(): assert ops['divs'] == 1 assert ops['sqrts'] == 1 - expr = DivFunc(x, y) + expr = x / y ops = count_operations(expr, only_type=None) assert ops['divs'] == 1 - expr = DivFunc(x + z, y + z) + expr = x + z / y + z ops = count_operations(expr, only_type=None) assert ops['adds'] == 2 assert ops['divs'] == 1 @@ -177,12 +176,12 @@ def test_count_operations(): ops = count_operations(expr, only_type=None) assert ops['muls'] == 99 - expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) + expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) ops = count_operations(expr, only_type=None) assert ops['divs'] == 1 assert ops['muls'] == 99 - expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) + expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) ops = count_operations(expr, only_type=None) assert ops['adds'] == 1 assert ops['divs'] == 1