diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 53731645797afeb62a4a5ff803c2ee7b4e5d8e16..f9044d3cb994ed81fef149073d2a48ef3e66da9f 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -692,6 +692,9 @@ class ResolvedFieldAccess(sp.Indexed): def __getnewargs__(self): return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values + def __getnewargs_ex__(self): + return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {} + class TemporaryMemoryAllocation(Node): """Node for temporary memory buffer allocation. @@ -837,6 +840,9 @@ class ConditionalFieldAccess(sp.Function): def __getnewargs__(self): return self.access, self.outofbounds_condition, self.outofbounds_value + def __getnewargs_ex__(self): + return (self.access, self.outofbounds_condition, self.outofbounds_value), {} + class NontemporalFence(Node): def __init__(self): diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 28fc4b32c4084ed319c4788b2e0b4ade0d0ccb7d..46abd84f30e15b8a56144f3f91ce78393304d074 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -219,9 +219,10 @@ class TypedSymbol(sp.Symbol): obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, name, dtype, *args, **kwargs): + def __new_stage2__(cls, name, dtype, **kwargs): assumptions = assumptions_from_dtype(dtype) - obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs) + assumptions.update(kwargs) + obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) try: obj._dtype = create_type(dtype) except (TypeError, ValueError): @@ -242,6 +243,9 @@ class TypedSymbol(sp.Symbol): def __getnewargs__(self): return self.name, self.dtype + def __getnewargs_ex__(self): + return (self.name, self.dtype), self.assumptions0 + @property def canonical(self): return self @@ -578,6 +582,13 @@ def get_type_of_expression(expr, raise NotImplementedError("Could not determine type for", expr, type(expr)) +sympy_version = sp.__version__.split('.') +if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: + # __setstate__ would bypass the contructor, so we remove it + sp.Number.__getstate__ = sp.Basic.__getstate__ + del sp.Basic.__getstate__ + + class Type(sp.Atom): def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) @@ -621,6 +632,9 @@ class BasicType(Type): def __getnewargs__(self): return self.numpy_dtype, self.const + def __getnewargs_ex__(self): + return (self.numpy_dtype, self.const), {} + @property def base_type(self): return None @@ -717,6 +731,9 @@ class VectorType(Type): def __getnewargs__(self): return self._base_type, self.width + def __getnewargs_ex__(self): + return (self._base_type, self.width), {} + class PointerType(Type): def __init__(self, base_type, const=False, restrict=True): @@ -727,6 +744,9 @@ class PointerType(Type): def __getnewargs__(self): return self.base_type, self.const, self.restrict + def __getnewargs_ex__(self): + return (self.base_type, self.const, self.restrict), {} + @property def alias(self): return not self.restrict @@ -768,6 +788,9 @@ class StructType: def __getnewargs__(self): return self.numpy_dtype, self.const + def __getnewargs_ex__(self): + return (self.numpy_dtype, self.const), {} + @property def base_type(self): return None @@ -815,13 +838,11 @@ class TypedImaginaryUnit(TypedSymbol): obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, dtype, *args, **kwargs): + def __new_stage2__(cls, dtype): obj = super(TypedImaginaryUnit, cls).__xnew__(cls, "_i", dtype, - imaginary=True, - *args, - **kwargs) + imaginary=True) return obj headers = ['"cuda_complex.hpp"'] @@ -831,3 +852,6 @@ class TypedImaginaryUnit(TypedSymbol): def __getnewargs__(self): return (self.dtype,) + + def __getnewargs_ex__(self): + return (self.dtype,), {} diff --git a/pystencils/field.py b/pystencils/field.py index fdc587e60523b7547e15859ed4b0ab17643f307e..82882c5487c246bfc1d5863e001c30e6870bf423 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -729,6 +729,9 @@ class Field(AbstractField): def __getnewargs__(self): return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype + def __getnewargs_ex__(self): + return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {} + # noinspection SpellCheckingInspection __xnew__ = staticmethod(__new_stage2__) # noinspection SpellCheckingInspection diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index a694cb8eac864fdcea59fced498f72619bbdfec1..07d67d1da9ad6b3eb5a2de20d04cf50a48201af5 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -50,6 +50,9 @@ class _InterpolationSymbol(TypedSymbol): def __getnewargs__(self): return self.name, self.field, self.interpolator + def __getnewargs_ex__(self): + return (self.name, self.field, self.interpolator), {} + # noinspection SpellCheckingInspection __xnew__ = staticmethod(__new_stage2__) # noinspection SpellCheckingInspection @@ -167,8 +170,8 @@ class NearestNeightborInterpolator(Interpolator): class InterpolatorAccess(TypedSymbol): - def __new__(cls, field, *offsets, **kwargs): - obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs) + def __new__(cls, field, *offsets): + obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets) return obj def __new_stage2__(cls, symbol, *offsets): @@ -354,13 +357,16 @@ class InterpolatorAccess(TypedSymbol): def __getnewargs__(self): return (self.symbol, *self.offsets) + def __getnewargs_ex__(self): + return (self.symbol, *self.offsets), {} + class DiffInterpolatorAccess(InterpolatorAccess): - def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs): + def __new__(cls, symbol, diff_coordinate_idx, *offsets): if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR: from pystencils.fd import Diff, Discretization2ndOrder return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx)) - obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs) + obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets) return obj def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets): @@ -399,6 +405,9 @@ class DiffInterpolatorAccess(InterpolatorAccess): def __getnewargs__(self): return (self.symbol, self.diff_coordinate_idx, *self.offsets) + def __getnewargs_ex__(self): + return (self.symbol, self.diff_coordinate_idx, *self.offsets), {} + ########################################################################################## # GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols # diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py index 3257522e419bf921b13010215e44a51a5290ce80..934c305cc21e3a5bcad2e9f6076230dd69ec1d40 100644 --- a/pystencils/kernelparameters.py +++ b/pystencils/kernelparameters.py @@ -38,6 +38,9 @@ class FieldStrideSymbol(TypedSymbol): def __getnewargs__(self): return self.field_name, self.coordinate + def __getnewargs_ex__(self): + return (self.field_name, self.coordinate), {} + __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) @@ -63,6 +66,9 @@ class FieldShapeSymbol(TypedSymbol): def __getnewargs__(self): return self.field_names, self.coordinate + def __getnewargs_ex__(self): + return (self.field_names, self.coordinate), {} + __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) @@ -86,6 +92,9 @@ class FieldPointerSymbol(TypedSymbol): def __getnewargs__(self): return self.field_name, self.dtype, self.dtype.const + def __getnewargs_ex__(self): + return (self.field_name, self.dtype, self.dtype.const), {} + def _hashable_content(self): return super()._hashable_content(), self.field_name diff --git a/pystencils_tests/test_astnodes.py b/pystencils_tests/test_astnodes.py index 1d66eae29f1ec4edd7af48813979abb6a9fd1087..a3edc1c2bfb860b2ff967dc74694a27442addb22 100644 --- a/pystencils_tests/test_astnodes.py +++ b/pystencils_tests/test_astnodes.py @@ -5,7 +5,7 @@ import pystencils as ps from pystencils import Assignment from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment -sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.')] +sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()] if len(sympy_numeric_version) < 3: sympy_numeric_version.append(0) sympy_numeric_version.reverse()