@@ -491,7 +491,7 @@ function withSpaceToBatchBasePaddings(
491
491
* result.print();
492
492
* ```
493
493
*
494
- * @param x The input tensor, of rank 5 of shape
494
+ * @param x The input tensor, of rank 5 or rank 4 of shape
495
495
* `[batch, depth, height, width, inChannels]`.
496
496
* @param filterSize The filter size:
497
497
* `[filterDepth, filterHeight, filterWidth]`.
@@ -525,20 +525,30 @@ function withSpaceToBatchBasePaddings(
525
525
* If it is greater than 1, then all values of `strides` must be 1.
526
526
*/
527
527
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
528
- function avgPool3d_ (
529
- x : Tensor5D | TensorLike , filterSize : [ number , number , number ] | number ,
530
- strides : [ number , number , number ] | number , pad : 'valid' | 'same' | number ,
528
+ function avgPool3d_ < T extends Tensor4D | Tensor5D > (
529
+ x : T | TensorLike ,
530
+ filterSize : [ number , number , number ] | number ,
531
+ strides : [ number , number , number ] | number ,
532
+ pad : 'valid' | 'same' | number ,
531
533
dimRoundingMode ?: 'floor' | 'round' | 'ceil' ,
532
534
dataFormat : 'NDHWC' | 'NCDHW' = 'NDHWC' ,
533
- dilations ?: [ number , number , number ] | number , ) : Tensor5D {
535
+ dilations ?: [ number , number , number ] | number ,
536
+ ) : T {
534
537
const $x = convertToTensor ( x , 'x' , 'avgPool3d' , 'float32' ) ;
535
538
539
+ let x5D = $x as Tensor5D ;
540
+ let reshapedTo5D = false ;
541
+ if ( $x . rank === 4 ) {
542
+ reshapedTo5D = true ;
543
+ x5D = $x . as5D ( 1 , $x . shape [ 0 ] , $x . shape [ 1 ] , $x . shape [ 2 ] , $x . shape [ 3 ] ) ;
544
+ }
545
+
536
546
if ( dilations == null ) {
537
547
dilations = [ 1 , 1 , 1 ] ;
538
548
}
539
549
util . assert (
540
- $x . rank === 5 ,
541
- ( ) => `Error in avgPool3d: x must be rank 5 but got rank ${ $x . rank } .` ) ;
550
+ x5D . rank === 5 ,
551
+ ( ) => `Error in avgPool3d: x must be rank 5 but got rank ${ x5D . rank } .` ) ;
542
552
util . assert (
543
553
dataFormat === 'NDHWC' ,
544
554
( ) => `Error in avgPool3d: Only NDHWC is currently supported, ` +
@@ -555,21 +565,25 @@ function avgPool3d_(
555
565
}
556
566
557
567
const convInfo = conv_util . computePool3DInfo (
558
- $x . shape , filterSize , strides , dilations , pad , dimRoundingMode ,
568
+ x5D . shape , filterSize , strides , dilations , pad , dimRoundingMode ,
559
569
dataFormat ) ;
560
570
561
571
const grad = ( dy : Tensor5D ) => {
562
572
return {
563
573
x : ( ) => avgPool3dBackprop (
564
- dy , $x , filterSize , strides , dilations , pad , dimRoundingMode )
574
+ dy , x5D , filterSize , strides , dilations , pad , dimRoundingMode )
565
575
} ;
566
576
} ;
567
577
568
578
let res = ENGINE . runKernel (
569
- backend => backend . avgPool3d ( $x , convInfo ) , { x : $x } , grad ) ;
570
- res = res . cast ( $x . dtype ) ;
579
+ backend => backend . avgPool3d ( x5D , convInfo ) , { x : x5D } , grad ) ;
580
+ res = res . cast ( x5D . dtype ) ;
581
+ if ( reshapedTo5D ) {
582
+ return res . as4D ( res . shape [ 1 ] , res . shape [ 2 ] , res . shape [ 3 ] , res . shape [ 4 ] ) as
583
+ T ;
584
+ }
571
585
572
- return res ;
586
+ return res as T ;
573
587
}
574
588
575
589
/**
@@ -578,7 +592,7 @@ function avgPool3d_(
578
592
* @param dy The dy error, of rank 5 of shape
579
593
* [batchSize, depth, height, width, channels].
580
594
* assumed.
581
- * @param input The original input image, of rank 5 of shape
595
+ * @param input The original input image, of rank 5 or rank4 of shape
582
596
* [batchSize, depth, height, width, channels].
583
597
* @param filterSize The filter size:
584
598
* `[filterDepth, filterHeight, filterWidth]`.
@@ -601,23 +615,33 @@ function avgPool3d_(
601
615
* number. If none is provided, it will not round and error if the output
602
616
* is of fractional size.
603
617
*/
604
- function avgPool3dBackprop (
605
- dy : Tensor5D | TensorLike , input : Tensor5D | TensorLike ,
618
+ function avgPool3dBackprop < T extends Tensor4D | Tensor5D > (
619
+ dy : T | TensorLike , input : T | TensorLike ,
606
620
filterSize : [ number , number , number ] | number ,
607
621
strides : [ number , number , number ] | number ,
608
622
dilations : [ number , number , number ] | number , pad : 'valid' | 'same' | number ,
609
- dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : Tensor5D {
623
+ dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
610
624
const $dy = convertToTensor ( dy , 'dy' , 'avgPool3dBackprop' ) ;
611
625
const $input = convertToTensor ( input , 'input' , 'avgPool3dBackprop' ) ;
612
626
627
+ let dy5D = $dy as Tensor5D ;
628
+ let input5D = $input as Tensor5D ;
629
+ let reshapedTo5D = false ;
630
+ if ( $input . rank === 4 ) {
631
+ reshapedTo5D = true ;
632
+ dy5D = $dy . as5D ( 1 , $dy . shape [ 0 ] , $dy . shape [ 1 ] , $dy . shape [ 2 ] , $dy . shape [ 3 ] ) ;
633
+ input5D = $input . as5D (
634
+ 1 , $input . shape [ 0 ] , $input . shape [ 1 ] , $input . shape [ 2 ] , $input . shape [ 3 ] ) ;
635
+ }
636
+
613
637
util . assert (
614
- $dy . rank === 5 ,
638
+ dy5D . rank === 5 ,
615
639
( ) => `Error in avgPool3dBackprop: dy must be rank 5 but got rank ` +
616
- `${ $dy . rank } .` ) ;
640
+ `${ dy5D . rank } .` ) ;
617
641
util . assert (
618
- $input . rank === 5 ,
642
+ input5D . rank === 5 ,
619
643
( ) => `Error in avgPool3dBackprop: input must be rank 5 but got rank ` +
620
- `${ $input . rank } .` ) ;
644
+ `${ input5D . rank } .` ) ;
621
645
if ( dilations == null ) {
622
646
dilations = [ 1 , 1 , 1 ] ;
623
647
}
@@ -633,12 +657,16 @@ function avgPool3dBackprop(
633
657
}
634
658
635
659
const convInfo = conv_util . computePool3DInfo (
636
- $input . shape , filterSize , strides , dilations , pad , dimRoundingMode ) ;
660
+ input5D . shape , filterSize , strides , dilations , pad , dimRoundingMode ) ;
637
661
const res = ENGINE . runKernel (
638
- backend => backend . avgPool3dBackprop ( $dy , $input , convInfo ) ,
639
- { $dy, $input} ) ;
662
+ backend => backend . avgPool3dBackprop ( dy5D , input5D , convInfo ) ,
663
+ { dy5D, input5D} ) ;
664
+ if ( reshapedTo5D ) {
665
+ return res . as4D ( res . shape [ 1 ] , res . shape [ 2 ] , res . shape [ 3 ] , res . shape [ 4 ] ) as
666
+ T ;
667
+ }
640
668
641
- return res ;
669
+ return res as T ;
642
670
}
643
671
644
672
/**
@@ -650,7 +678,7 @@ function avgPool3dBackprop(
650
678
* result.print();
651
679
* ```
652
680
*
653
- * @param x The input tensor, of rank 5 of shape
681
+ * @param x The input tensor, of rank 5 or rank 4 of shape
654
682
* `[batch, depth, height, width, inChannels]`.
655
683
* @param filterSize The filter size:
656
684
* `[filterDepth, filterHeight, filterWidth]`.
@@ -684,20 +712,27 @@ function avgPool3dBackprop(
684
712
* If it is greater than 1, then all values of `strides` must be 1.
685
713
*/
686
714
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
687
- function maxPool3d_ (
688
- x : Tensor5D | TensorLike , filterSize : [ number , number , number ] | number ,
715
+ function maxPool3d_ < T extends Tensor4D | Tensor5D > (
716
+ x : T | TensorLike , filterSize : [ number , number , number ] | number ,
689
717
strides : [ number , number , number ] | number , pad : 'valid' | 'same' | number ,
690
718
dimRoundingMode ?: 'floor' | 'round' | 'ceil' ,
691
719
dataFormat : 'NDHWC' | 'NCDHW' = 'NDHWC' ,
692
- dilations ?: [ number , number , number ] | number ) : Tensor5D {
720
+ dilations ?: [ number , number , number ] | number ) : T {
693
721
const $x = convertToTensor ( x , 'x' , 'maxPool3d' ) ;
694
722
723
+ let x5D = $x as Tensor5D ;
724
+ let reshapedTo5D = false ;
725
+ if ( $x . rank === 4 ) {
726
+ reshapedTo5D = true ;
727
+ x5D = $x . as5D ( 1 , $x . shape [ 0 ] , $x . shape [ 1 ] , $x . shape [ 2 ] , $x . shape [ 3 ] ) ;
728
+ }
729
+
695
730
if ( dilations == null ) {
696
731
dilations = [ 1 , 1 , 1 ] ;
697
732
}
698
733
util . assert (
699
- $x . rank === 5 ,
700
- ( ) => `Error in maxPool3d: x must be rank 5 but got rank ${ $x . rank } .` ) ;
734
+ x5D . rank === 5 ,
735
+ ( ) => `Error in maxPool3d: x must be rank 5 but got rank ${ x5D . rank } .` ) ;
701
736
util . assert (
702
737
dataFormat === 'NDHWC' ,
703
738
( ) => `Error in maxPool3d: Only NDHWC is currently supported, ` +
@@ -714,25 +749,29 @@ function maxPool3d_(
714
749
}
715
750
716
751
const convInfo = conv_util . computePool3DInfo (
717
- $x . shape , filterSize , strides , dilations , pad , dimRoundingMode ,
752
+ x5D . shape , filterSize , strides , dilations , pad , dimRoundingMode ,
718
753
dataFormat ) ;
719
754
720
755
const grad = ( dy : Tensor5D , saved : Tensor [ ] ) => {
721
- const [ $x , y ] = saved ;
756
+ const [ x5D , y ] = saved ;
722
757
return {
723
758
x : ( ) => maxPool3dBackprop (
724
- dy , $x as Tensor5D , y as Tensor5D , filterSize , strides , dilations ,
759
+ dy , x5D as Tensor5D , y as Tensor5D , filterSize , strides , dilations ,
725
760
pad , dimRoundingMode )
726
761
} ;
727
762
} ;
728
763
729
764
const res = ENGINE . runKernel ( ( backend , save ) => {
730
- const y = backend . maxPool3d ( $x , convInfo ) ;
731
- save ( [ $x , y ] ) ;
765
+ const y = backend . maxPool3d ( x5D , convInfo ) ;
766
+ save ( [ x5D , y ] ) ;
732
767
return y ;
733
- } , { x : $x } , grad ) ;
768
+ } , { x : x5D } , grad ) ;
769
+ if ( reshapedTo5D ) {
770
+ return res . as4D ( res . shape [ 1 ] , res . shape [ 2 ] , res . shape [ 3 ] , res . shape [ 4 ] ) as
771
+ T ;
772
+ }
734
773
735
- return res ;
774
+ return res as T ;
736
775
}
737
776
738
777
/**
@@ -741,7 +780,7 @@ function maxPool3d_(
741
780
* @param dy The dy error, of rank 5 of shape
742
781
* [batchSize, depth, height, width, channels].
743
782
* assumed.
744
- * @param input The original input image, of rank 5 of shape
783
+ * @param input The original input image, of rank 5 or rank 4 of shape
745
784
* [batchSize, depth, height, width, channels].
746
785
* @param output The original output image, of rank 5 of shape
747
786
* [batchSize, outDepth, outHeight, outWidth, channels].
@@ -766,28 +805,42 @@ function maxPool3d_(
766
805
* number. If none is provided, it will not round and error if the output
767
806
* is of fractional size.
768
807
*/
769
- function maxPool3dBackprop (
770
- dy : Tensor5D | TensorLike , input : Tensor5D | TensorLike ,
771
- output : Tensor5D | TensorLike , filterSize : [ number , number , number ] | number ,
808
+ function maxPool3dBackprop < T extends Tensor4D | Tensor5D > (
809
+ dy : T | TensorLike , input : T | TensorLike , output : T | TensorLike ,
810
+ filterSize : [ number , number , number ] | number ,
772
811
strides : [ number , number , number ] | number ,
773
812
dilations : [ number , number , number ] | number , pad : 'valid' | 'same' | number ,
774
- dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : Tensor5D {
813
+ dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
775
814
const $dy = convertToTensor ( dy , 'dy' , 'maxPool3dBackprop' ) ;
776
815
const $input = convertToTensor ( input , 'input' , 'maxPool3dBackprop' ) ;
777
816
const $output = convertToTensor ( output , 'output' , 'maxPool3dBackprop' ) ;
778
817
818
+ let dy5D = $dy as Tensor5D ;
819
+ let input5D = $input as Tensor5D ;
820
+ let output5D = $output as Tensor5D ;
821
+ let reshapedTo5D = false ;
822
+ if ( $input . rank === 4 ) {
823
+ reshapedTo5D = true ;
824
+ dy5D = $dy . as5D ( 1 , $dy . shape [ 0 ] , $dy . shape [ 1 ] , $dy . shape [ 2 ] , $dy . shape [ 3 ] ) ;
825
+ input5D = $input . as5D (
826
+ 1 , $input . shape [ 0 ] , $input . shape [ 1 ] , $input . shape [ 2 ] , $input . shape [ 3 ] ) ;
827
+ output5D = $output . as5D (
828
+ 1 , $output . shape [ 0 ] , $output . shape [ 1 ] , $output . shape [ 2 ] ,
829
+ $output . shape [ 3 ] ) ;
830
+ }
831
+
779
832
util . assert (
780
- $dy . rank === 5 ,
833
+ dy5D . rank === 5 ,
781
834
( ) => `Error in maxPool3dBackprop: dy must be rank 5 but got rank ` +
782
- `${ $dy . rank } .` ) ;
835
+ `${ dy5D . rank } .` ) ;
783
836
util . assert (
784
- $input . rank === 5 ,
837
+ input5D . rank === 5 ,
785
838
( ) => `Error in maxPool3dBackprop: input must be rank 5 but got rank ` +
786
- `${ $input . rank } .` ) ;
839
+ `${ input5D . rank } .` ) ;
787
840
util . assert (
788
- $output . rank === 5 ,
841
+ output5D . rank === 5 ,
789
842
( ) => `Error in maxPool3dBackprop: output must be rank 5 but got rank ` +
790
- `${ $output . rank } .` ) ;
843
+ `${ output5D . rank } .` ) ;
791
844
if ( dilations == null ) {
792
845
dilations = [ 1 , 1 , 1 ] ;
793
846
}
@@ -803,11 +856,16 @@ function maxPool3dBackprop(
803
856
}
804
857
805
858
const convInfo = conv_util . computePool3DInfo (
806
- $input . shape , filterSize , strides , dilations , pad , dimRoundingMode ) ;
859
+ input5D . shape , filterSize , strides , dilations , pad , dimRoundingMode ) ;
807
860
const res = ENGINE . runKernel (
808
- backend => backend . maxPool3dBackprop ( $dy , $input , $output , convInfo ) ,
809
- { $dy, $input} ) ;
810
- return res ;
861
+ backend => backend . maxPool3dBackprop ( dy5D , input5D , output5D , convInfo ) ,
862
+ { dy5D, input5D} ) ;
863
+ if ( reshapedTo5D ) {
864
+ return res . as4D ( res . shape [ 1 ] , res . shape [ 2 ] , res . shape [ 3 ] , res . shape [ 4 ] ) as
865
+ T ;
866
+ }
867
+
868
+ return res as T ;
811
869
}
812
870
813
871
export const maxPool = op ( { maxPool_} ) ;
0 commit comments