Skip to content
Snippets Groups Projects
Commit 8b30de1d authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'avoid-jinja2-dependency' into 'master'

Fix #10: Avoid jinja2 dependency

Closes #10

See merge request !18
parents 6942ed0b 74236fab
No related merge requests found
import uuid import uuid
from typing import Any, List, Optional, Sequence, Set, Union from typing import Any, List, Optional, Sequence, Set, Union
import jinja2
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.data_types import TypedSymbol, cast_func, create_type
...@@ -673,7 +672,7 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -673,7 +672,7 @@ class DestructuringBindingsForFieldClass(Node):
FieldShapeSymbol: "shape[%i]", FieldShapeSymbol: "shape[%i]",
FieldStrideSymbol: "stride[%i]" FieldStrideSymbol: "stride[%i]"
} }
CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>") CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"
@property @property
def fields_accessed(self) -> Set['ResolvedFieldAccess']: def fields_accessed(self) -> Set['ResolvedFieldAccess']:
...@@ -703,7 +702,7 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -703,7 +702,7 @@ class DestructuringBindingsForFieldClass(Node):
undefined_field_symbols = self.symbols_defined undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \ for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols) (self.body.undefined_symbols - undefined_field_symbols)
......
from collections import namedtuple from collections import namedtuple
from typing import Set from typing import Set
import jinja2
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
...@@ -9,11 +8,12 @@ from sympy.printing.ccode import C89CodePrinter ...@@ -9,11 +8,12 @@ from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
vector_memory_access) reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import ( from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
try: try:
...@@ -271,18 +271,11 @@ class CBackend: ...@@ -271,18 +271,11 @@ class CBackend:
for u in undefined_field_symbols for u in undefined_field_symbols
] ]
destructuring_bindings.sort() # only for code aesthetics destructuring_bindings.sort() # only for code aesthetics
template = jinja2.Template( return "{\n" + self._indent + \
"""{ ("\n" + self._indent).join(destructuring_bindings) + \
{% for binding in bindings -%} "\n" + self._indent + \
{{ binding | indent(3) }} ("\n" + self._indent).join(self._print(node.body).splitlines()) + \
{% endfor -%} "\n}"
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
# ------------------------------------------ Helper function & classes ------------------------------------------------- # ------------------------------------------ Helper function & classes -------------------------------------------------
......
import sympy import sympy
import jinja2
import pystencils import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass from pystencils.astnodes import DestructuringBindingsForFieldClass
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
def test_destructuring_field_class(): def test_destructuring_field_class():
...@@ -28,12 +25,13 @@ class DestructuringEmojiClass(DestructuringBindingsForFieldClass): ...@@ -28,12 +25,13 @@ class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
FieldShapeSymbol: "😳_%i", FieldShapeSymbol: "😳_%i",
FieldStrideSymbol: "🥵_%i" FieldStrideSymbol: "🥵_%i"
} }
CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>") CLASS_NAME_TEMPLATE = "🤯<{dtype}, {ndim}>"
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.headers = [] self.headers = []
def test_destructuring_alternative_field_class(): def test_destructuring_alternative_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]") z, x, y = pystencils.fields("z, y, x: [2d]")
...@@ -44,6 +42,7 @@ def test_destructuring_alternative_field_class(): ...@@ -44,6 +42,7 @@ def test_destructuring_alternative_field_class():
ast.body = DestructuringEmojiClass(ast.body) ast.body = DestructuringEmojiClass(ast.body)
print(pystencils.show_code(ast)) print(pystencils.show_code(ast))
def main(): def main():
test_destructuring_field_class() test_destructuring_field_class()
test_destructuring_alternative_field_class() test_destructuring_alternative_field_class()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment