diff --git a/pystencils/display_utils.py b/pystencils/display_utils.py index 55a4720c194e250cd549279ae76607c48968c676..8cdaa4820444cedd1c4cbf8f2db7d7391e3e6344 100644 --- a/pystencils/display_utils.py +++ b/pystencils/display_utils.py @@ -1,5 +1,7 @@ -import sympy as sp from typing import Any, Dict, Optional + +import sympy as sp + from pystencils.astnodes import KernelFunction @@ -32,7 +34,7 @@ def highlight_cpp(code: str): return HTML(highlight(code, CppLexer(), HtmlFormatter())) -def show_code(ast: KernelFunction): +def show_code(ast: KernelFunction, custom_backend=None): """Returns an object to display generated code (C/C++ or CUDA) Can either be displayed as HTML in Jupyter notebooks or printed as normal string. @@ -45,11 +47,11 @@ def show_code(ast: KernelFunction): self.ast = ast_input def _repr_html_(self): - return highlight_cpp(generate_c(self.ast, dialect=dialect)).__html__() + return highlight_cpp(generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)).__html__() def __str__(self): - return generate_c(self.ast, dialect=dialect) + return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend) def __repr__(self): - return generate_c(self.ast, dialect=dialect) + return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend) return CodeDisplay(ast)