Skip to content

Commit e7563f6

Browse files
ahoblitzmalfetsvekars
authored
Update parametrizations.py (#2642)
* Update parametrizations.py For PyTorch 2, torch.solve => torch.linalg.solve * Apply suggestions from code review Co-authored-by: Svetlana Karslioglu <[email protected]> --------- Co-authored-by: Nikita Shulga <[email protected]> Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 3ac15b1 commit e7563f6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

intermediate_source/parametrizations.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def __init__(self, n):
227227

228228
def forward(self, X):
229229
# (I + X)(I - X)^{-1}
230-
return torch.solve(self.Id + X, self.Id - X).solution
230+
return torch.linalg.solve(self.Id - X, self.Id + X)
231231

232232
layer = nn.Linear(3, 3)
233233
parametrize.register_parametrization(layer, "weight", Skew())
@@ -301,13 +301,13 @@ def __init__(self, n):
301301
def forward(self, X):
302302
# Assume X skew-symmetric
303303
# (I + X)(I - X)^{-1}
304-
return torch.solve(self.Id + X, self.Id - X).solution
304+
return torch.linalg.solve(self.Id - X, self.Id + X)
305305

306306
def right_inverse(self, A):
307307
# Assume A orthogonal
308308
# See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
309309
# (X - I)(X + I)^{-1}
310-
return torch.solve(X - self.Id, self.Id + X).solution
310+
return torch.linalg.solve(X + self.Id, self.Id - X)
311311

312312
layer_orthogonal = nn.Linear(3, 3)
313313
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())

0 commit comments

Comments
 (0)