Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit b484b28

Browse files
author
Nikhil Thorat
authored
Cleanup intopk to remove it as a kernel. (#1873)
- Removes the kernel intopk because we don't support async kernels - Renames inTopK to inTopKAsync - Changes unit tests to support the async error throwing expectations. FEATURE
1 parent d6300ce commit b484b28

File tree

8 files changed

+107
-119
lines changed

8 files changed

+107
-119
lines changed

src/backends/backend.ts

-5
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,6 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
241241
throw new Error('Not yet implemented');
242242
}
243243

244-
inTopK<T extends Tensor, U extends Tensor>(
245-
predictions: T, targets: U, k: number): U {
246-
throw new Error('Not yet implemented');
247-
}
248-
249244
min(x: Tensor, axes: number[]): Tensor {
250245
throw new Error('Not yet implemented');
251246
}

src/backends/cpu/backend_cpu.ts

-13
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
4141
import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';
4242
import * as backend_util from '../backend_util';
4343
import * as complex_util from '../complex_util';
44-
import {inTopKImpl} from '../inTopK_impl';
4544
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
4645
import {split} from '../split_shared';
4746
import {tile} from '../tile_impl';
@@ -859,18 +858,6 @@ export class MathBackendCPU implements KernelBackend {
859858
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
860859
}
861860

862-
inTopK<T extends Tensor, U extends Tensor>(
863-
predictions: T, targets: U, k: number): U {
864-
this.assertNotComplex([predictions, targets], 'inTopK');
865-
866-
const predictionsVals = this.readSync(predictions.dataId) as TypedArray;
867-
const targetsVals = this.readSync(targets.dataId) as TypedArray;
868-
869-
return inTopKImpl(
870-
predictionsVals, predictions.shape, targetsVals, targets.shape,
871-
k) as U;
872-
}
873-
874861
min(x: Tensor, axes: number[]): Tensor {
875862
this.assertNotComplex(x, 'min');
876863

src/backends/inTopK_impl.ts

-55
This file was deleted.

src/backends/webgl/backend_webgl.ts

-10
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr
4444
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
4545
import * as backend_util from '../backend_util';
4646
import {mergeRealAndImagArrays} from '../complex_util';
47-
import {inTopKImpl} from '../inTopK_impl';
4847
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
4948
import {split} from '../split_shared';
5049
import {tile} from '../tile_impl';
@@ -1360,15 +1359,6 @@ export class MathBackendWebGL implements KernelBackend {
13601359
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
13611360
}
13621361

1363-
inTopK<T extends Tensor, U extends Tensor>(
1364-
predictions: T, targets: U, k: number): U {
1365-
const predictionsVals = predictions.dataSync();
1366-
const targetsVals = targets.dataSync();
1367-
return inTopKImpl(
1368-
predictionsVals, predictions.shape, targetsVals, targets.shape,
1369-
k) as U;
1370-
}
1371-
13721362
min(x: Tensor, axes: number[]): Tensor {
13731363
axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);
13741364
const [outShape, reduceShape] =
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2018 Google LLC. All Rights Reserved.
3+
* Copyright 2019 Google LLC. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -15,21 +15,19 @@
1515
* =============================================================================
1616
*/
1717

18-
import {ENGINE} from '../engine';
19-
import {NumericTensor, Tensor} from '../tensor';
18+
import {Tensor} from '../tensor';
2019
import {convertToTensor} from '../tensor_util_env';
2120
import {TensorLike} from '../types';
22-
import {assert, assertShapesMatch} from '../util';
23-
24-
import {op} from './operation';
21+
import {assert, assertShapesMatch, getTypedArrayFromDType} from '../util';
22+
import {tensor} from './tensor_ops';
2523

2624
/**
27-
* Says whether the targets are in the top K predictions.
25+
* Returns whether the targets are in the top K predictions.
2826
*
2927
* ```js
3028
* const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
3129
* const targets = tf.tensor1d([2, 0]);
32-
* const precision = tf.inTopK(predictions, targets);
30+
* const precision = await tf.inTopKAsync(predictions, targets);
3331
* precision.print();
3432
* ```
3533
* @param predictions 2-D or higher `tf.Tensor` with last dimension being
@@ -39,8 +37,8 @@ import {op} from './operation';
3937
* default to 1.
4038
*/
4139
/** @doc {heading: 'Operations', subheading: 'Evaluation'} */
42-
function inTopK_<T extends Tensor, U extends Tensor>(
43-
predictions: T|TensorLike, targets: U|TensorLike, k = 1): U {
40+
async function inTopKAsync_<T extends Tensor, U extends Tensor>(
41+
predictions: T|TensorLike, targets: U|TensorLike, k = 1): Promise<U> {
4442
const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
4543
const $targets = convertToTensor(targets, 'targets', 'inTopK');
4644

@@ -50,9 +48,9 @@ function inTopK_<T extends Tensor, U extends Tensor>(
5048
`but got ${$predictions.rank}`);
5149
assert(
5250
$predictions.rank - 1 === $targets.rank,
53-
() => `predictions' rank should be 1 larger than ` +
54-
`targets' rank, but got predictions' rank ` +
55-
`${$predictions.rank} and targets' rank ${$targets.rank}`);
51+
() => `predictions rank should be 1 larger than ` +
52+
`targets rank, but got predictions rank ` +
53+
`${$predictions.rank} and targets rank ${$targets.rank}`);
5654
assertShapesMatch(
5755
$predictions.shape.slice(0, $predictions.shape.length - 1),
5856
$targets.shape,
@@ -61,15 +59,44 @@ function inTopK_<T extends Tensor, U extends Tensor>(
6159
const lastDim = $predictions.shape[$predictions.shape.length - 1];
6260
assert(
6361
k > 0 && k <= lastDim,
64-
() => `'k' passed to inTopK() must be > 0 && <= the predictions' last ` +
62+
() => `'k' passed to inTopK() must be > 0 && <= the predictions last ` +
6563
`dimension (${lastDim}), but got ${k}`);
6664

67-
const precision = ENGINE.runKernel(
68-
b =>
69-
b.inTopK($predictions as NumericTensor, $targets as NumericTensor, k),
70-
{$predictions, $targets});
65+
const predictionsVals = await $predictions.data();
66+
const targetsVals = await $targets.data();
67+
68+
// Reshape predictionsVals into a 2d tensor [batch, lastDim]
69+
// and look up topK along lastDim.
70+
const [batch, size] = [predictionsVals.length / lastDim, lastDim];
71+
const precision = getTypedArrayFromDType('bool', batch);
72+
73+
for (let b = 0; b < batch; b++) {
74+
const offset = b * size;
75+
const vals = predictionsVals.subarray(offset, offset + size);
76+
const valAndInd: Array<{value: number, index: number}> = [];
77+
for (let i = 0; i < vals.length; i++) {
78+
valAndInd.push({value: vals[i], index: i});
79+
}
80+
valAndInd.sort((a, b) => b.value - a.value);
81+
82+
precision[b] = 0;
83+
for (let i = 0; i < k; i++) {
84+
if (valAndInd[i].index === targetsVals[b]) {
85+
precision[b] = 1;
86+
break;
87+
}
88+
}
89+
}
90+
91+
if (predictions !== $predictions) {
92+
$predictions.dispose();
93+
}
94+
if (targets !== $targets) {
95+
$targets.dispose();
96+
}
7197

72-
return precision as U;
98+
// Output precision has the same shape as targets.
99+
return tensor(precision, $targets.shape, 'bool') as U;
73100
}
74101

75-
export const inTopK = op({inTopK_});
102+
export const inTopKAsync = inTopKAsync_;

src/ops/inTopK_test.ts renamed to src/ops/in_top_k_test.ts

+58-14
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ import {expectArraysClose} from '../test_util';
2121

2222
import {tensor1d, tensor2d, tensor3d} from './tensor_ops';
2323

24-
describeWithFlags('inTopK', ALL_ENVS, async () => {
24+
describeWithFlags('inTopKAsync', ALL_ENVS, async () => {
2525
it('predictions 2d array, targets 1d array, with default k', async () => {
2626
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
2727
const targets = tensor1d([2, 0]);
28-
const precision = tf.inTopK(predictions, targets);
28+
const precision = await tf.inTopKAsync(predictions, targets);
2929
expect(precision.shape).toEqual([2]);
3030
expect(precision.dtype).toBe('bool');
3131
expectArraysClose(await precision.data(), [1, 0]);
@@ -35,7 +35,7 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
3535
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
3636
const targets = tensor1d([2, 0]);
3737
const k = 2;
38-
const precision = tf.inTopK(predictions, targets, k);
38+
const precision = await tf.inTopKAsync(predictions, targets, k);
3939
expect(precision.shape).toEqual([2]);
4040
expect(precision.dtype).toBe('bool');
4141
expectArraysClose(await precision.data(), [1, 1]);
@@ -45,7 +45,7 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
4545
const predictions =
4646
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]);
4747
const targets = tensor2d([[1, 2], [0, 1]]);
48-
const precision = tf.inTopK(predictions, targets);
48+
const precision = await tf.inTopKAsync(predictions, targets);
4949
expect(precision.shape).toEqual([2, 2]);
5050
expect(precision.dtype).toBe('bool');
5151
expectArraysClose(await precision.data(), [1, 1, 1, 0]);
@@ -56,7 +56,7 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
5656
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]);
5757
const targets = tensor2d([[1, 2], [0, 1]]);
5858
const k = 2;
59-
const precision = tf.inTopK(predictions, targets, k);
59+
const precision = await tf.inTopKAsync(predictions, targets, k);
6060
expect(precision.shape).toEqual([2, 2]);
6161
expect(precision.dtype).toBe('bool');
6262
expectArraysClose(await precision.data(), [1, 1, 1, 1]);
@@ -66,13 +66,13 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
6666
const predictions = tensor2d([[1, 2, 2, 1]]);
6767

6868
const targets1 = tensor1d([1]);
69-
const precision1 = tf.inTopK(predictions, targets1);
69+
const precision1 = await tf.inTopKAsync(predictions, targets1);
7070
expect(precision1.shape).toEqual([1]);
7171
expect(precision1.dtype).toBe('bool');
7272
expectArraysClose(await precision1.data(), [1]);
7373

7474
const targets2 = tensor1d([2]);
75-
const precision2 = tf.inTopK(predictions, targets2);
75+
const precision2 = await tf.inTopKAsync(predictions, targets2);
7676
expect(precision2.shape).toEqual([1]);
7777
expect(precision2.dtype).toBe('bool');
7878
expectArraysClose(await precision2.data(), [0]);
@@ -81,28 +81,72 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
8181
it('accept tensor-like object, with default k', async () => {
8282
const predictions = [[20, 10, 40, 30], [30, 50, -20, 10]];
8383
const targets = [2, 0];
84-
const precision = tf.inTopK(predictions, targets);
84+
const precision = await tf.inTopKAsync(predictions, targets);
8585
expect(precision.shape).toEqual([2]);
8686
expect(precision.dtype).toBe('bool');
8787
expectArraysClose(await precision.data(), [1, 0]);
8888
});
8989

90-
it('throws when predictions_rank <2', () => {
90+
it('doesnt leak tensors with tensor-like objects', async () => {
91+
const numTensors = tf.memory().numTensors;
92+
93+
const predictions = [[20, 10, 40, 30], [30, 50, -20, 10]];
94+
const targets = [2, 0];
95+
const precision = await tf.inTopKAsync(predictions, targets);
96+
precision.dispose();
97+
98+
expect(tf.memory().numTensors).toBe(numTensors);
99+
});
100+
101+
it('throws when predictions_rank <2', async () => {
91102
const predictions = tensor1d([20, 10, 40, 30]);
92103
const targets = [2];
93-
expect(() => tf.inTopK(predictions, targets)).toThrowError();
104+
105+
// expect(...).toThrowError() does not support async functions.
106+
// See https://github.com/jasmine/jasmine/issues/1410
107+
try {
108+
await tf.inTopKAsync(predictions, targets);
109+
throw new Error('The line above should have thrown an error');
110+
} catch (ex) {
111+
expect(ex.message)
112+
.toEqual(
113+
'inTopK() expects the predictions to ' +
114+
'be of rank 2 or higher, but got 1');
115+
}
94116
});
95117

96-
it('throws when prediction_rank != targets_rank + 1', () => {
118+
it('throws when prediction.rank != targets.rank + 1', async () => {
97119
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
98120
const targets = tensor2d([[0], [0]]);
99-
expect(() => tf.inTopK(predictions, targets)).toThrowError();
121+
122+
// expect(...).toThrowError() does not support async functions.
123+
// See https://github.com/jasmine/jasmine/issues/1410
124+
try {
125+
await tf.inTopKAsync(predictions, targets);
126+
throw new Error('The line above should have thrown an error');
127+
} catch (ex) {
128+
expect(ex.message)
129+
.toEqual(
130+
'predictions rank should be 1 larger than targets rank,' +
131+
' but got predictions rank 2 and targets rank 2');
132+
}
100133
});
101134

102-
it('throws when k > size of last dimension of predictions', () => {
135+
it('throws when k > size of last dimension of predictions', async () => {
103136
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
104137
const targets = tensor1d([2, 0]);
105138
const k = 5;
106-
expect(() => tf.inTopK(predictions, targets, k)).toThrowError();
139+
140+
// expect(...).toThrowError() does not support async functions.
141+
// See https://github.com/jasmine/jasmine/issues/1410
142+
try {
143+
await tf.inTopKAsync(predictions, targets, k);
144+
throw new Error('The line above should have thrown an error');
145+
} catch (ex) {
146+
expect(ex.message)
147+
.toEqual(
148+
'\'k\' passed to inTopK() must be > 0 && <= the predictions ' +
149+
'last dimension (4), but got 5');
150+
}
107151
});
108152
});

src/ops/ops.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ export * from './gather_nd';
4848
export * from './diag';
4949
export * from './dropout';
5050
export * from './signal_ops';
51-
export * from './inTopK';
51+
export * from './in_top_k';
5252

5353
export {op} from './operation';
5454

src/tests.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ import './ops/dropout_test';
6161
import './ops/fused_test';
6262
import './ops/gather_nd_test';
6363
import './ops/image_ops_test';
64-
import './ops/inTopK_test';
64+
import './ops/in_top_k_test';
6565
import './ops/linalg_ops_test';
6666
import './ops/logical_ops_test';
6767
import './ops/loss_ops_test';

0 commit comments

Comments
 (0)