# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.

"""
Tests that a symbol provided through an AST node's ``required_global_declarations``
does not become a kernel parameter, while the same symbol without such a
declaration does.
"""
import sympy

import pystencils.astnodes
from pystencils.backends.cbackend import CBackend
from pystencils.data_types import TypedSymbol

# The symbol that BogusUsage reads and BogusDeclaration defines.
_FOO = TypedSymbol('Foo', 'double')


class _BogusNode(pystencils.astnodes.Node):
    """Leaf AST node (no children) shared by the bogus test nodes below."""

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node (none: it is a leaf)."""
        return set()

    @property
    def symbols_defined(self):
        """Set of symbols which are defined by this node."""
        return set()

    @property
    def undefined_symbols(self):
        """Symbols which are used but are not defined inside this node."""
        return set()

    def subs(self, subs_dict):
        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
        for a in self.args:
            a.subs(subs_dict)

    @property
    def func(self):
        return self.__class__

    def atoms(self, arg_type):
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        result = set()
        for arg in self.args:
            if isinstance(arg, arg_type):
                result.add(arg)
            result.update(arg.atoms(arg_type))
        return result


class BogusDeclaration(_BogusNode):
    """Fake global declaration node that defines the symbol ``Foo`` (double)."""

    @property
    def symbols_defined(self):
        """Set of symbols which are defined by this node."""
        return {_FOO}


class BogusUsage(_BogusNode):
    """Fake node that uses ``Foo`` and may request a global declaration of it.

    :param requires_global: if true, a ``BogusDeclaration`` is attached via
        ``required_global_declarations``; otherwise that attribute is absent.
    :param parent: parent AST node.
    """

    def __init__(self, requires_global: bool, parent=None):
        super_parent = parent
        _BogusNode.__init__(self, parent=super_parent)
        if requires_global:
            self.required_global_declarations = [BogusDeclaration()]

    @property
    def undefined_symbols(self):
        """Symbols which are used but are not defined inside this node."""
        return {_FOO}


def _register_bogus_printers():
    """Teach CBackend how to print the bogus nodes.

    NOTE: this patches the CBackend class itself, so the printers remain
    installed for the rest of the process.
    """
    CBackend._print_BogusUsage = lambda _, __: "// Bogus would go here"
    CBackend._print_BogusDeclaration = lambda _, __: "// Declaration would go here"


def _compile_kernel_with_bogus_usage(requires_global):
    """Build a small kernel, append a BogusUsage node, compile it and return its AST."""
    _register_bogus_printers()

    # Unpack in the same order as the field names so that `x` is field "x", etc.
    z, y, x = pystencils.fields("z, y, x: [2d]")

    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])

    ast = pystencils.create_kernel(normal_assignments)
    print(pystencils.show_code(ast))
    ast.body.append(BogusUsage(requires_global=requires_global))
    print(pystencils.show_code(ast))
    kernel = ast.compile()
    assert kernel is not None
    return ast


def test_global_definitions_with_global_symbol():
    ast = _compile_kernel_with_bogus_usage(requires_global=True)
    # Foo comes from the global BogusDeclaration, so it must not be a parameter.
    assert _FOO not in [p.symbol for p in ast.get_parameters()]


def test_global_definitions_without_global_symbol():
    ast = _compile_kernel_with_bogus_usage(requires_global=False)
    # Nothing defines Foo, so it has to be passed in as a kernel parameter.
    assert _FOO in [p.symbol for p in ast.get_parameters()]


def main():
    test_global_definitions_with_global_symbol()
    test_global_definitions_without_global_symbol()


if __name__ == '__main__':
    main()