Skip to content

Commit 6474f9f

Browse files
committed
replace pytensorf identity with pytensor identity
1 parent 44a75b0 commit 6474f9f

File tree

2 files changed

+1
-32
lines changed

2 files changed

+1
-32
lines changed

pymc/pytensorf.py

+1-23
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import pytensor.tensor as pt
2323
import scipy.sparse as sps
2424

25-
from pytensor import scalar
2625
from pytensor.compile import Function, Mode, get_mode
2726
from pytensor.compile.builders import OpFromGraph
2827
from pytensor.gradient import grad
@@ -38,6 +37,7 @@
3837
from pytensor.graph.fg import FunctionGraph
3938
from pytensor.graph.op import Op
4039
from pytensor.scalar.basic import Cast
40+
from pytensor.scalar.basic import identity as scalar_identity
4141
from pytensor.scan.op import Scan
4242
from pytensor.tensor.basic import _as_tensor_variable
4343
from pytensor.tensor.elemwise import Elemwise
@@ -412,28 +412,6 @@ def hessian_diag(f, vars=None, negate_output=True):
412412
return empty_gradient
413413

414414

415-
class IdentityOp(scalar.UnaryScalarOp):
416-
@staticmethod
417-
def st_impl(x):
418-
return x
419-
420-
def impl(self, x):
421-
return x
422-
423-
def grad(self, inp, grads):
424-
return grads
425-
426-
def c_code(self, node, name, inp, out, sub):
427-
return f"{out[0]} = {inp[0]};"
428-
429-
def __eq__(self, other):
430-
return isinstance(self, type(other))
431-
432-
def __hash__(self):
433-
return hash(type(self))
434-
435-
436-
scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
437415
identity = Elemwise(scalar_identity, name="identity")
438416

439417

pymc/sampling/jax.py

-9
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from pymc.distributions.multivariate import PosDefMatrix
4747
from pymc.initial_point import StartDict
4848
from pymc.logprob.utils import CheckParameterValue
49-
from pymc.pytensorf import IdentityOp
5049
from pymc.sampling.mcmc import _init_jitter
5150
from pymc.stats.convergence import log_warnings, run_convergence_checks
5251
from pymc.util import (
@@ -70,14 +69,6 @@
7069
)
7170

7271

73-
@jax_funcify.register(IdentityOp)
74-
def jax_funcify_Identity(op, **kwargs):
75-
def identity_fn(value):
76-
return value
77-
78-
return identity_fn
79-
80-
8172
@jax_funcify.register(Assert)
8273
@jax_funcify.register(CheckParameterValue)
8374
def jax_funcify_Assert(op, **kwargs):

0 commit comments

Comments
 (0)