Skip to content

Commit e8b04b7

Browse files
committed
Fix E721: do not compare types, for exact checks use is / is not
1 parent 4c417cc commit e8b04b7

File tree

27 files changed

+43
-43
lines changed

27 files changed

+43
-43
lines changed

pytensor/compile/debugmode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def _lessbroken_deepcopy(a):
689689
else:
690690
rval = copy.deepcopy(a)
691691

692-
assert type(rval) == type(a), (type(rval), type(a))
692+
assert type(rval) is type(a), (type(rval), type(a))
693693

694694
if isinstance(rval, np.ndarray):
695695
assert rval.dtype == a.dtype
@@ -1156,7 +1156,7 @@ def __str__(self):
11561156
return str(self.__dict__)
11571157

11581158
def __eq__(self, other):
1159-
rval = type(self) == type(other)
1159+
rval = type(self) is type(other)
11601160
if rval:
11611161
# nodes are not compared because this comparison is
11621162
# supposed to be true for corresponding events that happen

pytensor/compile/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape):
246246
self.infer_shape = self._infer_shape
247247

248248
def __eq__(self, other):
249-
return type(self) == type(other) and self.__fn == other.__fn
249+
return type(self) is type(other) and self.__fn == other.__fn
250250

251251
def __hash__(self):
252252
return hash(type(self)) ^ hash(self.__fn)

pytensor/graph/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def __eq__(self, other):
718718
return True
719719

720720
return (
721-
type(self) == type(other)
721+
type(self) is type(other)
722722
and self.id == other.id
723723
and self.type == other.type
724724
)

pytensor/graph/null_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True):
3333
raise ValueError("NullType has no values to compare")
3434

3535
def __eq__(self, other):
36-
return type(self) == type(other)
36+
return type(self) is type(other)
3737

3838
def __hash__(self):
3939
return hash(type(self))

pytensor/graph/rewriting/unify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def __new__(cls, constraint, token=None, prefix=""):
5858
return obj
5959

6060
def __eq__(self, other):
61-
if type(self) == type(other):
62-
return self.token == other.token and self.constraint == other.constraint
61+
if type(self) is type(other):
62+
return self.token is other.token and self.constraint == other.constraint
6363
return NotImplemented
6464

6565
def __hash__(self):

pytensor/graph/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def __hash__(self):
229229
if "__eq__" not in dct:
230230

231231
def __eq__(self, other):
232-
return type(self) == type(other) and tuple(
232+
return type(self) is type(other) and tuple(
233233
getattr(self, a) for a in props
234234
) == tuple(getattr(other, a) for a in props)
235235

pytensor/ifelse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None):
7878
self.name = name
7979

8080
def __eq__(self, other):
81-
if type(self) != type(other):
81+
if type(self) is not type(other):
8282
return False
8383
if self.as_view != other.as_view:
8484
return False

pytensor/link/c/params_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def __hash__(self):
297297

298298
def __eq__(self, other):
299299
return (
300-
type(self) == type(other)
300+
type(self) is type(other)
301301
and self.__params_type__ == other.__params_type__
302302
and all(
303303
# NB: Params object should have been already filtered.
@@ -432,7 +432,7 @@ def __repr__(self):
432432

433433
def __eq__(self, other):
434434
return (
435-
type(self) == type(other)
435+
type(self) is type(other)
436436
and self.fields == other.fields
437437
and self.types == other.types
438438
)

pytensor/link/c/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def __hash__(self):
519519

520520
def __eq__(self, other):
521521
return (
522-
type(self) == type(other)
522+
type(self) is type(other)
523523
and self.ctype == other.ctype
524524
and len(self) == len(other)
525525
and len(self.aliases) == len(other.aliases)

pytensor/raise_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class ExceptionType(Generic):
1818
def __eq__(self, other):
19-
return type(self) == type(other)
19+
return type(self) is type(other)
2020

2121
def __hash__(self):
2222
return hash(type(self))
@@ -51,7 +51,7 @@ def __str__(self):
5151
return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"
5252

5353
def __eq__(self, other):
54-
if type(self) != type(other):
54+
if type(self) is not type(other):
5555
return False
5656

5757
if self.msg == other.msg and self.exc_type == other.exc_type:

pytensor/scalar/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def __call__(self, *types):
10741074
return [rval]
10751075

10761076
def __eq__(self, other):
1077-
return type(self) == type(other) and self.tbl == other.tbl
1077+
return type(self) is type(other) and self.tbl == other.tbl
10781078

10791079
def __hash__(self):
10801080
return hash(type(self)) # ignore hash of table
@@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients):
11601160
return self.grad(inputs, output_gradients)
11611161

11621162
def __eq__(self, other):
1163-
test = type(self) == type(other) and getattr(
1163+
test = type(self) is type(other) and getattr(
11641164
self, "output_types_preference", None
11651165
) == getattr(other, "output_types_preference", None)
11661166
return test
@@ -4132,7 +4132,7 @@ def __eq__(self, other):
41324132
if self is other:
41334133
return True
41344134
if (
4135-
type(self) != type(other)
4135+
type(self) is not type(other)
41364136
or self.nin != other.nin
41374137
or self.nout != other.nout
41384138
):

pytensor/scalar/math.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub):
626626
raise NotImplementedError("only floatingpoint is implemented")
627627

628628
def __eq__(self, other):
629-
return type(self) == type(other)
629+
return type(self) is type(other)
630630

631631
def __hash__(self):
632632
return hash(type(self))
@@ -675,7 +675,7 @@ def c_code(self, node, name, inp, out, sub):
675675
raise NotImplementedError("only floatingpoint is implemented")
676676

677677
def __eq__(self, other):
678-
return type(self) == type(other)
678+
return type(self) is type(other)
679679

680680
def __hash__(self):
681681
return hash(type(self))
@@ -724,7 +724,7 @@ def c_code(self, node, name, inp, out, sub):
724724
raise NotImplementedError("only floatingpoint is implemented")
725725

726726
def __eq__(self, other):
727-
return type(self) == type(other)
727+
return type(self) is type(other)
728728

729729
def __hash__(self):
730730
return hash(type(self))
@@ -1033,7 +1033,7 @@ def c_code(self, node, name, inp, out, sub):
10331033
raise NotImplementedError("only floatingpoint is implemented")
10341034

10351035
def __eq__(self, other):
1036-
return type(self) == type(other)
1036+
return type(self) is type(other)
10371037

10381038
def __hash__(self):
10391039
return hash(type(self))
@@ -1074,7 +1074,7 @@ def c_code(self, node, name, inp, out, sub):
10741074
raise NotImplementedError("only floatingpoint is implemented")
10751075

10761076
def __eq__(self, other):
1077-
return type(self) == type(other)
1077+
return type(self) is type(other)
10781078

10791079
def __hash__(self):
10801080
return hash(type(self))

pytensor/scan/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1244,7 +1244,7 @@ def is_cpu_vector(s):
12441244
return apply_node
12451245

12461246
def __eq__(self, other):
1247-
if type(self) != type(other):
1247+
if type(self) is not type(other):
12481248
return False
12491249

12501250
if self.info != other.info:

pytensor/sparse/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def __eq__(self, other):
451451
return (
452452
a == x
453453
and (b.dtype == y.dtype)
454-
and (type(b) == type(y))
454+
and (type(b) is type(y))
455455
and (b.shape == y.shape)
456456
and (abs(b - y).sum() < 1e-6 * b.nnz)
457457
)

pytensor/tensor/random/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _eq(sa, sb):
102102
return _eq(sa, sb)
103103

104104
def __eq__(self, other):
105-
return type(self) == type(other)
105+
return type(self) is type(other)
106106

107107
def __hash__(self):
108108
return hash(type(self))
@@ -198,7 +198,7 @@ def _eq(sa, sb):
198198
return _eq(sa, sb)
199199

200200
def __eq__(self, other):
201-
return type(self) == type(other)
201+
return type(self) is type(other)
202202

203203
def __hash__(self):
204204
return hash(type(self))

pytensor/tensor/rewriting/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1770,7 +1770,7 @@ def local_reduce_broadcastable(fgraph, node):
17701770
ii += 1
17711771
new_reduced = reduced.dimshuffle(*pattern)
17721772
if new_axis:
1773-
if type(node.op) == CAReduce:
1773+
if type(node.op) is CAReduce:
17741774
# This case handles `CAReduce` instances
17751775
# (e.g. generated by `scalar_elemwise`), and not the
17761776
# scalar `Op`-specific subclasses

pytensor/tensor/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def values_eq_approx(
370370
return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol)
371371

372372
def __eq__(self, other):
373-
if type(self) != type(other):
373+
if type(self) is not type(other):
374374
return NotImplemented
375375

376376
return other.dtype == self.dtype and other.shape == self.shape
@@ -639,7 +639,7 @@ def c_code_cache_version(self):
639639

640640
class DenseTypeMeta(MetaType):
641641
def __instancecheck__(self, o):
642-
if type(o) == TensorType or isinstance(o, DenseTypeMeta):
642+
if type(o) is TensorType or isinstance(o, DenseTypeMeta):
643643
return True
644644
return False
645645

pytensor/tensor/type_other.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __str__(self):
6464
return "slice"
6565

6666
def __eq__(self, other):
67-
return type(self) == type(other)
67+
return type(self) is type(other)
6868

6969
def __hash__(self):
7070
return hash(type(self))

pytensor/tensor/variable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ class TensorConstantSignature(tuple):
931931
"""
932932

933933
def __eq__(self, other):
934-
if type(self) != type(other):
934+
if type(self) is not type(other):
935935
return False
936936
try:
937937
(t0, d0), (t1, d1) = self, other
@@ -1091,7 +1091,7 @@ def __deepcopy__(self, memo):
10911091

10921092
class DenseVariableMeta(MetaType):
10931093
def __instancecheck__(self, o):
1094-
if type(o) == TensorVariable or isinstance(o, DenseVariableMeta):
1094+
if type(o) is TensorVariable or isinstance(o, DenseVariableMeta):
10951095
return True
10961096
return False
10971097

@@ -1106,7 +1106,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
11061106

11071107
class DenseConstantMeta(MetaType):
11081108
def __instancecheck__(self, o):
1109-
if type(o) == TensorConstant or isinstance(o, DenseConstantMeta):
1109+
if type(o) is TensorConstant or isinstance(o, DenseConstantMeta):
11101110
return True
11111111
return False
11121112

pytensor/typed_list/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __eq__(self, other):
5555
Two lists are equal if they contain the same type.
5656
5757
"""
58-
return type(self) == type(other) and self.ttype == other.ttype
58+
return type(self) is type(other) and self.ttype == other.ttype
5959

6060
def __hash__(self):
6161
return hash((type(self), self.ttype))

tests/graph/rewriting/test_unify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def perform(self, node, inputs, outputs):
4242

4343
class CustomOpNoProps(CustomOpNoPropsNoEq):
4444
def __eq__(self, other):
45-
return type(self) == type(other) and self.a == other.a
45+
return type(self) is type(other) and self.a == other.a
4646

4747
def __hash__(self):
4848
return hash((type(self), self.a))

tests/graph/test_fg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ def test_pickle(self):
3030
s = pickle.dumps(func)
3131
new_func = pickle.loads(s)
3232

33-
assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs))
34-
assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs))
33+
assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
34+
assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
3535
assert all(
3636
type(a.op) is type(b.op) # noqa: E721
3737
for a, b in zip(func.apply_nodes, new_func.apply_nodes)
3838
)
39-
assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
39+
assert all(a.type is b.type for a, b in zip(func.variables, new_func.variables))
4040

4141
def test_validate_inputs(self):
4242
var1 = op1()

tests/graph/test_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, thingy):
2525
self.thingy = thingy
2626

2727
def __eq__(self, other):
28-
return type(other) == type(self) and other.thingy == self.thingy
28+
return type(other) is type(self) and other.thingy == self.thingy
2929

3030
def __str__(self):
3131
return str(self.thingy)

tests/link/c/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def c_code_cache_version(self):
7373
return (1,)
7474

7575
def __eq__(self, other):
76-
return type(self) == type(other)
76+
return type(self) is type(other)
7777

7878
def __hash__(self):
7979
return hash(type(self))

tests/sparse/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def __init__(self, structured):
348348
self.structured = structured
349349

350350
def __eq__(self, other):
351-
return (type(self) == type(other)) and self.structured == other.structured
351+
return (type(self) is type(other)) and self.structured == other.structured
352352

353353
def __hash__(self):
354354
return hash(type(self)) ^ hash(self.structured)

tests/tensor/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3156,7 +3156,7 @@ def test_stack():
31563156
sx, sy = dscalar(), dscalar()
31573157

31583158
rval = inplace_func([sx, sy], stack([sx, sy]))(-4.0, -2.0)
3159-
assert type(rval) == np.ndarray
3159+
assert type(rval) is np.ndarray
31603160
assert [-4, -2] == list(rval)
31613161

31623162

tests/tensor/test_subtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def test_ok_list(self):
819819
assert np.allclose(val, good), (val, good)
820820

821821
# Test reuse of output memory
822-
if type(AdvancedSubtensor1) == AdvancedSubtensor1:
822+
if type(AdvancedSubtensor1) is AdvancedSubtensor1:
823823
op = AdvancedSubtensor1()
824824
# When idx is a TensorConstant.
825825
if hasattr(idx, "data"):

0 commit comments

Comments
 (0)