Skip to content
Snippets Groups Projects
Commit 20d37fba authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code cleanup

parent 679bf618
No related merge requests found
Pipeline #60171 failed with stages
in 3 minutes and 53 seconds
from functools import reduce
from typing import Any, cast
from typing import cast
from pymbolic.primitives import Variable
from pymbolic.mapper.dependency import DependencyMapper
......@@ -47,7 +46,7 @@ class UndefinedVariablesCollector:
return self.collect(lhs) | self.collect(rhs)
case PsBlock(statements):
undefined_vars = set()
undefined_vars: set[PsTypedVariable] = set()
for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt)
......
......@@ -5,6 +5,7 @@ from .nodes import PsAstNode, PsBlock, failing_cast
from ..typed_expressions import PsTypedVariable
from ...enums import Target
class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function."""
......@@ -19,11 +20,11 @@ class PsKernelFunction(PsAstNode):
def target(self) -> Target:
"""See pystencils.Target"""
return self._target
@property
def body(self) -> PsBlock:
return self._body
@body.setter
def body(self, body: PsBlock):
self._body = body
......@@ -31,16 +32,16 @@ class PsKernelFunction(PsAstNode):
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value: str):
self._name = value
def num_children(self) -> int:
return 1
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._body, )
yield from (self._body,)
def get_child(self, idx: int):
if idx not in (0, -1):
......@@ -51,14 +52,14 @@ class PsKernelFunction(PsAstNode):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c)
def get_parameters(self) -> Sequence[PsTypedVariable]:
"""Collect the list of parameters to this function.
This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary.
"""
from .analysis import UndefinedVariablesCollector
params = UndefinedVariablesCollector().collect(self)
return sorted(params, key=lambda p: p.name)
\ No newline at end of file
......@@ -3,8 +3,6 @@ from typing import Sequence, Generator, Iterable, cast
from abc import ABC, abstractmethod
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
from .util import failing_cast
......@@ -41,7 +39,6 @@ class PsAstNode(ABC):
class PsBlock(PsAstNode):
__match_args__ = ("statements",)
def __init__(self, cs: Sequence[PsAstNode]):
......@@ -62,9 +59,9 @@ class PsBlock(PsAstNode):
@property
def statements(self) -> list[PsAstNode]:
return self._statements
@statements.setter
def statemetns(self, stm: Sequence[PsAstNode]):
def statements(self, stm: Sequence[PsAstNode]):
self._statements = list(stm)
......@@ -127,8 +124,10 @@ class PsSymbolExpr(PsLvalueExpr):
class PsAssignment(PsAstNode):
__match_args__ = ("lhs", "rhs",)
__match_args__ = (
"lhs",
"rhs",
)
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
self._lhs = lhs
......@@ -170,8 +169,10 @@ class PsAssignment(PsAstNode):
class PsDeclaration(PsAssignment):
__match_args__ = ("declared_variable", "rhs",)
__match_args__ = (
"declared_variable",
"rhs",
)
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super().__init__(lhs, rhs)
......@@ -203,7 +204,6 @@ class PsDeclaration(PsAssignment):
class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
def __init__(
......
class PsInternalCompilerError(Exception):
pass
class PsMalformedAstException(Exception):
pass
......@@ -226,13 +226,13 @@ class PsTypedConstant:
def __rsub__(self, other: Any):
return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
@staticmethod
def _divrem(dividend, divisor):
quotient = abs(dividend) // abs(divisor)
quotient = quotient if (dividend * divisor > 0) else (- quotient)
quotient = abs(dividend) // abs(divisor)
quotient = quotient if (dividend * divisor > 0) else (-quotient)
rem = abs(dividend) % abs(divisor)
rem = rem if dividend >= 0 else (- rem)
rem = rem if dividend >= 0 else (-rem)
return quotient, rem
def __truediv__(self, other: Any):
......@@ -274,7 +274,7 @@ class PsTypedConstant:
def __neg__(self):
minus_one = PsTypedConstant(-1, self._dtype)
return pb.Product((minus_one, self))
def __bool__(self):
return bool(self._value)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment