Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Add avgPool3d & maxPool3d #1778

Merged
merged 11 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,19 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
throw new Error('Not yet implemented');
}
avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
throw new Error('Not yet implemented');
}
avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
throw new Error('Not yet implemented');
}
maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
throw new Error('Not yet implemented');
}
maxPool3dBackprop(
dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
throw new Error('Not yet implemented');
}

reshape<T extends Tensor, R extends Rank>(x: T, shape: ShapeMap[R]):
Tensor<R> {
Expand Down
350 changes: 350 additions & 0 deletions src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2442,6 +2442,356 @@ export class MathBackendCPU implements KernelBackend {
return dx.toTensor();
}

private pool3d(x: Tensor5D, convInfo: Conv3DInfo, poolType: 'max'|'avg'):
Tensor5D {
this.assertNotComplex(x, 'pool3d');

const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;

const initialValue =
(poolType === 'max' ? Number.NEGATIVE_INFINITY :
Number.POSITIVE_INFINITY);

const xValues = this.readSync(x.dataId) as TypedArray;
const output = ops.buffer(convInfo.outShape, x.dtype);
const outputVals = output.values;

const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *
convInfo.outShape[3] * convInfo.outShape[4];
const outputDepthStrides =
convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
const outputColStrides = convInfo.outShape[4];

for (let batch = 0; batch < convInfo.batchSize; ++batch) {
const outputBatchOffset = batch * outputBatchStrides;
const inputBatchOffset = batch * x.strides[0];
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
const xDepthCorner = yDepth * strideDepth - padFront;
let xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
const xDepthMax =
Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
const outputDepthOffset =
outputBatchOffset + yDepth * outputDepthStrides;
for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
const xRowCorner = yRow * strideHeight - padTop;
let xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
const xRowMax =
Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
const xColCorner = yCol * strideWidth - padLeft;
let xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
const xColMax =
Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
// Shader code begins
const outputColOffset = outputRowOffset + yCol * outputColStrides;
let minMaxValue = initialValue;
let avgValue = 0;
let count = 0;
for (let xDepth = xDepthMin; xDepth < xDepthMax;
xDepth += dilationDepth) {
const xDepthOffset = inputBatchOffset + xDepth * x.strides[1];
for (let xRow = xRowMin; xRow < xRowMax;
xRow += dilationHeight) {
const xRowOffset = xDepthOffset + xRow * x.strides[2];
for (let xCol = xColMin; xCol < xColMax;
xCol += dilationWidth) {
const xColOffset = xRowOffset + xCol * x.strides[3];
const pixel = xValues[xColOffset + channel];
if ((poolType === 'max' && pixel > minMaxValue)) {
minMaxValue = pixel;
} else if (poolType === 'avg') {
avgValue += pixel;
count++;
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
const outputOffset = outputColOffset + channel;
outputVals[outputOffset] =
poolType === 'avg' ? avgValue / count : minMaxValue;
}
}
}
}
}
return output.toTensor() as Tensor5D;
}

avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
this.assertNotComplex(x, 'avgPool3d');

return this.pool3d(x, convInfo, 'avg').toFloat();
}

avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
this.assertNotComplex([dy, x], 'avgPool3dBackprop');

const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const filterDepth = convInfo.filterDepth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const dx = ops.buffer<Rank.R5>(x.shape, 'float32');

const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);

const dyBuf = this.bufferSync(dy);

for (let batch = 0; batch < convInfo.batchSize; ++batch) {
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
// Shader code begins.
const dyDepthCorner = dxDepth - padFront;
const dyRowCorner = dxRow - padTop;
const dyColCorner = dxCol - padLeft;
let dotProd = 0;
for (let wDepth = 0; wDepth < effectiveFilterDepth;
wDepth += dilationDepth) {
const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (let wRow = 0; wRow < effectiveFilterHeight;
wRow += dilationHeight) {
const dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight ||
Math.floor(dyRow) !== dyRow) {
continue;
}
for (let wCol = 0; wCol < effectiveFilterWidth;
wCol += dilationWidth) {
const dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth ||
Math.floor(dyCol) !== dyCol) {
continue;
}

const pixel =
dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel;
}
}
}
dx.set(
dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol,
channel);
}
}
}
}
}
return dx.toTensor() as Tensor5D;
}

maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
this.assertNotComplex(x, 'maxPool3d');

return this.pool3d(x, convInfo, 'max').toFloat();
}

private maxPool3dPositions(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
const maxPositions = ops.buffer(convInfo.outShape, 'int32');
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;

const xBuf = this.bufferSync(x);
for (let batch = 0; batch < convInfo.batchSize; ++batch) {
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
const xDepthCorner = yDepth * strideDepth - padFront;
let xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
const xDepthMax =
Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
const xRowCorner = yRow * strideHeight - padTop;
let xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
const xRowMax =
Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
const xColCorner = yCol * strideWidth - padLeft;
let xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
const xColMax =
Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);

// Shader code begins
let maxValue = Number.NEGATIVE_INFINITY;
let maxPosition = -1;

for (let xDepth = xDepthMin; xDepth < xDepthMax;
xDepth += dilationDepth) {
const wDepth = xDepth - xDepthCorner;
for (let xRow = xRowMin; xRow < xRowMax;
xRow += dilationHeight) {
const wRow = xRow - xRowCorner;
for (let xCol = xColMin; xCol < xColMax;
xCol += dilationWidth) {
const wCol = xCol - xColCorner;
const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
if (pixel >= maxValue) {
maxValue = pixel;
maxPosition = wDepth * effectiveFilterHeight *
effectiveFilterWidth +
wRow * effectiveFilterHeight + wCol;
}
}
}
}

maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
}
}
}
}
}
return maxPositions.toTensor() as Tensor5D;
}

maxPool3dBackprop(
dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
this.assertNotComplex([x, y], 'maxPool3dBackprop');

const maxPositions = this.maxPool3dPositions(x, convInfo);
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const dx = ops.buffer<Rank.R5>(x.shape, 'float32');

const maxPosBuf = this.bufferSync(maxPositions);
const dyBuf = this.bufferSync(dy);

for (let batch = 0; batch < convInfo.batchSize; ++batch) {
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
// Shader code begins
const dyDepthCorner = dxDepth - padFront;
const dyRowCorner = dxRow - padTop;
const dyColCorner = dxCol - padLeft;
let dotProd = 0;
for (let wDepth = 0; wDepth < effectiveFilterDepth;
wDepth += dilationDepth) {
const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (let wRow = 0; wRow < effectiveFilterHeight;
wRow += dilationHeight) {
const dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight ||
Math.floor(dyRow) !== dyRow) {
continue;
}
for (let wCol = 0; wCol < effectiveFilterWidth;
wCol += dilationWidth) {
const dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth ||
Math.floor(dyCol) !== dyCol) {
continue;
}

const maxPos = effectiveFilterDepth *
effectiveFilterHeight * effectiveFilterWidth -
1 -
maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
const curPos =
wDepth * effectiveFilterHeight * effectiveFilterWidth +
wRow * effectiveFilterWidth + wCol;

const mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}

const pixel =
dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel * mask;
}
}
}
dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return dx.toTensor() as Tensor5D;
}

cast<T extends Tensor>(x: T, dtype: DataType): T {
return backend_util.castTensor(x, dtype, this);
}
Expand Down
Loading