Skip to content
Snippets Groups Projects
Commit c1bb1ec2 authored by Michael Kuron's avatar Michael Kuron :mortar_board: Committed by Markus Holzer
Browse files

Sympy 1.9 support

parent 059de5fb
1 merge request!239Sympy 1.9 support
...@@ -692,6 +692,9 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -692,6 +692,9 @@ class ResolvedFieldAccess(sp.Indexed):
def __getnewargs__(self): def __getnewargs__(self):
return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
def __getnewargs_ex__(self):
return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {}
class TemporaryMemoryAllocation(Node): class TemporaryMemoryAllocation(Node):
"""Node for temporary memory buffer allocation. """Node for temporary memory buffer allocation.
...@@ -837,6 +840,9 @@ class ConditionalFieldAccess(sp.Function): ...@@ -837,6 +840,9 @@ class ConditionalFieldAccess(sp.Function):
def __getnewargs__(self): def __getnewargs__(self):
return self.access, self.outofbounds_condition, self.outofbounds_value return self.access, self.outofbounds_condition, self.outofbounds_value
def __getnewargs_ex__(self):
return (self.access, self.outofbounds_condition, self.outofbounds_value), {}
class NontemporalFence(Node): class NontemporalFence(Node):
def __init__(self): def __init__(self):
......
...@@ -219,9 +219,10 @@ class TypedSymbol(sp.Symbol): ...@@ -219,9 +219,10 @@ class TypedSymbol(sp.Symbol):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
return obj return obj
def __new_stage2__(cls, name, dtype, *args, **kwargs): def __new_stage2__(cls, name, dtype, **kwargs):
assumptions = assumptions_from_dtype(dtype) assumptions = assumptions_from_dtype(dtype)
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs) assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
try: try:
obj._dtype = create_type(dtype) obj._dtype = create_type(dtype)
except (TypeError, ValueError): except (TypeError, ValueError):
...@@ -242,6 +243,9 @@ class TypedSymbol(sp.Symbol): ...@@ -242,6 +243,9 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.name, self.dtype return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), self.assumptions0
@property @property
def canonical(self): def canonical(self):
return self return self
...@@ -578,6 +582,13 @@ def get_type_of_expression(expr, ...@@ -578,6 +582,13 @@ def get_type_of_expression(expr,
raise NotImplementedError("Could not determine type for", expr, type(expr)) raise NotImplementedError("Could not determine type for", expr, type(expr))
sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
# __setstate__ would bypass the contructor, so we remove it
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
class Type(sp.Atom): class Type(sp.Atom):
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls) return sp.Basic.__new__(cls)
...@@ -621,6 +632,9 @@ class BasicType(Type): ...@@ -621,6 +632,9 @@ class BasicType(Type):
def __getnewargs__(self): def __getnewargs__(self):
return self.numpy_dtype, self.const return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property @property
def base_type(self): def base_type(self):
return None return None
...@@ -717,6 +731,9 @@ class VectorType(Type): ...@@ -717,6 +731,9 @@ class VectorType(Type):
def __getnewargs__(self): def __getnewargs__(self):
return self._base_type, self.width return self._base_type, self.width
def __getnewargs_ex__(self):
return (self._base_type, self.width), {}
class PointerType(Type): class PointerType(Type):
def __init__(self, base_type, const=False, restrict=True): def __init__(self, base_type, const=False, restrict=True):
...@@ -727,6 +744,9 @@ class PointerType(Type): ...@@ -727,6 +744,9 @@ class PointerType(Type):
def __getnewargs__(self): def __getnewargs__(self):
return self.base_type, self.const, self.restrict return self.base_type, self.const, self.restrict
def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict), {}
@property @property
def alias(self): def alias(self):
return not self.restrict return not self.restrict
...@@ -768,6 +788,9 @@ class StructType: ...@@ -768,6 +788,9 @@ class StructType:
def __getnewargs__(self): def __getnewargs__(self):
return self.numpy_dtype, self.const return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property @property
def base_type(self): def base_type(self):
return None return None
...@@ -815,13 +838,11 @@ class TypedImaginaryUnit(TypedSymbol): ...@@ -815,13 +838,11 @@ class TypedImaginaryUnit(TypedSymbol):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds) obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj return obj
def __new_stage2__(cls, dtype, *args, **kwargs): def __new_stage2__(cls, dtype):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls, obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i", "_i",
dtype, dtype,
imaginary=True, imaginary=True)
*args,
**kwargs)
return obj return obj
headers = ['"cuda_complex.hpp"'] headers = ['"cuda_complex.hpp"']
...@@ -831,3 +852,6 @@ class TypedImaginaryUnit(TypedSymbol): ...@@ -831,3 +852,6 @@ class TypedImaginaryUnit(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return (self.dtype,) return (self.dtype,)
def __getnewargs_ex__(self):
return (self.dtype,), {}
...@@ -729,6 +729,9 @@ class Field(AbstractField): ...@@ -729,6 +729,9 @@ class Field(AbstractField):
def __getnewargs__(self): def __getnewargs__(self):
return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
def __getnewargs_ex__(self):
return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
......
...@@ -50,6 +50,9 @@ class _InterpolationSymbol(TypedSymbol): ...@@ -50,6 +50,9 @@ class _InterpolationSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.name, self.field, self.interpolator return self.name, self.field, self.interpolator
def __getnewargs_ex__(self):
return (self.name, self.field, self.interpolator), {}
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
...@@ -167,8 +170,8 @@ class NearestNeightborInterpolator(Interpolator): ...@@ -167,8 +170,8 @@ class NearestNeightborInterpolator(Interpolator):
class InterpolatorAccess(TypedSymbol): class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, *offsets, **kwargs): def __new__(cls, field, *offsets):
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs) obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets)
return obj return obj
def __new_stage2__(cls, symbol, *offsets): def __new_stage2__(cls, symbol, *offsets):
...@@ -354,13 +357,16 @@ class InterpolatorAccess(TypedSymbol): ...@@ -354,13 +357,16 @@ class InterpolatorAccess(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return (self.symbol, *self.offsets) return (self.symbol, *self.offsets)
def __getnewargs_ex__(self):
return (self.symbol, *self.offsets), {}
class DiffInterpolatorAccess(InterpolatorAccess): class DiffInterpolatorAccess(InterpolatorAccess):
def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs): def __new__(cls, symbol, diff_coordinate_idx, *offsets):
if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR: if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
from pystencils.fd import Diff, Discretization2ndOrder from pystencils.fd import Diff, Discretization2ndOrder
return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx)) return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs) obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets)
return obj return obj
def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets): def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
...@@ -399,6 +405,9 @@ class DiffInterpolatorAccess(InterpolatorAccess): ...@@ -399,6 +405,9 @@ class DiffInterpolatorAccess(InterpolatorAccess):
def __getnewargs__(self): def __getnewargs__(self):
return (self.symbol, self.diff_coordinate_idx, *self.offsets) return (self.symbol, self.diff_coordinate_idx, *self.offsets)
def __getnewargs_ex__(self):
return (self.symbol, self.diff_coordinate_idx, *self.offsets), {}
########################################################################################## ##########################################################################################
# GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols # # GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols #
......
...@@ -38,6 +38,9 @@ class FieldStrideSymbol(TypedSymbol): ...@@ -38,6 +38,9 @@ class FieldStrideSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.field_name, self.coordinate return self.field_name, self.coordinate
def __getnewargs_ex__(self):
return (self.field_name, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
...@@ -63,6 +66,9 @@ class FieldShapeSymbol(TypedSymbol): ...@@ -63,6 +66,9 @@ class FieldShapeSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.field_names, self.coordinate return self.field_names, self.coordinate
def __getnewargs_ex__(self):
return (self.field_names, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
...@@ -86,6 +92,9 @@ class FieldPointerSymbol(TypedSymbol): ...@@ -86,6 +92,9 @@ class FieldPointerSymbol(TypedSymbol):
def __getnewargs__(self): def __getnewargs__(self):
return self.field_name, self.dtype, self.dtype.const return self.field_name, self.dtype, self.dtype.const
def __getnewargs_ex__(self):
return (self.field_name, self.dtype, self.dtype.const), {}
def _hashable_content(self): def _hashable_content(self):
return super()._hashable_content(), self.field_name return super()._hashable_content(), self.field_name
......
...@@ -5,7 +5,7 @@ import pystencils as ps ...@@ -5,7 +5,7 @@ import pystencils as ps
from pystencils import Assignment from pystencils import Assignment
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.')] sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
if len(sympy_numeric_version) < 3: if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0) sympy_numeric_version.append(0)
sympy_numeric_version.reverse() sympy_numeric_version.reverse()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment