diff --git a/src/pystencils/backend/ast/astnode.py b/src/pystencils/backend/ast/astnode.py index 6680c58e5971409a585e0d9787691f646cabe701..3487d4200700ebd1de03e41e374e0a550553a57f 100644 --- a/src/pystencils/backend/ast/astnode.py +++ b/src/pystencils/backend/ast/astnode.py @@ -29,6 +29,10 @@ class PsAstNode(ABC): def set_child(self, idx: int, c: PsAstNode): pass + @abstractmethod + def clone(self) -> PsAstNode: + pass + def structurally_equal(self, other: PsAstNode) -> bool: """Check two ASTs for structural equality.""" return ( diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 4d4cfb457b82b4dffca341e1f961f1d556d99df1..eab309a1d380eba184a4a0393db7d34b2af53468 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -1,5 +1,5 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from typing import Sequence, overload from ..symbols import PsSymbol @@ -54,10 +54,18 @@ class PsExpression(PsAstNode, ABC): else: raise ValueError(f"Cannot make expression out of {obj}") + @abstractmethod + def clone(self) -> PsExpression: + pass + class PsLvalueExpr(PsExpression, ABC): """Base class for all expressions that may occur as an lvalue""" + @abstractmethod + def clone(self) -> PsLvalueExpr: + pass + class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr): """A single symbol as an expression.""" @@ -75,6 +83,9 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr): def symbol(self, symbol: PsSymbol): self._symbol = symbol + def clone(self) -> PsSymbolExpr: + return PsSymbolExpr(self._symbol) + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsSymbolExpr): return False @@ -99,6 +110,9 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def constant(self, c: PsConstant): self._constant = c + def clone(self) -> PsConstantExpr: + return PsConstantExpr(self._constant) + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsConstantExpr): return False @@ -132,6 +146,9 @@ class PsSubscript(PsLvalueExpr): def index(self, expr: PsExpression): self._index = expr + def clone(self) -> PsSubscript: + return PsSubscript(self._base.clone(), self._index.clone()) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._base, self._index) @@ -180,6 +197,9 @@ class PsArrayAccess(PsSubscript): """Data type of this expression, i.e. the element type of the underlying array""" return self._base_ptr.array.element_type + def clone(self) -> PsArrayAccess: + return PsArrayAccess(self._base_ptr, self._index.clone()) + def __repr__(self) -> str: return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})" @@ -226,6 +246,15 @@ class PsVectorArrayAccess(PsArrayAccess): def alignment(self) -> int: return self._alignment + def clone(self) -> PsVectorArrayAccess: + return PsVectorArrayAccess( + self._base_ptr, + self._index.clone(), + self.vector_entries, + self._stride, + self._alignment, + ) + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsVectorArrayAccess): return False @@ -243,7 +272,7 @@ class PsLookup(PsExpression): def __init__(self, aggregate: PsExpression, member_name: str) -> None: self._aggregate = aggregate - self._member = member_name + self._member_name = member_name @property def aggregate(self) -> PsExpression: @@ -255,12 +284,15 @@ class PsLookup(PsExpression): @property def member_name(self) -> str: - return self._member + return self._member_name @member_name.setter def member_name(self, name: str): self._name = name + def clone(self) -> PsLookup: + return PsLookup(self._aggregate.clone(), self._member_name) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._aggregate,) @@ -298,6 +330,9 @@ class PsCall(PsExpression): self._args = list(exprs) + def clone(self) -> PsCall: + return PsCall(self._function, [arg.clone() for arg in self._args]) + def get_children(self) -> tuple[PsAstNode, ...]: return self.args @@ -324,6 +359,9 @@ class PsUnOp(PsExpression): def operand(self, expr: PsExpression): self._operand = expr + def clone(self) -> PsUnOp: + return type(self)(self._operand.clone()) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._operand,) @@ -359,6 +397,9 @@ class PsCast(PsUnOp): def target_type(self, dtype: PsType): self._target_type = dtype + def clone(self) -> PsUnOp: + return PsCast(self._target_type, self._operand.clone()) + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsCast): return False @@ -391,6 +432,9 @@ class PsBinOp(PsExpression): def operand2(self, expr: PsExpression): self._op2 = expr + def clone(self) -> PsBinOp: + return type(self)(self._op1.clone(), self._op2.clone()) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._op1, self._op2) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index c9da546a829d85436f1efd84a223a5e35b704f0a..2338ee8c4064042827b5017a496755ce985913a8 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -20,6 +20,9 @@ class PsBlock(PsAstNode): def set_child(self, idx: int, c: PsAstNode): self._statements[idx] = c + def clone(self) -> PsBlock: + return PsBlock([stmt.clone() for stmt in self._statements]) + @property def statements(self) -> list[PsAstNode]: return self._statements @@ -47,6 +50,9 @@ class PsStatement(PsAstNode): def expression(self, expr: PsExpression): self._expression = expr + def clone(self) -> PsStatement: + return PsStatement(self._expression.clone()) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._expression,) @@ -82,6 +88,9 @@ class PsAssignment(PsAstNode): def rhs(self, expr: PsExpression): self._rhs = expr + def clone(self) -> PsAssignment: + return PsAssignment(self._lhs.clone(), self._rhs.clone()) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._lhs, self._rhs) @@ -123,6 +132,9 @@ class PsDeclaration(PsAssignment): def declared_variable(self, lvalue: PsSymbolExpr): self._lhs = lvalue + def clone(self) -> PsDeclaration: + return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) + def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] # trick to normalize index if idx == 0: @@ -193,6 +205,15 @@ class PsLoop(PsAstNode): def body(self, block: PsBlock): self._body = block + def clone(self) -> PsLoop: + return PsLoop( + self._ctr.clone(), + self._start.clone(), + self._stop.clone(), + self._step.clone(), + self._body.clone(), + ) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._ctr, self._start, self._stop, self._step, self._body) @@ -252,6 +273,13 @@ class PsConditional(PsAstNode): def branch_false(self, block: PsBlock | None): self._branch_false = block + def clone(self) -> PsConditional: + return PsConditional( + self._condition.clone(), + self._branch_true.clone(), + self._branch_false.clone() if self._branch_false is not None else None, + ) + def get_children(self) -> tuple[PsAstNode, ...]: return (self._condition, self._branch_true) + ( (self._branch_false,) if self._branch_false is not None else () @@ -285,6 +313,9 @@ class PsComment(PsLeafMixIn, PsAstNode): def lines(self) -> tuple[str, ...]: return self._lines + def clone(self) -> PsComment: + return PsComment(self._text) + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsComment): return False diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 8b54da651dd81a325612be6abda57e2d56fa3ed5..63dc9170a0efc1e5bb0b626353d290774c584ac8 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -117,7 +117,9 @@ class FreezeExpressions: for summand in expr.args: if summand.is_negative: signs.append(-1) - elif isinstance(summand, sp.Mul) and any(factor.is_negative for factor in summand.args): + elif isinstance(summand, sp.Mul) and any( + factor.is_negative for factor in summand.args + ): signs.append(-1) else: signs.append(1) @@ -126,18 +128,18 @@ class FreezeExpressions: for sign, arg in zip(signs[1:], expr.args[1:]): if sign == -1: - arg = - arg + arg = -arg op = sub else: op = add frozen_expr = op(frozen_expr, self.visit_expr(arg)) - + return frozen_expr def map_Mul(self, expr: sp.Mul) -> PsExpression: return reduce(mul, (self.visit_expr(arg) for arg in expr.args)) - + def map_Pow(self, expr: sp.Pow) -> PsExpression: base = expr.args[0] exponent = expr.args[1] @@ -147,18 +149,29 @@ class FreezeExpressions: expand_product = False if exponent.is_Integer: + if exponent == 0: + return PsExpression.make(PsConstant(1)) + if exponent.is_negative: reciprocal = True - exponent = - exponent + exponent = -exponent - if exponent <= sp.Integer(5): + if exponent <= sp.Integer( + 5 + ): # TODO: is this a sensible limit? maybe make this configurable. expand_product = True if expand_product: - frozen_expr = reduce(mul, [base_frozen] * int(exponent)) + frozen_expr = reduce( + mul, + [base_frozen] + + [base_frozen.clone() for _ in range(0, int(exponent) - 1)], + ) else: exponent_frozen = self.visit_expr(exponent) - frozen_expr = PsMathFunction(MathFunctions.Pow)(base_frozen, exponent_frozen) + frozen_expr = PsMathFunction(MathFunctions.Pow)( + base_frozen, exponent_frozen + ) if reciprocal: one = PsExpression.make(PsConstant(1)) diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 20d63e85ef802bbe5d5d5a25441fa326b533db77..c38ac60b44dc76007b183ac4e4ce5cfb61554c8b 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -42,10 +42,10 @@ class KernelParameter: type(self) is type(other) and self._hashable_contents() == other._hashable_contents() ) - + def __str__(self) -> str: return self._name - + def __repr__(self) -> str: return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})" @@ -60,7 +60,7 @@ class FieldParameter(KernelParameter, ABC): @property def field(self): return self._field - + def _hashable_contents(self): return super()._hashable_contents() + (self._field,) @@ -75,7 +75,7 @@ class FieldShapeParam(FieldParameter): @property def coordinate(self): return self._coordinate - + def _hashable_contents(self): return super()._hashable_contents() + (self._coordinate,) @@ -90,7 +90,7 @@ class FieldStrideParam(FieldParameter): @property def coordinate(self): return self._coordinate - + def _hashable_contents(self): return super()._hashable_contents() + (self._coordinate,) diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py index 4fdc0f612019cede0096a1e9c7afca5256c6ca5d..8483977d8a83067757a7677dc88c1fb2db3e003d 100644 --- a/src/pystencils/sympyextensions/astnodes.py +++ b/src/pystencils/sympyextensions/astnodes.py @@ -4,7 +4,7 @@ import uuid from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union import sympy as sp -from sympy.codegen.ast import Assignment, AugmentedAssignment +from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment from sympy.printing.latex import LatexPrinter import numpy as np diff --git a/tests/nbackend/test_ast_nodes.py b/tests/nbackend/test_ast_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2dd0e04081050138be1c8efc0210f81c999707 --- /dev/null +++ b/tests/nbackend/test_ast_nodes.py @@ -0,0 +1,66 @@ +from pystencils.backend.symbols import PsSymbol +from pystencils.backend.constants import PsConstant +from pystencils.backend.ast.expressions import ( + PsExpression, + PsCast, + PsDeref, + PsSubscript, +) +from pystencils.backend.ast.structural import ( + PsStatement, + PsAssignment, + PsBlock, + PsConditional, + PsComment, + PsLoop, +) +from pystencils.types.quick import Fp, Ptr + + +def test_cloning(): + x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] + c1 = PsExpression.make(PsConstant(3.0)) + c2 = PsExpression.make(PsConstant(-1.0)) + one = PsExpression.make(PsConstant(1)) + + def check(orig, clone): + assert not (orig is clone) + assert type(orig) is type(clone) + assert orig.structurally_equal(clone) + + for c1, c2 in zip(orig.children, clone.children, strict=True): + check(c1, c2) + + for ast in [ + x, + y, + c1, + x + y, + x / y + c1, + c1 + c2, + PsStatement(x * y * z + c1), + PsAssignment(y, x / c1), + PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]), + PsConditional( + y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + ), + PsLoop( + x, + y, + z, + one, + PsBlock( + [ + PsComment("Loop body"), + PsAssignment(x, y), + PsAssignment(x, y), + PsStatement( + PsDeref(PsCast(Ptr(Fp(32)), z)) + + PsSubscript(z, one + one + one) + ), + ] + ), + ), + ]: + ast_clone = ast.clone() + check(ast, ast_clone)