Skip to content
Snippets Groups Projects
Commit 25d72d18 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add pskernelfunction and parameter collector

parent be1a46b6
No related merge requests found
......@@ -22,5 +22,5 @@ __all__ = [
"PsAssignment",
"PsDeclaration",
"PsLoop",
"ast_subs",
"ast_subs"
]
from functools import reduce
from typing import Any, cast
from pymbolic.primitives import Variable
from pymbolic.mapper.dependency import DependencyMapper
from .kernelfunction import PsKernelFunction
from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock
from ..typed_expressions import PsTypedVariable
from ..exceptions import PsMalformedAstException, PsInternalCompilerError
class UndefinedVariablesCollector:
"""Collector for undefined variables.
This class implements an AST visitor that collects all `PsTypedVariable`s that have been used
in the AST without being defined prior to their usage.
"""
def __init__(self) -> None:
self._pb_dep_mapper = DependencyMapper(
include_subscripts=False,
include_lookups=False,
include_calls=False,
include_cses=False,
)
def collect(self, node: PsAstNode) -> set[PsTypedVariable]:
"""Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
match node:
case PsKernelFunction(block):
return self.collect(block)
case PsExpression(expr):
variables: set[Variable] = self._pb_dep_mapper(expr)
for var in variables:
if not isinstance(var, PsTypedVariable):
raise PsMalformedAstException(
f"Non-typed variable {var} encountered"
)
return cast(set[PsTypedVariable], variables)
case PsAssignment(lhs, rhs):
return self.collect(lhs) | self.collect(rhs)
case PsBlock(statements):
undefined_vars = set()
for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt)
return undefined_vars
case PsLoop(ctr, start, stop, step, body):
undefined_vars = (
self.collect(start)
| self.collect(stop)
| self.collect(step)
| self.collect(body)
)
undefined_vars.remove(ctr.symbol)
return undefined_vars
case unknown:
raise PsInternalCompilerError(
f"Don't know how to collect undefined variables from {unknown}"
)
def declared_variables(self, node: PsAstNode) -> set[PsTypedVariable]:
"""Returns the set of variables declared by the given node which are visible in the enclosing scope."""
match node:
case PsDeclaration(lhs, _):
return {lhs.symbol}
case PsAssignment() | PsExpression() | PsLoop() | PsBlock():
return set()
case unknown:
raise PsInternalCompilerError(
f"Don't know how to collect declared variables from {unknown}"
)
from typing import Sequence
from typing import Generator
from .nodes import PsAstNode, PsBlock, failing_cast
from ..typed_expressions import PsTypedVariable
from ...enums import Target
class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function."""
__match_args__ = ("block",)
def __init__(self, body: PsBlock, target: Target, name: str = "kernel"):
self._body = body
self._target = target
self._name = name
@property
def target(self) -> Target:
"""See pystencils.Target"""
return self._target
@property
def body(self) -> PsBlock:
return self._body
@body.setter
def body(self, body: PsBlock):
self._body = body
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value: str):
self._name = value
def num_children(self) -> int:
return 1
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._body, )
def get_child(self, idx: int):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
return self._body
def set_child(self, idx: int, c: PsAstNode):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c)
def get_parameters(self) -> Sequence[PsTypedVariable]:
"""Collect the list of parameters to this function.
This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary.
"""
from .analysis import UndefinedVariablesCollector
params = UndefinedVariablesCollector().collect(self)
return sorted(params, key=lambda p: p.name)
\ No newline at end of file
from __future__ import annotations
from typing import Sequence, Generator, TypeVar, Iterable, cast
from typing import Sequence, Generator, Iterable, cast
from abc import ABC, abstractmethod
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
from .util import failing_cast
class PsAstNode(ABC):
......@@ -49,20 +41,31 @@ class PsAstNode(ABC):
class PsBlock(PsAstNode):
__match_args__ = ("statements",)
def __init__(self, cs: Sequence[PsAstNode]):
self._children = list(cs)
self._statements = list(cs)
def num_children(self) -> int:
return len(self._children)
return len(self._statements)
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children
yield from self._statements
def get_child(self, idx: int):
return self._children[idx]
return self._statements[idx]
def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = c
self._statements[idx] = c
@property
def statements(self) -> list[PsAstNode]:
return self._statements
@statements.setter
def statemetns(self, stm: Sequence[PsAstNode]):
self._statements = list(stm)
class PsLeafNode(PsAstNode):
......@@ -82,6 +85,8 @@ class PsLeafNode(PsAstNode):
class PsExpression(PsLeafNode):
"""Wrapper around pymbolics expressions."""
__match_args__ = ("expression",)
def __init__(self, expr: pb.Expression):
self._expr = expr
......@@ -107,6 +112,8 @@ class PsLvalueExpr(PsExpression):
class PsSymbolExpr(PsLvalueExpr):
"""Wrapper around PsTypedSymbols"""
__match_args__ = ("symbol",)
def __init__(self, symbol: PsTypedVariable):
super().__init__(symbol)
......@@ -120,6 +127,9 @@ class PsSymbolExpr(PsLvalueExpr):
class PsAssignment(PsAstNode):
__match_args__ = ("lhs", "rhs",)
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
self._lhs = lhs
self._rhs = rhs
......@@ -160,6 +170,9 @@ class PsAssignment(PsAstNode):
class PsDeclaration(PsAssignment):
__match_args__ = ("declared_variable", "rhs",)
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super().__init__(lhs, rhs)
......@@ -172,11 +185,11 @@ class PsDeclaration(PsAssignment):
self._lhs = failing_cast(PsSymbolExpr, lvalue)
@property
def declared_symbol(self) -> PsSymbolExpr:
def declared_variable(self) -> PsSymbolExpr:
return cast(PsSymbolExpr, self._lhs)
@declared_symbol.setter
def declared_symbol(self, lvalue: PsSymbolExpr):
@declared_variable.setter
def declared_variable(self, lvalue: PsSymbolExpr):
self._lhs = lvalue
def set_child(self, idx: int, c: PsAstNode):
......@@ -190,6 +203,9 @@ class PsDeclaration(PsAssignment):
class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
def __init__(
self,
ctr: PsSymbolExpr,
......
from typing import TypeVar
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
......@@ -3,6 +3,7 @@ from __future__ import annotations
from pymbolic.mapper.c_code import CCodeMapper
from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop
from .ast.kernelfunction import PsKernelFunction
class CPrinter:
......@@ -23,6 +24,14 @@ class CPrinter:
@ast_visitor
def visit(self, node: PsAstNode):
raise ValueError("Cannot print this node.")
@visit.case(PsKernelFunction)
def function(self, func: PsKernelFunction) -> str:
params = func.get_parameters()
params_str = ", ".join(f"{p.dtype} {p.name}" for p in params)
decl = f"FUNC_PREFIX void {func.name} ( {params_str} )"
body = self.visit(func.body)
return f"{decl}\n{body}"
@visit.case(PsBlock)
def block(self, block: PsBlock):
......@@ -30,7 +39,7 @@ class CPrinter:
return self.indent("{ }")
self._current_indent_level += self._indent_width
interior = "".join(self.visit(c) for c in block.children())
interior = "\n".join(self.visit(c) for c in block.children())
self._current_indent_level -= self._indent_width
return self.indent("{\n") + interior + self.indent("}\n")
......@@ -40,11 +49,11 @@ class CPrinter:
@visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration):
lhs_symb = decl.declared_symbol.symbol
lhs_symb = decl.declared_variable.symbol
lhs_dtype = lhs_symb.dtype
rhs_code = self.visit(decl.rhs)
return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};\n")
return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};")
@visit.case(PsAssignment)
def assignment(self, asm: PsAssignment):
......@@ -54,7 +63,7 @@ class CPrinter:
@visit.case(PsLoop)
def loop(self, loop: PsLoop):
ctr_symbol = loop.counter.expression
ctr_symbol = loop.counter.symbol
ctr = ctr_symbol.name
start_code = self.visit(loop.start)
stop_code = self.visit(loop.stop)
......
class PsInternalCompilerError(Exception):
pass
class PsMalformedAstException(Exception):
pass
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment