@@ -532,7 +532,7 @@ function avgPool3d_(
532
532
dimRoundingMode ?: 'floor' | 'round' | 'ceil' ,
533
533
dataFormat : 'NDHWC' | 'NCDHW' = 'NDHWC' ) : Tensor5D {
534
534
const $x = convertToTensor ( x , 'x' , 'avgPool3d' , 'float32' ) ;
535
-
535
+
536
536
if ( dilations == null ) {
537
537
dilations = [ 1 , 1 , 1 ] ;
538
538
}
@@ -553,22 +553,22 @@ function avgPool3d_(
553
553
( ) => `Error in avgPool3d: pad must be an integer when using, ` +
554
554
`dimRoundingMode ${ dimRoundingMode } but got pad ${ pad } .` ) ;
555
555
}
556
-
556
+
557
557
const convInfo = conv_util . computePool3DInfo (
558
558
$x . shape , filterSize , strides , dilations , pad , dimRoundingMode ,
559
559
dataFormat ) ;
560
-
560
+
561
561
const grad = ( dy : Tensor5D ) => {
562
562
return {
563
563
x : ( ) => avgPool3dBackprop (
564
564
dy , $x , filterSize , strides , dilations , pad , dimRoundingMode )
565
565
} ;
566
566
} ;
567
-
567
+
568
568
let res = ENGINE . runKernel (
569
569
backend => backend . avgPool3d ( $x , convInfo ) , { x : $x } , grad ) ;
570
570
res = res . cast ( $x . dtype ) ;
571
-
571
+
572
572
return res ;
573
573
}
574
574
@@ -691,7 +691,7 @@ function maxPool3d_(
691
691
dimRoundingMode ?: 'floor' | 'round' | 'ceil' ,
692
692
dataFormat : 'NDHWC' | 'NCDHW' = 'NDHWC' ) : Tensor5D {
693
693
const $x = convertToTensor ( x , 'x' , 'maxPool3d' ) ;
694
-
694
+
695
695
if ( dilations == null ) {
696
696
dilations = [ 1 , 1 , 1 ] ;
697
697
}
@@ -712,11 +712,11 @@ function maxPool3d_(
712
712
( ) => `Error in maxPool3d: pad must be an integer when using, ` +
713
713
`dimRoundingMode ${ dimRoundingMode } but got pad ${ pad } .` ) ;
714
714
}
715
-
715
+
716
716
const convInfo = conv_util . computePool3DInfo (
717
717
$x . shape , filterSize , strides , dilations , pad , dimRoundingMode ,
718
718
dataFormat ) ;
719
-
719
+
720
720
const grad = ( dy : Tensor5D , saved : Tensor [ ] ) => {
721
721
const [ $x , y ] = saved ;
722
722
return {
@@ -725,13 +725,13 @@ function maxPool3d_(
725
725
pad , dimRoundingMode )
726
726
} ;
727
727
} ;
728
-
728
+
729
729
const res = ENGINE . runKernel ( ( backend , save ) => {
730
730
const y = backend . maxPool3d ( $x , convInfo ) ;
731
731
save ( [ $x , y ] ) ;
732
732
return y ;
733
733
} , { x : $x } , grad ) ;
734
-
734
+
735
735
return res ;
736
736
}
737
737
0 commit comments