diff --git a/TestSequencing.cpp b/TestSequencing.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c95fb5cd720106cb667991140fb2e4a300b3182 --- /dev/null +++ b/TestSequencing.cpp @@ -0,0 +1,106 @@ +#include "TestSequencing.h" + +#define FUNC_PREFIX inline + +namespace pystencils { + +/************************************************************************************* + * Kernels +*************************************************************************************/ + +namespace kernels{ + + +FUNC_PREFIX void streamCollide_even( double * RESTRICT const _data_src, int64_t const _size_src_0, int64_t const _size_src_1, int64_t const _stride_src_0, int64_t const _stride_src_1, int64_t const _stride_src_2, double omega) +{ + for (int64_t ctr_1 = 1; ctr_1 < _size_src_1 - 1; ctr_1 += 1) + { + for (int64_t ctr_0 = 1; ctr_0 < _size_src_0 - 1; ctr_0 += 1) + { + const double xi_1 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 7*_stride_src_2]; + const double xi_2 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 5*_stride_src_2]; + const double xi_3 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1]; + const double xi_4 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 4*_stride_src_2]; + const double xi_5 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 6*_stride_src_2]; + const double xi_6 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 2*_stride_src_2]; + const double xi_7 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 3*_stride_src_2]; + const double xi_8 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 8*_stride_src_2]; + const double xi_9 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_2]; + const double vel0Term = xi_4 + xi_5 + xi_8; + const double vel1Term = xi_2 + xi_9; + const double delta_rho = vel0Term + vel1Term + xi_1 + xi_3 + xi_6 + xi_7; + const double u_0 = vel0Term - xi_1 - xi_2 - xi_7; + const double u_1 = vel1Term - xi_1 + xi_5 - xi_6 - xi_8; + const double u0Mu1 = u_0 - u_1; + const double u0Pu1 = u_0 + u_1; + const double f_eq_common = delta_rho - 1.5*(u_0*u_0) - 1.5*(u_1*u_1); + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1] = omega*(f_eq_common*0.44444444444444442 - xi_3) + xi_3; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 2*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*0.33333333333333331 - xi_9 + 0.5*(u_1*u_1)) + xi_9; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*-0.33333333333333331 - xi_6 + 0.5*(u_1*u_1)) + xi_6; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 4*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*-0.33333333333333331 - xi_7 + 0.5*(u_0*u_0)) + xi_7; + _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 3*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*0.33333333333333331 - xi_4 + 0.5*(u_0*u_0)) + xi_4; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 8*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*-0.083333333333333329 - xi_2 + 0.125*(u0Mu1*u0Mu1)) + xi_2; + _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 7*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*0.083333333333333329 - xi_5 + 0.125*(u0Pu1*u0Pu1)) + xi_5; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 6*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*-0.083333333333333329 - xi_1 + 0.125*(u0Pu1*u0Pu1)) + xi_1; + _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 5*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*0.083333333333333329 - xi_8 + 0.125*(u0Mu1*u0Mu1)) + xi_8; + } + } +} + +FUNC_PREFIX void streamCollide_odd( double * RESTRICT const _data_src, int64_t const _size_src_0, int64_t const _size_src_1, int64_t const _stride_src_0, int64_t const _stride_src_1, int64_t const _stride_src_2, double omega) +{ + for (int64_t ctr_1 = 1; ctr_1 < _size_src_1 - 1; ctr_1 += 1) + { + for (int64_t ctr_0 = 1; ctr_0 < _size_src_0 - 1; ctr_0 += 1) + { + const double xi_1 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 5*_stride_src_2]; + const double xi_2 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1]; + const double xi_3 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 6*_stride_src_2]; + const double xi_4 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 2*_stride_src_2]; + const double xi_5 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 3*_stride_src_2]; + const double xi_6 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 8*_stride_src_2]; + const double xi_7 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 7*_stride_src_2]; + const double xi_8 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + _stride_src_2]; + const double xi_9 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 4*_stride_src_2]; + const double vel0Term = xi_1 + xi_5 + xi_7; + const double vel1Term = xi_4 + xi_6; + const double delta_rho = vel0Term + vel1Term + xi_2 + xi_3 + xi_8 + xi_9; + const double u_0 = vel0Term - xi_3 - xi_6 - xi_9; + const double u_1 = vel1Term - xi_1 - xi_3 + xi_7 - xi_8; + const double u0Mu1 = u_0 - u_1; + const double u0Pu1 = u_0 + u_1; + const double f_eq_common = delta_rho - 1.5*(u_0*u_0) - 1.5*(u_1*u_1); + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1] = omega*(f_eq_common*0.44444444444444442 - xi_2) + xi_2; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + _stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*0.33333333333333331 - xi_4 + 0.5*(u_1*u_1)) + xi_4; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 2*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*-0.33333333333333331 - xi_8 + 0.5*(u_1*u_1)) + xi_8; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 3*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*-0.33333333333333331 - xi_9 + 0.5*(u_0*u_0)) + xi_9; + _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 4*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*0.33333333333333331 - xi_5 + 0.5*(u_0*u_0)) + xi_5; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 5*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*-0.083333333333333329 - xi_6 + 0.125*(u0Mu1*u0Mu1)) + xi_6; + _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 6*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*0.083333333333333329 - xi_7 + 0.125*(u0Pu1*u0Pu1)) + xi_7; + _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 7*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*-0.083333333333333329 - xi_3 + 0.125*(u0Pu1*u0Pu1)) + xi_3; + _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 8*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*0.083333333333333329 - xi_1 + 0.125*(u0Mu1*u0Mu1)) + xi_1; + } + } +} + + +} // namespace kernels + + +/************************************************************************************* + * Functions +*************************************************************************************/ + + + +void myFunction ( double * RESTRICT const _data_src, int64_t const _size_src_0, int64_t const _size_src_1, int64_t const _stride_src_0, int64_t const _stride_src_1, int64_t const _stride_src_2, double omega ) { + if((timestep & 1) ^ 1) { + pystencils::kernels::streamCollide_even(_data_src, _size_src_0, _size_src_1, _stride_src_0, _stride_src_1, _stride_src_2, omega); + }else { + pystencils::kernels::streamCollide_odd(_data_src, _size_src_0, _size_src_1, _stride_src_0, _stride_src_1, _stride_src_2, omega); + } +} + + + +} // namespace pystencils \ No newline at end of file diff --git a/TestSequencing.h b/TestSequencing.h new file mode 100644 index 0000000000000000000000000000000000000000..af6f9689e7d9f48afadba784b499fb7d294881b2 --- /dev/null +++ b/TestSequencing.h @@ -0,0 +1,11 @@ +#pragma once + +#define RESTRICT __restrict__ + +#include <cstdint> + + +namespace pystencils { + + +} // namespace pystencils \ No newline at end of file diff --git a/pystencils_sfg/call_tree/abstract_node.py b/pystencils_sfg/call_tree/abstract_node.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/pystencils_sfg/context.py b/pystencils_sfg/context.py index f2e6493716a1c1b3d98c05a1fd835e4a3ed80adb..7fc01763e5058089933993b1998a2d789cdeea1e 100644 --- a/pystencils_sfg/context.py +++ b/pystencils_sfg/context.py @@ -1,9 +1,32 @@ +from typing import Callable, Sequence, Generator, Union, Optional +from dataclasses import dataclass + +import os from os import path -from .kernel_namespace import SfgKernelNamespace +from jinja2.filters import do_indent + +from pystencils.astnodes import KernelFunction + +from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle +from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgCondition, SfgBranch +from .tree.builders import SfgBranchBuilder, SfgSequencer +from .source_components import SfgFunction + + +@dataclass +class SfgCodeStyle: + indent_width: int = 2 + + def indent(self, s: str): + return do_indent(s, self.indent_width, first=True) + class SourceFileGenerator: - def __init__(self, namespace: str = "pystencils", basename: str = None): + def __init__(self, + namespace: str = "pystencils", + basename: str = None, + codestyle: SfgCodeStyle = SfgCodeStyle()): if basename is None: import __main__ @@ -15,32 +38,50 @@ class SourceFileGenerator: self.header_filename = basename + ".h" self.cpp_filename = basename + ".cpp" - self._context = SfgContext(namespace) + self._context = SfgContext(namespace, codestyle) + + def clean_files(self): + for file in (self.header_filename, self.cpp_filename): + if path.exists(file): + os.remove(file) def __enter__(self): + self.clean_files() return self._context - def __exit__(self, *args): - from .emitters.cpu.basic_cpu import BasicCpuEmitter - BasicCpuEmitter(self._context, self.basename).write_files() + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + from .emitters.cpu.basic_cpu import BasicCpuEmitter + BasicCpuEmitter(self._context, self.basename).write_files() class SfgContext: - def __init__(self, root_namespace: str): + def __init__(self, root_namespace: str, codestyle: SfgCodeStyle): self._root_namespace = root_namespace + self._codestyle = codestyle self._default_kernel_namespace = SfgKernelNamespace(self, "kernels") + # Source Components + self._includes = [] self._kernel_namespaces = { self._default_kernel_namespace.name : self._default_kernel_namespace } + self._functions = dict() + + # Builder Components + self._sequencer = SfgSequencer(self) @property - def root_namespace(self): + def root_namespace(self) -> str: return self._root_namespace + + @property + def codestyle(self) -> SfgCodeStyle: + return self._codestyle @property - def kernels(self): + def kernels(self) -> SfgKernelNamespace: return self._default_kernel_namespace - def kernel_namespace(self, name): + def kernel_namespace(self, name: str) -> SfgKernelNamespace: if name in self._kernel_namespaces: raise ValueError(f"Duplicate kernel namespace: {name}") @@ -48,6 +89,50 @@ class SfgContext: self._kernel_namespaces[name] = kns return kns - @property - def kernel_namespaces(self): + def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: yield from self._kernel_namespaces.values() + + def functions(self) -> Generator[SfgFunction, None, None]: + yield from self._functions.values() + + def include(self, header_file: str): + self._includes.append(header_file) + + def function(self, + name: str, + ast_or_kernel_handle : Optional[Union[KernelFunction, SfgKernelHandle]] = None): + if name in self._functions: + raise ValueError(f"Duplicate function: {name}") + + if ast_or_kernel_handle is not None: + if isinstance(ast_or_kernel_handle, KernelFunction): + khandle = self._default_kernel_namespace.add(ast_or_kernel_handle) + tree = SfgKernelCallNode(self, khandle) + elif isinstance(ast_or_kernel_handle, SfgKernelCallNode): + tree = ast_or_kernel_handle + else: + raise TypeError(f"Invalid type of argument `ast_or_kernel_handle`!") + else: + def sequencer(*args: SfgCallTreeNode): + tree = self.seq(*args) + func = SfgFunction(self, name, tree) + self._functions[name] = func + + return sequencer + + + #---------------------------------------------------------------------------------------------- + # Call Tree Node Factory + #---------------------------------------------------------------------------------------------- + + @property + def seq(self) -> SfgSequencer: + return self._sequencer + + def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: + return SfgKernelCallNode(kernel_handle) + + @property + def branch(self) -> SfgBranchBuilder: + return SfgBranchBuilder(self) + \ No newline at end of file diff --git a/pystencils_sfg/emitters/cpu/basic_cpu.py b/pystencils_sfg/emitters/cpu/basic_cpu.py index 7dd25a280a3bea1827edc08b6bbafb7e856416b2..2181681e8bbf952ed2ef16e7ba7aab927684598f 100644 --- a/pystencils_sfg/emitters/cpu/basic_cpu.py +++ b/pystencils_sfg/emitters/cpu/basic_cpu.py @@ -14,7 +14,8 @@ class BasicCpuEmitter: 'ctx': self._ctx, 'basename': self._basename, 'root_namespace': self._ctx.root_namespace, - 'kernel_namespaces': list(self._ctx.kernel_namespaces) + 'kernel_namespaces': list(self._ctx.kernel_namespaces()), + 'functions': list(self._ctx.functions()) } template_name = "BasicCpu" diff --git a/pystencils_sfg/emitters/cpu/jinja_filters.py b/pystencils_sfg/emitters/cpu/jinja_filters.py index 0bf4b76c07a9d1e917a577d3ab95d094a59620cb..0abbdf3385dfd954d79b6da6ebd23b197b6b1e16 100644 --- a/pystencils_sfg/emitters/cpu/jinja_filters.py +++ b/pystencils_sfg/emitters/cpu/jinja_filters.py @@ -1,11 +1,25 @@ from jinja2 import pass_context + from pystencils.astnodes import KernelFunction from pystencils import Backend from pystencils.backends import generate_c +from pystencils_sfg.tree import SfgCallTreeNode +from pystencils_sfg.source_components import SfgFunction + @pass_context def generate_kernel_definition(ctx, ast: KernelFunction): return generate_c(ast, dialect=Backend.C) +@pass_context +def generate_function_parameter_list(ctx, func: SfgFunction): + params = sorted(list(func.parameters), key=lambda p: p.name) + return ", ".join(f"{param.dtype} {param.name}" for param in params) + +def generate_function_body(func: SfgFunction): + return func.get_code() + def add_filters_to_jinja(jinja_env): - jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition \ No newline at end of file + jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition + jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list + jinja_env.filters['generate_function_body'] = generate_function_body diff --git a/pystencils_sfg/emitters/cpu/templates/BasicCpu.tmpl.cpp b/pystencils_sfg/emitters/cpu/templates/BasicCpu.tmpl.cpp index 82afa145c89b19d1a179e5c9bf3e92e4ea61a9f2..b9782a3899637076eead9d17af0da85c434cc0b8 100644 --- a/pystencils_sfg/emitters/cpu/templates/BasicCpu.tmpl.cpp +++ b/pystencils_sfg/emitters/cpu/templates/BasicCpu.tmpl.cpp @@ -4,14 +4,30 @@ namespace {{root_namespace}} { +/************************************************************************************* + * Kernels +*************************************************************************************/ + {% for kns in kernel_namespaces -%} namespace {{ kns.name }}{ -{% for ast in kns.asts -%} +{% for ast in kns.asts %} {{ ast | generate_kernel_definition }} -{%- endfor %} +{% endfor %} } // namespace {{ kns.name }} {% endfor %} +/************************************************************************************* + * Functions +*************************************************************************************/ + +{% for function in functions %} + +void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) { + {{ function | generate_function_body | indent(2) }} +} + +{% endfor %} + } // namespace {{root_namespace}} diff --git a/pystencils_sfg/kernel_namespace.py b/pystencils_sfg/kernel_namespace.py index f71a5d9a2be41158cfaa95fba6b76317f5fa1867..218780fc6aa6598cbf30e4c1755d9776628b83bf 100644 --- a/pystencils_sfg/kernel_namespace.py +++ b/pystencils_sfg/kernel_namespace.py @@ -20,11 +20,11 @@ class SfgKernelNamespace: """Adds an existing pystencils AST to this namespace.""" astname = ast.function_name if astname in self._asts: - raise ValueError(f"Duplicate ASTs: An AST with name {astname} was already registered in namespace {self._name}") + raise ValueError(f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}") self._asts[astname] = ast - return SfgKernelHandle(self._ctx, astname, self) + return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters()) def create(self, assignments, config: CreateKernelConfig): ast = create_kernel(assignments, config) @@ -32,10 +32,11 @@ class SfgKernelNamespace: class SfgKernelHandle: - def __init__(self, ctx, name: str, namespace: SfgKernelNamespace): + def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters): self._ctx = ctx self._name = name self._namespace = namespace + self._parameters = parameters @property def kernel_name(self): @@ -47,5 +48,9 @@ class SfgKernelHandle: @property def fully_qualified_name(self): - return f"{self.ctx.root_namespace}::{self.kernel_namespace}::{self.kernel_name}" + return f"{self._ctx.root_namespace}::{self.kernel_namespace.name}::{self.kernel_name}" + + @property + def parameters(self): + return self._parameters \ No newline at end of file diff --git a/pystencils_sfg/source_components.py b/pystencils_sfg/source_components.py new file mode 100644 index 0000000000000000000000000000000000000000..38e93307e5e1433a740ffe88c0df98150aaa36cd --- /dev/null +++ b/pystencils_sfg/source_components.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .context import SfgContext + +from .tree import SfgCallTreeNode, SfgSequence +from .tree.visitors import ParameterCollector + +class SfgFunction: + def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): + self._ctx = ctx + self._name = name + self._tree = tree + + param_collector = ParameterCollector() + self._parameters = param_collector.visit(tree) + + @property + def name(self): + return self._name + + @property + def parameters(self): + return self._parameters + + def get_code(self): + return self._tree.get_code(self._ctx) + diff --git a/pystencils_sfg/tree/__init__.py b/pystencils_sfg/tree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9615188f0c7892b78bdf00e277cbe7f134b3b75f --- /dev/null +++ b/pystencils_sfg/tree/__init__.py @@ -0,0 +1,7 @@ +from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence +from .conditional import SfgBranch, SfgCondition + +__all__ = [ + SfgCallTreeNode, SfgKernelCallNode, SfgSequence, SfgBlock, + SfgCondition, SfgBranch +] \ No newline at end of file diff --git a/pystencils_sfg/tree/basic_nodes.py b/pystencils_sfg/tree/basic_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..194c404e45f49e9d53fa5f84592ebd42de5d6daf --- /dev/null +++ b/pystencils_sfg/tree/basic_nodes.py @@ -0,0 +1,108 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Sequence + +if TYPE_CHECKING: + from ..context import SfgContext + +from abc import ABC, abstractmethod +from functools import reduce + +from jinja2.filters import do_indent + +from ..kernel_namespace import SfgKernelHandle + +from pystencils.typing import TypedSymbol + +class SfgCallTreeNode(ABC): + """Base class for all nodes comprising SFG call trees. """ + + @property + @abstractmethod + def children(self) -> Sequence[SfgCallTreeNode]: + pass + + @abstractmethod + def get_code(self, ctx: SfgContext) -> str: + """Returns the code of this node. + + By convention, the code block emitted by this function should not contain a trailing newline. + """ + pass + + +class SfgCallTreeLeaf(SfgCallTreeNode, ABC): + + @property + def children(self) -> Sequence[SfgCallTreeNode]: + return () + + @property + @abstractmethod + def required_symbols(self) -> set(TypedSymbol): + pass + + @property + @abstractmethod + def defined_symbols(self) -> set(TypedSymbol): + pass + + +class SfgCustomStatement(SfgCallTreeLeaf): + def __init__(self, statement: str): + self._statement = statement + + def required_symbols(self) -> set(TypedSymbol): + return set() + + def defined_symbols(self) -> set(TypedSymbol): + return set() + + def get_code(self, ctx: SfgContext) -> str: + return self._statement + + +class SfgSequence(SfgCallTreeNode): + def __init__(self, children: Sequence[SfgCallTreeNode]): + self._children = tuple(children) + + @property + def children(self) -> Sequence[SfgCallTreeNode]: + return self._children + + def get_code(self, ctx: SfgContext) -> str: + return "\n".join(c.get_code(ctx) for c in self._children) + + +class SfgBlock(SfgCallTreeNode): + def __init__(self, subtree: SfgCallTreeNode): + super().__init__(ctx) + self._subtree = subtree + + @property + def children(self) -> Sequence[SfgCallTreeNode]: + return { self._subtree } + + def get_code(self, ctx: SfgContext) -> str: + subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx)) + + return "{\n" + subtree_code + "\n}" + + +class SfgKernelCallNode(SfgCallTreeLeaf): + def __init__(self, kernel_handle: SfgKernelHandle): + self._kernel_handle = kernel_handle + + @property + def required_symbols(self) -> set(TypedSymbol): + return set(p.symbol for p in self._kernel_handle.parameters) + + @property + def defined_symbols(self) -> set(TypedSymbol): + return set() + + def get_code(self, ctx: SfgContext) -> str: + ast_params = self._kernel_handle.parameters + fnc_name = self._kernel_handle.fully_qualified_name + call_parameters = ", ".join([p.symbol.name for p in ast_params]) + + return f"{fnc_name}({call_parameters});" diff --git a/pystencils_sfg/tree/builders.py b/pystencils_sfg/tree/builders.py new file mode 100644 index 0000000000000000000000000000000000000000..4a74424e39fcbec81dbe1351b1d8e7e8d5c1cd61 --- /dev/null +++ b/pystencils_sfg/tree/builders.py @@ -0,0 +1,80 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Sequence + +if TYPE_CHECKING: + from ..context import SfgContext + +from abc import ABC, abstractmethod +from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgCustomStatement +from .conditional import SfgCondition, SfgCustomCondition, SfgBranch + +class SfgNodeBuilder(ABC): + def __init__(self, ctx: SfgContext) -> None: + self._ctx = ctx + + @abstractmethod + def resolve(self) -> SfgCallTreeNode: + pass + +class SfgSequencer: + def __init__(self, ctx: SfgContext) -> None: + self._ctx = ctx + + def __call__(self, *args) -> SfgSequence: + children = [] + for i, arg in enumerate(args): + if isinstance(arg, SfgNodeBuilder): + children.append(arg.resolve()) + elif isinstance(arg, SfgCallTreeNode): + children.append(arg) + elif isinstance(arg, str): + children.append(SfgCustomStatement(arg)) + elif isinstance(arg, tuple): + # Tuples are treated as blocks + subseq = self(*arg) + children.append(SfgBlock(subseq)) + else: + raise TypeError(f"Sequence argument {i} has invalid type.") + + return SfgSequence(children) + + +class SfgBranchBuilder(SfgNodeBuilder): + def __init__(self, ctx: SfgContext) -> None: + super().__init__(ctx) + self._phase = 0 + + self._cond = None + self._branch_true = SfgSequence(()) + self._branch_false = None + + def __call__(self, *args) -> SfgBranchBuilder: + match self._phase: + case 0: # Condition + if len(args) != 1: + raise ValueError("Must specify exactly one argument as branch condition!") + + cond = args[0] + + if isinstance(cond, str): + cond = SfgCustomCondition(cond) + elif not isinstance(cond, SfgCondition): + raise ValueError("Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") + + self._cond = cond + + case 1: # Then-branch + self._branch_true = self._ctx.seq(*args) + case 2: # Else-branch + self._branch_false = self._ctx.seq(*args) + case _: # There's not third branch! + raise TypeError("Branch construct already complete.") + + self._phase += 1 + + return self + + def resolve(self) -> SfgCallTreeNode: + return SfgBranch(self._cond, self._branch_true, self._branch_false) + + \ No newline at end of file diff --git a/pystencils_sfg/tree/conditional.py b/pystencils_sfg/tree/conditional.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a967bf6c980ef4387b32a4861c9ce0ecc8328d --- /dev/null +++ b/pystencils_sfg/tree/conditional.py @@ -0,0 +1,57 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Sequence, Optional + +if TYPE_CHECKING: + from ..context import SfgContext + +from jinja2.filters import do_indent +from pystencils.typing import TypedSymbol + +from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf + +class SfgCondition(SfgCallTreeLeaf): + pass + +class SfgCustomCondition(SfgCondition): + def __init__(self, cond_text: str): + self._cond_text = cond_text + + def required_symbols(self) -> set(TypedSymbol): + return set() + + def defined_symbols(self) -> set(TypedSymbol): + return set() + + def get_code(self, ctx: SfgContext) -> str: + return self._cond_text + + +# class IntEven(SfgCondition): +# def __init__(self, ) + + +class SfgBranch(SfgCallTreeNode): + def __init__(self, cond: SfgCondition, branch_true: SfgCallTreeNode, branch_false: Optional[SfgCallTreeNode] = None): + self._cond = cond + self._branch_true = branch_true + self._branch_false = branch_false + + @property + def children(self) -> Sequence[SfgCallTreeNode]: + if self._branch_false is not None: + return (self._branch_true, self._branch_false) + else: + return (self._branch_true,) + + def get_code(self, ctx: SfgContext) -> str: + code = f"if({self._cond.get_code(ctx)}) {{\n" + code += ctx.codestyle.indent(self._branch_true.get_code(ctx)) + code += "\n}" + + if self._branch_false is not None: + code += "else {\n" + code += ctx.codestyle.indent(self._branch_false.get_code(ctx)) + code += "\n}" + + return code + diff --git a/pystencils_sfg/call_tree/__init__.py b/pystencils_sfg/tree/sequencing.py similarity index 100% rename from pystencils_sfg/call_tree/__init__.py rename to pystencils_sfg/tree/sequencing.py diff --git a/pystencils_sfg/tree/visitors.py b/pystencils_sfg/tree/visitors.py new file mode 100644 index 0000000000000000000000000000000000000000..4d830d8baf6d48b0fe5f608b04e57dc730765e70 --- /dev/null +++ b/pystencils_sfg/tree/visitors.py @@ -0,0 +1,41 @@ +from typing import Set +from functools import reduce + +from pystencils.typing import TypedSymbol + +from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence + + +class ParameterCollector(): + def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: + if isinstance(node, SfgCallTreeLeaf): + return self._visit_SfgCallTreeLeaf(node) + elif isinstance(node, SfgSequence): + return self._visit_SfgSequence(node) + else: + return self._visit_branchingNode(node) + + def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]: + return leaf.required_symbols + + def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: + """ + Only in a sequence may parameters be defined and visible to subsequent nodes. + """ + + params = set() + for c in sequence.children[::-1]: + if isinstance(c, SfgCallTreeLeaf): + # Only a leaf in a sequence may effectively define symbols + # Remove these from the required parameters + params -= c.defined_symbols + + params |= self.visit(c) + return params + + def _visit_branchingNode(self, node: SfgCallTreeNode): + """ + Each interior node that is not a sequence simply requires the union of all parameters + required by its children. + """ + return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set()) diff --git a/tests/TestSequencing.py b/tests/TestSequencing.py new file mode 100644 index 0000000000000000000000000000000000000000..1303bd31f988783f378236922202387ee4b32c7b --- /dev/null +++ b/tests/TestSequencing.py @@ -0,0 +1,25 @@ +from pystencils_sfg import SourceFileGenerator + +from lbmpy.advanced_streaming import Timestep +from lbmpy import LBMConfig, create_lb_ast + +with SourceFileGenerator() as sfg: + + lb_config = LBMConfig(streaming_pattern='esotwist') + + lb_ast_even = create_lb_ast(lbm_config=lb_config, timestep=Timestep.EVEN) + lb_ast_even.function_name = "streamCollide_even" + + lb_ast_odd = create_lb_ast(lbm_config=lb_config, timestep=Timestep.ODD) + lb_ast_odd.function_name = "streamCollide_odd" + + kernel_even = sfg.kernels.add(lb_ast_even) + kernel_odd = sfg.kernels.add(lb_ast_odd) + + sfg.function("myFunction")( + sfg.branch("(timestep & 1) ^ 1")( + sfg.call(kernel_even) + )( + sfg.call(kernel_odd) + ) + )