|
4 | 4 | from pytensor.graph.basic import Apply, Constant
|
5 | 5 | from pytensor.graph.op import Op
|
6 | 6 | from pytensor.misc.safe_asarray import _asarray
|
7 |
| -from pytensor.tensor.basic import arange, as_tensor_variable, flatten, switch |
| 7 | +from pytensor.tensor.basic import arange, as_tensor_variable, switch |
8 | 8 | from pytensor.tensor.math import eq, ge, mul
|
9 |
| -from pytensor.tensor.shape import shape |
10 |
| -from pytensor.tensor.subtensor import set_subtensor |
11 |
| -from pytensor.tensor.type import TensorType, integer_dtypes |
| 9 | +from pytensor.tensor.type import TensorType |
12 | 10 |
|
13 | 11 |
|
14 | 12 | def _variable_is_none(var):
|
@@ -304,270 +302,3 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
|
304 | 302 | else:
|
305 | 303 | zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
|
306 | 304 | return zi.astype(idx_dtype)
|
307 |
| - |
308 |
| - |
309 |
class TopKOp(Op):
    """Operations related to finding k-largest elements.

    Parameters
    ----------
    axis: integer
        Defaults to ``-1``.
        The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where
        ``ndim`` is the dimensionality of input tensor.

    idx_dtype: string
        Specify output dtype for indices, defaults to ``int64``, must be integer type.

    sorted: bool
        NOTE: NOT IMPLEMENTED YET
        Defaults to ``True``

        If True, the result array would be sorted in descending order.


    Notes
    -----
    - The output order is not guaranteed. On the CPU, we use
      ``np.partition`` and ``np.argpartition`` that only make sure the
      k-th element is the correct one and that the other
      elements are on the correct side.
    - By default, this Op gives two outputs: values and indices. However
      optimizers may remove a certain output if not needed.
    - Computing the gradient requests the computation of the indices in
      forward pass.
    - If the top-k-th value is not unique, we cannot guarantee the
      output indices being deterministically chosen.

    See Also
    --------
    topk
    argtopk
    argtopk_and_topk

    """

    # TODO more params
    """
    only_top_kth: bool
        Defaults to ``False``

        If ``True``, will only find one exact top k-th element on given axis.

    """

    # TODO c_code
    # TODO add opt, if k==1, use max/min reduce
    # also if k is axis size, just copy input tensor
    # TODO add opt, to merge argtopk / topk
    __props__ = ("axis", "sorted", "return_values", "return_indices", "idx_dtype")

    def __init__(
        self,
        axis=-1,
        sorted=True,
        idx_dtype="int64",
        return_values=True,
        return_indices=True,
    ):
        # numpy always uses int64 as output dtype for arg*() routines
        # however, we add "idx_dtype" param as memory is more precious on gpu
        if not isinstance(axis, int):
            raise TypeError(f'"axis" parameter must be integer, got "{type(axis)}"')
        if sorted:
            raise NotImplementedError(
                "The sorted parameter is not yet implemented. Use sorted=False for now."
            )
        if idx_dtype not in integer_dtypes:
            raise TypeError(
                f'"idx_dtype" parameter must be an integer dtype, got "{idx_dtype}"'
            )

        # At least one of the two outputs must be requested.
        if not (return_indices or return_values):
            raise ValueError(
                "Neither return_values nor return_indices is True, this isn't allowed"
            )

        self.axis = axis
        self.sorted = sorted
        self.return_values = return_values
        self.return_indices = return_indices
        self.idx_dtype = idx_dtype

    def __str__(self):
        # f-string form of the old %-dict formatting; output is identical.
        return f"{self.__class__.__name__}{{axis={self.axis}, sorted={self.sorted}}}"

    def make_node(self, inp, kth):
        """Build an Apply node; outputs depend on return_values/return_indices."""
        inp = as_tensor_variable(inp)
        ndim = inp.ndim
        if ndim == 0:
            raise ValueError("Cannot take scalar as input")
        if not -ndim <= self.axis < ndim:
            raise IndexError(
                '"axis" parameter out of range,'
                f" expected integer within [{int(-ndim)}, {int(ndim - 1)}]"
            )

        kth = as_tensor_variable(kth)
        _check_tensor_is_scalar(kth)
        outs = []
        # Output shapes are unknown at graph-build time (k is symbolic),
        # hence all-None static shapes.
        if self.return_values:
            outs.append(
                TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)()
            )
        if self.return_indices:
            outs.append(
                TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)()
            )
        return Apply(self, [inp, kth], outs)

    def perform(self, node, inputs, output_storage):
        """Compute values and/or indices via the numpy-based _topk_py_impl."""
        x, k = inputs
        axis = self.axis
        if not self.return_indices:
            # values only
            pzv = output_storage[0]
            pzv[0] = _topk_py_impl(self, x, k, axis, None)
        elif self.return_values:
            # both values and indices
            pzv = output_storage[0]
            pzi = output_storage[1]
            pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype)
        else:
            # indices only
            pzi = output_storage[0]
            pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)

    def infer_shape(self, fgraph, node, inp_shapes):
        shp = list(inp_shapes[0])
        # k may be negative (k-smallest); the output extent is |k| either way.
        shp[self.axis] = np.abs(node.inputs[1])
        shp = tuple(shp)
        # One shape entry per enabled output, in (values, indices) order.
        return [shp for i in [self.return_values, self.return_indices] if i]

    def L_op(self, inputs, outputs, out_grads):
        x, k = inputs
        k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

        # BUG FIX: this guard previously used `or`, which made it unreachable
        # (__init__ already rejects both flags being False), so the else branch
        # ran even with return_indices=False — in that case outputs[-1] is the
        # *values* output and would be misused as indices. The error message
        # ("without both indices and values") shows `and` was intended.
        if not (self.return_indices and self.return_values):
            x_grad = grad_undefined(
                self,
                0,
                x,
                "topk: cannot get gradient without both indices and values",
            )
        else:
            x_shp = shape(x)
            z_grad = out_grads[0]
            ndim = x.ndim
            axis = self.axis % ndim
            # Advanced-indexing tuple selecting the top-k positions: an
            # aranged, broadcastable index on every axis except `axis`,
            # where the indices output (outputs[-1]) is used.
            grad_indices = [
                arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
                if i != axis
                else outputs[-1]
                for i in range(ndim)
            ]
            # Scatter the output gradient back to the selected positions;
            # everything else gets zero gradient.
            x_grad = x.zeros_like(dtype=z_grad.dtype)
            x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)

        return [x_grad, k_grad]
472 |
| - |
473 |
| - |
474 |
def topk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Returns the k-largest elements along an axis.

    Parameters
    ----------

    x: tensor instance

    kth: integer constant/variable
        Must not be 0. If negative, gives k-smallest elements instead.

    axis: integer or ``None``
        Upon which axis shall the operation be performed on.
        If ``None``, works on flattened array.

    sorted: bool
        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
        Defaults to ``True``

        If True, the result array would be sorted in descending order.

    idx_dtype: string
        Specify output dtype used in indices, defaults to ``int64``, must be integer type.
        This option is here because indices are needed for gradient.

    Returns
    -------
    Tensor variable with same dtype as `x`.

    Notes
    -----
    - ``sorted=True`` is not supported yet.

    """
    # axis=None means: operate over the flattened tensor.
    if axis is None:
        x = flatten(x)
        axis = 0
    op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    values, _indices = op(x, kth)
    return values
513 |
| - |
514 |
| - |
515 |
def argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Returns the indices of k-largest elements along an axis.

    Parameters
    ----------

    x: tensor instance

    kth: integer constant/variable
        Must not be 0. If negative, gives k-smallest elements instead.

    sorted: bool
        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
        Defaults to ``True``

        If True, the result array of corresponding indices would be sorted in descending order.


    axis: integer, tuple/list of integers, or ``None``
        Upon which axis shall the operation be performed on.
        If ``None``, works on flattened array.

    idx_dtype: string
        Specify output dtype, defaults to ``int64``, must be integer type.

    Returns
    -------
    Tensor variable with dtype specified in `idx_dtype`.

    Notes
    -----
    - ``sorted=True`` is not supported yet.

    - If the top-k-th value is not unique, we cannot guarantee the output
      indices are deterministically chosen.

    """
    # axis=None means: operate over the flattened tensor.
    if axis is None:
        x = flatten(x)
        axis = 0
    op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    _values, indices = op(x, kth)
    return indices
557 |
| - |
558 |
| - |
559 |
def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Returns the results of both topk() and argtopk() in one Op.

    See the respective documentation for details.

    Returns
    -------
    tuple: (values, indices)

    """
    # axis=None means: operate over the flattened tensor.
    if axis is None:
        x = flatten(x)
        axis = 0
    op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    return op(x, kth)
0 commit comments