From c9e5903579b3af59ac65896f434abd96f86da87c Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Sun, 31 Oct 2021 12:25:35 +0100
Subject: [PATCH] Fix deepcopy issue with Sympy 1.9

---
 pystencils/data_types.py               | 22 ++++++++++++++++++++++
 pystencils_tests/test_vectorization.py | 17 +++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index b8653539b..295afe089 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -590,6 +590,28 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
     sp.Number.__getstate__ = sp.Basic.__getstate__
     del sp.Basic.__getstate__
 
+    class FunctorWithStoredKwargs:
+        def __init__(self, func, **kwargs):
+            self.func = func
+            self.kwargs = kwargs
+
+        def __call__(self, *args):
+            return self.func(*args, **self.kwargs)
+
+    # __reduce_ex__ would strip kwargs, so we override it
+    def basic_reduce_ex(self, protocol):
+        if hasattr(self, '__getnewargs_ex__'):
+            args, kwargs = self.__getnewargs_ex__()
+        else:
+            args, kwargs = self.__getnewargs__(), {}
+        if hasattr(self, '__getstate__'):
+            state = self.__getstate__()
+        else:
+            state = None
+        return FunctorWithStoredKwargs(type(self), **kwargs), args, state
+    sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
+    sp.Basic.__reduce_ex__ = basic_reduce_ex
+
 
 class Type(sp.Atom):
     def __new__(cls, *args, **kwargs):
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 9c34949be..1bbc6a4a4 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -65,6 +65,7 @@ def test_aligned_and_nt_stores(instruction_set=instruction_set, openmp=False):
     dh.run_kernel(kernel)
     np.testing.assert_equal(np.sum(dh.cpu_arrays['f']), np.prod(domain_size))
 
+
 def test_aligned_and_nt_stores_openmp(instruction_set=instruction_set):
     test_aligned_and_nt_stores(instruction_set, True)
 
@@ -278,3 +279,19 @@ def test_vectorised_fast_approximations(instruction_set=instruction_set):
     ast = ps.create_kernel(insert_fast_sqrts(assignment))
     vectorize(ast, instruction_set=instruction_set)
     ast.compile()
+
+
+def test_issue40(*_):
+    """https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/40"""
+    opt = {'instruction_set': "avx512", 'assume_aligned': False,
+           'nontemporal': False, 'assume_inner_stride_one': True}
+
+    src = ps.fields("src(1): double[2D]", layout='fzyx')
+    eq = [ps.Assignment(sp.Symbol('rho'), 1.0),
+          ps.Assignment(src[0, 0](0), sp.Rational(4, 9) * sp.Symbol('rho'))]
+
+    config = ps.CreateKernelConfig(target=Target.CPU, cpu_vectorize_info=opt, data_type='float64')
+    ast = ps.create_kernel(eq, config=config)
+
+    code = ps.get_code_str(ast)
+    assert 'epi32' not in code
-- 
GitLab