-
Notifications
You must be signed in to change notification settings - Fork 15
/
a2c.py
311 lines (250 loc) · 14.2 KB
/
a2c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""
This project was developed by Rocky Duan, Peter Chen, Pieter Abbeel for the Berkeley Deep RL Bootcamp, August 2017. Bootcamp website with slides and lecture videos: https://sites.google.com/view/deep-rl-bootcamp/.
Copyright 2017 Deep RL Bootcamp Organizers.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from alg_utils import *
from models import WeightSharingValueFunction
from simplepg.simple_utils import test_once
import tests.a2c_tests
def samples_generator(env_pool, policy, vf, k):
def compute_dists_values(obs):
# Special handling when value function and policy share weights. We try to avoid an additional forward pass
if isinstance(vf, WeightSharingValueFunction) and vf.feature_model is policy:
feats = policy.compute_features(obs)
dists = policy.compute_dists(obs=None, feats=feats)
values = vf.compute_state_values(obs=None, feats=feats)
else:
dists = policy.compute_dists(obs)
values = vf.compute_state_values(obs)
return dists, values
obs = Variable(np.asarray(env_pool.reset()))
dists, values = compute_dists_values(obs)
while True:
all_actions = []
all_rewards = []
all_dones = []
all_dists = []
all_values = []
for _ in range(k):
# To reuse computation, we retain the computation graph for actions and distributions, so that we can
# backprop later without an additional forward pass
# Make sure we don't accidentally differentiate through the actions
with chainer.no_backprop_mode():
actions = dists.sample()
actions_val = actions.data
next_obs, rews, dones, _ = env_pool.step(actions_val)
all_actions.append(actions)
all_rewards.append(rews)
all_dones.append(dones)
all_dists.append(dists.as_dict())
all_values.append(values)
obs = Variable(np.asarray(next_obs))
dists, values = compute_dists_values(obs)
yield all_actions, all_rewards, all_dones, all_dists, all_values, chainer.Variable(values.data)
def compute_returns_advantages(rewards, dones, values, next_values, discount):
"""
Compute returns and advantages given rewards, terminal indicators, values at each state, values at the future
states, and the discount factor.
:param rewards: A matrix of shape T * N, where T is the number of time steps and N is the number of environments.
Each entry is the reward value.
:param dones: A matrix of shape T * N, where each entry is a binary flag indicating whether
an episode has finished after this time step.
:param values: A matrix of shape T * N, where each entry is the estimated value V(s_t), t = 0, ..., T-1.
:param next_values: A vector of length N, where each entry is the estimated value at the future state V(s_T).
:param discount: The discount factor.
:return: A tuple (returns, advantages), each of which should be a matrix of shape T * N
"""
"*** YOUR CODE HERE ***"
T, N = rewards.shape
Rs = np.zeros((T+1, N))
As = np.zeros_like(rewards)
Rs[T] = next_values
for t in range(T-1, -1, -1):
Rs[t] = rewards[t] + (1-dones[t]) * discount * Rs[t+1]
As[t] = Rs[t] - values[t]
return Rs[:T], As
def a2c(env, env_maker, policy, vf, joint_model=None, k=20, n_envs=16, discount=0.99,
optimizer=chainer.optimizers.RMSprop(lr=1e-3), max_grad_norm=1.0, vf_loss_coeff=0.5,
ent_coeff=0.01, last_epoch=-1, epoch_length=10000, n_epochs=8000, snapshot_saver=None):
"""
This method implements (Synchronous) Advantage Actor-Critic algorithm. Rather than having asynchronous workers,
which can be more efficient due to less coordination but also less stable and harder to extend / debug, we use a
pool of environment workers performing simulation, while computing actions and performing gradient updates
centrally. This also makes it easier to utilize GPUs for neural network computation.
:param env: An environment instance, which should have the same class as what env_maker.make() returns.
:param env_maker: An object such that calling env_maker.make() will generate a new environment.
:param policy: A stochastic policy which we will be optimizing.
:param vf: A value function which estimates future returns given a state. It can potentially share weights with the
policy by calling policy.create_vf().
:param joint_model: The joint model of policy and value function. This is usually automatically computed.
:param k: Number of simulation steps per environment for each gradient update.
:param n_envs: Number of environments running simultaneously.
:param discount: Discount factor.
:param optimizer: A chainer optimizer instance. By default we use the RMSProp algorithm.
:param max_grad_norm: If provided, apply gradient clipping with the specified maximum L2 norm.
:param vf_loss_coeff: Coefficient for the value function loss.
:param ent_coeff: Coefficient for the entropy loss (the negative bonus).
:param last_epoch: The index of the last epoch. This is normally -1 when starting afresh, but may be different when
loaded from a snapshot.
:param epoch: The starting epoch. This is normally 0, but may be different when loaded from a snapshot. Since A2C
is an online algorithm, an epoch is just an artificial boundary so that we record logs after each epoch.
:param epoch_length: Number of total environment steps per epoch.
:param n_epochs: Total number of epochs to run the algorithm.
:param snapshot_saver: An object for saving snapshots.
"""
# ensures that shared parameters are only counted once
if joint_model is None:
joint_model = UniqueChainList(policy, vf)
if getattr(optimizer, 'target', None) is not joint_model:
optimizer.setup(joint_model)
try:
# remove existing hook if necessary (this should only be needed when restarting experiments)
optimizer.remove_hook('gradient_clipping')
except KeyError:
pass
if max_grad_norm is not None:
# Clip L2 norm of gradient, to improve stability
optimizer.add_hook(chainer.optimizer.GradientClipping(
threshold=max_grad_norm), 'gradient_clipping')
epoch = last_epoch + 1
global_t = epoch * epoch_length
loggings = defaultdict(list)
logger.info("Starting env pool")
with EnvPool(env_maker, n_envs=n_envs) as env_pool:
gen = samples_generator(env_pool, policy, vf, k)
logger.info("Starting epoch {}".format(epoch))
if logger.get_level() <= logger.INFO:
progbar = tqdm(total=epoch_length)
else:
progbar = None
while global_t < epoch_length * n_epochs:
# Run k steps in the environment
# Note:
# - all_actions, all_values, all_dists, and next_values are chainer variables
# - all_rewards, all_dones are lists numpy arrays
# The first dimension of each variable is time, and the second dimension is the index of the environment
all_actions, all_rewards, all_dones, all_dists, all_values, next_values = next(
gen)
global_t += n_envs * k
# Compute returns and advantages
# Size: (k, n_envs)
all_values = F.stack(all_values)
all_rewards = np.asarray(all_rewards, dtype=np.float32)
all_dones = np.asarray(all_dones, dtype=np.float32)
all_values_data = all_values.data
next_values_data = next_values.data
test_once(compute_returns_advantages)
all_returns, all_advs = compute_returns_advantages(
all_rewards,
all_dones,
all_values_data,
next_values_data,
discount
)
all_returns = chainer.Variable(all_returns.astype(np.float32))
all_advs = chainer.Variable(all_advs.astype(np.float32))
# Concatenate data
# Size: (k*n_envs,) + action_shape
all_flat_actions = F.concat(all_actions, axis=0)
# Size: key -> (k*n_envs,) + dist_shape
all_flat_dists = {k: F.concat(
[d[k] for d in all_dists], axis=0) for k in all_dists[0].keys()}
all_flat_dists = policy.distribution.from_dict(all_flat_dists)
# Prepare variables needed for gradient computation
logli = all_flat_dists.logli(all_flat_actions)
ent = all_flat_dists.entropy()
# Flatten advantages
all_advs = F.concat(all_advs, axis=0)
# Form the loss - you should only need to use the variables provided as input arguments below
def compute_total_loss(logli, all_advs, ent_coeff, ent, vf_loss_coeff, all_returns, all_values):
"""
:param logli: A chainer variable, which should be a vector of size N
:param all_advs: A chainer variable, which should be a vector of size N
:param ent_coeff: A scalar
:param ent: A chainer variable, which should be a vector of size N
:param vf_loss_coeff: A scalar
:param all_returns: A chainer variable, which should be a vector of size N
:param all_values: A chainer variable, which should be a vector of size N
:return: A tuple of (policy_loss, vf_loss, total_loss)
policy_loss should be the weighted sum of the surrogate loss and the average entropy loss
vf_loss should be the (unweighted) squared loss of value function prediction.
total_loss should be the weighted sum of policy_loss and vf_loss
"""
"*** YOUR CODE HERE ***"
policy_loss = -F.mean(logli * all_advs) - ent_coeff * F.mean(ent)
vf_loss = F.mean_squared_error(all_returns, all_values)
total_loss = policy_loss + vf_loss_coeff * vf_loss
return policy_loss, vf_loss, total_loss
test_once(compute_total_loss)
policy_loss, vf_loss, total_loss = compute_total_loss(
logli=logli, all_advs=all_advs, ent_coeff=ent_coeff,
ent=ent, vf_loss_coeff=vf_loss_coeff,
all_returns=all_returns, all_values=all_values
)
joint_model.cleargrads()
total_loss.backward()
optimizer.update()
vf_loss_data = vf_loss.data
all_returns_data = all_returns.data
all_flat_dists_data = {
k: v.data
for k, v in all_flat_dists.as_dict().items()
}
loggings["vf_loss"].append(vf_loss_data)
loggings["vf_preds"].append(all_values_data)
loggings["vf_targets"].append(all_returns_data)
loggings["dists"].append(all_flat_dists_data)
if progbar is not None:
progbar.update(k * n_envs)
# An entire epoch has passed
if global_t // epoch_length > epoch:
logger.logkv('Epoch', epoch)
log_reward_statistics(env)
all_dists = {
k: Variable(np.concatenate([d[k] for d in loggings["dists"]], axis=0))
for k in loggings["dists"][0].keys()
}
log_action_distribution_statistics(
policy.distribution.from_dict(all_dists))
logger.logkv('|VfPred|', np.mean(np.abs(loggings["vf_preds"])))
logger.logkv('|VfTarget|', np.mean(
np.abs(loggings["vf_targets"])))
logger.logkv('VfLoss', np.mean(loggings["vf_loss"]))
logger.dumpkvs()
if snapshot_saver is not None:
logger.info("Saving snapshot")
snapshot_saver.save_state(
epoch,
dict(
alg=a2c,
alg_state=dict(
env_maker=env_maker,
policy=policy,
vf=vf,
joint_model=joint_model,
k=k,
n_envs=n_envs,
discount=discount,
last_epoch=epoch,
n_epochs=n_epochs,
epoch_length=epoch_length,
optimizer=optimizer,
vf_loss_coeff=vf_loss_coeff,
ent_coeff=ent_coeff,
max_grad_norm=max_grad_norm,
)
)
)
# Reset stored logging information
loggings = defaultdict(list)
if progbar is not None:
progbar.close()
epoch = global_t // epoch_length
logger.info("Starting epoch {}".format(epoch))
if progbar is not None:
progbar = tqdm(total=epoch_length)
if progbar is not None:
progbar.close()