diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 53731645797afeb62a4a5ff803c2ee7b4e5d8e16..f9044d3cb994ed81fef149073d2a48ef3e66da9f 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -692,6 +692,9 @@ class ResolvedFieldAccess(sp.Indexed):
     def __getnewargs__(self):
         return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
 
+    def __getnewargs_ex__(self):
+        return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {}
+
 
 class TemporaryMemoryAllocation(Node):
     """Node for temporary memory buffer allocation.
@@ -837,6 +840,9 @@ class ConditionalFieldAccess(sp.Function):
     def __getnewargs__(self):
         return self.access, self.outofbounds_condition, self.outofbounds_value
 
+    def __getnewargs_ex__(self):
+        return (self.access, self.outofbounds_condition, self.outofbounds_value), {}
+
 
 class NontemporalFence(Node):
     def __init__(self):
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 28fc4b32c4084ed319c4788b2e0b4ade0d0ccb7d..46abd84f30e15b8a56144f3f91ce78393304d074 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -219,9 +219,10 @@ class TypedSymbol(sp.Symbol):
         obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, name, dtype, *args, **kwargs):
+    def __new_stage2__(cls, name, dtype, **kwargs):
         assumptions = assumptions_from_dtype(dtype)
-        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
+        assumptions.update(kwargs)
+        obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
         try:
             obj._dtype = create_type(dtype)
         except (TypeError, ValueError):
@@ -242,6 +243,9 @@ class TypedSymbol(sp.Symbol):
     def __getnewargs__(self):
         return self.name, self.dtype
 
+    def __getnewargs_ex__(self):
+        return (self.name, self.dtype), self.assumptions0
+
     @property
     def canonical(self):
         return self
@@ -578,6 +582,13 @@ def get_type_of_expression(expr,
     raise NotImplementedError("Could not determine type for", expr, type(expr))
 
 
+sympy_version = sp.__version__.split('.')
+if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
+    # __setstate__ would bypass the contructor, so we remove it
+    sp.Number.__getstate__ = sp.Basic.__getstate__
+    del sp.Basic.__getstate__
+
+
 class Type(sp.Atom):
     def __new__(cls, *args, **kwargs):
         return sp.Basic.__new__(cls)
@@ -621,6 +632,9 @@ class BasicType(Type):
     def __getnewargs__(self):
         return self.numpy_dtype, self.const
 
+    def __getnewargs_ex__(self):
+        return (self.numpy_dtype, self.const), {}
+
     @property
     def base_type(self):
         return None
@@ -717,6 +731,9 @@ class VectorType(Type):
     def __getnewargs__(self):
         return self._base_type, self.width
 
+    def __getnewargs_ex__(self):
+        return (self._base_type, self.width), {}
+
 
 class PointerType(Type):
     def __init__(self, base_type, const=False, restrict=True):
@@ -727,6 +744,9 @@ class PointerType(Type):
     def __getnewargs__(self):
         return self.base_type, self.const, self.restrict
 
+    def __getnewargs_ex__(self):
+        return (self.base_type, self.const, self.restrict), {}
+
     @property
     def alias(self):
         return not self.restrict
@@ -768,6 +788,9 @@ class StructType:
     def __getnewargs__(self):
         return self.numpy_dtype, self.const
 
+    def __getnewargs_ex__(self):
+        return (self.numpy_dtype, self.const), {}
+
     @property
     def base_type(self):
         return None
@@ -815,13 +838,11 @@ class TypedImaginaryUnit(TypedSymbol):
         obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, dtype, *args, **kwargs):
+    def __new_stage2__(cls, dtype):
         obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
                                                       "_i",
                                                       dtype,
-                                                      imaginary=True,
-                                                      *args,
-                                                      **kwargs)
+                                                      imaginary=True)
         return obj
 
     headers = ['"cuda_complex.hpp"']
@@ -831,3 +852,6 @@ class TypedImaginaryUnit(TypedSymbol):
 
     def __getnewargs__(self):
         return (self.dtype,)
+
+    def __getnewargs_ex__(self):
+        return (self.dtype,), {}
diff --git a/pystencils/field.py b/pystencils/field.py
index fdc587e60523b7547e15859ed4b0ab17643f307e..82882c5487c246bfc1d5863e001c30e6870bf423 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -729,6 +729,9 @@ class Field(AbstractField):
         def __getnewargs__(self):
             return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
 
+        def __getnewargs_ex__(self):
+            return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
+
         # noinspection SpellCheckingInspection
         __xnew__ = staticmethod(__new_stage2__)
         # noinspection SpellCheckingInspection
diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py
index a694cb8eac864fdcea59fced498f72619bbdfec1..07d67d1da9ad6b3eb5a2de20d04cf50a48201af5 100644
--- a/pystencils/interpolation_astnodes.py
+++ b/pystencils/interpolation_astnodes.py
@@ -50,6 +50,9 @@ class _InterpolationSymbol(TypedSymbol):
     def __getnewargs__(self):
         return self.name, self.field, self.interpolator
 
+    def __getnewargs_ex__(self):
+        return (self.name, self.field, self.interpolator), {}
+
     # noinspection SpellCheckingInspection
     __xnew__ = staticmethod(__new_stage2__)
     # noinspection SpellCheckingInspection
@@ -167,8 +170,8 @@ class NearestNeightborInterpolator(Interpolator):
 
 
 class InterpolatorAccess(TypedSymbol):
-    def __new__(cls, field, *offsets, **kwargs):
-        obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
+    def __new__(cls, field, *offsets):
+        obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets)
         return obj
 
     def __new_stage2__(cls, symbol, *offsets):
@@ -354,13 +357,16 @@ class InterpolatorAccess(TypedSymbol):
     def __getnewargs__(self):
         return (self.symbol, *self.offsets)
 
+    def __getnewargs_ex__(self):
+        return (self.symbol, *self.offsets), {}
+
 
 class DiffInterpolatorAccess(InterpolatorAccess):
-    def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
+    def __new__(cls, symbol, diff_coordinate_idx, *offsets):
         if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
             from pystencils.fd import Diff, Discretization2ndOrder
             return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
-        obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
+        obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets)
         return obj
 
     def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
@@ -399,6 +405,9 @@ class DiffInterpolatorAccess(InterpolatorAccess):
     def __getnewargs__(self):
         return (self.symbol, self.diff_coordinate_idx, *self.offsets)
 
+    def __getnewargs_ex__(self):
+        return (self.symbol, self.diff_coordinate_idx, *self.offsets), {}
+
 
 ##########################################################################################
 # GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols #
diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py
index 3257522e419bf921b13010215e44a51a5290ce80..934c305cc21e3a5bcad2e9f6076230dd69ec1d40 100644
--- a/pystencils/kernelparameters.py
+++ b/pystencils/kernelparameters.py
@@ -38,6 +38,9 @@ class FieldStrideSymbol(TypedSymbol):
     def __getnewargs__(self):
         return self.field_name, self.coordinate
 
+    def __getnewargs_ex__(self):
+        return (self.field_name, self.coordinate), {}
+
     __xnew__ = staticmethod(__new_stage2__)
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
 
@@ -63,6 +66,9 @@ class FieldShapeSymbol(TypedSymbol):
     def __getnewargs__(self):
         return self.field_names, self.coordinate
 
+    def __getnewargs_ex__(self):
+        return (self.field_names, self.coordinate), {}
+
     __xnew__ = staticmethod(__new_stage2__)
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
 
@@ -86,6 +92,9 @@ class FieldPointerSymbol(TypedSymbol):
     def __getnewargs__(self):
         return self.field_name, self.dtype, self.dtype.const
 
+    def __getnewargs_ex__(self):
+        return (self.field_name, self.dtype, self.dtype.const), {}
+
     def _hashable_content(self):
         return super()._hashable_content(), self.field_name
 
diff --git a/pystencils_tests/test_astnodes.py b/pystencils_tests/test_astnodes.py
index 1d66eae29f1ec4edd7af48813979abb6a9fd1087..a3edc1c2bfb860b2ff967dc74694a27442addb22 100644
--- a/pystencils_tests/test_astnodes.py
+++ b/pystencils_tests/test_astnodes.py
@@ -5,7 +5,7 @@ import pystencils as ps
 from pystencils import Assignment
 from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
 
-sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.')]
+sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
 if len(sympy_numeric_version) < 3:
     sympy_numeric_version.append(0)
 sympy_numeric_version.reverse()