Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit c53fe9a

Browse files
committed
Support 4D
1 parent d63800d commit c53fe9a

File tree

4 files changed

+153
-57
lines changed

4 files changed

+153
-57
lines changed

src/ops/conv_util.ts

+2-4
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,7 @@ export function computeDefaultPad(
362362
(inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
363363
}
364364

365-
function parseTupleParam(
366-
param: number|number[]): [number, number, number] {
365+
function parseTupleParam(param: number|number[]): [number, number, number] {
367366
if (typeof param === 'number') {
368367
return [param, param, param];
369368
}
@@ -526,8 +525,7 @@ function conditionalRound(
526525
}
527526
}
528527

529-
export function tupleValuesAreOne(
530-
param: number|number[]): boolean {
528+
export function tupleValuesAreOne(param: number|number[]): boolean {
531529
const [dimA, dimB, dimC] = parseTupleParam(param);
532530
return dimA === 1 && dimB === 1 && dimC === 1;
533531
}

src/ops/conv_util_test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ describe('conv_util computePool3dInfo', () => {
974974
expect(
975975
() => conv_util.computePool3DInfo(
976976
inShape, filterSize, stride, dilation, 1, 'floor', fakeDataFormat))
977-
.toThrowError();
977+
.toThrowError();
978978
});
979979
});
980980

src/ops/pool.ts

+110-52
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ function withSpaceToBatchBasePaddings(
491491
* result.print();
492492
* ```
493493
*
494-
* @param x The input tensor, of rank 5 of shape
494+
* @param x The input tensor, of rank 5 or rank 4 (batch of 1 assumed), of shape
495495
* `[batch, depth, height, width, inChannels]`.
496496
* @param filterSize The filter size:
497497
* `[filterDepth, filterHeight, filterWidth]`.
@@ -525,20 +525,30 @@ function withSpaceToBatchBasePaddings(
525525
* If it is greater than 1, then all values of `strides` must be 1.
526526
*/
527527
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
528-
function avgPool3d_(
529-
x: Tensor5D|TensorLike, filterSize: [number, number, number]|number,
530-
strides: [number, number, number]|number, pad: 'valid'|'same'|number,
528+
function avgPool3d_<T extends Tensor4D|Tensor5D>(
529+
x: T|TensorLike,
530+
filterSize: [number, number, number]|number,
531+
strides: [number, number, number]|number,
532+
pad: 'valid'|'same'|number,
531533
dimRoundingMode?: 'floor'|'round'|'ceil',
532534
dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC',
533-
dilations?: [number, number, number]|number,): Tensor5D {
535+
dilations?: [number, number, number]|number,
536+
): T {
534537
const $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
535538

539+
let x5D = $x as Tensor5D;
540+
let reshapedTo5D = false;
541+
if ($x.rank === 4) {
542+
reshapedTo5D = true;
543+
x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]);
544+
}
545+
536546
if (dilations == null) {
537547
dilations = [1, 1, 1];
538548
}
539549
util.assert(
540-
$x.rank === 5,
541-
() => `Error in avgPool3d: x must be rank 5 but got rank ${$x.rank}.`);
550+
x5D.rank === 5,
551+
() => `Error in avgPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
542552
util.assert(
543553
dataFormat === 'NDHWC',
544554
() => `Error in avgPool3d: Only NDHWC is currently supported, ` +
@@ -555,21 +565,25 @@ function avgPool3d_(
555565
}
556566

557567
const convInfo = conv_util.computePool3DInfo(
558-
$x.shape, filterSize, strides, dilations, pad, dimRoundingMode,
568+
x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode,
559569
dataFormat);
560570

561571
const grad = (dy: Tensor5D) => {
562572
return {
563573
x: () => avgPool3dBackprop(
564-
dy, $x, filterSize, strides, dilations, pad, dimRoundingMode)
574+
dy, x5D, filterSize, strides, dilations, pad, dimRoundingMode)
565575
};
566576
};
567577

568578
let res = ENGINE.runKernel(
569-
backend => backend.avgPool3d($x, convInfo), {x: $x}, grad);
570-
res = res.cast($x.dtype);
579+
backend => backend.avgPool3d(x5D, convInfo), {x: x5D}, grad);
580+
res = res.cast(x5D.dtype);
581+
if (reshapedTo5D) {
582+
return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as
583+
T;
584+
}
571585

572-
return res;
586+
return res as T;
573587
}
574588

575589
/**
@@ -578,7 +592,7 @@ function avgPool3d_(
578592
* @param dy The dy error, of rank 5 of shape
579593
* [batchSize, depth, height, width, channels].
580594
* If the input is rank 4, a batch of 1 is assumed.
581-
* @param input The original input image, of rank 5 of shape
595+
* @param input The original input image, of rank 5 or rank 4 (batch of 1 assumed), of shape
582596
* [batchSize, depth, height, width, channels].
583597
* @param filterSize The filter size:
584598
* `[filterDepth, filterHeight, filterWidth]`.
@@ -601,23 +615,33 @@ function avgPool3d_(
601615
* number. If none is provided, it will not round and error if the output
602616
* is of fractional size.
603617
*/
604-
function avgPool3dBackprop(
605-
dy: Tensor5D|TensorLike, input: Tensor5D|TensorLike,
618+
function avgPool3dBackprop<T extends Tensor4D|Tensor5D>(
619+
dy: T|TensorLike, input: T|TensorLike,
606620
filterSize: [number, number, number]|number,
607621
strides: [number, number, number]|number,
608622
dilations: [number, number, number]|number, pad: 'valid'|'same'|number,
609-
dimRoundingMode?: 'floor'|'round'|'ceil'): Tensor5D {
623+
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
610624
const $dy = convertToTensor(dy, 'dy', 'avgPool3dBackprop');
611625
const $input = convertToTensor(input, 'input', 'avgPool3dBackprop');
612626

627+
let dy5D = $dy as Tensor5D;
628+
let input5D = $input as Tensor5D;
629+
let reshapedTo5D = false;
630+
if ($input.rank === 4) {
631+
reshapedTo5D = true;
632+
dy5D = $dy.as5D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]);
633+
input5D = $input.as5D(
634+
1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]);
635+
}
636+
613637
util.assert(
614-
$dy.rank === 5,
638+
dy5D.rank === 5,
615639
() => `Error in avgPool3dBackprop: dy must be rank 5 but got rank ` +
616-
`${$dy.rank}.`);
640+
`${dy5D.rank}.`);
617641
util.assert(
618-
$input.rank === 5,
642+
input5D.rank === 5,
619643
() => `Error in avgPool3dBackprop: input must be rank 5 but got rank ` +
620-
`${$input.rank}.`);
644+
`${input5D.rank}.`);
621645
if (dilations == null) {
622646
dilations = [1, 1, 1];
623647
}
@@ -633,12 +657,16 @@ function avgPool3dBackprop(
633657
}
634658

635659
const convInfo = conv_util.computePool3DInfo(
636-
$input.shape, filterSize, strides, dilations, pad, dimRoundingMode);
660+
input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);
637661
const res = ENGINE.runKernel(
638-
backend => backend.avgPool3dBackprop($dy, $input, convInfo),
639-
{$dy, $input});
662+
backend => backend.avgPool3dBackprop(dy5D, input5D, convInfo),
663+
{dy5D, input5D});
664+
if (reshapedTo5D) {
665+
return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as
666+
T;
667+
}
640668

641-
return res;
669+
return res as T;
642670
}
643671

644672
/**
@@ -650,7 +678,7 @@ function avgPool3dBackprop(
650678
* result.print();
651679
* ```
652680
*
653-
* @param x The input tensor, of rank 5 of shape
681+
* @param x The input tensor, of rank 5 or rank 4 (batch of 1 assumed), of shape
654682
* `[batch, depth, height, width, inChannels]`.
655683
* @param filterSize The filter size:
656684
* `[filterDepth, filterHeight, filterWidth]`.
@@ -684,20 +712,27 @@ function avgPool3dBackprop(
684712
* If it is greater than 1, then all values of `strides` must be 1.
685713
*/
686714
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
687-
function maxPool3d_(
688-
x: Tensor5D|TensorLike, filterSize: [number, number, number]|number,
715+
function maxPool3d_<T extends Tensor4D|Tensor5D>(
716+
x: T|TensorLike, filterSize: [number, number, number]|number,
689717
strides: [number, number, number]|number, pad: 'valid'|'same'|number,
690718
dimRoundingMode?: 'floor'|'round'|'ceil',
691719
dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC',
692-
dilations?: [number, number, number]|number): Tensor5D {
720+
dilations?: [number, number, number]|number): T {
693721
const $x = convertToTensor(x, 'x', 'maxPool3d');
694722

723+
let x5D = $x as Tensor5D;
724+
let reshapedTo5D = false;
725+
if ($x.rank === 4) {
726+
reshapedTo5D = true;
727+
x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]);
728+
}
729+
695730
if (dilations == null) {
696731
dilations = [1, 1, 1];
697732
}
698733
util.assert(
699-
$x.rank === 5,
700-
() => `Error in maxPool3d: x must be rank 5 but got rank ${$x.rank}.`);
734+
x5D.rank === 5,
735+
() => `Error in maxPool3d: x must be rank 5 but got rank ${x5D.rank}.`);
701736
util.assert(
702737
dataFormat === 'NDHWC',
703738
() => `Error in maxPool3d: Only NDHWC is currently supported, ` +
@@ -714,25 +749,29 @@ function maxPool3d_(
714749
}
715750

716751
const convInfo = conv_util.computePool3DInfo(
717-
$x.shape, filterSize, strides, dilations, pad, dimRoundingMode,
752+
x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode,
718753
dataFormat);
719754

720755
const grad = (dy: Tensor5D, saved: Tensor[]) => {
721-
const [$x, y] = saved;
756+
const [x5D, y] = saved;
722757
return {
723758
x: () => maxPool3dBackprop(
724-
dy, $x as Tensor5D, y as Tensor5D, filterSize, strides, dilations,
759+
dy, x5D as Tensor5D, y as Tensor5D, filterSize, strides, dilations,
725760
pad, dimRoundingMode)
726761
};
727762
};
728763

729764
const res = ENGINE.runKernel((backend, save) => {
730-
const y = backend.maxPool3d($x, convInfo);
731-
save([$x, y]);
765+
const y = backend.maxPool3d(x5D, convInfo);
766+
save([x5D, y]);
732767
return y;
733-
}, {x: $x}, grad);
768+
}, {x: x5D}, grad);
769+
if (reshapedTo5D) {
770+
return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as
771+
T;
772+
}
734773

735-
return res;
774+
return res as T;
736775
}
737776

738777
/**
@@ -741,7 +780,7 @@ function maxPool3d_(
741780
* @param dy The dy error, of rank 5 of shape
742781
* [batchSize, depth, height, width, channels].
743782
* If the input is rank 4, a batch of 1 is assumed.
744-
* @param input The original input image, of rank 5 of shape
783+
* @param input The original input image, of rank 5 or rank 4 (batch of 1 assumed), of shape
745784
* [batchSize, depth, height, width, channels].
746785
* @param output The original output image, of rank 5 of shape
747786
* [batchSize, outDepth, outHeight, outWidth, channels].
@@ -766,28 +805,42 @@ function maxPool3d_(
766805
* number. If none is provided, it will not round and error if the output
767806
* is of fractional size.
768807
*/
769-
function maxPool3dBackprop(
770-
dy: Tensor5D|TensorLike, input: Tensor5D|TensorLike,
771-
output: Tensor5D|TensorLike, filterSize: [number, number, number]|number,
808+
function maxPool3dBackprop<T extends Tensor4D|Tensor5D>(
809+
dy: T|TensorLike, input: T|TensorLike, output: T|TensorLike,
810+
filterSize: [number, number, number]|number,
772811
strides: [number, number, number]|number,
773812
dilations: [number, number, number]|number, pad: 'valid'|'same'|number,
774-
dimRoundingMode?: 'floor'|'round'|'ceil'): Tensor5D {
813+
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
775814
const $dy = convertToTensor(dy, 'dy', 'maxPool3dBackprop');
776815
const $input = convertToTensor(input, 'input', 'maxPool3dBackprop');
777816
const $output = convertToTensor(output, 'output', 'maxPool3dBackprop');
778817

818+
let dy5D = $dy as Tensor5D;
819+
let input5D = $input as Tensor5D;
820+
let output5D = $output as Tensor5D;
821+
let reshapedTo5D = false;
822+
if ($input.rank === 4) {
823+
reshapedTo5D = true;
824+
dy5D = $dy.as5D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]);
825+
input5D = $input.as5D(
826+
1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]);
827+
output5D = $output.as5D(
828+
1, $output.shape[0], $output.shape[1], $output.shape[2],
829+
$output.shape[3]);
830+
}
831+
779832
util.assert(
780-
$dy.rank === 5,
833+
dy5D.rank === 5,
781834
() => `Error in maxPool3dBackprop: dy must be rank 5 but got rank ` +
782-
`${$dy.rank}.`);
835+
`${dy5D.rank}.`);
783836
util.assert(
784-
$input.rank === 5,
837+
input5D.rank === 5,
785838
() => `Error in maxPool3dBackprop: input must be rank 5 but got rank ` +
786-
`${$input.rank}.`);
839+
`${input5D.rank}.`);
787840
util.assert(
788-
$output.rank === 5,
841+
output5D.rank === 5,
789842
() => `Error in maxPool3dBackprop: output must be rank 5 but got rank ` +
790-
`${$output.rank}.`);
843+
`${output5D.rank}.`);
791844
if (dilations == null) {
792845
dilations = [1, 1, 1];
793846
}
@@ -803,11 +856,16 @@ function maxPool3dBackprop(
803856
}
804857

805858
const convInfo = conv_util.computePool3DInfo(
806-
$input.shape, filterSize, strides, dilations, pad, dimRoundingMode);
859+
input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);
807860
const res = ENGINE.runKernel(
808-
backend => backend.maxPool3dBackprop($dy, $input, $output, convInfo),
809-
{$dy, $input});
810-
return res;
861+
backend => backend.maxPool3dBackprop(dy5D, input5D, output5D, convInfo),
862+
{dy5D, input5D});
863+
if (reshapedTo5D) {
864+
return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as
865+
T;
866+
}
867+
868+
return res as T;
811869
}
812870

813871
export const maxPool = op({maxPool_});

0 commit comments

Comments
 (0)