Commit 539d30a

REINFORCE: Generalize for any environment
- Parameterized by number of observations and actions
1 parent 37a1866 commit 539d30a
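
The change reads the policy network's input and output sizes from the environment's observation and action spaces instead of hard-coding CartPole-v1's 4 observations and 2 actions, and adds an --env-id flag (default CartPole-v1) so the environment is chosen at run time. Presumably any environment with a flat observation vector and a discrete action space then works without further edits, e.g. an invocation along the lines of python reinforce.py --env-id Acrobot-v1 (Acrobot-v1 is only an illustrative example, not part of this commit); see the sketch after the diff.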

1 file changed: +6 -7 lines changed

reinforcement_learning/reinforce.py

+6 -7

@@ -19,20 +19,20 @@
                     help='render the environment')
 parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                     help='interval between training status logs (default: 10)')
+parser.add_argument('--env-id', type=str, default='CartPole-v1')
 args = parser.parse_args()
 
-
-env = gym.make('CartPole-v1')
+env = gym.make(args.env_id)
 env.reset(seed=args.seed)
 torch.manual_seed(args.seed)
 
 
 class Policy(nn.Module):
-    def __init__(self):
+    def __init__(self, n_observation, n_actions):
         super(Policy, self).__init__()
-        self.affine1 = nn.Linear(4, 128)
+        self.affine1 = nn.Linear(n_observation, 128)
         self.dropout = nn.Dropout(p=0.6)
-        self.affine2 = nn.Linear(128, 2)
+        self.affine2 = nn.Linear(128, n_actions)
 
         self.saved_log_probs = []
         self.rewards = []
@@ -44,8 +44,7 @@ def forward(self, x):
         action_scores = self.affine2(x)
         return F.softmax(action_scores, dim=1)
 
-
-policy = Policy()
+policy = Policy(env.observation_space.shape[0], env.action_space.n)
 optimizer = optim.Adam(policy.parameters(), lr=1e-2)
 eps = np.finfo(np.float32).eps.item()
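
For reference, here is a minimal sketch of what the generalized construction amounts to for an environment other than the default. Assumptions beyond this diff: import gym provides the gym.make used above, the hidden layers of forward follow the original script, and Acrobot-v1 stands in as an illustrative environment with a 6-dimensional observation and 3 discrete actions.

import gym
import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    """Two-layer policy from the script, with sizes passed in rather than hard-coded."""

    def __init__(self, n_observation, n_actions):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(n_observation, 128)  # was nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, n_actions)      # was nn.Linear(128, 2)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


# Illustrative environment (an assumption, not part of the commit): any env with a
# flat Box observation and a Discrete action space fits the generalized Policy.
env = gym.make('Acrobot-v1')
env.reset(seed=0)                                # the script seeds via args.seed
n_observation = env.observation_space.shape[0]   # 6 for Acrobot-v1, 4 for CartPole-v1
n_actions = env.action_space.n                   # 3 for Acrobot-v1, 2 for CartPole-v1
policy = Policy(n_observation, n_actions)        # mirrors the new module-level call

With --env-id wired through gym.make(args.env_id), the script performs this derivation itself, so switching environments needs no code changes.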
