diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 661cabd12456fb11b6bc00f49d7af604930b35c6..38bb9883dbbe5c7d3ca1450840f1f331e4efc4ef 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,7 +1,6 @@ import uuid from typing import Any, List, Optional, Sequence, Set, Union -import jinja2 import sympy as sp from pystencils.data_types import TypedSymbol, cast_func, create_type @@ -673,7 +672,7 @@ class DestructuringBindingsForFieldClass(Node): FieldShapeSymbol: "shape[%i]", FieldStrideSymbol: "stride[%i]" } - CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>") + CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>" @property def fields_accessed(self) -> Set['ResolvedFieldAccess']: @@ -703,7 +702,7 @@ class DestructuringBindingsForFieldClass(Node): undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') for f in corresponding_field_names} | \ (self.body.undefined_symbols - undefined_field_symbols) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 965b5d84bdf19eddec05bdd2d3b9ac83182eaa58..a73904172916184feafc3643ee7e2d7b1826f813 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -1,7 +1,6 @@ from collections import namedtuple from typing import Set -import jinja2 import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter @@ -9,11 +8,12 @@ from sympy.printing.ccode import C89CodePrinter from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, - vector_memory_access) + PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, + reinterpret_cast_func, vector_memory_access) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( - bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) + bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, + int_div, int_power_of_2, modulo_ceil) from pystencils.kernelparameters import FieldPointerSymbol try: @@ -271,18 +271,11 @@ class CBackend: for u in undefined_field_symbols ] destructuring_bindings.sort() # only for code aesthetics - template = jinja2.Template( - """{ - {% for binding in bindings -%} - {{ binding | indent(3) }} - {% endfor -%} - {{ block | indent(3) }} -} - -""") - code = template.render(bindings=destructuring_bindings, - block=self._print(node.body)) - return code + return "{\n" + self._indent + \ + ("\n" + self._indent).join(destructuring_bindings) + \ + "\n" + self._indent + \ + ("\n" + self._indent).join(self._print(node.body).splitlines()) + \ + "\n}" # ------------------------------------------ Helper function & classes ------------------------------------------------- diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py index 6e1eefff9e47cdee1e6c9508796886050bb0505a..fb15068cdac0d41ce7c7a8fc3d32510f7c5da8c6 100644 --- a/pystencils_tests/test_destructuring_field_class.py +++ b/pystencils_tests/test_destructuring_field_class.py @@ -1,11 +1,8 @@ import sympy -import jinja2 - import pystencils from pystencils.astnodes import DestructuringBindingsForFieldClass -from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol - +from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol def test_destructuring_field_class(): @@ -28,12 +25,13 @@ class DestructuringEmojiClass(DestructuringBindingsForFieldClass): FieldShapeSymbol: "😳_%i", FieldStrideSymbol: "🥵_%i" } - CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>") + CLASS_NAME_TEMPLATE = "🤯<{dtype}, {ndim}>" + def __init__(self, node): super().__init__(node) self.headers = [] - - + + def test_destructuring_alternative_field_class(): z, x, y = pystencils.fields("z, y, x: [2d]") @@ -44,6 +42,7 @@ def test_destructuring_alternative_field_class(): ast.body = DestructuringEmojiClass(ast.body) print(pystencils.show_code(ast)) + def main(): test_destructuring_field_class() test_destructuring_alternative_field_class()