Skip to content
Snippets Groups Projects
Commit 3463ff54 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add global_declarations to cbackend

This enables astnodes.Nodes to have a member required_global_declarations
by which they can specify a global declaration required for their usage.
parent 1754ef27
Branches
1 merge request!5Add global_declarations to cbackend
......@@ -38,10 +38,39 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
Returns:
C-like code for the ast node and its descendants
"""
global_declarations = get_global_declarations(ast_node)
for d in global_declarations:
if hasattr(ast_node, "global_variables"):
ast_node.global_variables.update(d.symbols_defined)
else:
ast_node.global_variables = d.symbols_defined
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set,
dialect=dialect)
return printer(ast_node)
code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code
for declaration in global_declarations:
code = printer(declaration) + "\n" + code
return code
def get_global_declarations(ast):
global_declarations = []
def visit_node(sub_ast):
if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"):
for node in sub_ast.args:
visit_node(node)
visit_node(ast)
return set(global_declarations)
def get_headers(ast_node: Node) -> Set[str]:
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy
import pystencils.astnodes
from pystencils.backends.cbackend import CBackend
from pystencils.data_types import TypedSymbol
class BogusDeclaration(pystencils.astnodes.Node):
"""Base class for all AST nodes."""
def __init__(self, parent=None):
self.parent = parent
@property
def args(self):
"""Returns all arguments/children of this node."""
raise set()
@property
def symbols_defined(self):
"""Set of symbols which are defined by this node."""
return {TypedSymbol('Foo', 'double')}
@property
def undefined_symbols(self):
"""Symbols which are used but are not defined inside this node."""
set()
def subs(self, subs_dict):
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
a.subs(subs_dict)
@property
def func(self):
return self.__class__
def atoms(self, arg_type):
"""Returns a set of all descendants recursively, which are an instance of the given type."""
result = set()
for arg in self.args:
if isinstance(arg, arg_type):
result.add(arg)
result.update(arg.atoms(arg_type))
return result
class BogusUsage(pystencils.astnodes.Node):
"""Base class for all AST nodes."""
def __init__(self, requires_global: bool, parent=None):
self.parent = parent
if requires_global:
self.required_global_declarations = [BogusDeclaration()]
@property
def args(self):
"""Returns all arguments/children of this node."""
return set()
@property
def symbols_defined(self):
"""Set of symbols which are defined by this node."""
return set()
@property
def undefined_symbols(self):
"""Symbols which are used but are not defined inside this node."""
return {TypedSymbol('Foo', 'double')}
def subs(self, subs_dict):
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
a.subs(subs_dict)
@property
def func(self):
return self.__class__
def atoms(self, arg_type):
"""Returns a set of all descendants recursively, which are an instance of the given type."""
result = set()
for arg in self.args:
if isinstance(arg, arg_type):
result.add(arg)
result.update(arg.atoms(arg_type))
return result
def test_global_definitions_with_global_symbol():
# Teach our printer to print new ast nodes
CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments)
print(pystencils.show_code(ast))
ast.body.append(BogusUsage(requires_global=True))
print(pystencils.show_code(ast))
kernel = ast.compile()
assert kernel is not None
assert TypedSymbol('Foo', 'double') not in [p.symbol for p in ast.get_parameters()]
def test_global_definitions_without_global_symbol():
# Teach our printer to print new ast nodes
CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments)
print(pystencils.show_code(ast))
ast.body.append(BogusUsage(requires_global=False))
print(pystencils.show_code(ast))
kernel = ast.compile()
assert kernel is not None
assert TypedSymbol('Foo', 'double') in [p.symbol for p in ast.get_parameters()]
def main():
test_global_definitions_with_global_symbol()
test_global_definitions_without_global_symbol()
if __name__ == '__main__':
main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment