PPO-fix (#145)
* Investigating PPO crashing

* Removing debugging prints

* generalize calc_log_probs and refactor for SIL

* improve reliability of ppo and sil loss calc

* add log_prob nanguard at creation

* improve logger

* add computation logging

* improve debug logging

* use base_case_openai for test

* fix SIL log_probs

* fix singleton cont action separate AC output unit

* fix PPO weight copy

* replace clone with detach properly

* revert detach to clone to fix PPO

* typo

* refactor log_probs to policy_util

* add net arg to calc_pdparam function

* add PPOSIL

* refactor calc_pdparams in policy_util

* fix typo
kengz authored Aug 8, 2018
1 parent 925c1d2 commit 5ec2a0f
Showing 25 changed files with 721 additions and 142 deletions.
10 changes: 7 additions & 3 deletions slm_lab/agent/__init__.py
@@ -122,13 +122,15 @@ def post_body_init(self):
@lab_api
def reset(self, state_a):
'''Do agent reset per session, such as memory pointer'''
logger.debug(f'Agent {self.a} reset')
for (e, b), body in util.ndenumerate_nonan(self.body_a):
body.memory.epi_reset(state_a[(e, b)])

@lab_api
def act(self, state_a):
'''Standard act method from algorithm.'''
action_a = self.algorithm.act(state_a)
logger.debug(f'Agent {self.a} act: {action_a}')
return action_a

@lab_api
@@ -144,6 +146,7 @@ def update(self, action_a, reward_a, state_a, done_a):
body.loss = loss_a[(e, b)]
explore_var_a = self.algorithm.update()
explore_var_a = util.guard_data_a(self, explore_var_a, 'explore_var')
logger.debug(f'Agent {self.a} loss: {loss_a}, explore_var_a {explore_var_a}')
return loss_a, explore_var_a

@lab_api
@@ -179,12 +182,13 @@ def get(self, a):

@lab_api
def reset(self, state_space):
logger.debug('AgentSpace.reset')
logger.debug3('AgentSpace.reset')
_action_v, _loss_v, _explore_var_v = self.aeb_space.init_data_v(AGENT_DATA_NAMES)
for agent in self.agents:
state_a = state_space.get(a=agent.a)
agent.reset(state_a)
_action_space, _loss_space, _explore_var_space = self.aeb_space.add(AGENT_DATA_NAMES, [_action_v, _loss_v, _explore_var_v])
logger.debug3(f'action_space: {_action_space}')
return _action_space

@lab_api
@@ -197,7 +201,7 @@ def act(self, state_space):
action_a = agent.act(state_a)
action_v[a, 0:len(action_a)] = action_a
action_space, = self.aeb_space.add(data_names, [action_v])
logger.debug(f'\naction_space: {action_space}')
logger.debug3(f'\naction_space: {action_space}')
return action_space

@lab_api
@@ -214,7 +218,7 @@ def update(self, action_space, reward_space, state_space, done_space):
loss_v[a, 0:len(loss_a)] = loss_a
explore_var_v[a, 0:len(explore_var_a)] = explore_var_a
loss_space, explore_var_space = self.aeb_space.add(data_names, [loss_v, explore_var_v])
logger.debug(f'\nloss_space: {loss_space}\nexplore_var_space: {explore_var_space}')
logger.debug3(f'\nloss_space: {loss_space}\nexplore_var_space: {explore_var_space}')
return loss_space, explore_var_space

@lab_api
44 changes: 26 additions & 18 deletions slm_lab/agent/algorithm/actor_critic.py
@@ -161,6 +161,8 @@ def init_nets(self):
assert 'Separate' in net_type
self.share_architecture = False
out_dim = self.body.action_dim * [2]
if len(out_dim) == 1:
out_dim = out_dim[0]
critic_out_dim = 1

self.net_spec['type'] = net_type = net_type.replace('Shared', '').replace('Separate', '')
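
Note on the singleton continuous-action fix above: for a single continuous action, self.body.action_dim * [2] yields the list [2], and collapsing it to the scalar 2 makes the separate actor net build one output layer with two units (loc and scale) rather than a list of tails. A minimal standalone illustration of the intent (not the repo's net-builder code):

action_dim = 1                # single continuous action
out_dim = action_dim * [2]    # list repetition -> [2], one (loc, scale) pair
if len(out_dim) == 1:
    out_dim = out_dim[0]      # -> 2, a plain int: a single output layer with 2 units
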
@@ -195,37 +197,39 @@ def init_nets(self):
self.post_init_nets()

@lab_api
def calc_pdparam(self, x, evaluate=True):
def calc_pdparam(self, x, evaluate=True, net=None):
'''
The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
'''
net = self.net if net is None else net
if evaluate:
pdparam = self.net.wrap_eval(x)
pdparam = net.wrap_eval(x)
else:
self.net.train()
pdparam = self.net(x)
net.train()
pdparam = net(x)
if self.share_architecture:
# MLPHeterogenousTails, get front (no critic)
if self.body.is_discrete:
return pdparam[0]
pdparam = pdparam[0]
else:
if len(pdparam) == 2: # only (loc, scale) and (v)
return pdparam[0]
pdparam = pdparam[0]
else:
return pdparam[:-1]
else:
return pdparam
pdparam = pdparam[:-1]
logger.debug(f'pdparam: {pdparam}')
return pdparam

def calc_v(self, x, evaluate=True):
def calc_v(self, x, evaluate=True, net=None):
'''
Forward-pass to calculate the predicted state-value from critic.
'''
net = self.net if net is None else net
if self.share_architecture:
if evaluate:
out = self.net.wrap_eval(x)
out = net.wrap_eval(x)
else:
self.net.train()
out = self.net(x)
net.train()
out = net(x)
# MLPHeterogenousTails, get last
v = out[-1].squeeze_(dim=1)
else:
Expand All @@ -235,6 +239,7 @@ def calc_v(self, x, evaluate=True):
self.critic.train()
out = self.critic(x)
v = out.squeeze_(dim=1)
logger.debug(f'v: {v}')
return v
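
The new optional net argument on calc_pdparam and calc_v lets callers run the forward pass through a network other than self.net, which is what the PPO and SIL loss calculations need when they recompute action pdparams or state-values under a specific copy of the weights. A hypothetical usage sketch (the old_net attribute name is an assumption for illustration, not necessarily the repo's):

# Hypothetical: evaluate the same states under a frozen copy and the current network
old_pdparam = self.calc_pdparam(states, net=self.old_net)  # frozen policy copy
new_pdparam = self.calc_pdparam(states, net=self.net)      # current policy
v_pred = self.calc_v(states, net=self.net)                 # critic value estimate
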

@lab_api
@@ -264,7 +269,7 @@ def train_shared(self):
self.to_train = 0
self.body.log_probs = []
self.body.entropies = []
logger.debug(f'Total loss: {loss:.2f}')
logger.debug(f'Total loss: {loss:.4f}')
self.last_loss = loss.item()
return self.last_loss

@@ -282,7 +287,7 @@ def train_separate(self):
self.to_train = 0
self.body.entropies = []
self.body.log_probs = []
logger.debug(f'Total loss: {loss:.2f}')
logger.debug(f'Total loss: {loss:.4f}')
self.last_loss = loss.item()
return self.last_loss

@@ -309,7 +314,7 @@ def train_critic(self, batch):

def calc_policy_loss(self, batch, advs):
'''Calculate the actor's policy loss'''
assert len(self.body.log_probs) == len(advs), f'{len(self.body.log_probs)} vs {len(advs)}'
assert len(self.body.log_probs) == len(advs), f'batch_size of log_probs {len(self.body.log_probs)} vs advs: {len(advs)}'
log_probs = torch.stack(self.body.log_probs)
policy_loss = - self.policy_loss_coef * log_probs * advs
if self.add_entropy:
@@ -318,7 +323,7 @@ def calc_policy_loss(self, batch, advs):
policy_loss = torch.mean(policy_loss)
if torch.cuda.is_available() and self.net.gpu:
policy_loss = policy_loss.cuda()
logger.debug(f'Actor policy loss: {policy_loss:.2f}')
logger.debug(f'Actor policy loss: {policy_loss:.4f}')
return policy_loss

def calc_val_loss(self, batch, v_targets):
Expand All @@ -329,7 +334,7 @@ def calc_val_loss(self, batch, v_targets):
val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
if torch.cuda.is_available() and self.net.gpu:
val_loss = val_loss.cuda()
logger.debug(f'Critic value loss: {val_loss:.2f}')
logger.debug(f'Critic value loss: {val_loss:.4f}')
return val_loss

def calc_gae_advs_v_targets(self, batch):
Expand Down Expand Up @@ -360,6 +365,7 @@ def calc_gae_advs_v_targets(self, batch):
adv_std[adv_std != adv_std] = 0
adv_std += 1e-08
adv_targets = (adv_targets - adv_targets.mean()) / adv_std
logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
return adv_targets, v_targets
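
The adv_std[adv_std != adv_std] = 0 guard above relies on NaN comparing unequal to itself, so a NaN standard deviation (e.g. from a one-element batch) is zeroed before the 1e-08 floor keeps the division finite. A small standalone illustration of the trick:

import torch

x = torch.tensor([1.0, float('nan'), 3.0])
x[x != x] = 0.0   # NaN != NaN, so only the NaN entry is selected and zeroed
# x is now tensor([1., 0., 3.])
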

def calc_nstep_advs_v_targets(self, batch):
Expand All @@ -375,6 +381,7 @@ def calc_nstep_advs_v_targets(self, batch):
if torch.cuda.is_available() and self.net.gpu:
nstep_advs = nstep_advs.cuda()
adv_targets = v_targets = nstep_advs
logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
return adv_targets, v_targets

def calc_td_advs_v_targets(self, batch):
Expand All @@ -388,6 +395,7 @@ def calc_td_advs_v_targets(self, batch):
td_returns = td_returns.cuda()
v_targets = td_returns
adv_targets = v_targets - v_preds # TD error, but called adv for API consistency
logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
return adv_targets, v_targets

@lab_api
2 changes: 1 addition & 1 deletion slm_lab/agent/algorithm/base.py
@@ -59,7 +59,7 @@ def post_init_nets(self):
logger.info(f'Initialized algorithm models for lab_mode: {util.get_lab_mode()}')

@lab_api
def calc_pdparam(self, x, evaluate=True):
def calc_pdparam(self, x, evaluate=True, net=None):
'''
To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs.
The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
14 changes: 10 additions & 4 deletions slm_lab/agent/algorithm/dqn.py
@@ -109,6 +109,7 @@ def calc_q_targets(self, batch):
q_targets = (max_q_targets * batch['actions']) + (q_preds * (1 - batch['actions']))
if torch.cuda.is_available() and self.net.gpu:
q_targets = q_targets.cuda()
logger.debug(f'q_targets: {q_targets}')
return q_targets

@lab_api
@@ -221,6 +222,7 @@ def calc_q_targets(self, batch):
q_targets = (max_q_targets * batch['actions']) + (q_preds * (1 - batch['actions']))
if torch.cuda.is_available() and self.net.gpu:
q_targets = q_targets.cuda()
logger.debug(f'q_targets: {q_targets}')
return q_targets

def update_nets(self):
@@ -333,12 +335,13 @@ def init_nets(self):
self.eval_net = self.target_net

@lab_api
def calc_pdparam(self, x, evaluate=True):
def calc_pdparam(self, x, evaluate=True, net=None):
'''
Calculate pdparams for multi-action by chunking the network logits output
'''
pdparam = super(MultitaskDQN, self).calc_pdparam(x, evaluate=evaluate)
pdparam = super(MultitaskDQN, self).calc_pdparam(x, evaluate=evaluate, net=net)
pdparam = torch.cat(torch.split(pdparam, self.action_dims, dim=1))
logger.debug(f'pdparam: {pdparam}')
return pdparam

@lab_api
@@ -359,6 +362,7 @@ def act(self, state_a):
action_pd = action_pd_a[idx]
body.entropies.append(action_pd.entropy())
body.log_probs.append(action_pd.log_prob(action_a[idx].float()))
assert not torch.isnan(body.log_probs[-1])
return action_a.cpu().numpy()

@lab_api
@@ -410,6 +414,7 @@ def calc_q_targets(self, batch):
q_targets = torch.cat(multi_q_targets, dim=1)
if torch.cuda.is_available() and self.net.gpu:
q_targets = q_targets.cuda()
logger.debug(f'q_targets: {q_targets}')
return q_targets


@@ -432,12 +437,12 @@ def init_nets(self):
self.eval_net = self.target_net

@lab_api
def calc_pdparam(self, x, evaluate=True):
def calc_pdparam(self, x, evaluate=True, net=None):
'''
Calculate pdparams for multi-action by chunking the network logits output
'''
x = torch.cat(torch.split(x, self.state_dims, dim=1)).unsqueeze_(dim=1)
pdparam = SARSA.calc_pdparam(self, x, evaluate=evaluate)
pdparam = SARSA.calc_pdparam(self, x, evaluate=evaluate, net=net)
return pdparam

@lab_api
@@ -479,6 +484,7 @@ def calc_q_targets(self, batch):
multi_q_targets.append(q_targets)
# return as list for compatibility with net output in training_step
q_targets = multi_q_targets
logger.debug(f'q_targets: {q_targets}')
return q_targets

@lab_api
3 changes: 0 additions & 3 deletions slm_lab/agent/algorithm/math_util.py
@@ -86,6 +86,3 @@ def calc_gaes(rewards, v_preds, next_v_preds, gamma, lam):
assert not np.isnan(gaes).any(), f'GAE has nan: {gaes}'
gaes = torch.from_numpy(gaes).float()
return gaes


# Q-learning calc
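
For reference, calc_gaes implements Generalized Advantage Estimation. A simplified sketch of the standard recursion it follows (episode-done masking omitted; this is the textbook form, not the repo's exact code):

import numpy as np

def gae_reference(rewards, v_preds, next_v_preds, gamma, lam):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * next_v_preds - v_preds
    advs = np.zeros_like(deltas)
    running = 0.0
    # A_t = delta_t + (gamma * lam) * A_{t+1}, accumulated backwards in time
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advs[t] = running
    return advs
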
56 changes: 53 additions & 3 deletions slm_lab/agent/algorithm/policy_util.py
@@ -16,6 +16,7 @@
from slm_lab.lib import logger, util
from torch import distributions
import numpy as np
import pydash as ps
import torch

logger = logger.get_logger(__name__)
@@ -155,15 +156,17 @@ def sample_action_pd(ActionPD, pdparam, body):
action_pd = ActionPD(logits=pdparam)
else: # continuous outputs a list, loc and scale
assert len(pdparam) == 2, pdparam
# scale (stdev) must be >=0
clamp_pdparam = torch.stack([pdparam[0], torch.clamp(pdparam[1], 1e-8)])
action_pd = ActionPD(*clamp_pdparam)
# scale (stdev) must be >0, use softplus
if pdparam[1] < 5:
pdparam[1] = torch.log(1 + torch.exp(pdparam[1])) + 1e-8
action_pd = ActionPD(*pdparam)
action = action_pd.sample()
return action, action_pd
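
The softplus guard above keeps the scale (stdev) parameter strictly positive while staying smooth near zero; the < 5 threshold presumably skips the transform where softplus(x) ≈ x anyway and exp could overflow. A quick numeric check of the transform (standalone, not the repo's code path):

import torch

raw_scale = torch.tensor([-2.0, 0.0, 1.0])
softplus_scale = torch.log(1 + torch.exp(raw_scale)) + 1e-8
# tensor([0.1269, 0.6931, 1.3133]) -- strictly positive for any raw input
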


# interface action sampling methods


def default(state, algorithm, body):
'''Plain policy by direct sampling using outputs of net as logits and constructing ActionPD as appropriate'''
ActionPD, pdparam, body = init_action_pd(state, algorithm, body)
@@ -341,3 +344,50 @@ def rate_decay(algorithm, body):
def periodic_decay(algorithm, body):
'''Apply _periodic_decay to explore_var'''
return fn_decay_explore_var(algorithm, body, _periodic_decay)


# misc calc methods


def guard_multi_pdparams(pdparams, body):
'''Guard pdparams for multi action'''
action_dim = body.action_dim
is_multi_action = ps.is_iterable(action_dim)
if is_multi_action:
assert ps.is_list(pdparams)
pdparams = [t.clone() for t in pdparams] # clone for grad safety
assert len(pdparams) == len(action_dim), pdparams
# transpose into (batch_size, [action_dims])
pdparams = [list(torch.split(t, action_dim, dim=0)) for t in torch.cat(pdparams, dim=1)]
return pdparams
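
For multi-action bodies, the transpose in guard_multi_pdparams turns a per-tail list of (batch_size, pdparam_dim) tensors into a per-sample list of per-action pdparams, which the loop in calc_log_probs below consumes one sample at a time. A toy shape walk-through with hypothetical dimensions:

import torch

action_dim = [2, 3]                                # two discrete action branches
pdparams = [torch.randn(4, 2), torch.randn(4, 3)]  # one logits tensor per tail, batch of 4
cat = torch.cat(pdparams, dim=1)                   # shape (4, 5)
per_sample = [list(torch.split(row, action_dim, dim=0)) for row in cat]
# len(per_sample) == 4; per_sample[0] is [2 logits for branch 1, 3 logits for branch 2]
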


def calc_log_probs(algorithm, net, body, batch):
'''
Method to calculate log_probs fresh from batch data
Body already stores log_prob from self.net. This is used for PPO where log_probs needs to be recalculated.
'''
states, actions = batch['states'], batch['actions']
action_dim = body.action_dim
is_multi_action = ps.is_iterable(action_dim)
# construct log_probs for each state-action
pdparams = algorithm.calc_pdparam(states, net=net)
pdparams = guard_multi_pdparams(pdparams, body)
assert len(pdparams) == len(states), f'batch_size of pdparams: {len(pdparams)} vs states: {len(states)}'

pdtypes = ACTION_PDS[body.action_type]
ActionPD = getattr(distributions, body.action_pdtype)

log_probs = []
for idx, pdparam in enumerate(pdparams):
if not is_multi_action: # already cloned for multi_action above
pdparam = pdparam.clone() # clone for grad safety
_action, action_pd = sample_action_pd(ActionPD, pdparam, body)
log_probs.append(action_pd.log_prob(actions[idx]))
log_probs = torch.stack(log_probs)
if is_multi_action:
log_probs = log_probs.mean(dim=1)
log_probs = torch.tensor(log_probs, requires_grad=True)
assert not torch.isnan(log_probs).any(), f'log_probs: {log_probs}, \npdparams: {pdparams} \nactions: {actions}'
logger.debug(f'log_probs: {log_probs}')
return log_probs
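
A hypothetical sketch of how a PPO-style loss could consume calc_log_probs, recomputing log-probabilities under a fixed and the current network to form the clipped probability ratio (the old_net attribute, clip_eps, and advs below are illustrative assumptions, not necessarily the repo's names):

# Hypothetical PPO clipped-surrogate usage of calc_log_probs (illustrative only)
old_log_probs = calc_log_probs(algorithm, algorithm.old_net, body, batch).detach()
new_log_probs = calc_log_probs(algorithm, algorithm.net, body, batch)
ratios = torch.exp(new_log_probs - old_log_probs)              # pi_new / pi_old per sample
clipped = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps)
policy_loss = -torch.min(ratios * advs, clipped * advs).mean()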