From 3463ff547dead9c0eb63ec0e48f9547e9b35662f Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 8 Jul 2019 13:48:26 +0200
Subject: [PATCH] Add global_declarations to cbackend

This enables astnodes.Nodes to have a member required_global_declarations
by which they can specify a global declaration required for their usage.
---
 pystencils/backends/cbackend.py             |  31 ++++-
 pystencils_tests/test_global_definitions.py | 146 ++++++++++++++++++++
 2 files changed, 176 insertions(+), 1 deletion(-)
 create mode 100644 pystencils_tests/test_global_definitions.py

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 92a6080c7..9f4aadadc 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -38,10 +38,39 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
     Returns:
         C-like code for the ast node and its descendants
     """
+    global_declarations = get_global_declarations(ast_node)
+    for d in global_declarations:
+        if hasattr(ast_node, "global_variables"):
+            ast_node.global_variables.update(d.symbols_defined)
+        else:
+            ast_node.global_variables = d.symbols_defined
     printer = CBackend(signature_only=signature_only,
                        vector_instruction_set=ast_node.instruction_set,
                        dialect=dialect)
-    return printer(ast_node)
+    code = printer(ast_node)
+    if not signature_only and isinstance(ast_node, KernelFunction):
+        code = "\n" + code
+        for declaration in global_declarations:
+            code = printer(declaration) + "\n" + code
+
+    return code
+
+
+def get_global_declarations(ast):
+    global_declarations = []
+
+    def visit_node(sub_ast):
+        if hasattr(sub_ast, "required_global_declarations"):
+            nonlocal global_declarations
+            global_declarations += sub_ast.required_global_declarations
+
+        if hasattr(sub_ast, "args"):
+            for node in sub_ast.args:
+                visit_node(node)
+
+    visit_node(ast)
+
+    return set(global_declarations)
 
 
 def get_headers(ast_node: Node) -> Set[str]:
diff --git a/pystencils_tests/test_global_definitions.py b/pystencils_tests/test_global_definitions.py
new file mode 100644
index 000000000..9b6609eb8
--- /dev/null
+++ b/pystencils_tests/test_global_definitions.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+
+import sympy
+
+import pystencils.astnodes
+from pystencils.backends.cbackend import CBackend
+from pystencils.data_types import TypedSymbol
+
+
+class BogusDeclaration(pystencils.astnodes.Node):
+    """Base class for all AST nodes."""
+
+    def __init__(self, parent=None):
+        self.parent = parent
+
+    @property
+    def args(self):
+        """Returns all arguments/children of this node."""
+        raise set()
+
+    @property
+    def symbols_defined(self):
+        """Set of symbols which are defined by this node."""
+        return {TypedSymbol('Foo', 'double')}
+
+    @property
+    def undefined_symbols(self):
+        """Symbols which are used but are not defined inside this node."""
+        set()
+
+    def subs(self, subs_dict):
+        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
+        for a in self.args:
+            a.subs(subs_dict)
+
+    @property
+    def func(self):
+        return self.__class__
+
+    def atoms(self, arg_type):
+        """Returns a set of all descendants recursively, which are an instance of the given type."""
+        result = set()
+        for arg in self.args:
+            if isinstance(arg, arg_type):
+                result.add(arg)
+            result.update(arg.atoms(arg_type))
+        return result
+
+
+class BogusUsage(pystencils.astnodes.Node):
+    """Base class for all AST nodes."""
+
+    def __init__(self, requires_global: bool, parent=None):
+        self.parent = parent
+        if requires_global:
+            self.required_global_declarations = [BogusDeclaration()]
+
+    @property
+    def args(self):
+        """Returns all arguments/children of this node."""
+        return set()
+
+    @property
+    def symbols_defined(self):
+        """Set of symbols which are defined by this node."""
+        return set()
+
+    @property
+    def undefined_symbols(self):
+        """Symbols which are used but are not defined inside this node."""
+        return {TypedSymbol('Foo', 'double')}
+
+    def subs(self, subs_dict):
+        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
+        for a in self.args:
+            a.subs(subs_dict)
+
+    @property
+    def func(self):
+        return self.__class__
+
+    def atoms(self, arg_type):
+        """Returns a set of all descendants recursively, which are an instance of the given type."""
+        result = set()
+        for arg in self.args:
+            if isinstance(arg, arg_type):
+                result.add(arg)
+            result.update(arg.atoms(arg_type))
+        return result
+
+
+def test_global_definitions_with_global_symbol():
+    # Teach our printer to print new ast nodes
+    CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
+    CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"
+
+    z, x, y = pystencils.fields("z, y, x: [2d]")
+
+    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
+        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
+
+    ast = pystencils.create_kernel(normal_assignments)
+    print(pystencils.show_code(ast))
+    ast.body.append(BogusUsage(requires_global=True))
+    print(pystencils.show_code(ast))
+    kernel = ast.compile()
+    assert kernel is not None
+
+    assert TypedSymbol('Foo', 'double') not in [p.symbol for p in ast.get_parameters()]
+
+
+def test_global_definitions_without_global_symbol():
+    # Teach our printer to print new ast nodes
+    CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
+    CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"
+
+    z, x, y = pystencils.fields("z, y, x: [2d]")
+
+    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
+        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
+
+    ast = pystencils.create_kernel(normal_assignments)
+    print(pystencils.show_code(ast))
+    ast.body.append(BogusUsage(requires_global=False))
+    print(pystencils.show_code(ast))
+    kernel = ast.compile()
+    assert kernel is not None
+
+    assert TypedSymbol('Foo', 'double') in [p.symbol for p in ast.get_parameters()]
+
+
+def main():
+    test_global_definitions_with_global_symbol()
+    test_global_definitions_without_global_symbol()
+
+
+if __name__ == '__main__':
+    main()
-- 
GitLab