Skip to content
Snippets Groups Projects
Commit c7b3958a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

much progress.

parent c60cce87
No related merge requests found
Showing
with 605 additions and 20 deletions
#include "TestSequencing.h"
#define FUNC_PREFIX inline
namespace pystencils {
/*************************************************************************************
* Kernels
*************************************************************************************/
namespace kernels{
FUNC_PREFIX void streamCollide_even( double * RESTRICT const _data_src, int64_t const _size_src_0, int64_t const _size_src_1, int64_t const _stride_src_0, int64_t const _stride_src_1, int64_t const _stride_src_2, double omega)
{
for (int64_t ctr_1 = 1; ctr_1 < _size_src_1 - 1; ctr_1 += 1)
{
for (int64_t ctr_0 = 1; ctr_0 < _size_src_0 - 1; ctr_0 += 1)
{
const double xi_1 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 7*_stride_src_2];
const double xi_2 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 5*_stride_src_2];
const double xi_3 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1];
const double xi_4 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 4*_stride_src_2];
const double xi_5 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 6*_stride_src_2];
const double xi_6 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 2*_stride_src_2];
const double xi_7 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 3*_stride_src_2];
const double xi_8 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 8*_stride_src_2];
const double xi_9 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_2];
const double vel0Term = xi_4 + xi_5 + xi_8;
const double vel1Term = xi_2 + xi_9;
const double delta_rho = vel0Term + vel1Term + xi_1 + xi_3 + xi_6 + xi_7;
const double u_0 = vel0Term - xi_1 - xi_2 - xi_7;
const double u_1 = vel1Term - xi_1 + xi_5 - xi_6 - xi_8;
const double u0Mu1 = u_0 - u_1;
const double u0Pu1 = u_0 + u_1;
const double f_eq_common = delta_rho - 1.5*(u_0*u_0) - 1.5*(u_1*u_1);
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1] = omega*(f_eq_common*0.44444444444444442 - xi_3) + xi_3;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 2*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*0.33333333333333331 - xi_9 + 0.5*(u_1*u_1)) + xi_9;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*-0.33333333333333331 - xi_6 + 0.5*(u_1*u_1)) + xi_6;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 4*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*-0.33333333333333331 - xi_7 + 0.5*(u_0*u_0)) + xi_7;
_data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 3*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*0.33333333333333331 - xi_4 + 0.5*(u_0*u_0)) + xi_4;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 8*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*-0.083333333333333329 - xi_2 + 0.125*(u0Mu1*u0Mu1)) + xi_2;
_data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 7*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*0.083333333333333329 - xi_5 + 0.125*(u0Pu1*u0Pu1)) + xi_5;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 6*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*-0.083333333333333329 - xi_1 + 0.125*(u0Pu1*u0Pu1)) + xi_1;
_data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 5*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*0.083333333333333329 - xi_8 + 0.125*(u0Mu1*u0Mu1)) + xi_8;
}
}
}
FUNC_PREFIX void streamCollide_odd( double * RESTRICT const _data_src, int64_t const _size_src_0, int64_t const _size_src_1, int64_t const _stride_src_0, int64_t const _stride_src_1, int64_t const _stride_src_2, double omega)
{
for (int64_t ctr_1 = 1; ctr_1 < _size_src_1 - 1; ctr_1 += 1)
{
for (int64_t ctr_0 = 1; ctr_0 < _size_src_0 - 1; ctr_0 += 1)
{
const double xi_1 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 5*_stride_src_2];
const double xi_2 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1];
const double xi_3 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 6*_stride_src_2];
const double xi_4 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 2*_stride_src_2];
const double xi_5 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 3*_stride_src_2];
const double xi_6 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 8*_stride_src_2];
const double xi_7 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 7*_stride_src_2];
const double xi_8 = _data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + _stride_src_2];
const double xi_9 = _data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 4*_stride_src_2];
const double vel0Term = xi_1 + xi_5 + xi_7;
const double vel1Term = xi_4 + xi_6;
const double delta_rho = vel0Term + vel1Term + xi_2 + xi_3 + xi_8 + xi_9;
const double u_0 = vel0Term - xi_3 - xi_6 - xi_9;
const double u_1 = vel1Term - xi_1 - xi_3 + xi_7 - xi_8;
const double u0Mu1 = u_0 - u_1;
const double u0Pu1 = u_0 + u_1;
const double f_eq_common = delta_rho - 1.5*(u_0*u_0) - 1.5*(u_1*u_1);
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1] = omega*(f_eq_common*0.44444444444444442 - xi_2) + xi_2;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + _stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*0.33333333333333331 - xi_4 + 0.5*(u_1*u_1)) + xi_4;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 2*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_1*-0.33333333333333331 - xi_8 + 0.5*(u_1*u_1)) + xi_8;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 3*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*-0.33333333333333331 - xi_9 + 0.5*(u_0*u_0)) + xi_9;
_data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 4*_stride_src_2] = omega*(f_eq_common*0.1111111111111111 + u_0*0.33333333333333331 - xi_5 + 0.5*(u_0*u_0)) + xi_5;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + _stride_src_1 + 5*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*-0.083333333333333329 - xi_6 + 0.125*(u0Mu1*u0Mu1)) + xi_6;
_data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + _stride_src_1 + 6*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*0.083333333333333329 - xi_7 + 0.125*(u0Pu1*u0Pu1)) + xi_7;
_data_src[_stride_src_0*ctr_0 + _stride_src_1*ctr_1 + 7*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Pu1*-0.083333333333333329 - xi_3 + 0.125*(u0Pu1*u0Pu1)) + xi_3;
_data_src[_stride_src_0*ctr_0 + _stride_src_0 + _stride_src_1*ctr_1 + 8*_stride_src_2] = omega*(f_eq_common*0.027777777777777776 + u0Mu1*0.083333333333333329 - xi_1 + 0.125*(u0Mu1*u0Mu1)) + xi_1;
}
}
}
} // namespace kernels
/*************************************************************************************
* Functions
*************************************************************************************/
void myFunction ( double * RESTRICT const _data_src, int64_t const _size_src_0, int64_t const _size_src_1, int64_t const _stride_src_0, int64_t const _stride_src_1, int64_t const _stride_src_2, double omega ) {
if((timestep & 1) ^ 1) {
pystencils::kernels::streamCollide_even(_data_src, _size_src_0, _size_src_1, _stride_src_0, _stride_src_1, _stride_src_2, omega);
}else {
pystencils::kernels::streamCollide_odd(_data_src, _size_src_0, _size_src_1, _stride_src_0, _stride_src_1, _stride_src_2, omega);
}
}
} // namespace pystencils
\ No newline at end of file
#pragma once
#define RESTRICT __restrict__
#include <cstdint>
namespace pystencils {
} // namespace pystencils
\ No newline at end of file
from typing import Callable, Sequence, Generator, Union, Optional
from dataclasses import dataclass
import os
from os import path
from .kernel_namespace import SfgKernelNamespace
from jinja2.filters import do_indent
from pystencils.astnodes import KernelFunction
from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgCondition, SfgBranch
from .tree.builders import SfgBranchBuilder, SfgSequencer
from .source_components import SfgFunction
@dataclass
class SfgCodeStyle:
indent_width: int = 2
def indent(self, s: str):
return do_indent(s, self.indent_width, first=True)
class SourceFileGenerator:
def __init__(self, namespace: str = "pystencils", basename: str = None):
def __init__(self,
namespace: str = "pystencils",
basename: str = None,
codestyle: SfgCodeStyle = SfgCodeStyle()):
if basename is None:
import __main__
......@@ -15,32 +38,50 @@ class SourceFileGenerator:
self.header_filename = basename + ".h"
self.cpp_filename = basename + ".cpp"
self._context = SfgContext(namespace)
self._context = SfgContext(namespace, codestyle)
def clean_files(self):
for file in (self.header_filename, self.cpp_filename):
if path.exists(file):
os.remove(file)
def __enter__(self):
self.clean_files()
return self._context
def __exit__(self, *args):
from .emitters.cpu.basic_cpu import BasicCpuEmitter
BasicCpuEmitter(self._context, self.basename).write_files()
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
from .emitters.cpu.basic_cpu import BasicCpuEmitter
BasicCpuEmitter(self._context, self.basename).write_files()
class SfgContext:
def __init__(self, root_namespace: str):
def __init__(self, root_namespace: str, codestyle: SfgCodeStyle):
self._root_namespace = root_namespace
self._codestyle = codestyle
self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")
# Source Components
self._includes = []
self._kernel_namespaces = { self._default_kernel_namespace.name : self._default_kernel_namespace }
self._functions = dict()
# Builder Components
self._sequencer = SfgSequencer(self)
@property
def root_namespace(self):
def root_namespace(self) -> str:
return self._root_namespace
@property
def codestyle(self) -> SfgCodeStyle:
return self._codestyle
@property
def kernels(self):
def kernels(self) -> SfgKernelNamespace:
return self._default_kernel_namespace
def kernel_namespace(self, name):
def kernel_namespace(self, name: str) -> SfgKernelNamespace:
if name in self._kernel_namespaces:
raise ValueError(f"Duplicate kernel namespace: {name}")
......@@ -48,6 +89,50 @@ class SfgContext:
self._kernel_namespaces[name] = kns
return kns
@property
def kernel_namespaces(self):
def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
yield from self._kernel_namespaces.values()
def functions(self) -> Generator[SfgFunction, None, None]:
yield from self._functions.values()
def include(self, header_file: str):
self._includes.append(header_file)
def function(self,
name: str,
ast_or_kernel_handle : Optional[Union[KernelFunction, SfgKernelHandle]] = None):
if name in self._functions:
raise ValueError(f"Duplicate function: {name}")
if ast_or_kernel_handle is not None:
if isinstance(ast_or_kernel_handle, KernelFunction):
khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
tree = SfgKernelCallNode(self, khandle)
elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
tree = ast_or_kernel_handle
else:
raise TypeError(f"Invalid type of argument `ast_or_kernel_handle`!")
else:
def sequencer(*args: SfgCallTreeNode):
tree = self.seq(*args)
func = SfgFunction(self, name, tree)
self._functions[name] = func
return sequencer
#----------------------------------------------------------------------------------------------
# Call Tree Node Factory
#----------------------------------------------------------------------------------------------
@property
def seq(self) -> SfgSequencer:
return self._sequencer
def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
return SfgKernelCallNode(kernel_handle)
@property
def branch(self) -> SfgBranchBuilder:
return SfgBranchBuilder(self)
\ No newline at end of file
......@@ -14,7 +14,8 @@ class BasicCpuEmitter:
'ctx': self._ctx,
'basename': self._basename,
'root_namespace': self._ctx.root_namespace,
'kernel_namespaces': list(self._ctx.kernel_namespaces)
'kernel_namespaces': list(self._ctx.kernel_namespaces()),
'functions': list(self._ctx.functions())
}
template_name = "BasicCpu"
......
from jinja2 import pass_context
from pystencils.astnodes import KernelFunction
from pystencils import Backend
from pystencils.backends import generate_c
from pystencils_sfg.tree import SfgCallTreeNode
from pystencils_sfg.source_components import SfgFunction
@pass_context
def generate_kernel_definition(ctx, ast: KernelFunction):
return generate_c(ast, dialect=Backend.C)
@pass_context
def generate_function_parameter_list(ctx, func: SfgFunction):
params = sorted(list(func.parameters), key=lambda p: p.name)
return ", ".join(f"{param.dtype} {param.name}" for param in params)
def generate_function_body(func: SfgFunction):
return func.get_code()
def add_filters_to_jinja(jinja_env):
jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition
\ No newline at end of file
jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition
jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list
jinja_env.filters['generate_function_body'] = generate_function_body
......@@ -4,14 +4,30 @@
namespace {{root_namespace}} {
/*************************************************************************************
* Kernels
*************************************************************************************/
{% for kns in kernel_namespaces -%}
namespace {{ kns.name }}{
{% for ast in kns.asts -%}
{% for ast in kns.asts %}
{{ ast | generate_kernel_definition }}
{%- endfor %}
{% endfor %}
} // namespace {{ kns.name }}
{% endfor %}
/*************************************************************************************
* Functions
*************************************************************************************/
{% for function in functions %}
void {{ function.name }} ( {{ function | generate_function_parameter_list }} ) {
{{ function | generate_function_body | indent(2) }}
}
{% endfor %}
} // namespace {{root_namespace}}
......@@ -20,11 +20,11 @@ class SfgKernelNamespace:
"""Adds an existing pystencils AST to this namespace."""
astname = ast.function_name
if astname in self._asts:
raise ValueError(f"Duplicate ASTs: An AST with name {astname} was already registered in namespace {self._name}")
raise ValueError(f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}")
self._asts[astname] = ast
return SfgKernelHandle(self._ctx, astname, self)
return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters())
def create(self, assignments, config: CreateKernelConfig):
ast = create_kernel(assignments, config)
......@@ -32,10 +32,11 @@ class SfgKernelNamespace:
class SfgKernelHandle:
def __init__(self, ctx, name: str, namespace: SfgKernelNamespace):
def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters):
self._ctx = ctx
self._name = name
self._namespace = namespace
self._parameters = parameters
@property
def kernel_name(self):
......@@ -47,5 +48,9 @@ class SfgKernelHandle:
@property
def fully_qualified_name(self):
return f"{self.ctx.root_namespace}::{self.kernel_namespace}::{self.kernel_name}"
return f"{self._ctx.root_namespace}::{self.kernel_namespace.name}::{self.kernel_name}"
@property
def parameters(self):
return self._parameters
\ No newline at end of file
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .context import SfgContext
from .tree import SfgCallTreeNode, SfgSequence
from .tree.visitors import ParameterCollector
class SfgFunction:
def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode):
self._ctx = ctx
self._name = name
self._tree = tree
param_collector = ParameterCollector()
self._parameters = param_collector.visit(tree)
@property
def name(self):
return self._name
@property
def parameters(self):
return self._parameters
def get_code(self):
return self._tree.get_code(self._ctx)
from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence
from .conditional import SfgBranch, SfgCondition
__all__ = [
SfgCallTreeNode, SfgKernelCallNode, SfgSequence, SfgBlock,
SfgCondition, SfgBranch
]
\ No newline at end of file
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from functools import reduce
from jinja2.filters import do_indent
from ..kernel_namespace import SfgKernelHandle
from pystencils.typing import TypedSymbol
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """
@property
@abstractmethod
def children(self) -> Sequence[SfgCallTreeNode]:
pass
@abstractmethod
def get_code(self, ctx: SfgContext) -> str:
"""Returns the code of this node.
By convention, the code block emitted by this function should not contain a trailing newline.
"""
pass
class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return ()
@property
@abstractmethod
def required_symbols(self) -> set(TypedSymbol):
pass
@property
@abstractmethod
def defined_symbols(self) -> set(TypedSymbol):
pass
class SfgCustomStatement(SfgCallTreeLeaf):
def __init__(self, statement: str):
self._statement = statement
def required_symbols(self) -> set(TypedSymbol):
return set()
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
return self._statement
class SfgSequence(SfgCallTreeNode):
def __init__(self, children: Sequence[SfgCallTreeNode]):
self._children = tuple(children)
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return self._children
def get_code(self, ctx: SfgContext) -> str:
return "\n".join(c.get_code(ctx) for c in self._children)
class SfgBlock(SfgCallTreeNode):
def __init__(self, subtree: SfgCallTreeNode):
super().__init__(ctx)
self._subtree = subtree
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return { self._subtree }
def get_code(self, ctx: SfgContext) -> str:
subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx))
return "{\n" + subtree_code + "\n}"
class SfgKernelCallNode(SfgCallTreeLeaf):
def __init__(self, kernel_handle: SfgKernelHandle):
self._kernel_handle = kernel_handle
@property
def required_symbols(self) -> set(TypedSymbol):
return set(p.symbol for p in self._kernel_handle.parameters)
@property
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
ast_params = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fully_qualified_name
call_parameters = ", ".join([p.symbol.name for p in ast_params])
return f"{fnc_name}({call_parameters});"
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgCustomStatement
from .conditional import SfgCondition, SfgCustomCondition, SfgBranch
class SfgNodeBuilder(ABC):
def __init__(self, ctx: SfgContext) -> None:
self._ctx = ctx
@abstractmethod
def resolve(self) -> SfgCallTreeNode:
pass
class SfgSequencer:
def __init__(self, ctx: SfgContext) -> None:
self._ctx = ctx
def __call__(self, *args) -> SfgSequence:
children = []
for i, arg in enumerate(args):
if isinstance(arg, SfgNodeBuilder):
children.append(arg.resolve())
elif isinstance(arg, SfgCallTreeNode):
children.append(arg)
elif isinstance(arg, str):
children.append(SfgCustomStatement(arg))
elif isinstance(arg, tuple):
# Tuples are treated as blocks
subseq = self(*arg)
children.append(SfgBlock(subseq))
else:
raise TypeError(f"Sequence argument {i} has invalid type.")
return SfgSequence(children)
class SfgBranchBuilder(SfgNodeBuilder):
def __init__(self, ctx: SfgContext) -> None:
super().__init__(ctx)
self._phase = 0
self._cond = None
self._branch_true = SfgSequence(())
self._branch_false = None
def __call__(self, *args) -> SfgBranchBuilder:
match self._phase:
case 0: # Condition
if len(args) != 1:
raise ValueError("Must specify exactly one argument as branch condition!")
cond = args[0]
if isinstance(cond, str):
cond = SfgCustomCondition(cond)
elif not isinstance(cond, SfgCondition):
raise ValueError("Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.")
self._cond = cond
case 1: # Then-branch
self._branch_true = self._ctx.seq(*args)
case 2: # Else-branch
self._branch_false = self._ctx.seq(*args)
case _: # There's not third branch!
raise TypeError("Branch construct already complete.")
self._phase += 1
return self
def resolve(self) -> SfgCallTreeNode:
return SfgBranch(self._cond, self._branch_true, self._branch_false)
\ No newline at end of file
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Optional
if TYPE_CHECKING:
from ..context import SfgContext
from jinja2.filters import do_indent
from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf
class SfgCondition(SfgCallTreeLeaf):
pass
class SfgCustomCondition(SfgCondition):
def __init__(self, cond_text: str):
self._cond_text = cond_text
def required_symbols(self) -> set(TypedSymbol):
return set()
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
return self._cond_text
# class IntEven(SfgCondition):
# def __init__(self, )
class SfgBranch(SfgCallTreeNode):
def __init__(self, cond: SfgCondition, branch_true: SfgCallTreeNode, branch_false: Optional[SfgCallTreeNode] = None):
self._cond = cond
self._branch_true = branch_true
self._branch_false = branch_false
@property
def children(self) -> Sequence[SfgCallTreeNode]:
if self._branch_false is not None:
return (self._branch_true, self._branch_false)
else:
return (self._branch_true,)
def get_code(self, ctx: SfgContext) -> str:
code = f"if({self._cond.get_code(ctx)}) {{\n"
code += ctx.codestyle.indent(self._branch_true.get_code(ctx))
code += "\n}"
if self._branch_false is not None:
code += "else {\n"
code += ctx.codestyle.indent(self._branch_false.get_code(ctx))
code += "\n}"
return code
from typing import Set
from functools import reduce
from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence
class ParameterCollector():
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node)
elif isinstance(node, SfgSequence):
return self._visit_SfgSequence(node)
else:
return self._visit_branchingNode(node)
def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
return leaf.required_symbols
def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
"""
Only in a sequence may parameters be defined and visible to subsequent nodes.
"""
params = set()
for c in sequence.children[::-1]:
if isinstance(c, SfgCallTreeLeaf):
# Only a leaf in a sequence may effectively define symbols
# Remove these from the required parameters
params -= c.defined_symbols
params |= self.visit(c)
return params
def _visit_branchingNode(self, node: SfgCallTreeNode):
"""
Each interior node that is not a sequence simply requires the union of all parameters
required by its children.
"""
return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())
from pystencils_sfg import SourceFileGenerator
from lbmpy.advanced_streaming import Timestep
from lbmpy import LBMConfig, create_lb_ast
with SourceFileGenerator() as sfg:
lb_config = LBMConfig(streaming_pattern='esotwist')
lb_ast_even = create_lb_ast(lbm_config=lb_config, timestep=Timestep.EVEN)
lb_ast_even.function_name = "streamCollide_even"
lb_ast_odd = create_lb_ast(lbm_config=lb_config, timestep=Timestep.ODD)
lb_ast_odd.function_name = "streamCollide_odd"
kernel_even = sfg.kernels.add(lb_ast_even)
kernel_odd = sfg.kernels.add(lb_ast_odd)
sfg.function("myFunction")(
sfg.branch("(timestep & 1) ^ 1")(
sfg.call(kernel_even)
)(
sfg.call(kernel_odd)
)
)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment