display_utils.py 3.84 KB
Newer Older
1
from typing import Any, Dict, Optional, Union
2
3
4

import sympy as sp

Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.astnodes import KernelFunction
6
from pystencils.kernel_wrapper import KernelWrapper
Martin Bauer's avatar
Martin Bauer committed
7

Martin Bauer's avatar
Martin Bauer committed
8

9
def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
    """Show a sympy expression or a pystencils AST as a graphviz dot graph.

    Args:
        expr: a sympy expression or a pystencils ``Node``.
        graph_style: graph attributes passed on as ``graph_attr``; ``None`` means an empty dict.
        short: forwarded to pystencils' ``print_dot``; not used for plain sympy expressions.

    Returns:
        A ``graphviz.Source`` wrapping the generated dot description.
    """
    from pystencils.astnodes import Node
    import graphviz

    attrs = graph_style if graph_style is not None else {}

    if not isinstance(expr, Node):
        # Plain sympy expression: use sympy's own dot printer.
        from sympy.printing.dot import dotprint
        return graphviz.Source(dotprint(expr, graph_attr=attrs))

    from pystencils.backends.dot import print_dot
    return graphviz.Source(print_dot(expr, short=short, graph_attr=attrs))
Martin Bauer's avatar
Martin Bauer committed
21
22


Martin Bauer's avatar
Martin Bauer committed
23
24
def highlight_cpp(code: str):
    """Highlight the given C/C++ source code with pygments.

    As a side effect, a ``<style>`` tag with the pygments CSS is displayed via
    IPython. It uses the configured light theme by default and the configured
    dark theme inside a ``prefers-color-scheme: dark`` media query.

    Args:
        code: the source code to highlight.

    Returns:
        An ``IPython.display.HTML`` object containing the highlighted code.
    """
    from IPython.display import HTML, display
    from pygments import highlight
    # noinspection PyUnresolvedReferences
    from pygments.formatters import HtmlFormatter
    # noinspection PyUnresolvedReferences
    from pygments.lexers import CppLexer
    from pygments.util import ClassNotFound

    from pystencils.cpu.cpujit import get_highlight_style_config
    config = get_highlight_style_config()

    light_theme = config['light_theme']
    try:
        css = HtmlFormatter(style=light_theme).get_style_defs('.highlight')
    except ClassNotFound:
        # Unknown style name: fall back to pygments' built-in 'default' style.
        css = HtmlFormatter(style='default').get_style_defs('.highlight')
        print(f"Could not find light theme: {light_theme}")

    try:
        dark_css = HtmlFormatter(style=config['dark_theme']).get_style_defs('.highlight')
    except (KeyError, ClassNotFound):
        # The dark theme is best-effort: reuse the light CSS if none is usable.
        dark_css = css

    css_tag = f"<style>{css} @media (prefers-color-scheme: dark) {{ {dark_css} }}</style>"
    display(HTML(css_tag))
    return HTML(highlight(code, CppLexer(), HtmlFormatter()))


52
def get_code_obj(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
    """Returns an object to display generated code (C/C++, CUDA or OpenCL).

    The object renders as highlighted HTML in Jupyter notebooks (``_repr_html_``)
    and as the plain generated source via ``str()`` / ``repr()``.

    Args:
        ast: a ``KernelFunction``, or a ``KernelWrapper`` whose ``ast`` attribute is used.
        custom_backend: forwarded unchanged to ``generate_c``.

    Returns:
        A display object. The source is regenerated each time it is rendered.
    """
    from pystencils.backends.cbackend import generate_c

    kernel_ast = ast.ast if isinstance(ast, KernelWrapper) else ast

    # Map the kernel backend name to the generate_c dialect; anything else is plain C.
    backend_to_dialect = {'gpucuda': 'cuda', 'opencl': 'opencl'}
    dialect = backend_to_dialect.get(kernel_ast.backend, 'c')

    class CodeDisplay:
        def __init__(self, ast_input):
            self.ast = ast_input

        def _source(self):
            # Generate the source code for the wrapped AST.
            return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)

        def _repr_html_(self):
            return highlight_cpp(self._source()).__html__()

        def __str__(self):
            return self._source()

        __repr__ = __str__

    return CodeDisplay(kernel_ast)
82
83
84


def get_code_str(ast, custom_backend=None):
    """Return the generated source code of ``ast`` as a plain string.

    Args:
        ast: a kernel AST or kernel wrapper accepted by ``get_code_obj``.
        custom_backend: forwarded to ``get_code_obj``.
    """
    display_obj = get_code_obj(ast, custom_backend=custom_backend)
    return str(display_obj)
86
87


88
89
90
91
92
93
94
95
96
97
98
99
100
def _isnotebook():
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False


101
def show_code(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
    """Display the generated source code of a kernel.

    In a Jupyter notebook the code object is passed to IPython's ``display``.
    Elsewhere the code is printed to the terminal: highlighted with ``rich`` if it
    can be imported, and as plain text otherwise.

    Args:
        ast: a ``KernelFunction`` or ``KernelWrapper``, as accepted by ``get_code_obj``.
        custom_backend: forwarded to ``get_code_obj``.
    """
    code = get_code_obj(ast, custom_backend)

    if _isnotebook():
        from IPython.display import display
        display(code)
        return

    # Only the optional import is guarded. An ImportError raised while rendering
    # should surface instead of silently triggering a second, plain print.
    try:
        import rich.syntax
        import rich.console
    except ImportError:
        print(code)
        return

    syntax = rich.syntax.Syntax(str(code), "c++", theme="monokai", line_numbers=True)
    console = rich.console.Console()
    console.print(syntax)