Commit 43bb69a0 authored by Markus Holzer's avatar Markus Holzer Committed by Jan Hönig
Browse files

Introduce default assignment simplifications

parent 3961a021
...@@ -142,17 +142,13 @@ ubuntu: ...@@ -142,17 +142,13 @@ ubuntu:
- sed -i 's/--doctest-modules //g' pytest.ini - sed -i 's/--doctest-modules //g' pytest.ini
- env - env
- pip3 list - pip3 list
- pytest-3 -v -n $NUM_CORES --cov-report html --cov-report term --cov=. --junitxml=report.xml pystencils_tests/test_*vec*.py pystencils_tests/test_random.py - pytest-3 -v -n $NUM_CORES --junitxml=report.xml pystencils_tests/test_*vec*.py pystencils_tests/test_random.py
- python3 -m coverage xml
tags: tags:
- docker - docker
- AVX - AVX
artifacts: artifacts:
when: always when: always
paths:
- coverage_report
reports: reports:
cobertura: coverage.xml
junit: report.xml junit: report.xml
arm64v8: arm64v8:
...@@ -252,6 +248,7 @@ pycodegen-integration: ...@@ -252,6 +248,7 @@ pycodegen-integration:
- pip install -e pystencils/ - pip install -e pystencils/
- pip install -e lbmpy/ - pip install -e lbmpy/
- pip install -e pygrandchem/ - pip install -e pygrandchem/
- cmake --version
- ./install_walberla.sh - ./install_walberla.sh
- export NUM_CORES=$(nproc --all) - export NUM_CORES=$(nproc --all)
- mkdir -p ~/.config/matplotlib - mkdir -p ~/.config/matplotlib
......
...@@ -14,5 +14,6 @@ pystencils can help you to generate blazingly fast code for image processing, nu ...@@ -14,5 +14,6 @@ pystencils can help you to generate blazingly fast code for image processing, nu
.. image:: /img/pystencils_arch_block_diagram.svg .. image:: /img/pystencils_arch_block_diagram.svg
:height: 450px :height: 450px
:align: center
...@@ -10,13 +10,27 @@ AssignmentCollection ...@@ -10,13 +10,27 @@ AssignmentCollection
:members: :members:
SimplificationStrategy
======================
.. autoclass:: pystencils.simp.SimplificationStrategy
:members:
Simplifications Simplifications
=============== ===============
.. automodule:: pystencils.simp .. automodule:: pystencils.simp.simplifications
:members: :members:
Subexpression insertion
=======================
The subexpression insertions have the goal to insert subexpressions which will not reduce the number of FLOPs.
For example a constant value kept as subexpression will lead to a new variable in the code which will occupy
a register slot. On the other side a single variable could just be inserted in all assignments.
.. automodule:: pystencils.simp.subexpression_insertion
:members:
......
...@@ -554,7 +554,7 @@ class ExtensionModuleCode: ...@@ -554,7 +554,7 @@ class ExtensionModuleCode:
if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))] if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))]
header_hash = b''.join([hashlib.sha256(open(h, 'rb').read()).digest() for h in ps_headers]) header_hash = b''.join([hashlib.sha256(open(h, 'rb').read()).digest() for h in ps_headers])
includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list]) includes = "\n".join([f"#include {include_file}" for include_file in header_list])
self._code_string += includes self._code_string += includes
self._code_string += "\n" self._code_string += "\n"
self._code_string += f"#define RESTRICT {restrict_qualifier} \n" self._code_string += f"#define RESTRICT {restrict_qualifier} \n"
...@@ -563,7 +563,7 @@ class ExtensionModuleCode: ...@@ -563,7 +563,7 @@ class ExtensionModuleCode:
for ast, name in zip(self._ast_nodes, self._function_names): for ast, name in zip(self._ast_nodes, self._function_names):
old_name = ast.function_name old_name = ast.function_name
ast.function_name = "kernel_" + name ast.function_name = f"kernel_{name}"
self._code_string += generate_c(ast, custom_backend=self._custom_backend) self._code_string += generate_c(ast, custom_backend=self._custom_backend)
self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast) self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast)
ast.function_name = old_name ast.function_name = old_name
......
...@@ -205,5 +205,5 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass ...@@ -205,5 +205,5 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass
prefix = f"#pragma omp for schedule({schedule})" prefix = f"#pragma omp for schedule({schedule})"
if collapse: if collapse:
prefix += " collapse(%d)" % (collapse, ) prefix += f" collapse({collapse})"
loop_to_parallelize.prefix_lines.append(prefix) loop_to_parallelize.prefix_lines.append(prefix)
...@@ -14,6 +14,8 @@ from pystencils.enums import Target, Backend ...@@ -14,6 +14,8 @@ from pystencils.enums import Target, Backend
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.simp.simplifications import apply_sympy_optimisations
from pystencils.simplificationfactory import create_simplification_strategy
from pystencils.stencil import direction_string_to_offset, inverse_direction_string from pystencils.stencil import direction_string_to_offset, inverse_direction_string
from pystencils.transformations import ( from pystencils.transformations import (
loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel) loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
...@@ -83,6 +85,18 @@ class CreateKernelConfig: ...@@ -83,6 +85,18 @@ class CreateKernelConfig:
Dict with indexing parameters (constructor parameters of indexing class) Dict with indexing parameters (constructor parameters of indexing class)
e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'. e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'.
""" """
default_assignment_simplifications: bool = False
"""
If `True` default simplifications are first performed on the Assignments. If problems occur during the
simplification a warning will be thrown.
Furthermore, it is essential to know that this is a two-stage process. The first stage of the process acts
on the level of the `AssignmentCollection`. In this part, `create_simplification_strategy`
from pystencils.simplificationfactory will be used to apply optimisations like insertion of constants to
remove pressure from the registers. Thus the first part of the optimisations can only be executed if
an `AssignmentCollection` is passed. The second part of the optimisation acts on the level of each Assignment
individually. In this stage, all optimisations from `sympy.codegen.rewriting.optims_c99` are applied
to each Assignment. Thus this stage can also be applied if a list of Assignments is passed.
"""
cpu_prepend_optimizations: List[Callable] = field(default_factory=list) cpu_prepend_optimizations: List[Callable] = field(default_factory=list)
""" """
List of extra optimizations to perform first on the AST. List of extra optimizations to perform first on the AST.
...@@ -195,8 +209,8 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC ...@@ -195,8 +209,8 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
>>> import numpy as np >>> import numpy as np
>>> s, d = ps.fields('s, d: [2D]') >>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0]) >>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
>>> config = ps.CreateKernelConfig(cpu_openmp=True) >>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True)
>>> kernel_ast = ps.kernelcreation.create_domain_kernel([assignment], config=config) >>> kernel_ast = ps.kernelcreation.create_domain_kernel([assignment], config=kernel_config)
>>> kernel = kernel_ast.compile() >>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5]) >>> d_arr = np.zeros([5, 5])
>>> kernel(d=d_arr, s=np.ones([5, 5])) >>> kernel(d=d_arr, s=np.ones([5, 5]))
...@@ -207,6 +221,15 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC ...@@ -207,6 +221,15 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
[0., 4., 4., 4., 0.], [0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]]) [0., 0., 0., 0., 0.]])
""" """
# --- applying first default simplifications
try:
if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection):
simplification = create_simplification_strategy()
assignments = simplification(assignments)
except Exception as e:
warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
f"AssignmentCollection due to the following problem :{e}")
# ---- Normalizing parameters # ---- Normalizing parameters
split_groups = () split_groups = ()
if isinstance(assignments, AssignmentCollection): if isinstance(assignments, AssignmentCollection):
...@@ -214,6 +237,13 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC ...@@ -214,6 +237,13 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
split_groups = assignments.simplification_hints['split_groups'] split_groups = assignments.simplification_hints['split_groups']
assignments = assignments.all_assignments assignments = assignments.all_assignments
try:
if config.default_assignment_simplifications:
assignments = apply_sympy_optimisations(assignments)
except Exception as e:
warnings.warn(f"It was not possible to apply the default SymPy optimisations to the "
f"Assignments due to the following problem :{e}")
# ---- Creating ast # ---- Creating ast
ast = None ast = None
if config.target == Target.CPU: if config.target == Target.CPU:
...@@ -304,9 +334,9 @@ def create_indexed_kernel(assignments: List[Assignment], *, config: CreateKernel ...@@ -304,9 +334,9 @@ def create_indexed_kernel(assignments: List[Assignment], *, config: CreateKernel
>>> >>>
>>> # Additional values stored in index field can be accessed in the kernel as well >>> # Additional values stored in index field can be accessed in the kernel as well
>>> s, d = ps.fields('s, d: [2D]') >>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val')) >>> assignment = ps.Assignment(d[0, 0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val'))
>>> config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y')) >>> kernel_config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y'))
>>> kernel_ast = ps.create_indexed_kernel([assignment], config=config) >>> kernel_ast = ps.create_indexed_kernel([assignment], config=kernel_config)
>>> kernel = kernel_ast.compile() >>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5]) >>> d_arr = np.zeros([5, 5])
>>> kernel(s=np.ones([5, 5]), d=d_arr, idx=index_arr) >>> kernel(s=np.ones([5, 5]), d=d_arr, idx=index_arr)
......
...@@ -19,13 +19,14 @@ class AssignmentCollection: ...@@ -19,13 +19,14 @@ class AssignmentCollection:
Additionally a dictionary of simplification hints is stored, which are set by the functions that create Additionally a dictionary of simplification hints is stored, which are set by the functions that create
assignment collections to transport information to the simplification system. assignment collections to transport information to the simplification system.
Attributes: Args:
main_assignments: list of assignments main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
subexpressions: list of assignments defining subexpressions used in main equations assignment is a field access. Thus the generated equations write on arrays.
simplification_hints: dict that is used to annotate the assignment collection with hints that are subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning. potentially required hints and their meaning.
subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection used to get new symbols that are unique for this AssignmentCollection
""" """
...@@ -33,9 +34,13 @@ class AssignmentCollection: ...@@ -33,9 +34,13 @@ class AssignmentCollection:
# ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = {}, subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None, simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict): if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v) main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()] for k, v in main_assignments.items()]
......
...@@ -3,9 +3,11 @@ from typing import Callable, List, Sequence, Union ...@@ -3,9 +3,11 @@ from typing import Callable, List, Sequence, Union
from collections import defaultdict from collections import defaultdict
import sympy as sp import sympy as sp
from sympy.codegen.rewriting import optims_c99, optimize
from sympy.codegen.rewriting import ReplaceOptim
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.astnodes import Node, SympyAssignment
from pystencils.field import AbstractField, Field from pystencils.field import AbstractField, Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
...@@ -223,3 +225,24 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]): ...@@ -223,3 +225,24 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
def apply_sympy_optimisations(assignments):
""" Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
and applies the default sympy optimisations. See sympy.codegen.rewriting
"""
# Evaluates all constant terms
evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf())
sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
if hasattr(a, 'lhs')
else a for a in assignments]
assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
for a in chain.from_iterable(assignments_nodes):
a.optimize(sympy_optimisations)
return assignments
...@@ -4,7 +4,7 @@ from pystencils.sympyextensions import is_constant ...@@ -4,7 +4,7 @@ from pystencils.sympyextensions import is_constant
# Subexpression Insertion # Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=set()): def insert_subexpressions(ac, selection_callback, skip=None):
""" """
Removes a number of subexpressions from an assignment collection by Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur. inserting their right-hand side wherever they occur.
...@@ -16,6 +16,8 @@ def insert_subexpressions(ac, selection_callback, skip=set()): ...@@ -16,6 +16,8 @@ def insert_subexpressions(ac, selection_callback, skip=set()):
- skip: Set of symbols (left-hand sides of subexpressions) that should be - skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback. ignored even if qualified by the callback.
""" """
if skip is None:
skip = set()
i = 0 i = 0
while i < len(ac.subexpressions): while i < len(ac.subexpressions):
exp = ac.subexpressions[i] exp = ac.subexpressions[i]
......
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
"""
Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
terms beforehand.
"""
s = SimplificationStrategy()
s.add(insert_symbol_times_minus_one)
s.add(insert_constant_multiples)
s.add(insert_constant_additions)
s.add(insert_squares)
s.add(insert_zeros)
s.add(insert_constants)
s.add(lambda ac: ac.new_without_unused_subexpressions())
...@@ -453,7 +453,7 @@ def recursive_collect(expr, symbols, order_by_occurences=False): ...@@ -453,7 +453,7 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
return rec_sum return rec_sum
def count_operations(term: Union[sp.Expr, List[sp.Expr]], def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]: only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division. """Counts the number of additions, multiplications and division.
......
import sympy import pytest
import sympy as sp
import numpy import numpy
import pystencils import pystencils
from pystencils.datahandling import create_data_handling from pystencils.datahandling import create_data_handling
def test_max(): @pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max(sympy_function):
dh = create_data_handling(domain_size=(10, 10), periodicity=True) dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1) x = dh.add_array('x', values_per_cell=1)
...@@ -15,56 +18,28 @@ def test_max(): ...@@ -15,56 +18,28 @@ def test_max():
dh.fill("z", 2.0, ghost_layers=True) dh.fill("z", 2.0, ghost_layers=True)
# test sp.Max with one argument # test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy.Max(y.center + 3.3)) assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1) ast_1 = pystencils.create_kernel(assignment_1)
kernel_1 = ast_1.compile() kernel_1 = ast_1.compile()
# test sp.Max with two arguments # test sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy.Max(0.5, y.center - 1.5)) assignment_2 = pystencils.Assignment(x.center, sympy_function(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2) ast_2 = pystencils.create_kernel(assignment_2)
kernel_2 = ast_2.compile() kernel_2 = ast_2.compile()
# test sp.Max with many arguments # test sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy.Max(z.center, 4.5, y.center - 1.5, y.center + z.center)) assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3) ast_3 = pystencils.create_kernel(assignment_3)
kernel_3 = ast_3.compile() kernel_3 = ast_3.compile()
dh.run_kernel(kernel_1) if sympy_function is sp.Max:
assert numpy.all(dh.cpu_arrays["x"] == 4.3) results = [4.3, 0.5, 4.5]
dh.run_kernel(kernel_2) else:
assert numpy.all(dh.cpu_arrays["x"] == 0.5) results = [4.3, -0.5, -0.5]
dh.run_kernel(kernel_3)
assert numpy.all(dh.cpu_arrays["x"] == 4.5)
def test_min():
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1)
dh.fill("z", 2.0, ghost_layers=True)
# test sp.Min with one argument
assignment_1 = pystencils.Assignment(x.center, sympy.Min(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1)
kernel_1 = ast_1.compile()
# test sp.Min with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy.Min(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2)
kernel_2 = ast_2.compile()
# test sp.Min with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy.Min(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3)
kernel_3 = ast_3.compile()
dh.run_kernel(kernel_1) dh.run_kernel(kernel_1)
assert numpy.all(dh.cpu_arrays["x"] == 4.3) assert numpy.all(dh.gather_array('x') == results[0])
dh.run_kernel(kernel_2) dh.run_kernel(kernel_2)
assert numpy.all(dh.cpu_arrays["x"] == - 0.5) assert numpy.all(dh.gather_array('x') == results[1])
dh.run_kernel(kernel_3) dh.run_kernel(kernel_3)
assert numpy.all(dh.cpu_arrays["x"] == - 0.5) assert numpy.all(dh.gather_array('x') == results[2])
import pytest import pytest
import sys
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
from pystencils import Assignment from pystencils import Assignment
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0)
sympy_numeric_version.reverse()
sympy_version = sum(x * (100 ** i) for i, x in enumerate(sympy_numeric_version))
dst = ps.fields('dst(8): double[2D]') dst = ps.fields('dst(8): double[2D]')
s = sp.symbols('s_:8') s = sp.symbols('s_:8')
x = sp.symbols('x') x = sp.symbols('x')
y = sp.symbols('y') y = sp.symbols('y')
python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_kernel_function(): def test_kernel_function():
assignments = [ assignments = [
Assignment(dst[0, 0](0), s[0]), Assignment(dst[0, 0](0), s[0]),
...@@ -44,8 +39,6 @@ def test_skip_iteration(): ...@@ -44,8 +39,6 @@ def test_skip_iteration():
assert skipped.undefined_symbols == set() assert skipped.undefined_symbols == set()
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_block(): def test_block():
assignments = [ assignments = [
Assignment(dst[0, 0](0), s[0]), Assignment(dst[0, 0](0), s[0]),
...@@ -92,17 +85,23 @@ def test_loop_over_coordinate(): ...@@ -92,17 +85,23 @@ def test_loop_over_coordinate():
assert loop.step == 2 assert loop.step == 2
def test_sympy_assignment(): @pytest.mark.parametrize('default_assignment_simplifications', [False, True])
pytest.importorskip('sympy.codegen.rewriting') @pytest.mark.skipif(python_version == '3.8.2', reason="For this python version a strange bug in mpmath occurs")
from sympy.codegen.rewriting import optims_c99 def test_sympy_assignment(default_assignment_simplifications):
assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1)) assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1))
assignment.optimize(optims_c99)
ast = ps.create_kernel([assignment]) config = ps.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications)
ast = ps.create_kernel([assignment], config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
assert 'log1p' in code if default_assignment_simplifications:
assert 'log2' in code assert 'log1p' in code
# constant term is directly evaluated
assert 'log2' not in code
else:
# no optimisations will be applied so the optimised version of log will not be in the code
assert 'log1p' not in code
assert 'log2' not in code
assignment.replace(assignment.lhs, dst[0, 0](1)) assignment.replace(assignment.lhs, dst[0, 0](1))
assignment.replace(assignment.rhs, sp.log(2)) assignment.replace(assignment.rhs, sp.log(2))
......
from sys import version_info as vs
import pytest import pytest
import sympy as sp import sympy as sp
import pystencils as ps