File tree 1 file changed +3
-3
lines changed
1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -227,7 +227,7 @@ def __init__(self, n):
227
227
228
228
def forward (self , X ):
229
229
# (I + X)(I - X)^{-1}
230
- return torch .solve (self .Id + X , self .Id - X ). solution
230
+ return torch .linalg . solve (self .Id - X , self .Id + X )
231
231
232
232
layer = nn .Linear (3 , 3 )
233
233
parametrize .register_parametrization (layer , "weight" , Skew ())
@@ -301,13 +301,13 @@ def __init__(self, n):
301
301
def forward (self , X ):
302
302
# Assume X skew-symmetric
303
303
# (I + X)(I - X)^{-1}
304
- return torch .solve (self .Id + X , self .Id - X ). solution
304
+ return torch .linalg . solve (self .Id - X , self .Id + X )
305
305
306
306
def right_inverse (self , A ):
307
307
# Assume A orthogonal
308
308
# See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
309
309
# (X - I)(X + I)^{-1}
310
- return torch .solve (X - self .Id , self .Id + X ). solution
310
+ return torch .linalg . solve (X + self .Id , self .Id - X )
311
311
312
312
layer_orthogonal = nn .Linear (3 , 3 )
313
313
parametrize .register_parametrization (layer_orthogonal , "weight" , Skew ())
You can’t perform that action at this time.
0 commit comments