Skip to content

Commit f312a0f

Browse files
committed
Override __bool__ of TypedListType
1 parent b6ea0b1 commit f312a0f

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

pytensor/typed_list/basic.py

+20
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def __getitem__(self, index):
1919
def __len__(self):
2020
return length(self)
2121

22+
def __bool__(self):
23+
# Truthiness of typedLists cannot depend on length,
24+
# just like truthiness of TensorVariables does not depend on size or contents
25+
return True
26+
2227
def append(self, toAppend):
2328
return append(self, toAppend)
2429

@@ -677,3 +682,18 @@ def perform(self, node, inputs, outputs):
677682
All PyTensor variables must have the same type.
678683
679684
"""
685+
686+
687+
class MakeEmptyList(Op):
688+
__props__ = ()
689+
690+
def make_node(self, ttype):
691+
tl = TypedListType(ttype)()
692+
return Apply(self, [], [tl])
693+
694+
def perform(self, node, inputs, outputs):
695+
(out,) = outputs
696+
out[0] = []
697+
698+
699+
make_empty_list = MakeEmptyList()

tests/typed_list/test_type.py

+4
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,7 @@ def test_variable_is_Typed_List_variable(self):
150150
)()
151151

152152
assert isinstance(mySymbolicVariable, TypedListVariable)
153+
154+
def test_any(self):
155+
tlist = TypedListType(TensorType(dtype="int64", shape=(None,)))()
156+
assert any([tlist])

0 commit comments

Comments
 (0)