Skip to content
This repository was archived by the owner on Oct 17, 2021. It is now read-only.

Add maxPooling3d layer & averagePooling3d layer #555

Merged
merged 8 commits into from
Aug 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/exports_layers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {Add, Average, Concatenate, ConcatenateLayerArgs, Dot, DotLayerArgs, Maxi
import {AlphaDropout, AlphaDropoutArgs, GaussianDropout, GaussianDropoutArgs, GaussianNoise, GaussianNoiseArgs} from './layers/noise';
import {BatchNormalization, BatchNormalizationLayerArgs} from './layers/normalization';
import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
import {AveragePooling1D, AveragePooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, Pooling1DLayerArgs, Pooling2DLayerArgs} from './layers/pooling';
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';

Expand Down Expand Up @@ -918,6 +918,38 @@ export function avgPooling2d(args: Pooling2DLayerArgs): Layer {
return averagePooling2d(args);
}

/**
* Average pooling operation for 3D data.
*
* Input shape
* - If `dataFormat === channelsLast`:
* 5D tensor with shape:
* `[batchSize, depths, rows, cols, channels]`
* - If `dataFormat === channelsFirst`:
* 4D tensor with shape:
* `[batchSize, channels, depths, rows, cols]`
*
* Output shape
* - If `dataFormat=channelsLast`:
* 5D tensor with shape:
* `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
* - If `dataFormat=channelsFirst`:
* 5D tensor with shape:
* `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
*/
/** @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */
export function averagePooling3d(args: Pooling3DLayerArgs): Layer {
return new AveragePooling3D(args);
}
export function avgPool3d(args: Pooling3DLayerArgs): Layer {
return averagePooling3d(args);
}
// For backwards compatibility.
// See https://github.com/tensorflow/tfjs/issues/152
export function avgPooling3d(args: Pooling3DLayerArgs): Layer {
return averagePooling3d(args);
}

/**
* Global average pooling operation for temporal data.
*
Expand Down Expand Up @@ -1012,6 +1044,30 @@ export function maxPooling2d(args: Pooling2DLayerArgs): Layer {
return new MaxPooling2D(args);
}

/**
* Max pooling operation for 3D data.
*
* Input shape
* - If `dataFormat === channelsLast`:
* 5D tensor with shape:
* `[batchSize, depths, rows, cols, channels]`
* - If `dataFormat === channelsFirst`:
* 5D tensor with shape:
* `[batchSize, channels, depths, rows, cols]`
*
* Output shape
* - If `dataFormat=channelsLast`:
* 5D tensor with shape:
* `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
* - If `dataFormat=channelsFirst`:
* 5D tensor with shape:
* `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
*/
/** @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */
export function maxPooling3d(args: Pooling3DLayerArgs): Layer {
return new MaxPooling3D(args);
}

// Recurrent Layers.

/**
Expand Down
204 changes: 202 additions & 2 deletions src/layers/pooling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/

import * as tfc from '@tensorflow/tfjs-core';
import {serialization, Tensor, Tensor3D, Tensor4D, tidy} from '@tensorflow/tfjs-core';
import {serialization, Tensor, Tensor3D, Tensor4D, Tensor5D, tidy} from '@tensorflow/tfjs-core';

import {imageDataFormat} from '../backend/common';
import * as K from '../backend/tfjs_backend';
Expand All @@ -27,7 +27,7 @@ import {convOutputLength} from '../utils/conv_utils';
import {assertPositiveInteger} from '../utils/generic_utils';
import {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';

import {preprocessConv2DInput} from './convolutional';
import {preprocessConv2DInput, preprocessConv3DInput} from './convolutional';

/**
* 2D pooling.
Expand Down Expand Up @@ -82,6 +82,52 @@ export function pool2d(
});
}

/**
* 3D pooling.
* @param x
* @param poolSize. Default to [1, 1, 1].
* @param strides strides. Defaults to [1, 1, 1].
* @param padding padding. Defaults to 'valid'.
* @param dataFormat data format. Defaults to 'channelsLast'.
* @param poolMode Mode of pooling. Defaults to 'max'.
* @returns Result of the 3D pooling.
*/
export function pool3d(
x: Tensor5D, poolSize: [number, number, number],
strides?: [number, number, number], padding?: PaddingMode,
dataFormat?: DataFormat, poolMode?: PoolMode): Tensor {
return tidy(() => {
checkDataFormat(dataFormat);
checkPoolMode(poolMode);
checkPaddingMode(padding);
if (strides == null) {
strides = [1, 1, 1];
}
if (padding == null) {
padding = 'valid';
}
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
if (poolMode == null) {
poolMode = 'max';
}

// x is NDHWC after preprocessing.
x = preprocessConv3DInput(x as Tensor, dataFormat) as Tensor5D;
let y: Tensor;
const paddingString = (padding === 'same') ? 'same' : 'valid';
if (poolMode === 'max') {
y = tfc.maxPool3d(x, poolSize, strides, paddingString);
} else { // 'avg'
y = tfc.avgPool3d(x, poolSize, strides, paddingString);
}
if (dataFormat === 'channelsFirst') {
y = tfc.transpose(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
}
return y;
});
}

export declare interface Pooling1DLayerArgs extends LayerArgs {
/**
Expand Down Expand Up @@ -370,6 +416,160 @@ export class AveragePooling2D extends Pooling2D {
}
serialization.registerClass(AveragePooling2D);

export declare interface Pooling3DLayerArgs extends LayerArgs {
/**
* Factors by which to downscale in each dimension [depth, height, width].
* Expects an integer or an array of 3 integers.
*
* For example, `[2, 2, 2]` will halve the input in three dimensions.
* If only one integer is specified, the same window length
* will be used for all dimensions.
*/
poolSize?: number|[number, number, number];

/**
* The size of the stride in each dimension of the pooling window. Expects
* an integer or an array of 3 integers. Integer, tuple of 3 integers, or
* None.
*
* If `null`, defaults to `poolSize`.
*/
strides?: number|[number, number, number];

/** The padding type to use for the pooling layer. */
padding?: PaddingMode;
/** The data format to use for the pooling layer. */
dataFormat?: DataFormat;
}

/**
* Abstract class for different pooling 3D layers.
*/
export abstract class Pooling3D extends Layer {
protected readonly poolSize: [number, number, number];
protected readonly strides: [number, number, number];
protected readonly padding: PaddingMode;
protected readonly dataFormat: DataFormat;

constructor(args: Pooling3DLayerArgs) {
if (args.poolSize == null) {
args.poolSize = [2, 2, 2];
}
super(args);
this.poolSize = Array.isArray(args.poolSize) ?
args.poolSize :
[args.poolSize, args.poolSize, args.poolSize];
if (args.strides == null) {
this.strides = this.poolSize;
} else if (Array.isArray(args.strides)) {
if (args.strides.length !== 3) {
throw new ValueError(
`If the strides property of a 3D pooling layer is an Array, ` +
`it is expected to have a length of 3, but received length ` +
`${args.strides.length}.`);
}
this.strides = args.strides;
} else {
// `config.strides` is a number.
this.strides = [args.strides, args.strides, args.strides];
}
assertPositiveInteger(this.poolSize, 'poolSize');
assertPositiveInteger(this.strides, 'strides');
this.padding = args.padding == null ? 'valid' : args.padding;
this.dataFormat =
args.dataFormat == null ? 'channelsLast' : args.dataFormat;
checkDataFormat(this.dataFormat);
checkPaddingMode(this.padding);

this.inputSpec = [new InputSpec({ndim: 5})];
}

computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
inputShape = getExactlyOneShape(inputShape);
let depths =
this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
let rows =
this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
let cols =
this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
depths = convOutputLength(
depths, this.poolSize[0], this.padding, this.strides[0]);
rows =
convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
cols =
convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
if (this.dataFormat === 'channelsFirst') {
return [inputShape[0], inputShape[1], depths, rows, cols];
} else {
return [inputShape[0], depths, rows, cols, inputShape[4]];
}
}

protected abstract poolingFunction(
inputs: Tensor, poolSize: [number, number, number],
strides: [number, number, number], padding: PaddingMode,
dataFormat: DataFormat): Tensor;

call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
return this.poolingFunction(
getExactlyOneTensor(inputs), this.poolSize, this.strides,
this.padding, this.dataFormat);
});
}

getConfig(): serialization.ConfigDict {
const config = {
poolSize: this.poolSize,
padding: this.padding,
strides: this.strides,
dataFormat: this.dataFormat
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}

export class MaxPooling3D extends Pooling3D {
/** @nocollapse */
static className = 'MaxPooling3D';
constructor(args: Pooling3DLayerArgs) {
super(args);
}

protected poolingFunction(
inputs: Tensor, poolSize: [number, number, number],
strides: [number, number, number], padding: PaddingMode,
dataFormat: DataFormat): Tensor {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool3d(
inputs as Tensor5D, poolSize, strides, padding, dataFormat, 'max');
}
}
serialization.registerClass(MaxPooling3D);

export class AveragePooling3D extends Pooling3D {
/** @nocollapse */
static className = 'AveragePooling3D';
constructor(args: Pooling3DLayerArgs) {
super(args);
}

protected poolingFunction(
inputs: Tensor, poolSize: [number, number, number],
strides: [number, number, number], padding: PaddingMode,
dataFormat: DataFormat): Tensor {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool3d(
inputs as Tensor5D, poolSize, strides, padding, dataFormat, 'avg');
}
}
serialization.registerClass(AveragePooling3D);

/**
* Abstract class for different global pooling 1D layers.
*/
Expand Down
Loading