diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 3d1940e6971b1fb5b3a074de6abc81685f7f7e31..f399287ed02ec4eb0d3d0e295d72ee1cf5ecc14b 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -673,7 +673,7 @@ class SympyAssignment(Node): return hash((self.lhs, self.rhs)) def __eq__(self, other): - return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs) + return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs) class ResolvedFieldAccess(sp.Indexed): diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 5c8259699247d3c20f893c933c71bf37058010ed..7dbf84d378d768530cfe9c706186f7d2a684581f 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -152,7 +152,7 @@ class CustomCodeNode(Node): return self._symbols_read - self._symbols_defined def __eq__(self, other): - return type(self) == type(other) and self._code == other._code + return type(self) is type(other) and self._code == other._code def __hash__(self): return hash(self._code) diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py index 65243177dafe6dbce44cfb14bf7f1eb5c53fa39c..5fd8480b65539994bfae3e2c472a81fd4fa86f51 100644 --- a/pystencils/boundaries/boundaryconditions.py +++ b/pystencils/boundaries/boundaryconditions.py @@ -76,7 +76,7 @@ class Neumann(Boundary): return hash("Neumann") def __eq__(self, other): - return type(other) == Neumann + return type(other) is Neumann class Dirichlet(Boundary): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 0e2b77da646729bc87af4f359c12db82195bbed8..872f0b3c45983a38b5b1be8fd7f425eb422b570d 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -295,7 +295,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): elif isinstance(expr, CastFunc): cast_type = expr.args[1] arg = visit_expr(expr.args[0], default_type, force_vectorize) - assert cast_type in [BasicType('float32'), BasicType('float64')],\ + assert cast_type in [BasicType('float32'), BasicType('float64')], \ f'Vectorization cannot vectorize type {cast_type}' return expr.func(arg, VectorType(cast_type, instruction_set['width'])) elif expr.func is sp.Abs and 'abs' not in instruction_set: diff --git a/pystencils/rng.py b/pystencils/rng.py index 6e9bc95480cf83654ce4b4b0b7d783fbb0c6718b..84155b00c28f685b584ddb42e806e425e21df486 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -65,7 +65,7 @@ class RNGBase(CustomCodeNode): return (self._name, *self.result_symbols, *self.args) def __eq__(self, other): - return type(self) == type(other) and self._hashable_content() == other._hashable_content() + return type(self) is type(other) and self._hashable_content() == other._hashable_content() def __hash__(self): return hash(self._hashable_content()) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 40be43eaaa2b6b014bf93447afcb8a68ab407eff..680b58670f61a188c0474a1cf9137fa35a552a01 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -356,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order factor_count = 0 if type(product) is Mul: for factor in product.args: - if type(factor) == Pow: + if type(factor) is Pow: if factor.args[0] in symbols: factor_count += factor.args[1] if factor in symbols: @@ -366,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order factor_count += product.args[1] return factor_count - if type(expr) == Mul or type(expr) == Pow: + if type(expr) is Mul or type(expr) is Pow: if velocity_factors_in_product(expr) <= order: return expr else: return Zero() - if type(expr) != Add: + if type(expr) is not Add: return expr for sum_term in expr.args: