
Commit 7ca61dd

Implemented Repeat and Unique Ops in PyTorch

1 parent a6e79f2 commit 7ca61dd

2 files changed, +88 −1 lines changed

pytensor/link/pytorch/dispatch/extra_ops.py (+42 −1)
@@ -1,7 +1,7 @@
 import torch

 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
-from pytensor.tensor.extra_ops import CumOp
+from pytensor.tensor.extra_ops import CumOp, Repeat, Unique


 @pytorch_funcify.register(CumOp)
@@ -21,3 +21,44 @@ def cumop(x):
         return torch.cumprod(x, dim=dim)

     return cumop
+
+
+@pytorch_funcify.register(Repeat)
+def pytorch_funcify_Repeat(op, **kwargs):
+    axis = op.axis
+
+    def repeat(x, repeats, axis=axis):
+        return x.repeat_interleave(repeats, dim=axis)
+
+    return repeat
+
+
+@pytorch_funcify.register(Unique)
+def pytorch_funcify_Unique(op, **kwargs):
+    return_index = op.return_index
+
+    if return_index:
+        # TODO: evaluate whether it is worth implementing this param
+        # (see https://github.com/pytorch/pytorch/issues/36748)
+        raise NotImplementedError("return_index is not implemented for pytorch")
+
+    axis = op.axis
+    return_inverse = op.return_inverse
+    return_counts = op.return_counts
+
+    def unique(
+        x,
+        return_index=return_index,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        axis=axis,
+    ):
+        return torch.unique(
+            x,
+            sorted=True,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=axis,
+        )
+
+    return unique
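
For context beyond the diff itself, here is a minimal sketch of how these dispatches get exercised end to end: a small PyTensor graph using pt.repeat and pt.unique compiled against the PyTorch backend. The mode="PYTORCH" selection string is an assumption about how the backend is chosen and is not part of this commit.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x", dtype="float64")
repeated = pt.repeat(x, (1, 2, 3), axis=0)  # lowered via pytorch_funcify_Repeat
uniques = pt.unique(x, axis=0)              # lowered via pytorch_funcify_Unique

# mode="PYTORCH" is assumed to select the PyTorch linker; adjust to your setup.
f = pytensor.function([x], [repeated, uniques], mode="PYTORCH")
rep_val, uniq_val = f(np.arange(6, dtype="float64").reshape(3, 2))
print(rep_val.shape)  # rows repeated 1, 2, and 3 times -> (6, 2)
print(uniq_val)       # the unique rows of the input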

tests/link/pytorch/test_extra_ops.py (+46)
@@ -41,3 +41,49 @@ def test_pytorch_CumOp(axis, dtype):
     out = pt.cumprod(a, axis=axis)
     fgraph = FunctionGraph([a], [out])
     compare_pytorch_and_py(fgraph, [test_value])
+
+
+def test_pytorch_Repeat():
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.arange(6, dtype="float64").reshape((3, 2))
+
+    # Test along axis 0
+    out = pt.repeat(a, (1, 2, 3), axis=0)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test along axis 1
+    out = pt.repeat(a, (3, 3), axis=1)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+
+def test_pytorch_Unique():
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )
+
+    # Test along axis 0
+    out = pt.unique(a, axis=0)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test along axis 1
+    out = pt.unique(a, axis=1)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test with params
+    out = pt.unique(a, return_inverse=True, return_counts=True, axis=0)
+    fgraph = FunctionGraph([a], [out[0]])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+    # Test with return_index=True
+    out = pt.unique(a, return_index=True, axis=0)
+    fgraph = FunctionGraph([a], [out[0]])
+
+    with pytest.raises(NotImplementedError):
+        compare_pytorch_and_py(fgraph, [test_value])
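
As a quick sanity check on the PyTorch calls these tests ultimately rely on, here is a standalone sketch (independent of the PyTensor test harness) of what torch.repeat_interleave and torch.unique return for the same test values:

import torch

a = torch.arange(6, dtype=torch.float64).reshape(3, 2)
# Repeat row 0 once, row 1 twice, row 2 three times, matching np.repeat semantics.
print(torch.repeat_interleave(a, torch.tensor([1, 2, 3]), dim=0).shape)  # torch.Size([6, 2])

b = torch.tensor([[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]])
# Unique rows along dim=0, plus inverse indices and counts; PyTorch has no
# return_index counterpart, which is why the dispatch raises NotImplementedError.
vals, inverse, counts = torch.unique(
    b, sorted=True, return_inverse=True, return_counts=True, dim=0
)
print(vals)     # two unique rows: [1., 1., 2.] and [3., 3., 0.]
print(inverse)  # tensor([0, 0, 1])
print(counts)   # tensor([2, 1])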
