Commit c5c6019f authored by Stephan Seitz's avatar Stephan Seitz Committed by Markus Holzer
Browse files

Fix deprecation warning for Sympy 1.7

We have to try from newest to oldest import to avoid deprecation
warnings.
parent 9f966136
...@@ -98,6 +98,10 @@ try: ...@@ -98,6 +98,10 @@ try:
except ImportError: except ImportError:
collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils/datahandling/vtk.py")] collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils/datahandling/vtk.py")]
# TODO: Remove if Ubuntu 18.04 is no longer supported
if pytest_version < 50403:
collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_jupyter_extensions.ipynb")]
collect_ignore += [os.path.join(SCRIPT_FOLDER, 'setup.py')] collect_ignore += [os.path.join(SCRIPT_FOLDER, 'setup.py')]
for root, sub_dirs, files in os.walk('.'): for root, sub_dirs, files in os.walk('.'):
......
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
try:
from sympy.codegen.ast import Assignment
except ImportError:
Assignment = None
__all__ = ['Assignment', 'assignment_from_stencil'] __all__ = ['Assignment', 'assignment_from_stencil']
...@@ -21,43 +17,22 @@ def assignment_str(assignment): ...@@ -21,43 +17,22 @@ def assignment_str(assignment):
return r"{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs) return r"{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs)
if Assignment: _old_new = sp.codegen.ast.Assignment.__new__
_old_new = sp.codegen.ast.Assignment.__new__
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
else: def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
# back port for older sympy versions that don't have Assignment yet if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
class Assignment(sp.Rel): # pragma: no cover
rel_op = ':=' Assignment.__str__ = assignment_str
__slots__ = [] Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
def __new__(cls, lhs, rhs=0, **assumptions): sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
from sympy.matrices.expressions.matexpr import (
MatrixElement, MatrixSymbol)
lhs = sp.sympify(lhs)
rhs = sp.sympify(rhs)
# Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
if not isinstance(lhs, assignable):
raise TypeError(f"Cannot assign to lhs of type {type(lhs)}.")
return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
__str__ = assignment_str
_print_Assignment = print_assignment_latex
# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master # Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master
try: try:
......
...@@ -18,12 +18,9 @@ from pystencils.integer_functions import ( ...@@ -18,12 +18,9 @@ from pystencils.integer_functions import (
int_div, int_power_of_2, modulo_ceil) int_div, int_power_of_2, modulo_ceil)
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.c import C99CodePrinter as CCodePrinter # for sympy versions > 1.6
except ImportError: except ImportError:
try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
except ImportError:
from sympy.printing.c import C11CodePrinter as CCodePrinter # for sympy versions > 1.6
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
......
...@@ -352,7 +352,7 @@ class InterpolatorAccess(TypedSymbol): ...@@ -352,7 +352,7 @@ class InterpolatorAccess(TypedSymbol):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self): def __getnewargs__(self):
return tuple(self.symbol, *self.offsets) return (self.symbol, *self.offsets)
class DiffInterpolatorAccess(InterpolatorAccess): class DiffInterpolatorAccess(InterpolatorAccess):
...@@ -397,7 +397,7 @@ class DiffInterpolatorAccess(InterpolatorAccess): ...@@ -397,7 +397,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self): def __getnewargs__(self):
return tuple(self.symbol, self.diff_coordinate_idx, *self.offsets) return (self.symbol, self.diff_coordinate_idx, *self.offsets)
########################################################################################## ##########################################################################################
......
import pytest
import sympy as sp import sympy as sp
import pystencils as ps
import pystencils as ps
from pystencils import Assignment from pystencils import Assignment
from pystencils.astnodes import Block, SkipIteration, LoopOverCoordinate, SympyAssignment from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
from sympy.codegen.rewriting import optims_c99
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.')]
if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0)
sympy_numeric_version.reverse()
sympy_version = sum(x * (100 ** i) for i, x in enumerate(sympy_numeric_version))
dst = ps.fields('dst(8): double[2D]') dst = ps.fields('dst(8): double[2D]')
s = sp.symbols('s_:8') s = sp.symbols('s_:8')
...@@ -11,6 +17,8 @@ x = sp.symbols('x') ...@@ -11,6 +17,8 @@ x = sp.symbols('x')
y = sp.symbols('y') y = sp.symbols('y')
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_kernel_function(): def test_kernel_function():
assignments = [ assignments = [
Assignment(dst[0, 0](0), s[0]), Assignment(dst[0, 0](0), s[0]),
...@@ -36,6 +44,8 @@ def test_skip_iteration(): ...@@ -36,6 +44,8 @@ def test_skip_iteration():
assert skipped.undefined_symbols == set() assert skipped.undefined_symbols == set()
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_block(): def test_block():
assignments = [ assignments = [
Assignment(dst[0, 0](0), s[0]), Assignment(dst[0, 0](0), s[0]),
...@@ -83,7 +93,8 @@ def test_loop_over_coordinate(): ...@@ -83,7 +93,8 @@ def test_loop_over_coordinate():
def test_sympy_assignment(): def test_sympy_assignment():
pytest.importorskip('sympy.codegen.rewriting')
from sympy.codegen.rewriting import optims_c99
assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1)) assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1))
assignment.optimize(optims_c99) assignment.optimize(optims_c99)
......
...@@ -390,7 +390,7 @@ ...@@ -390,7 +390,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.5" "version": "3.8.5"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -369,7 +369,7 @@ ...@@ -369,7 +369,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.9" "version": "3.8.5"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment