Skip to content

Commit c6865c4

Browse files
Fix Theano type determination for generator observations
1 parent 32bffec commit c6865c4

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

pymc3/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,11 @@ def init_value(self):
16751675

16761676

16771677
def pandas_to_array(data):
1678+
"""Convert a Pandas object to a NumPy array.
1679+
1680+
XXX: When `data` is a generator, this will return a Theano tensor!
1681+
1682+
"""
16781683
if hasattr(data, "values"): # pandas
16791684
if data.isnull().any().any(): # missing values
16801685
ret = np.ma.MaskedArray(data.values, data.isnull().values)
@@ -1776,7 +1781,10 @@ def __init__(
17761781

17771782
if type is None:
17781783
data = pandas_to_array(data)
1779-
type = TensorType(distribution.dtype, data.shape)
1784+
if isinstance(data, theano.gof.graph.Variable):
1785+
type = data.type
1786+
else:
1787+
type = TensorType(distribution.dtype, data.shape)
17801788

17811789
self.observations = data
17821790

0 commit comments

Comments
 (0)