                     help='render the environment')
 parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                     help='interval between training status logs (default: 10)')
+parser.add_argument('--env-id', type=str, default='CartPole-v1')
 args = parser.parse_args()
 
-
-env = gym.make('CartPole-v1')
+env = gym.make(args.env_id)
 env.reset(seed=args.seed)
 torch.manual_seed(args.seed)
 
 
 class Policy(nn.Module):
-    def __init__(self):
+    def __init__(self, n_observation, n_actions):
         super(Policy, self).__init__()
-        self.affine1 = nn.Linear(4, 128)
+        self.affine1 = nn.Linear(n_observation, 128)
         self.dropout = nn.Dropout(p=0.6)
-        self.affine2 = nn.Linear(128, 2)
+        self.affine2 = nn.Linear(128, n_actions)
 
         self.saved_log_probs = []
         self.rewards = []
@@ -44,8 +44,7 @@ def forward(self, x):
         action_scores = self.affine2(x)
         return F.softmax(action_scores, dim=1)
 
-
-policy = Policy()
+policy = Policy(env.observation_space.shape[0], env.action_space.n)
 optimizer = optim.Adam(policy.parameters(), lr=1e-2)
 eps = np.finfo(np.float32).eps.item()
 
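Illustrative note (not part of the diff): with the layer sizes taken from the environment rather than hard-coded for CartPole, the same Policy fits any Gym task with a flat Box observation space and a Discrete action space. A minimal sketch, assuming it runs in the same script below the Policy definition and that Acrobot-v1 (6 observations, 3 actions) is installed:

import gym
import torch

env = gym.make('Acrobot-v1')                        # instead of CartPole-v1 (4 obs, 2 actions)
policy = Policy(env.observation_space.shape[0],     # n_observation = 6
                env.action_space.n)                 # n_actions = 3

dummy_obs = torch.zeros(1, env.observation_space.shape[0])
probs = policy(dummy_obs)                           # (1, n_actions) action probabilities
assert probs.shape == (1, env.action_space.n)

The same effect is available from the command line through the new flag, e.g. --env-id Acrobot-v1.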