Commit 17ea6bd2 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'pickle' into 'master'

Sympy 1.9 support

Closes #35

See merge request pycodegen/pystencils!239
parents 059de5fb c1bb1ec2
Pipeline #31755 passed with stage
in 11 minutes and 30 seconds
......@@ -692,6 +692,9 @@ class ResolvedFieldAccess(sp.Indexed):
def __getnewargs__(self):
return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
def __getnewargs_ex__(self):
return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {}
class TemporaryMemoryAllocation(Node):
"""Node for temporary memory buffer allocation.
......@@ -837,6 +840,9 @@ class ConditionalFieldAccess(sp.Function):
def __getnewargs__(self):
return self.access, self.outofbounds_condition, self.outofbounds_value
def __getnewargs_ex__(self):
return (self.access, self.outofbounds_condition, self.outofbounds_value), {}
class NontemporalFence(Node):
def __init__(self):
......
......@@ -219,9 +219,10 @@ class TypedSymbol(sp.Symbol):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype, *args, **kwargs):
def __new_stage2__(cls, name, dtype, **kwargs):
assumptions = assumptions_from_dtype(dtype)
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
try:
obj._dtype = create_type(dtype)
except (TypeError, ValueError):
......@@ -242,6 +243,9 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), self.assumptions0
@property
def canonical(self):
return self
......@@ -578,6 +582,13 @@ def get_type_of_expression(expr,
raise NotImplementedError("Could not determine type for", expr, type(expr))
sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
# __setstate__ would bypass the contructor, so we remove it
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
class Type(sp.Atom):
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
......@@ -621,6 +632,9 @@ class BasicType(Type):
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
......@@ -717,6 +731,9 @@ class VectorType(Type):
def __getnewargs__(self):
return self._base_type, self.width
def __getnewargs_ex__(self):
return (self._base_type, self.width), {}
class PointerType(Type):
def __init__(self, base_type, const=False, restrict=True):
......@@ -727,6 +744,9 @@ class PointerType(Type):
def __getnewargs__(self):
return self.base_type, self.const, self.restrict
def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict), {}
@property
def alias(self):
return not self.restrict
......@@ -768,6 +788,9 @@ class StructType:
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
......@@ -815,13 +838,11 @@ class TypedImaginaryUnit(TypedSymbol):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, dtype, *args, **kwargs):
def __new_stage2__(cls, dtype):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
imaginary=True)
return obj
headers = ['"cuda_complex.hpp"']
......@@ -831,3 +852,6 @@ class TypedImaginaryUnit(TypedSymbol):
def __getnewargs__(self):
return (self.dtype,)
def __getnewargs_ex__(self):
return (self.dtype,), {}
......@@ -729,6 +729,9 @@ class Field(AbstractField):
def __getnewargs__(self):
return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
def __getnewargs_ex__(self):
return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection
......
......@@ -50,6 +50,9 @@ class _InterpolationSymbol(TypedSymbol):
def __getnewargs__(self):
return self.name, self.field, self.interpolator
def __getnewargs_ex__(self):
return (self.name, self.field, self.interpolator), {}
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection
......@@ -167,8 +170,8 @@ class NearestNeightborInterpolator(Interpolator):
class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, *offsets, **kwargs):
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
def __new__(cls, field, *offsets):
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets)
return obj
def __new_stage2__(cls, symbol, *offsets):
......@@ -354,13 +357,16 @@ class InterpolatorAccess(TypedSymbol):
def __getnewargs__(self):
return (self.symbol, *self.offsets)
def __getnewargs_ex__(self):
return (self.symbol, *self.offsets), {}
class DiffInterpolatorAccess(InterpolatorAccess):
def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
def __new__(cls, symbol, diff_coordinate_idx, *offsets):
if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
from pystencils.fd import Diff, Discretization2ndOrder
return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets)
return obj
def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
......@@ -399,6 +405,9 @@ class DiffInterpolatorAccess(InterpolatorAccess):
def __getnewargs__(self):
return (self.symbol, self.diff_coordinate_idx, *self.offsets)
def __getnewargs_ex__(self):
return (self.symbol, self.diff_coordinate_idx, *self.offsets), {}
##########################################################################################
# GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols #
......
......@@ -38,6 +38,9 @@ class FieldStrideSymbol(TypedSymbol):
def __getnewargs__(self):
return self.field_name, self.coordinate
def __getnewargs_ex__(self):
return (self.field_name, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
......@@ -63,6 +66,9 @@ class FieldShapeSymbol(TypedSymbol):
def __getnewargs__(self):
return self.field_names, self.coordinate
def __getnewargs_ex__(self):
return (self.field_names, self.coordinate), {}
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
......@@ -86,6 +92,9 @@ class FieldPointerSymbol(TypedSymbol):
def __getnewargs__(self):
return self.field_name, self.dtype, self.dtype.const
def __getnewargs_ex__(self):
return (self.field_name, self.dtype, self.dtype.const), {}
def _hashable_content(self):
return super()._hashable_content(), self.field_name
......
......@@ -5,7 +5,7 @@ import pystencils as ps
from pystencils import Assignment
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.')]
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0)
sympy_numeric_version.reverse()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment