diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 8eb44766199e2aaad990bab167fc9267bc4017c9..1c28e324981d7bc4f5ea3ee310ca4ede0ee4e5f8 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -40,10 +40,39 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str Returns: C-like code for the ast node and its descendants """ + global_declarations = get_global_declarations(ast_node) + for d in global_declarations: + if hasattr(ast_node, "global_variables"): + ast_node.global_variables.update(d.symbols_defined) + else: + ast_node.global_variables = d.symbols_defined printer = CBackend(signature_only=signature_only, vector_instruction_set=ast_node.instruction_set, dialect=dialect) - return printer(ast_node) + code = printer(ast_node) + if not signature_only and isinstance(ast_node, KernelFunction): + code = "\n" + code + for declaration in global_declarations: + code = printer(declaration) + "\n" + code + + return code + + +def get_global_declarations(ast): + global_declarations = [] + + def visit_node(sub_ast): + if hasattr(sub_ast, "required_global_declarations"): + nonlocal global_declarations + global_declarations += sub_ast.required_global_declarations + + if hasattr(sub_ast, "args"): + for node in sub_ast.args: + visit_node(node) + + visit_node(ast) + + return set(global_declarations) def get_headers(ast_node: Node) -> Set[str]: diff --git a/pystencils_tests/test_global_definitions.py b/pystencils_tests/test_global_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6609eb89649bee63d6db813d64d3660ed859f0 --- /dev/null +++ b/pystencils_tests/test_global_definitions.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" + +import sympy + +import pystencils.astnodes +from pystencils.backends.cbackend import CBackend +from pystencils.data_types import TypedSymbol + + +class BogusDeclaration(pystencils.astnodes.Node): + """Base class for all AST nodes.""" + + def __init__(self, parent=None): + self.parent = parent + + @property + def args(self): + """Returns all arguments/children of this node.""" + raise set() + + @property + def symbols_defined(self): + """Set of symbols which are defined by this node.""" + return {TypedSymbol('Foo', 'double')} + + @property + def undefined_symbols(self): + """Symbols which are used but are not defined inside this node.""" + set() + + def subs(self, subs_dict): + """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" + for a in self.args: + a.subs(subs_dict) + + @property + def func(self): + return self.__class__ + + def atoms(self, arg_type): + """Returns a set of all descendants recursively, which are an instance of the given type.""" + result = set() + for arg in self.args: + if isinstance(arg, arg_type): + result.add(arg) + result.update(arg.atoms(arg_type)) + return result + + +class BogusUsage(pystencils.astnodes.Node): + """Base class for all AST nodes.""" + + def __init__(self, requires_global: bool, parent=None): + self.parent = parent + if requires_global: + self.required_global_declarations = [BogusDeclaration()] + + @property + def args(self): + """Returns all arguments/children of this node.""" + return set() + + @property + def symbols_defined(self): + """Set of symbols which are defined by this node.""" + return set() + + @property + def undefined_symbols(self): + """Symbols which are used but are not defined inside this node.""" + return {TypedSymbol('Foo', 'double')} + + def subs(self, subs_dict): + """Inplace! substitute, similar to sympy's but modifies the AST inplace.""" + for a in self.args: + a.subs(subs_dict) + + @property + def func(self): + return self.__class__ + + def atoms(self, arg_type): + """Returns a set of all descendants recursively, which are an instance of the given type.""" + result = set() + for arg in self.args: + if isinstance(arg, arg_type): + result.add(arg) + result.update(arg.atoms(arg_type)) + return result + + +def test_global_definitions_with_global_symbol(): + # Teach our printer to print new ast nodes + CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here" + CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here" + + z, x, y = pystencils.fields("z, y, x: [2d]") + + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments) + print(pystencils.show_code(ast)) + ast.body.append(BogusUsage(requires_global=True)) + print(pystencils.show_code(ast)) + kernel = ast.compile() + assert kernel is not None + + assert TypedSymbol('Foo', 'double') not in [p.symbol for p in ast.get_parameters()] + + +def test_global_definitions_without_global_symbol(): + # Teach our printer to print new ast nodes + CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here" + CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here" + + z, x, y = pystencils.fields("z, y, x: [2d]") + + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments) + print(pystencils.show_code(ast)) + ast.body.append(BogusUsage(requires_global=False)) + print(pystencils.show_code(ast)) + kernel = ast.compile() + assert kernel is not None + + assert TypedSymbol('Foo', 'double') in [p.symbol for p in ast.get_parameters()] + + +def main(): + test_global_definitions_with_global_symbol() + test_global_definitions_without_global_symbol() + + +if __name__ == '__main__': + main()