diff --git a/pystencils/nbackend/ast/__init__.py b/pystencils/nbackend/ast/__init__.py index 567840b2d61724b48dab7ca0158f0355cfd797cf..95cb7831b7d6d66dba3c65c4ed766f97b45bd121 100644 --- a/pystencils/nbackend/ast/__init__.py +++ b/pystencils/nbackend/ast/__init__.py @@ -22,5 +22,5 @@ __all__ = [ "PsAssignment", "PsDeclaration", "PsLoop", - "ast_subs", + "ast_subs" ] diff --git a/pystencils/nbackend/ast/analysis.py b/pystencils/nbackend/ast/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..d52c65cf3f279206e21fe6a6fb66901c572860b2 --- /dev/null +++ b/pystencils/nbackend/ast/analysis.py @@ -0,0 +1,85 @@ +from functools import reduce +from typing import Any, cast + +from pymbolic.primitives import Variable +from pymbolic.mapper.dependency import DependencyMapper + +from .kernelfunction import PsKernelFunction +from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock +from ..typed_expressions import PsTypedVariable +from ..exceptions import PsMalformedAstException, PsInternalCompilerError + + +class UndefinedVariablesCollector: + """Collector for undefined variables. + + This class implements an AST visitor that collects all `PsTypedVariable`s that have been used + in the AST without being defined prior to their usage. + """ + + def __init__(self) -> None: + self._pb_dep_mapper = DependencyMapper( + include_subscripts=False, + include_lookups=False, + include_calls=False, + include_cses=False, + ) + + def collect(self, node: PsAstNode) -> set[PsTypedVariable]: + """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage.""" + + match node: + case PsKernelFunction(block): + return self.collect(block) + + case PsExpression(expr): + variables: set[Variable] = self._pb_dep_mapper(expr) + + for var in variables: + if not isinstance(var, PsTypedVariable): + raise PsMalformedAstException( + f"Non-typed variable {var} encountered" + ) + + return cast(set[PsTypedVariable], variables) + + case PsAssignment(lhs, rhs): + return self.collect(lhs) | self.collect(rhs) + + case PsBlock(statements): + undefined_vars = set() + for stmt in statements[::-1]: + undefined_vars -= self.declared_variables(stmt) + undefined_vars |= self.collect(stmt) + + return undefined_vars + + case PsLoop(ctr, start, stop, step, body): + undefined_vars = ( + self.collect(start) + | self.collect(stop) + | self.collect(step) + | self.collect(body) + ) + undefined_vars.remove(ctr.symbol) + return undefined_vars + + case unknown: + raise PsInternalCompilerError( + f"Don't know how to collect undefined variables from {unknown}" + ) + + def declared_variables(self, node: PsAstNode) -> set[PsTypedVariable]: + """Returns the set of variables declared by the given node which are visible in the enclosing scope.""" + + match node: + case PsDeclaration(lhs, _): + return {lhs.symbol} + + case PsAssignment() | PsExpression() | PsLoop() | PsBlock(): + return set() + + case unknown: + raise PsInternalCompilerError( + f"Don't know how to collect declared variables from {unknown}" + ) diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9aad854fb86d0824309f70f11135e441a337fe --- /dev/null +++ b/pystencils/nbackend/ast/kernelfunction.py @@ -0,0 +1,64 @@ +from typing import Sequence + +from typing import Generator +from .nodes import PsAstNode, PsBlock, failing_cast +from ..typed_expressions import PsTypedVariable +from ...enums import Target + +class PsKernelFunction(PsAstNode): + """A complete pystencils kernel function.""" + + __match_args__ = ("block",) + + def __init__(self, body: PsBlock, target: Target, name: str = "kernel"): + self._body = body + self._target = target + self._name = name + + @property + def target(self) -> Target: + """See pystencils.Target""" + return self._target + + @property + def body(self) -> PsBlock: + return self._body + + @body.setter + def body(self, body: PsBlock): + self._body = body + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, value: str): + self._name = value + + def num_children(self) -> int: + return 1 + + def children(self) -> Generator[PsAstNode, None, None]: + yield from (self._body, ) + + def get_child(self, idx: int): + if idx not in (0, -1): + raise IndexError(f"Child index out of bounds: {idx}") + return self._body + + def set_child(self, idx: int, c: PsAstNode): + if idx not in (0, -1): + raise IndexError(f"Child index out of bounds: {idx}") + self._body = failing_cast(PsBlock, c) + + def get_parameters(self) -> Sequence[PsTypedVariable]: + """Collect the list of parameters to this function. + + This function performs a full traversal of the AST. + To improve performance, make sure to cache the result if necessary. + """ + from .analysis import UndefinedVariablesCollector + params = UndefinedVariablesCollector().collect(self) + return sorted(params, key=lambda p: p.name) + \ No newline at end of file diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index 40200ff38132f6ce920b79402f646165b4c15d1f..30b1a4dd60450002e88597559a728ec8a29a03df 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -1,20 +1,12 @@ from __future__ import annotations -from typing import Sequence, Generator, TypeVar, Iterable, cast +from typing import Sequence, Generator, Iterable, cast from abc import ABC, abstractmethod import pymbolic.primitives as pb from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue - - -T = TypeVar("T") - - -def failing_cast(target: type, obj: T): - if not isinstance(obj, target): - raise TypeError(f"Casting {obj} to {target} failed.") - return obj +from .util import failing_cast class PsAstNode(ABC): @@ -49,20 +41,31 @@ class PsAstNode(ABC): class PsBlock(PsAstNode): + + __match_args__ = ("statements",) + def __init__(self, cs: Sequence[PsAstNode]): - self._children = list(cs) + self._statements = list(cs) def num_children(self) -> int: - return len(self._children) + return len(self._statements) def children(self) -> Generator[PsAstNode, None, None]: - yield from self._children + yield from self._statements def get_child(self, idx: int): - return self._children[idx] + return self._statements[idx] def set_child(self, idx: int, c: PsAstNode): - self._children[idx] = c + self._statements[idx] = c + + @property + def statements(self) -> list[PsAstNode]: + return self._statements + + @statements.setter + def statemetns(self, stm: Sequence[PsAstNode]): + self._statements = list(stm) class PsLeafNode(PsAstNode): @@ -82,6 +85,8 @@ class PsLeafNode(PsAstNode): class PsExpression(PsLeafNode): """Wrapper around pymbolics expressions.""" + __match_args__ = ("expression",) + def __init__(self, expr: pb.Expression): self._expr = expr @@ -107,6 +112,8 @@ class PsLvalueExpr(PsExpression): class PsSymbolExpr(PsLvalueExpr): """Wrapper around PsTypedSymbols""" + __match_args__ = ("symbol",) + def __init__(self, symbol: PsTypedVariable): super().__init__(symbol) @@ -120,6 +127,9 @@ class PsSymbolExpr(PsLvalueExpr): class PsAssignment(PsAstNode): + + __match_args__ = ("lhs", "rhs",) + def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): self._lhs = lhs self._rhs = rhs @@ -160,6 +170,9 @@ class PsAssignment(PsAstNode): class PsDeclaration(PsAssignment): + + __match_args__ = ("declared_variable", "rhs",) + def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): super().__init__(lhs, rhs) @@ -172,11 +185,11 @@ class PsDeclaration(PsAssignment): self._lhs = failing_cast(PsSymbolExpr, lvalue) @property - def declared_symbol(self) -> PsSymbolExpr: + def declared_variable(self) -> PsSymbolExpr: return cast(PsSymbolExpr, self._lhs) - @declared_symbol.setter - def declared_symbol(self, lvalue: PsSymbolExpr): + @declared_variable.setter + def declared_variable(self, lvalue: PsSymbolExpr): self._lhs = lvalue def set_child(self, idx: int, c: PsAstNode): @@ -190,6 +203,9 @@ class PsDeclaration(PsAssignment): class PsLoop(PsAstNode): + + __match_args__ = ("counter", "start", "stop", "step", "body") + def __init__( self, ctr: PsSymbolExpr, diff --git a/pystencils/nbackend/ast/util.py b/pystencils/nbackend/ast/util.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1866baf6d89ad7dd79b857f5d45115d3597a81 --- /dev/null +++ b/pystencils/nbackend/ast/util.py @@ -0,0 +1,9 @@ +from typing import TypeVar + +T = TypeVar("T") + + +def failing_cast(target: type, obj: T): + if not isinstance(obj, target): + raise TypeError(f"Casting {obj} to {target} failed.") + return obj diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py index a1f176f8375922578c7d52aa3d501b57e7cb44cd..7c4bf4b7c015d8f45b0377ce38cbf395a6f874cb 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/c_printer.py @@ -3,6 +3,7 @@ from __future__ import annotations from pymbolic.mapper.c_code import CCodeMapper from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop +from .ast.kernelfunction import PsKernelFunction class CPrinter: @@ -23,6 +24,14 @@ class CPrinter: @ast_visitor def visit(self, node: PsAstNode): raise ValueError("Cannot print this node.") + + @visit.case(PsKernelFunction) + def function(self, func: PsKernelFunction) -> str: + params = func.get_parameters() + params_str = ", ".join(f"{p.dtype} {p.name}" for p in params) + decl = f"FUNC_PREFIX void {func.name} ( {params_str} )" + body = self.visit(func.body) + return f"{decl}\n{body}" @visit.case(PsBlock) def block(self, block: PsBlock): @@ -30,7 +39,7 @@ class CPrinter: return self.indent("{ }") self._current_indent_level += self._indent_width - interior = "".join(self.visit(c) for c in block.children()) + interior = "\n".join(self.visit(c) for c in block.children()) self._current_indent_level -= self._indent_width return self.indent("{\n") + interior + self.indent("}\n") @@ -40,11 +49,11 @@ class CPrinter: @visit.case(PsDeclaration) def declaration(self, decl: PsDeclaration): - lhs_symb = decl.declared_symbol.symbol + lhs_symb = decl.declared_variable.symbol lhs_dtype = lhs_symb.dtype rhs_code = self.visit(decl.rhs) - return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};\n") + return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};") @visit.case(PsAssignment) def assignment(self, asm: PsAssignment): @@ -54,7 +63,7 @@ class CPrinter: @visit.case(PsLoop) def loop(self, loop: PsLoop): - ctr_symbol = loop.counter.expression + ctr_symbol = loop.counter.symbol ctr = ctr_symbol.name start_code = self.visit(loop.start) stop_code = self.visit(loop.stop) diff --git a/pystencils/nbackend/exceptions.py b/pystencils/nbackend/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..a388017a56de7f9bcea784f0ea0a4cb1de3e875f --- /dev/null +++ b/pystencils/nbackend/exceptions.py @@ -0,0 +1,8 @@ + +class PsInternalCompilerError(Exception): + pass + + +class PsMalformedAstException(Exception): + pass +