Skip to content
Snippets Groups Projects
Commit 89b6f68b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

baby steps toward field parameter mapping

parent c13f1095
No related merge requests found
......@@ -32,12 +32,21 @@ class SfgKernelNamespace:
class SfgKernelHandle:
def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters):
def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters: Sequence[KernelFunction.Parameter]):
self._ctx = ctx
self._name = name
self._namespace = namespace
self._parameters = parameters
self._scalar_params = set()
self._fields = set()
for param in self._parameters:
if param.is_field_parameter:
self._fields |= set(param.fields)
else:
self._scalar_params.add(param.symbol)
@property
def kernel_name(self):
return self._name
......@@ -53,4 +62,12 @@ class SfgKernelHandle:
@property
def parameters(self):
return self._parameters
@property
def scalar_parameters(self):
return self._scalar_params
@property
def fields(self):
return self.fields
\ No newline at end of file
......@@ -5,16 +5,19 @@ if TYPE_CHECKING:
from .context import SfgContext
from .tree import SfgCallTreeNode, SfgSequence
from .tree.visitors import ParameterCollector
from .tree.visitors import FlattenSequences, ParameterCollector
class SfgFunction:
def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode):
self._ctx = ctx
self._name = name
self._tree = tree
flattener = FlattenSequences()
flattener.visit(self._tree)
param_collector = ParameterCollector()
self._parameters = param_collector.visit(tree)
self._parameters = param_collector.visit(self._tree)
@property
def name(self):
......
from abc import ABC, abstractmethod
from .source_concepts import SrcObject, SrcMemberAccess
class SrcContiguousContainer(SrcObject):
def __init__(self, src_type, identifier: Optional[str]):
super().__init__(src_type, identifier)
@abstractmethod
def ptr(self) -> SrcMemberAccess:
pass
@abstractmethod
def size(self, dimension: int) -> SrcMemberAccess:
pass
@abstractmethod
def stride(self, dimension: int) -> SrcMemberAccess:
pass
from typing import Optional
from ..source_concepts import SrcMemberAccess
from ..containers import SrcContiguousContainer
class std_mdspan(SrcContiguousContainer):
def __init__(self, identifer: str):
super().__init__("std::mdspan", identifier)
def ptr(self):
return SrcMemberAccess(self, f"{self._identifier}.data_handle()")
def size(self, dimension: int):
return SrcMemberAccess(self, f"{self._identifier}.extents().extent({dimension})")
def stride(self, dimension: int):
return SrcMemberAccess(self, f"{self._identifier}.stride({dimension})")
from typing import Optional
from abc import ABC, abstractmethod
from pystencils import TypedSymbol
class SrcClass:
def __init__(self):
pass
class SrcObject(ABC):
def __init__(self, src_type, identifier: Optional[str]):
self._src_type = src_type
self._identifier = identifier
@property
def _sfg_symbol(self):
return TypedSymbol(self._identifier, self._src_type)
class SrcMemberAccess():
def __init__(self, obj: SrcObject, code_string: str):
self._obj = obj
self._code_string = code_string
def _sfg_code_string():
return self._code_string
......@@ -41,11 +41,24 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
def required_symbols(self) -> set(TypedSymbol):
pass
class SfgParameterDefinition(SfgCallTreeLeaf):
def __init__(self, defined_param: TypedSymbol, required_params: Set[TypedSymbol], code_string: str):
self._defined_param = defined_param
self._required_params = required_params
self._code_string = code_string
@property
@abstractmethod
def defined_symbols(self) -> set(TypedSymbol):
pass
def defined_symbol(self) -> TypedSymbol:
return self._defined_param
@property
def required_symbols(self) -> set(TypedSymbol):
return self._required_params
def get_code(self):
return self._code_string
class SfgCustomStatement(SfgCallTreeLeaf):
def __init__(self, statement: str):
......@@ -54,9 +67,6 @@ class SfgCustomStatement(SfgCallTreeLeaf):
def required_symbols(self) -> set(TypedSymbol):
return set()
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
return self._statement
......@@ -96,10 +106,6 @@ class SfgKernelCallNode(SfgCallTreeLeaf):
def required_symbols(self) -> set(TypedSymbol):
return set(p.symbol for p in self._kernel_handle.parameters)
@property
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
ast_params = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fully_qualified_name
......
......@@ -5,13 +5,14 @@ if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from pystencils import Field
from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgCustomStatement
from .conditional import SfgCondition, SfgCustomCondition, SfgBranch
from ..source_concepts.containers import SrcContiguousContainer
class SfgNodeBuilder(ABC):
def __init__(self, ctx: SfgContext) -> None:
self._ctx = ctx
@abstractmethod
def resolve(self) -> SfgCallTreeNode:
pass
......@@ -40,8 +41,8 @@ class SfgSequencer:
class SfgBranchBuilder(SfgNodeBuilder):
def __init__(self, ctx: SfgContext) -> None:
super().__init__(ctx)
def __init__(self, ctx: SfgContext):
self._ctx = ctx
self._phase = 0
self._cond = None
......@@ -67,7 +68,7 @@ class SfgBranchBuilder(SfgNodeBuilder):
self._branch_true = self._ctx.seq(*args)
case 2: # Else-branch
self._branch_false = self._ctx.seq(*args)
case _: # There's not third branch!
case _: # There's no third branch!
raise TypeError("Branch construct already complete.")
self._phase += 1
......@@ -77,4 +78,14 @@ class SfgBranchBuilder(SfgNodeBuilder):
def resolve(self) -> SfgCallTreeNode:
return SfgBranch(self._cond, self._branch_true, self._branch_false)
\ No newline at end of file
class SfgFieldMappingBuilder(SfgNodeBuilder):
def __init__(self, ctx: SfgContext):
super().__init__(ctx)
self._field = None
self._container = None
def __call__(self, field: Field, container: SrcContiguousContainer):
self._field = field
self._container = container
\ No newline at end of file
......@@ -18,9 +18,6 @@ class SfgCustomCondition(SfgCondition):
def required_symbols(self) -> set(TypedSymbol):
return set()
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
return self._cond_text
......
......@@ -3,10 +3,41 @@ from functools import reduce
from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgParameterDefinition
class FlattenSequences():
"""Flattens any nested sequences occuring in a kernel call tree."""
def visit(self, node: SfgCallTreeNode) -> None:
if isinstance(node, SfgSequence):
return self._visit_SfgSequence(node)
else:
for c in node.children:
self.visit(c)
def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
children_flattened = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
self.visit(c)
sequence._children = children_flattened
class ParameterCollector():
"""Collects all parameters required but not defined in a kernel call tree.
Requires that all sequences in the tree are flattened.
"""
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node)
......@@ -25,11 +56,10 @@ class ParameterCollector():
params = set()
for c in sequence.children[::-1]:
if isinstance(c, SfgCallTreeLeaf):
# Only a leaf in a sequence may effectively define symbols
# Remove these from the required parameters
if isinstance(c, SfgParameterDefinitionNode):
params -= c.defined_symbols
assert not isinstance(c, SfgSequence), "Sequence not flattened."
params |= self.visit(c)
return params
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment