Commit 513acdfe authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow return types for WrapperFunction

parent 32e36a69
Pipeline #27136 failed with stage
in 1 minute and 37 seconds
......@@ -215,9 +215,11 @@ class FunctionCall(Node):
class WrapperFunction(pystencils.astnodes.KernelFunction):
def __init__(self, body, function_name='wrapper', target='cpu', backend='c'):
def __init__(self, body, function_name='wrapper', target='cpu', backend='c', return_type=None, return_value=None):
super().__init__(body, target, backend, compile_function=None, ghost_layers=0)
self.function_name = function_name
self.return_type = return_type
self.return_value = return_value
def generate_kernel_call(kernel_function):
......
......@@ -3,6 +3,7 @@ import functools
import sympy as sp
import pystencils.backends.cbackend
from pystencils.astnodes import KernelFunction
from pystencils.data_types import TypedSymbol
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils_autodiff.framework_integration.types import TemplateType
......@@ -38,7 +39,9 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
return "\n%s\n" % (''.join(block_contents.splitlines(True)))
def _print_WrapperFunction(self, node):
super_result = super()._print_KernelFunction(node)
if node.return_value:
node._body._nodes.append(node.return_value)
super_result = self._print_KernelFunction_extended(node)
if self._signatureOnly:
super_result += ';'
return super_result.replace('FUNC_PREFIX ', '')
......@@ -46,7 +49,23 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
def _print_TextureDeclaration(self, node):
return str(node)
def _print_KernelFunction(self, node):
def _print_KernelFunction_extended(self, node: KernelFunction):
return_type = node.return_type if hasattr(node, 'return_type') and node.return_type else 'void'
function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}"
for s in node.get_parameters() if hasattr(s.symbol, 'dtype')]
launch_bounds = ""
if self._dialect == 'cuda':
max_threads = node.indexing.max_threads_per_block()
if max_threads:
launch_bounds = f"__launch_bounds__({max_threads}) "
func_declaration = f"FUNC_PREFIX {launch_bounds} {self._print(return_type)} {node.function_name}({', '.join(function_arguments)})" # noqa
if self._signatureOnly:
return func_declaration
body = self._print(node.body)
return func_declaration + "\n" + body
def _print_KernelFunction(self, node, return_type='void'):
if node.backend == 'gpucuda':
prefix = '#define FUNC_PREFIX static __global__\n'
kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda', with_globals=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment