File tree 1 file changed +20
-9
lines changed
1 file changed +20
-9
lines changed Original file line number Diff line number Diff line change 7
7
import unittest
8
8
9
9
import torch
10
+ import torch ._dynamo
10
11
11
12
12
13
class TestCompat (unittest .TestCase ):
@@ -22,31 +23,41 @@ def test_pytorch3d(self):
22
23
def test_hf_tokenizers (self ):
23
24
import tokenizers # noqa: F401
24
25
25
- @unittest .skip ("torch.Library is not supported" )
26
26
def test_torchdynamo_eager (self ):
27
- import torch ._dynamo as torchdynamo
28
27
29
- @torchdynamo .optimize ("eager" )
28
+ torch ._dynamo .reset ()
29
+
30
30
def fn (x , y ):
31
31
a = torch .cos (x )
32
32
b = torch .sin (y )
33
33
return a + b
34
34
35
- fn (torch .randn (10 ), torch .randn (10 ))
35
+ c_fn = torch .compile (fn , backend = "eager" )
36
+ c_fn (torch .randn (10 ), torch .randn (10 ))
36
37
37
- @unittest .skip ("torch.Library is not supported" )
38
38
def test_torchdynamo_ofi (self ):
39
- import torch ._dynamo as torchdynamo
40
39
41
- torchdynamo .reset ()
40
+ torch ._dynamo .reset ()
41
+
42
+ def fn (x , y ):
43
+ a = torch .cos (x )
44
+ b = torch .sin (y )
45
+ return a + b
46
+
47
+ c_fn = torch .compile (fn , backend = "ofi" )
48
+ c_fn (torch .randn (10 ), torch .randn (10 ))
49
+
50
+ def test_torchdynamo_inductor (self ):
51
+
52
+ torch ._dynamo .reset ()
42
53
43
- @torchdynamo .optimize ("ofi" )
44
54
def fn (x , y ):
45
55
a = torch .cos (x )
46
56
b = torch .sin (y )
47
57
return a + b
48
58
49
- fn (torch .randn (10 ), torch .randn (10 ))
59
+ c_fn = torch .compile (fn )
60
+ c_fn (torch .randn (10 ), torch .randn (10 ))
50
61
51
62
52
63
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments