40
40
ger ,
41
41
ger_destructive ,
42
42
)
43
- from pytensor .tensor .math import Dot , dot , mean , mul , outer , sigmoid
43
+ from pytensor .tensor .math import Dot , dot , mean , mul , sigmoid
44
44
from pytensor .tensor .rewriting .blas import local_dot22_to_dot22scalar , local_gemm_to_ger
45
45
from pytensor .tensor .type import (
46
46
cmatrix ,
@@ -1721,9 +1721,12 @@ def clone(self, op):
1721
1721
class TestGer (unittest_tools .OptimizationTestMixin ):
1722
1722
shared = staticmethod (shared )
1723
1723
1724
+ def outer_via_dot (self , x , y ):
1725
+ return pt .dot (x [:, None ], y [None , :])
1726
+
1724
1727
def setup_method (self ):
1725
1728
self .mode = pytensor .compile .get_default_mode ().including ("fast_run" )
1726
- self .mode = self .mode .excluding ("c_blas" , "scipy_blas" )
1729
+ self .mode = self .mode .excluding ("c_blas" , "scipy_blas" , "local_dot_to_mul" )
1727
1730
dtype = self .dtype = "float64" # optimization isn't dtype-dependent
1728
1731
self .A = tensor (dtype = dtype , shape = (None , None ))
1729
1732
self .a = tensor (dtype = dtype , shape = ())
@@ -1795,7 +1798,7 @@ def test_b_nonconst_does_not_triggers_ger(self):
1795
1798
1796
1799
def test_outer (self ):
1797
1800
rng = np .random .default_rng (unittest_tools .fetch_seed ())
1798
- f = self .function ([self .x , self .y ], outer (self .x , self .y ))
1801
+ f = self .function ([self .x , self .y ], self . outer_via_dot (self .x , self .y ))
1799
1802
self .assertFunctionContains (f , self .ger_destructive )
1800
1803
f (
1801
1804
rng .random (5 ).astype (self .dtype ),
@@ -1804,7 +1807,9 @@ def test_outer(self):
1804
1807
1805
1808
def test_A_plus_outer (self ):
1806
1809
rng = np .random .default_rng (unittest_tools .fetch_seed ())
1807
- f = self .function ([self .A , self .x , self .y ], self .A + outer (self .x , self .y ))
1810
+ f = self .function (
1811
+ [self .A , self .x , self .y ], self .A + self .outer_via_dot (self .x , self .y )
1812
+ )
1808
1813
self .assertFunctionContains (f , self .ger )
1809
1814
f (
1810
1815
rng .random ((5 , 4 )).astype (self .dtype ),
@@ -1820,7 +1825,7 @@ def test_A_plus_outer(self):
1820
1825
def test_A_plus_scaled_outer (self ):
1821
1826
rng = np .random .default_rng (unittest_tools .fetch_seed ())
1822
1827
f = self .function (
1823
- [self .A , self .x , self .y ], self .A + 0.1 * outer (self .x , self .y )
1828
+ [self .A , self .x , self .y ], self .A + 0.1 * self . outer_via_dot (self .x , self .y )
1824
1829
)
1825
1830
self .assertFunctionContains (f , self .ger )
1826
1831
f (
@@ -1839,7 +1844,7 @@ def test_scaled_A_plus_scaled_outer(self):
1839
1844
f = self .function (
1840
1845
[self .A , self .x , self .y ],
1841
1846
np .asarray (0.2 , self .dtype ) * self .A
1842
- + np .asarray (0.1 , self .dtype ) * outer (self .x , self .y ),
1847
+ + np .asarray (0.1 , self .dtype ) * self . outer_via_dot (self .x , self .y ),
1843
1848
)
1844
1849
# Why gemm? This make the graph simpler did we test that it
1845
1850
# make it faster?
@@ -1863,7 +1868,7 @@ def given_dtype(self, dtype, M, N, *, destructive=True):
1863
1868
x = tensor (dtype = dtype , shape = (None ,))
1864
1869
y = tensor (dtype = dtype , shape = (None ,))
1865
1870
1866
- f = self .function ([A , x , y ], A + 0.1 * outer (x , y ))
1871
+ f = self .function ([A , x , y ], A + 0.1 * self . outer_via_dot (x , y ))
1867
1872
self .assertFunctionContains (
1868
1873
f , self .ger_destructive if destructive else self .ger
1869
1874
)
@@ -1923,7 +1928,12 @@ def test_inplace(self):
1923
1928
[self .x , self .y ],
1924
1929
[],
1925
1930
updates = [
1926
- (A , A + pt .constant (0.1 , dtype = self .dtype ) * outer (self .x , self .y ))
1931
+ (
1932
+ A ,
1933
+ A
1934
+ + pt .constant (0.1 , dtype = self .dtype )
1935
+ * self .outer_via_dot (self .x , self .y ),
1936
+ )
1927
1937
],
1928
1938
)
1929
1939
self .assertFunctionContains (f , self .ger_destructive )
@@ -2264,10 +2274,15 @@ def cmp_ger(self, a_shp, b_shp, c_shp, rng):
2264
2274
b_dev = b .get_value (borrow = False , return_internal_type = True )
2265
2275
c_dev = c .get_value (borrow = False , return_internal_type = True )
2266
2276
2267
- f_n = function ([], [], updates = [(a , (a + l * outer (b , c )))], mode = self .mode )
2277
+ f_n = function (
2278
+ [], [], updates = [(a , (a + l * self .outer_via_dot (b , c )))], mode = self .mode
2279
+ )
2268
2280
2269
2281
f_t = function (
2270
- [], [], updates = [(a_t , (a_t + l * outer (b , c ).T ))], mode = self .mode
2282
+ [],
2283
+ [],
2284
+ updates = [(a_t , (a_t + l * self .outer_via_dot (b , c ).T ))],
2285
+ mode = self .mode ,
2271
2286
)
2272
2287
2273
2288
# Try with all stride patterns, and all transposed patterns
0 commit comments