Skip to content

Commit f6d1b33

Browse files
ferrinericardoV94
authored andcommitted
fix: Replace unsafe under vectorize pt.zeros(value.shape) with zeros_like
1 parent 8fd4f1c commit f6d1b33

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc/logprob/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ def log_jac_det(self, value, *inputs):
961961
N = N.astype(value.dtype)
962962
sum_value = pt.sum(value, -1, keepdims=True)
963963
value_sum_expanded = value + sum_value
964-
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)
964+
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros_like(sum_value)], -1)
965965
logsumexp_value_expanded = pt.logsumexp(value_sum_expanded, -1, keepdims=True)
966966
res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
967967
return pt.sum(res, -1)
@@ -977,7 +977,7 @@ def forward(self, value, *inputs):
977977
return pt.as_tensor_variable(value)
978978

979979
def log_jac_det(self, value, *inputs):
980-
return pt.zeros(value.shape)
980+
return pt.zeros_like(value)
981981

982982

983983
class ChainedTransform(Transform):

0 commit comments

Comments
 (0)