Commit f7b0a7a

mory91 authored and ricardoV94 committed
Remove TopkOp
1 parent ef22377 commit f7b0a7a

File tree: 4 files changed (+4, −590 lines)

pytensor/tensor/__init__.py

+1 −1

@@ -142,7 +142,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 
 # We import as `_shared` instead of `shared` to avoid confusion between
 # `pytensor.shared` and `tensor._shared`.
-from pytensor.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk
+from pytensor.tensor.sort import argsort, sort
 from pytensor.tensor.subtensor import *
 from pytensor.tensor.type import *
 from pytensor.tensor.type_other import *
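
Since this commit removes the public `topk`, `argtopk`, and `topk_and_argtopk` helpers, equivalent graphs can still be built from the `argsort`/`sort` ops that stay imported above. A minimal sketch, not part of this commit, assuming a plain Python constant `k`:

    import pytensor.tensor as pt

    x = pt.vector("x")
    k = 3  # assumed constant here; the removed Op also accepted a symbolic k

    # argsort sorts ascending, so the last k positions hold the k largest
    # values; reverse them to get a largest-first order.
    topk_indices = pt.argsort(x)[-k:][::-1]
    topk_values = x[topk_indices]

Note that this sorts the full axis, whereas the removed `TopKOp` used a partial partition; the trade-off is simplicity versus asymptotic cost.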

pytensor/tensor/rewriting/basic.py

−32

@@ -68,7 +68,6 @@
 from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import Sum, add, eq
 from pytensor.tensor.shape import Shape_i, shape_padleft
-from pytensor.tensor.sort import TopKOp
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 from pytensor.utils import NoDuplicateOptWarningFilter
@@ -1224,35 +1223,4 @@ def local_merge_alloc(fgraph, node):
     return [alloc(inputs_inner[0], *dims_outer)]
 
 
-@register_useless("fast_compile")
-@node_rewriter([TopKOp])
-def local_useless_topk(fgraph, node):
-    """Remove unused `TopKOp` outputs."""
-    op = node.op
-    if not isinstance(op, TopKOp):
-        return
-    if not (op.return_values and op.return_indices):
-        return False
-
-    x, k = node.inputs
-    ret_val = bool(fgraph.clients[node.outputs[0]])
-    ret_idx = bool(fgraph.clients[node.outputs[1]])
-
-    if not (ret_val ^ ret_idx):
-        # both true -> nothing to remove
-        # both false -> let pruner handle
-        return False
-
-    old_output = node.outputs[ret_idx]
-    new_output = TopKOp(
-        axis=op.axis,
-        sorted=op.sorted,
-        idx_dtype=op.idx_dtype,
-        return_values=ret_val,
-        return_indices=ret_idx,
-    )(x, k)
-    copy_stack_trace(node.outputs[0], new_output)
-    return {old_output: new_output}
-
-
 register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
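
For reference, the guard in the removed rewrite fired only when exactly one of the two outputs still had clients in the graph; the XOR below (plain Python, only illustrating the branch logic spelled out in the removed comments) enumerates the four cases:

    for ret_val, ret_idx in [(True, True), (True, False), (False, True), (False, False)]:
        # both True -> nothing to remove; both False -> left to the graph pruner
        print(f"values used: {ret_val}, indices used: {ret_idx}, fires: {ret_val ^ ret_idx}")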

pytensor/tensor/sort.py

+2 −271

@@ -4,11 +4,9 @@
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.op import Op
 from pytensor.misc.safe_asarray import _asarray
-from pytensor.tensor.basic import arange, as_tensor_variable, flatten, switch
+from pytensor.tensor.basic import arange, as_tensor_variable, switch
 from pytensor.tensor.math import eq, ge, mul
-from pytensor.tensor.shape import shape
-from pytensor.tensor.subtensor import set_subtensor
-from pytensor.tensor.type import TensorType, integer_dtypes
+from pytensor.tensor.type import TensorType
 
 
 def _variable_is_none(var):
@@ -304,270 +302,3 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
     else:
         zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
     return zi.astype(idx_dtype)
-
-
-class TopKOp(Op):
-    """Operations related to finding the k-largest elements.
-
-    Parameters
-    ----------
-    axis: integer
-        Defaults to ``-1``.
-        The axis on which to perform the operation. Must be in range ``[-ndim, ndim)``,
-        where ``ndim`` is the dimensionality of the input tensor.
-
-    idx_dtype: string
-        Specify the output dtype for indices; defaults to ``int64``, must be an integer type.
-
-    sorted: bool
-        NOTE: NOT IMPLEMENTED YET
-        Defaults to ``True``
-
-        If True, the result array will be sorted in descending order.
-
-    Notes
-    -----
-    - The output order is not guaranteed. On the CPU, we use
-      ``np.partition`` and ``np.argpartition``, which only make sure the
-      k-th element is the correct one and that the other
-      elements are on the correct side.
-    - By default, this Op gives two outputs: values and indices. However,
-      optimizers may remove an output if it is not needed.
-    - Computing the gradient requires the computation of the indices in
-      the forward pass.
-    - If the top-k-th value is not unique, we cannot guarantee that the
-      output indices are deterministically chosen.
-
-    See Also
-    --------
-    topk
-    argtopk
-    topk_and_argtopk
-
-    """
-
-    # TODO more params
-    """
-    only_top_kth: bool
-        Defaults to ``False``
-
-        If ``True``, will only find the one exact top k-th element on the given axis.
-
-    """
-
-    # TODO c_code
-    # TODO add opt: if k == 1, use max/min reduce;
-    #      also if k is the axis size, just copy the input tensor
-    # TODO add opt to merge argtopk / topk
-    __props__ = ("axis", "sorted", "return_values", "return_indices", "idx_dtype")
-
-    def __init__(
-        self,
-        axis=-1,
-        sorted=True,
-        idx_dtype="int64",
-        return_values=True,
-        return_indices=True,
-    ):
-        # numpy always uses int64 as the output dtype for the arg*() routines;
-        # however, we add the "idx_dtype" param as memory is more precious on gpu
-        if not isinstance(axis, int):
-            raise TypeError(f'"axis" parameter must be integer, got "{type(axis)}"')
-        if sorted:
-            raise NotImplementedError(
-                "The sorted parameter is not yet implemented. Use sorted=False for now."
-            )
-        if idx_dtype not in integer_dtypes:
-            raise TypeError(
-                f'"idx_dtype" parameter must be an integer dtype, got "{idx_dtype}"'
-            )
-
-        if not (return_indices or return_values):
-            raise ValueError(
-                "Neither return_values nor return_indices is True; this isn't allowed"
-            )
-
-        self.axis = axis
-        self.sorted = sorted
-        self.return_values = return_values
-        self.return_indices = return_indices
-        self.idx_dtype = idx_dtype
-
-    def __str__(self):
-        return "%(op)s{axis=%(axis)d, sorted=%(sorted)s}" % dict(
-            op=self.__class__.__name__, axis=self.axis, sorted=self.sorted
-        )
-
-    def make_node(self, inp, kth):
-        inp = as_tensor_variable(inp)
-        ndim = inp.ndim
-        if ndim == 0:
-            raise ValueError("Cannot take scalar as input")
-        if not -ndim <= self.axis < ndim:
-            raise IndexError(
-                '"axis" parameter out of range,'
-                f" expected integer within [{int(-ndim)}, {int(ndim - 1)}]"
-            )
-
-        kth = as_tensor_variable(kth)
-        _check_tensor_is_scalar(kth)
-        outs = []
-        if self.return_values:
-            outs.append(
-                TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)()
-            )
-        if self.return_indices:
-            outs.append(
-                TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)()
-            )
-        return Apply(self, [inp, kth], outs)
-
-    def perform(self, node, inputs, output_storage):
-        x, k = inputs
-        axis = self.axis
-        if not self.return_indices:
-            pzv = output_storage[0]
-            pzv[0] = _topk_py_impl(self, x, k, axis, None)
-        elif self.return_values:
-            pzv = output_storage[0]
-            pzi = output_storage[1]
-            pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype)
-        else:
-            pzi = output_storage[0]
-            pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)
-
-    def infer_shape(self, fgraph, node, inp_shapes):
-        shp = list(inp_shapes[0])
-        shp[self.axis] = np.abs(node.inputs[1])
-        shp = tuple(shp)
-        return [shp for i in [self.return_values, self.return_indices] if i]
-
-    def L_op(self, inputs, outputs, out_grads):
-        x, k = inputs
-        k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")
-
-        if not (self.return_indices or self.return_values):
-            x_grad = grad_undefined(
-                self,
-                0,
-                x,
-                "topk: cannot get gradient without both indices and values",
-            )
-        else:
-            x_shp = shape(x)
-            z_grad = out_grads[0]
-            ndim = x.ndim
-            axis = self.axis % ndim
-            grad_indices = [
-                arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
-                if i != axis
-                else outputs[-1]
-                for i in range(ndim)
-            ]
-            x_grad = x.zeros_like(dtype=z_grad.dtype)
-            x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)
-
-        return [x_grad, k_grad]
-
-
-def topk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
-    """
-    Returns the k-largest elements along an axis.
-
-    Parameters
-    ----------
-    x: tensor instance
-
-    kth: integer constant/variable
-        Must not be 0. If negative, gives the k-smallest elements instead.
-
-    axis: integer or ``None``
-        The axis on which to perform the operation.
-        If ``None``, works on the flattened array.
-
-    sorted: bool
-        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
-        Defaults to ``True``
-
-        If True, the result array will be sorted in descending order.
-
-    idx_dtype: string
-        Specify the output dtype used for indices; defaults to ``int64``, must be an
-        integer type. This option is here because indices are needed for the gradient.
-
-    Returns
-    -------
-    Tensor variable with the same dtype as `x`.
-
-    Notes
-    -----
-    - ``sorted=True`` is not supported yet.
-
-    """
-    if axis is None:
-        x = flatten(x)
-        axis = 0
-    return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)[0]
-
-
-def argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
-    """
-    Returns the indices of the k-largest elements along an axis.
-
-    Parameters
-    ----------
-    x: tensor instance
-
-    kth: integer constant/variable
-        Must not be 0. If negative, gives the k-smallest elements instead.
-
-    sorted: bool
-        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
-        Defaults to ``True``
-
-        If True, the result array of corresponding indices will be sorted in
-        descending order.
-
-    axis: integer or ``None``
-        The axis on which to perform the operation.
-        If ``None``, works on the flattened array.
-
-    idx_dtype: string
-        Specify the output dtype; defaults to ``int64``, must be an integer type.
-
-    Returns
-    -------
-    Tensor variable with the dtype specified in `idx_dtype`.
-
-    Notes
-    -----
-    - ``sorted=True`` is not supported yet.
-
-    - If the top-k-th value is not unique, we cannot guarantee that the output
-      indices are deterministically chosen.
-
-    """
-    if axis is None:
-        x = flatten(x)
-        axis = 0
-    return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)[1]
-
-
-def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
-    """
-    Returns the results of both topk() and argtopk() in one Op.
-
-    See the respective documentation for details.
-
-    Returns
-    -------
-    tuple: (values, indices)
-
-    """
-    if axis is None:
-        x = flatten(x)
-        axis = 0
-    return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)
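
The removed Python implementation (`_topk_py_impl`, kept above as diff context) was built on `np.partition`/`np.argpartition`, which is where the "output order is not guaranteed" caveat in the removed docstring comes from. A small NumPy sketch of that behavior, and of the gradient scatter the removed `L_op` performed, with illustrative values only:

    import numpy as np

    x = np.array([5, 1, 9, 3, 7])
    k = 2

    # np.partition only guarantees that the element at position -k lands in
    # its sorted position, with larger elements to its right; the k largest
    # values therefore end up in the last k slots, in unspecified order.
    values = np.partition(x, -k)[-k:]
    indices = np.argpartition(x, -k)[-k:]
    assert set(values) == {7, 9}
    assert set(x[indices]) == {7, 9}

    # The removed L_op scattered the output gradient back onto the selected
    # input positions (via set_subtensor); the 1-D NumPy equivalent:
    g_out = np.ones(k)
    g_x = np.zeros_like(x, dtype=float)
    g_x[indices] = g_out
    assert g_x.sum() == k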
