7
7
import unittest
8
8
9
9
import torch
10
-
10
+ import torch . _dynamo
11
11
12
12
class TestCompat (unittest .TestCase ):
13
13
def test_torchvision (self ):
@@ -22,32 +22,41 @@ def test_pytorch3d(self):
22
22
def test_hf_tokenizers (self ):
23
23
import tokenizers # noqa: F401
24
24
25
- @unittest .skip ("torch.Library is not supported" )
26
25
def test_torchdynamo_eager (self ):
27
- import torch ._dynamo as torchdynamo
28
26
29
- @torchdynamo .optimize ("eager" )
27
+ torch ._dynamo .reset ()
28
+
30
29
def fn (x , y ):
31
30
a = torch .cos (x )
32
31
b = torch .sin (y )
33
32
return a + b
34
33
35
- fn (torch .randn (10 ), torch .randn (10 ))
34
+ c_fn = torch .compile (fn , backend = "eager" )
35
+ c_fn (torch .randn (10 ), torch .randn (10 ))
36
36
37
- @unittest .skip ("torch.Library is not supported" )
38
37
def test_torchdynamo_ofi (self ):
39
- import torch ._dynamo as torchdynamo
40
38
41
- torchdynamo .reset ()
39
+ torch . _dynamo .reset ()
42
40
43
- @torchdynamo .optimize ("ofi" )
44
41
def fn (x , y ):
45
42
a = torch .cos (x )
46
43
b = torch .sin (y )
47
44
return a + b
48
45
49
- fn (torch .randn (10 ), torch .randn (10 ))
46
+ c_fn = torch .compile (fn , backend = "ofi" )
47
+ c_fn (torch .randn (10 ), torch .randn (10 ))
48
+
49
+ def test_torchdynamo_inductor (self ):
50
+
51
+ torch ._dynamo .reset ()
52
+
53
+ def fn (x , y ):
54
+ a = torch .cos (x )
55
+ b = torch .sin (y )
56
+ return a + b
50
57
58
+ c_fn = torch .compile (fn )
59
+ c_fn (torch .randn (10 ), torch .randn (10 ))
51
60
52
61
if __name__ == "__main__" :
53
62
unittest .main ()
0 commit comments