
Commit 4c78408

Implemented Repeat and Unique Ops in PyTorch (#890)
1 parent a6b9585 commit 4c78408


2 files changed: +94 −1 lines changed


pytensor/link/pytorch/dispatch/extra_ops.py

+36 −1
@@ -1,7 +1,7 @@
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
-from pytensor.tensor.extra_ops import CumOp
+from pytensor.tensor.extra_ops import CumOp, Repeat, Unique
 
 
 @pytorch_funcify.register(CumOp)
@@ -21,3 +21,38 @@ def cumop(x):
         return torch.cumprod(x, dim=dim)
 
     return cumop
+
+
+@pytorch_funcify.register(Repeat)
+def pytorch_funcify_Repeat(op, **kwargs):
+    axis = op.axis
+
+    def repeat(x, repeats):
+        return x.repeat_interleave(repeats, dim=axis)
+
+    return repeat
+
+
+@pytorch_funcify.register(Unique)
+def pytorch_funcify_Unique(op, **kwargs):
+    return_index = op.return_index
+
+    if return_index:
+        # TODO: evaluate whether it is worth implementing this param
+        # (see https://github.com/pytorch/pytorch/issues/36748)
+        raise NotImplementedError("return_index is not implemented for pytorch")
+
+    axis = op.axis
+    return_inverse = op.return_inverse
+    return_counts = op.return_counts
+
+    def unique(x):
+        return torch.unique(
+            x,
+            sorted=True,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=axis,
+        )
+
+    return unique
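
A rough usage sketch (not part of this commit) of how the new dispatches get exercised: compiling a small graph with the PyTorch backend lowers pt.repeat to x.repeat_interleave and pt.unique to torch.unique. The "PYTORCH" mode string below is assumed to be the registered name of PyTensor's PyTorch compile mode in this build.

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.matrix("a", dtype="float64")
repeated = pt.repeat(a, (1, 2, 3), axis=0)  # dispatched to x.repeat_interleave(..., dim=0)
uniques = pt.unique(a, axis=0)              # dispatched to torch.unique(..., dim=0)

# Assumes the PyTorch backend is registered under the mode name "PYTORCH".
fn = pytensor.function([a], [repeated, uniques], mode="PYTORCH")
print(fn(np.arange(6, dtype="float64").reshape((3, 2))))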

tests/link/pytorch/test_extra_ops.py

+58
@@ -41,3 +41,61 @@ def test_pytorch_CumOp(axis, dtype):
     out = pt.cumprod(a, axis=axis)
     fgraph = FunctionGraph([a], [out])
     compare_pytorch_and_py(fgraph, [test_value])
+
+
+@pytest.mark.parametrize(
+    "axis, repeats",
+    [
+        (0, (1, 2, 3)),
+        (1, (3, 3)),
+        pytest.param(
+            None,
+            3,
+            marks=pytest.mark.xfail(reason="Reshape not implemented"),
+        ),
+    ],
+)
+def test_pytorch_Repeat(axis, repeats):
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.arange(6, dtype="float64").reshape((3, 2))
+
+    out = pt.repeat(a, repeats, axis=axis)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_pytorch_Unique_axis(axis):
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )
+
+    out = pt.unique(a, axis=axis)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+
+@pytest.mark.parametrize("return_inverse", [False, True])
+@pytest.mark.parametrize("return_counts", [False, True])
+@pytest.mark.parametrize(
+    "return_index",
+    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
+)
+def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
+    a = pt.matrix("a", dtype="float64")
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )
+
+    out = pt.unique(
+        a,
+        return_index=return_index,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        axis=0,
+    )
+    fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
+    compare_pytorch_and_py(fgraph, [test_value])
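
The new tests can be run locally with pytest (assuming torch is installed), for example:

pytest tests/link/pytorch/test_extra_ops.py -k "Repeat or Unique"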

0 commit comments
