Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit d6300ce

Browse files
syt123450 authored and dsmilkov committed
Add avgPool3d & maxPool3d (#1778)
This PR adds `avgPool3d` op & `maxPool3d` op with CPU and WebGL implementation, supports inference and gradient. The APIs align with TensorFlow’s Python API [tf.nn.avg_pool3d](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/avg_pool3d) and [tf.nn.max_pool3d](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/max_pool3d). The `avgPool3d` & `maxPool3d` ops support `tf.layers.averagePooling3d` & `tf.layers.maxPooling3d` (feature requested in [tensorflow/tfjs#1035](tensorflow/tfjs#1035)). As a checklist, features in this PR: * Add `tf.avgPool3d` to ops * Add `tf.maxPool3d` to ops * Add `avgPool3d` kernel to handle `tf.avgPool3d`’s prediction * Add `maxPool3d` kernel to handle `tf.maxPool3d`’s prediction * Add private helper `pool3d` function in CPU kernel as the implementation of `avgPool3d` kernel and `maxPool3d` kernel in CPU end * Add `Pool3DProgram` in WebGL kernel as the implementation of `avgPool3d` kernel and `maxPool3d` kernel in GPU end * Add `avgPool3dBackprop` kernel to handle `tf.avgPool3d`’s gradient * Add `avgPool3dBackprop` CPU kernel implementation * Add `AvgPool3DBackpropProgram` as the implementation of `avgPool3dBackprop` WebGL kernel * Add `maxPool3dBackprop` kernel to handle `tf.maxPool3d`’s gradient * Add `maxPool3dBackprop` CPU kernel implementation * Add a private helper function `maxPool3dPositions` for maxPool3dBackprop in CPU kernel * Add `MaxPool3DBackpropProgram` as the implementation of `maxPool3dBackprop` WebGL kernel * Integrate WebGL end’s `maxPool3dPositions` helper function into Pool3DProgram * Add a `computePool3DInfo` util function to compute operation information for `avgPool3d` & `maxPool3d` * Make check function `eitherStridesOrDilationsAreOne` support 3D input * Add 14 unit tests for `avgPool3d` (one test case failed in nodejs env as TF Backend doesn’t support `NUMBER` pad mode) * Add 7 unit tests for `avgPool3dBackprop` * Add 13 unit tests for `maxPool3d` (one test case failed in nodejs env as TF 
Backend doesn’t support `NUMBER` pad mode) * Add 10 unit tests for `maxPool3dBackprop` * Add 8 unit tests for util function `computePool3DInfo` * Export `Tensor5D` type * Add JSDoc comments and executable examples for the website API documentation. I built the website locally; if everything goes well, the `avgPool3d` and `maxPool3d` APIs will appear in the `Operations/Convolution` section and look like the screenshot below: <img width="1184" alt="Screen Shot 2019-06-06 at 1 32 30 AM" src="https://user-images.githubusercontent.com/7977100/59018883-c13c3400-87fb-11e9-8f36-316edab35231.png"> **Related PRs:** * Make nodejs kernel support avgPool3d & maxPool3d [tensorflow/tfjs-node#256](tensorflow/tfjs-node#256) * Add averagePooling3d layer & maxPooling3d layer [tensorflow/tfjs-layers#555](tensorflow/tfjs-layers#555) * Add avgPool3d & maxPool3d ops for graph model [tensorflow/tfjs-converter#375](tensorflow/tfjs-converter#375) **References:** * [tf.nn.avg_pool3d TensorFlow Python API](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/avg_pool3d) * [tf.nn.max_pool3d TensorFlow Python API](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/max_pool3d) --- To see the logs from the Cloud Build CI, please join either our [discussion](https://groups.google.com/a/tensorflow.org/forum/#!forum/tfjs) or [announcement](https://groups.google.com/a/tensorflow.org/forum/#!forum/tfjs-announce) mailing list. FEATURE
1 parent e4d7607 commit d6300ce

10 files changed

+2097
-24
lines changed

src/backends/backend.ts

+13
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,19 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
468468
/**
 * Placeholder for the 2D average-pooling backprop kernel.
 *
 * This base implementation does no work and always throws.
 *
 * @throws Error with the message 'Not yet implemented'.
 */
avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
  const reason = 'Not yet implemented';
  throw new Error(reason);
}
471+
/**
 * Placeholder for the 3D average-pooling kernel.
 *
 * This base implementation does no work and always throws.
 *
 * @throws Error with the message 'Not yet implemented'.
 */
avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  const reason = 'Not yet implemented';
  throw new Error(reason);
}
474+
/**
 * Placeholder for the 3D average-pooling backprop kernel.
 *
 * This base implementation does no work and always throws.
 *
 * @throws Error with the message 'Not yet implemented'.
 */
avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  const reason = 'Not yet implemented';
  throw new Error(reason);
}
477+
/**
 * Placeholder for the 3D max-pooling kernel.
 *
 * This base implementation does no work and always throws.
 *
 * @throws Error with the message 'Not yet implemented'.
 */
maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  const reason = 'Not yet implemented';
  throw new Error(reason);
}
480+
/**
 * Placeholder for the 3D max-pooling backprop kernel.
 *
 * This base implementation does no work and always throws.
 *
 * @throws Error with the message 'Not yet implemented'.
 */
maxPool3dBackprop(
    dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  const reason = 'Not yet implemented';
  throw new Error(reason);
}
471484

472485
reshape<T extends Tensor, R extends Rank>(x: T, shape: ShapeMap[R]):
473486
Tensor<R> {

src/backends/cpu/backend_cpu.ts

+350
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,356 @@ export class MathBackendCPU implements KernelBackend {
25212521
return dx.toTensor();
25222522
}
25232523

2524+
/**
 * Shared CPU routine behind `avgPool3d` and `maxPool3d`.
 *
 * For each output element it walks the pooling window over `x`, stepping by
 * the dilation along depth, height and width and clipping the window to the
 * input bounds, then stores either the largest visited value ('max') or the
 * mean of the visited values ('avg').
 *
 * @param x Input tensor; addressed through `x.strides` as
 *     [batch, depth, row, col] plus a channel offset.
 * @param convInfo Pooling geometry: strides, dilations, effective filter
 *     sizes, padding and input/output sizes.
 * @param poolType 'max' keeps the maximum, 'avg' computes the mean.
 * @return A tensor of shape `convInfo.outShape` with the dtype of `x`.
 */
private pool3d(x: Tensor5D, convInfo: Conv3DInfo, poolType: 'max'|'avg'):
    Tensor5D {
  this.assertNotComplex(x, 'pool3d');

  const {
    strideDepth,
    strideHeight,
    strideWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth
  } = convInfo;
  const {front: padFront, top: padTop, left: padLeft} = convInfo.padInfo;

  const isMax = poolType === 'max';
  const isAvg = poolType === 'avg';
  // Starting value of the running extreme; only the 'max' path updates it.
  const seed = isMax ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;

  const inVals = this.readSync(x.dataId) as TypedArray;
  const result = ops.buffer(convInfo.outShape, x.dtype);
  const outVals = result.values;

  // Flat strides of the output buffer, innermost first.
  const outDims = convInfo.outShape;
  const outColStride = outDims[4];
  const outRowStride = outDims[3] * outColStride;
  const outDepthStride = outDims[2] * outRowStride;
  const outBatchStride = outDims[1] * outDepthStride;

  // Flat strides of the input buffer.
  const inBatchStride = x.strides[0];
  const inDepthStride = x.strides[1];
  const inRowStride = x.strides[2];
  const inColStride = x.strides[3];

  for (let b = 0; b < convInfo.batchSize; ++b) {
    const outBatchBase = b * outBatchStride;
    const inBatchBase = b * inBatchStride;

    for (let c = 0; c < convInfo.inChannels; ++c) {
      for (let yD = 0; yD < convInfo.outDepth; ++yD) {
        const dCorner = yD * strideDepth - padFront;
        // Skip window taps that fall into the front padding.
        let dStart = dCorner;
        while (dStart < 0) {
          dStart += dilationDepth;
        }
        const dEnd =
            Math.min(convInfo.inDepth, effectiveFilterDepth + dCorner);
        const outDepthBase = outBatchBase + yD * outDepthStride;

        for (let yR = 0; yR < convInfo.outHeight; ++yR) {
          const rCorner = yR * strideHeight - padTop;
          // Skip window taps that fall into the top padding.
          let rStart = rCorner;
          while (rStart < 0) {
            rStart += dilationHeight;
          }
          const rEnd =
              Math.min(convInfo.inHeight, effectiveFilterHeight + rCorner);
          const outRowBase = outDepthBase + yR * outRowStride;

          for (let yC = 0; yC < convInfo.outWidth; ++yC) {
            const cCorner = yC * strideWidth - padLeft;
            // Skip window taps that fall into the left padding.
            let cStart = cCorner;
            while (cStart < 0) {
              cStart += dilationWidth;
            }
            const cEnd =
                Math.min(convInfo.inWidth, effectiveFilterWidth + cCorner);

            // Shader code begins
            const outColBase = outRowBase + yC * outColStride;
            let best = seed;
            let sum = 0;
            let visited = 0;

            // A single labeled break replaces the cascade of per-level
            // `isNaN` exits: once the running extreme is NaN the whole
            // window scan stops.
            windowScan:
            for (let xD = dStart; xD < dEnd; xD += dilationDepth) {
              const inDepthBase = inBatchBase + xD * inDepthStride;
              for (let xR = rStart; xR < rEnd; xR += dilationHeight) {
                const inRowBase = inDepthBase + xR * inRowStride;
                for (let xC = cStart; xC < cEnd; xC += dilationWidth) {
                  const inColBase = inRowBase + xC * inColStride;
                  const sample = inVals[inColBase + c];
                  if (isAvg) {
                    sum += sample;
                    visited++;
                  } else if (isMax && sample > best) {
                    best = sample;
                  }
                  if (isNaN(best)) {
                    break windowScan;
                  }
                }
              }
            }

            outVals[outColBase + c] = isAvg ? sum / visited : best;
          }
        }
      }
    }
  }
  return result.toTensor() as Tensor5D;
}
2630+
2631+
/**
 * 3D average pooling on the CPU.
 *
 * Rejects complex input, delegates to `pool3d` in 'avg' mode and converts
 * the result with `toFloat()`.
 */
avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  this.assertNotComplex(x, 'avgPool3d');

  const pooled = this.pool3d(x, convInfo, 'avg');
  return pooled.toFloat();
}
2636+
2637+
/**
 * CPU gradient of 3D average pooling with respect to its input.
 *
 * For every input position it sums the `dy` entries whose pooling window
 * covers that position, then scales the sum by
 * 1 / (filterDepth * filterHeight * filterWidth).
 *
 * @param dy Upstream gradient, read as [batch, depth, row, col, channel].
 * @param x Forward-pass input; only its shape is used here.
 * @param convInfo Pooling geometry used in the forward pass.
 * @return A float32 tensor with the shape of `x`.
 */
avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  this.assertNotComplex([dy, x], 'avgPool3dBackprop');

  const {
    strideDepth,
    strideHeight,
    strideWidth,
    filterDepth,
    filterHeight,
    filterWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth
  } = convInfo;

  // Padding as seen from the dy side: effective filter size - 1 - forward pad.
  const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
  const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  const dx = ops.buffer<Rank.R5>(x.shape, 'float32');

  const scale = 1 / (filterDepth * filterHeight * filterWidth);

  const dyBuf = this.bufferSync(dy);

  for (let b = 0; b < convInfo.batchSize; ++b) {
    for (let c = 0; c < convInfo.inChannels; ++c) {
      for (let dxD = 0; dxD < convInfo.inDepth; ++dxD) {
        for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
          for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
            // Shader code begins.
            const depthCorner = dxD - padFront;
            const rowCorner = dxR - padTop;
            const colCorner = dxC - padLeft;
            let total = 0;

            for (let wD = 0; wD < effectiveFilterDepth; wD += dilationDepth) {
              // Only integer, in-range quotients correspond to a dy element.
              const dyD = (depthCorner + wD) / strideDepth;
              const depthOutside = dyD < 0 || dyD >= convInfo.outDepth;
              if (depthOutside || Math.floor(dyD) !== dyD) {
                continue;
              }

              for (let wR = 0; wR < effectiveFilterHeight;
                   wR += dilationHeight) {
                const dyR = (rowCorner + wR) / strideHeight;
                const rowOutside = dyR < 0 || dyR >= convInfo.outHeight;
                if (rowOutside || Math.floor(dyR) !== dyR) {
                  continue;
                }

                for (let wC = 0; wC < effectiveFilterWidth;
                     wC += dilationWidth) {
                  const dyC = (colCorner + wC) / strideWidth;
                  const colOutside = dyC < 0 || dyC >= convInfo.outWidth;
                  if (colOutside || Math.floor(dyC) !== dyC) {
                    continue;
                  }

                  total += dyBuf.get(b, dyD, dyR, dyC, c);
                }
              }
            }

            dx.set(total * scale, b, dxD, dxR, dxC, c);
          }
        }
      }
    }
  }
  return dx.toTensor() as Tensor5D;
}
2709+
2710+
/**
 * 3D max pooling on the CPU.
 *
 * Rejects complex input, delegates to `pool3d` in 'max' mode and converts
 * the result with `toFloat()`.
 */
maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  this.assertNotComplex(x, 'maxPool3d');

  const pooled = this.pool3d(x, convInfo, 'max');
  return pooled.toFloat();
}
2715+
2716+
/**
 * Records, for every output element of a 3D max pool, where inside the
 * pooling window the maximum was found.
 *
 * The position is the row-major flat index of the window tap within an
 * effectiveFilterDepth x effectiveFilterHeight x effectiveFilterWidth box:
 *
 *   wDepth * effectiveFilterHeight * effectiveFilterWidth +
 *   wRow * effectiveFilterWidth + wCol
 *
 * This is the same layout `maxPool3dBackprop` uses when it rebuilds the
 * position of the tap it is visiting. The row term must be multiplied by the
 * window *width*; multiplying by the height yields wrong (and possibly
 * colliding) indices whenever the effective filter height and width differ.
 *
 * Ties resolve to the last tap visited because the comparison is `>=`. A NaN
 * sample never satisfies the comparison, so a window with no comparable
 * sample keeps the sentinel position -1.
 *
 * @param x Input tensor, read as [batch, depth, row, col, channel].
 * @param convInfo Pooling geometry used in the forward pass.
 * @return An int32 tensor of shape `convInfo.outShape` holding the flat
 *     in-window position of each maximum.
 */
private maxPool3dPositions(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  const maxPositions = ops.buffer(convInfo.outShape, 'int32');
  const strideDepth = convInfo.strideDepth;
  const strideHeight = convInfo.strideHeight;
  const strideWidth = convInfo.strideWidth;
  const dilationDepth = convInfo.dilationDepth;
  const dilationHeight = convInfo.dilationHeight;
  const dilationWidth = convInfo.dilationWidth;
  const effectiveFilterDepth = convInfo.effectiveFilterDepth;
  const effectiveFilterHeight = convInfo.effectiveFilterHeight;
  const effectiveFilterWidth = convInfo.effectiveFilterWidth;
  const padFront = convInfo.padInfo.front;
  const padTop = convInfo.padInfo.top;
  const padLeft = convInfo.padInfo.left;

  // Number of taps in one depth slice of the window.
  const windowPlaneSize = effectiveFilterHeight * effectiveFilterWidth;

  const xBuf = this.bufferSync(x);
  for (let batch = 0; batch < convInfo.batchSize; ++batch) {
    for (let channel = 0; channel < convInfo.inChannels; ++channel) {
      for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
        const xDepthCorner = yDepth * strideDepth - padFront;
        // Skip window taps that fall into the front padding.
        let xDepthMin = xDepthCorner;
        while (xDepthMin < 0) {
          xDepthMin += dilationDepth;
        }
        const xDepthMax =
            Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
        for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
          const xRowCorner = yRow * strideHeight - padTop;
          // Skip window taps that fall into the top padding.
          let xRowMin = xRowCorner;
          while (xRowMin < 0) {
            xRowMin += dilationHeight;
          }
          const xRowMax =
              Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
          for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
            const xColCorner = yCol * strideWidth - padLeft;
            // Skip window taps that fall into the left padding.
            let xColMin = xColCorner;
            while (xColMin < 0) {
              xColMin += dilationWidth;
            }
            const xColMax =
                Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);

            // Shader code begins
            let maxValue = Number.NEGATIVE_INFINITY;
            let maxPosition = -1;

            for (let xDepth = xDepthMin; xDepth < xDepthMax;
                 xDepth += dilationDepth) {
              const wDepth = xDepth - xDepthCorner;
              for (let xRow = xRowMin; xRow < xRowMax;
                   xRow += dilationHeight) {
                const wRow = xRow - xRowCorner;
                for (let xCol = xColMin; xCol < xColMax;
                     xCol += dilationWidth) {
                  const wCol = xCol - xColCorner;
                  const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
                  if (pixel >= maxValue) {
                    maxValue = pixel;
                    // Row stride inside the window is its width (was
                    // mistakenly the height).
                    maxPosition = wDepth * windowPlaneSize +
                        wRow * effectiveFilterWidth + wCol;
                  }
                }
              }
            }

            maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
          }
        }
      }
    }
  }
  return maxPositions.toTensor() as Tensor5D;
}
2791+
2792+
/**
 * CPU gradient of 3D max pooling with respect to its input.
 *
 * Recomputes the in-window position of each maximum with
 * `maxPool3dPositions`, then, for every input position, adds up the `dy`
 * entries of the output elements whose recorded maximum matches the window
 * tap currently being visited.
 *
 * @param dy Upstream gradient, read as [batch, depth, row, col, channel].
 * @param x Forward-pass input.
 * @param y Forward-pass output; only passed to the complex-dtype check.
 * @param convInfo Pooling geometry used in the forward pass.
 * @return A float32 tensor with the shape of `x`.
 */
maxPool3dBackprop(
    dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
  this.assertNotComplex([x, y], 'maxPool3dBackprop');

  const argmax = this.maxPool3dPositions(x, convInfo);
  const {
    strideDepth,
    strideHeight,
    strideWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth
  } = convInfo;

  // Padding as seen from the dy side: effective filter size - 1 - forward pad.
  const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
  const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  const dx = ops.buffer<Rank.R5>(x.shape, 'float32');

  const argmaxBuf = this.bufferSync(argmax);
  const dyBuf = this.bufferSync(dy);

  // Largest flat index inside the effective window. Recorded positions are
  // subtracted from it before being compared with the current tap.
  const lastWindowIndex =
      effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;

  for (let b = 0; b < convInfo.batchSize; ++b) {
    for (let c = 0; c < convInfo.inChannels; ++c) {
      for (let dxD = 0; dxD < convInfo.inDepth; ++dxD) {
        for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
          for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
            // Shader code begins
            const depthCorner = dxD - padFront;
            const rowCorner = dxR - padTop;
            const colCorner = dxC - padLeft;
            let total = 0;

            for (let wD = 0; wD < effectiveFilterDepth; wD += dilationDepth) {
              // Only integer, in-range quotients correspond to a dy element.
              const dyD = (depthCorner + wD) / strideDepth;
              const depthOutside = dyD < 0 || dyD >= convInfo.outDepth;
              if (depthOutside || Math.floor(dyD) !== dyD) {
                continue;
              }

              for (let wR = 0; wR < effectiveFilterHeight;
                   wR += dilationHeight) {
                const dyR = (rowCorner + wR) / strideHeight;
                const rowOutside = dyR < 0 || dyR >= convInfo.outHeight;
                if (rowOutside || Math.floor(dyR) !== dyR) {
                  continue;
                }

                for (let wC = 0; wC < effectiveFilterWidth;
                     wC += dilationWidth) {
                  const dyC = (colCorner + wC) / strideWidth;
                  const colOutside = dyC < 0 || dyC >= convInfo.outWidth;
                  if (colOutside || Math.floor(dyC) !== dyC) {
                    continue;
                  }

                  const mirroredMax =
                      lastWindowIndex - argmaxBuf.get(b, dyD, dyR, dyC, c);
                  const tapIndex =
                      wD * effectiveFilterHeight * effectiveFilterWidth +
                      wR * effectiveFilterWidth + wC;

                  // Gradient flows only through the tap that held the max.
                  if (mirroredMax !== tapIndex) {
                    continue;
                  }

                  total += dyBuf.get(b, dyD, dyR, dyC, c);
                }
              }
            }

            dx.set(total, b, dxD, dxR, dxC, c);
          }
        }
      }
    }
  }
  return dx.toTensor() as Tensor5D;
}
2873+
25242874
/**
 * Casts `x` to `dtype` by delegating to `backend_util.castTensor`, passing
 * this backend as the third argument.
 */
cast<T extends Tensor>(x: T, dtype: DataType): T {
  const converted = backend_util.castTensor(x, dtype, this);
  return converted;
}

0 commit comments

Comments
 (0)