diff --git a/pystencils/nbackend/ast/analysis.py b/pystencils/nbackend/ast/analysis.py index d52c65cf3f279206e21fe6a6fb66901c572860b2..6a3162c1be47d8fd2d8e58f239ed579c719809fe 100644 --- a/pystencils/nbackend/ast/analysis.py +++ b/pystencils/nbackend/ast/analysis.py @@ -1,5 +1,4 @@ -from functools import reduce -from typing import Any, cast +from typing import cast from pymbolic.primitives import Variable from pymbolic.mapper.dependency import DependencyMapper @@ -47,7 +46,7 @@ class UndefinedVariablesCollector: return self.collect(lhs) | self.collect(rhs) case PsBlock(statements): - undefined_vars = set() + undefined_vars: set[PsTypedVariable] = set() for stmt in statements[::-1]: undefined_vars -= self.declared_variables(stmt) undefined_vars |= self.collect(stmt) diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py index a12abb45a700f596e77d626f4378091b12fd19e3..aaf1ac5e51c30277d1b0f64ca7ccfcde28145339 100644 --- a/pystencils/nbackend/ast/kernelfunction.py +++ b/pystencils/nbackend/ast/kernelfunction.py @@ -5,6 +5,7 @@ from .nodes import PsAstNode, PsBlock, failing_cast from ..typed_expressions import PsTypedVariable from ...enums import Target + class PsKernelFunction(PsAstNode): """A complete pystencils kernel function.""" @@ -19,11 +20,11 @@ class PsKernelFunction(PsAstNode): def target(self) -> Target: """See pystencils.Target""" return self._target - + @property def body(self) -> PsBlock: return self._body - + @body.setter def body(self, body: PsBlock): self._body = body @@ -31,16 +32,16 @@ class PsKernelFunction(PsAstNode): @property def name(self) -> str: return self._name - + @name.setter def name(self, value: str): self._name = value def num_children(self) -> int: return 1 - + def children(self) -> Generator[PsAstNode, None, None]: - yield from (self._body, ) + yield from (self._body,) def get_child(self, idx: int): if idx not in (0, -1): @@ -51,14 +52,14 @@ class PsKernelFunction(PsAstNode): if idx not in (0, -1): raise IndexError(f"Child index out of bounds: {idx}") self._body = failing_cast(PsBlock, c) - + def get_parameters(self) -> Sequence[PsTypedVariable]: """Collect the list of parameters to this function. - + This function performs a full traversal of the AST. To improve performance, make sure to cache the result if necessary. """ from .analysis import UndefinedVariablesCollector + params = UndefinedVariablesCollector().collect(self) return sorted(params, key=lambda p: p.name) - \ No newline at end of file diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index 5418a1a37dc6f43f08084d0e10e74cde5d830ae8..da3ead273ce14a80c0874630aefebf5cf2835d80 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -3,8 +3,6 @@ from typing import Sequence, Generator, Iterable, cast from abc import ABC, abstractmethod -import pymbolic.primitives as pb - from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant from .util import failing_cast @@ -41,7 +39,6 @@ class PsAstNode(ABC): class PsBlock(PsAstNode): - __match_args__ = ("statements",) def __init__(self, cs: Sequence[PsAstNode]): @@ -62,9 +59,9 @@ class PsBlock(PsAstNode): @property def statements(self) -> list[PsAstNode]: return self._statements - + @statements.setter - def statemetns(self, stm: Sequence[PsAstNode]): + def statements(self, stm: Sequence[PsAstNode]): self._statements = list(stm) @@ -127,8 +124,10 @@ class PsSymbolExpr(PsLvalueExpr): class PsAssignment(PsAstNode): - - __match_args__ = ("lhs", "rhs",) + __match_args__ = ( + "lhs", + "rhs", + ) def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): self._lhs = lhs @@ -170,8 +169,10 @@ class PsAssignment(PsAstNode): class PsDeclaration(PsAssignment): - - __match_args__ = ("declared_variable", "rhs",) + __match_args__ = ( + "declared_variable", + "rhs", + ) def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): super().__init__(lhs, rhs) @@ -203,7 +204,6 @@ class PsDeclaration(PsAssignment): class PsLoop(PsAstNode): - __match_args__ = ("counter", "start", "stop", "step", "body") def __init__( diff --git a/pystencils/nbackend/exceptions.py b/pystencils/nbackend/exceptions.py index a388017a56de7f9bcea784f0ea0a4cb1de3e875f..8e48653047d70ccac3fea1d490871ff861510e9b 100644 --- a/pystencils/nbackend/exceptions.py +++ b/pystencils/nbackend/exceptions.py @@ -1,8 +1,6 @@ - class PsInternalCompilerError(Exception): pass class PsMalformedAstException(Exception): pass - diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 62c8b76951510a64be1d36b1753d0c69bdc3c20b..ee3536a52e3b2ffc3bf2908c3f3254348020e94d 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -226,13 +226,13 @@ class PsTypedConstant: def __rsub__(self, other: Any): return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype) - + @staticmethod def _divrem(dividend, divisor): - quotient = abs(dividend) // abs(divisor) - quotient = quotient if (dividend * divisor > 0) else (- quotient) + quotient = abs(dividend) // abs(divisor) + quotient = quotient if (dividend * divisor > 0) else (-quotient) rem = abs(dividend) % abs(divisor) - rem = rem if dividend >= 0 else (- rem) + rem = rem if dividend >= 0 else (-rem) return quotient, rem def __truediv__(self, other: Any): @@ -274,7 +274,7 @@ class PsTypedConstant: def __neg__(self): minus_one = PsTypedConstant(-1, self._dtype) return pb.Product((minus_one, self)) - + def __bool__(self): return bool(self._value)