Commit 0f0983d6 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'fix-deprecation-warning-1.7' into 'master'

Fix deprecation warning for Sympy 1.7

See merge request pycodegen/pystencils!191
parents 9f966136 c5c6019f
Pipeline #28466 waiting for manual action with stages
in 5 minutes and 38 seconds
......@@ -98,6 +98,10 @@ try:
except ImportError:
collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils/datahandling/vtk.py")]
# TODO: Remove if Ubuntu 18.04 is no longer supported
if pytest_version < 50403:
collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_jupyter_extensions.ipynb")]
collect_ignore += [os.path.join(SCRIPT_FOLDER, 'setup.py')]
for root, sub_dirs, files in os.walk('.'):
......
import numpy as np
import sympy as sp
from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter
try:
from sympy.codegen.ast import Assignment
except ImportError:
Assignment = None
__all__ = ['Assignment', 'assignment_from_stencil']
......@@ -21,43 +17,22 @@ def assignment_str(assignment):
return r"{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs)
if Assignment:
_old_new = sp.codegen.ast.Assignment.__new__
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
_old_new = sp.codegen.ast.Assignment.__new__
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
else:
# back port for older sympy versions that don't have Assignment yet
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
class Assignment(sp.Rel): # pragma: no cover
rel_op = ':='
__slots__ = []
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
def __new__(cls, lhs, rhs=0, **assumptions):
from sympy.matrices.expressions.matexpr import (
MatrixElement, MatrixSymbol)
lhs = sp.sympify(lhs)
rhs = sp.sympify(rhs)
# Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
if not isinstance(lhs, assignable):
raise TypeError(f"Cannot assign to lhs of type {type(lhs)}.")
return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
__str__ = assignment_str
_print_Assignment = print_assignment_latex
# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master
try:
......
......@@ -18,12 +18,9 @@ from pystencils.integer_functions import (
int_div, int_power_of_2, modulo_ceil)
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from sympy.printing.c import C99CodePrinter as CCodePrinter # for sympy versions > 1.6
except ImportError:
try:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
except ImportError:
from sympy.printing.c import C11CodePrinter as CCodePrinter # for sympy versions > 1.6
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
......
......@@ -352,7 +352,7 @@ class InterpolatorAccess(TypedSymbol):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return tuple(self.symbol, *self.offsets)
return (self.symbol, *self.offsets)
class DiffInterpolatorAccess(InterpolatorAccess):
......@@ -397,7 +397,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return tuple(self.symbol, self.diff_coordinate_idx, *self.offsets)
return (self.symbol, self.diff_coordinate_idx, *self.offsets)
##########################################################################################
......
import pytest
import sympy as sp
import pystencils as ps
import pystencils as ps
from pystencils import Assignment
from pystencils.astnodes import Block, SkipIteration, LoopOverCoordinate, SympyAssignment
from sympy.codegen.rewriting import optims_c99
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.')]
if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0)
sympy_numeric_version.reverse()
sympy_version = sum(x * (100 ** i) for i, x in enumerate(sympy_numeric_version))
dst = ps.fields('dst(8): double[2D]')
s = sp.symbols('s_:8')
......@@ -11,6 +17,8 @@ x = sp.symbols('x')
y = sp.symbols('y')
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_kernel_function():
assignments = [
Assignment(dst[0, 0](0), s[0]),
......@@ -36,6 +44,8 @@ def test_skip_iteration():
assert skipped.undefined_symbols == set()
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_block():
assignments = [
Assignment(dst[0, 0](0), s[0]),
......@@ -83,7 +93,8 @@ def test_loop_over_coordinate():
def test_sympy_assignment():
pytest.importorskip('sympy.codegen.rewriting')
from sympy.codegen.rewriting import optims_c99
assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1))
assignment.optimize(optims_c99)
......
......@@ -390,7 +390,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.8.5"
}
},
"nbformat": 4,
......
......@@ -369,7 +369,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.8.5"
}
},
"nbformat": 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment