diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 3c9ce80..fb2f2db 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -15,6 +15,7 @@ from baselines.common import set_global_seeds from baselines.common.atari_wrappers import make_atari, wrap_deepmind from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv +from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv from baselines.common.vec_env.dummy_vec_env import DummyVecEnv from baselines.common import retro_wrappers from baselines.common.wrappers import ClipActionsWrapper @@ -54,7 +55,7 @@ def make_thunk(rank, initializer=None): set_global_seeds(seed) if not force_dummy and num_env > 1: - return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)]) + return ShmemVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)]) else: return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)]) diff --git a/baselines/ppo2/defaults.py b/baselines/ppo2/defaults.py index d9931b5..cc7d4a2 100644 --- a/baselines/ppo2/defaults.py +++ b/baselines/ppo2/defaults.py @@ -80,8 +80,8 @@ def mara_lstm(): total_timesteps = 1e8, save_interval = 10, env_name = 'MARARandomTarget-v0', - num_envs = 2, + num_envs = 8, transfer_path = None, # transfer_path = '/tmp/ros2learn/MARACollisionOrientRandomTarget-v0/ppo2_lstm/checkpoints/00090', - trained_path = '/home/rkojcev/MARA_NN/LSTM_no_pr/checkpoints/best' + trained_path = '/home/rkojcev/MARA_NN/lstm_server/best' )