diff --git a/pystencils/display_utils.py b/pystencils/display_utils.py index 8cdaa4820444cedd1c4cbf8f2db7d7391e3e6344..638d1290acbbfc4d86bec12028dc59b37e2f98ea 100644 --- a/pystencils/display_utils.py +++ b/pystencils/display_utils.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional import sympy as sp from pystencils.astnodes import KernelFunction +from pystencils.kernel_wrapper import KernelWrapper def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True): @@ -40,6 +41,10 @@ def show_code(ast: KernelFunction, custom_backend=None): Can either be displayed as HTML in Jupyter notebooks or printed as normal string. """ from pystencils.backends.cbackend import generate_c + + if isinstance(ast, KernelWrapper): + ast = ast.ast + dialect = 'cuda' if ast.backend == 'gpucuda' else 'c' class CodeDisplay: