diff --git a/marl_self_play.py b/marl_self_play.py
index af01b4802..ddf28614f 100644
--- a/marl_self_play.py
+++ b/marl_self_play.py
@@ -37,14 +37,9 @@ class Args:
     """the wandb's project name"""
     wandb_entity: str = None
     """the entity (team) of wandb's project"""
-    capture_video: bool = False
-    """whether to capture videos of the agent performances (check out `videos` folder)"""
-
-    async_envs: bool = False
-    """whether to use Gymnasium's async vector environment or not (disabled by default)"""
     mt_wd: str = "./"
     """Directory where the Minetest working directories will be created (defaults to the current one)"""
-    frameskip: int = 1
+    frameskip: int = 4
     """Number of frames to skip between observations"""
     save_agent: bool = False
     """Save the agent's model (disabled by default)"""
@@ -52,6 +47,8 @@ class Args:
     """Number of times to save the agent's model. By default it is randomly selected in the [49152, 65535] range."""
     mt_port: int = np.random.randint(49152, 65535)
     """TCP port used by Minetest server and client communication. Multiple envs will use successive ports."""
+    fps_max: int = 200
+    """Target FPS to run the environment"""

     # Algorithm specific arguments
     """the id of the environment"""
@@ -97,7 +94,7 @@ class Args:
     """the number of iterations (computed in runtime)"""


-def make_env(frameskip, mt_port, mt_wd):
+def make_env(fps_max, frameskip, mt_port, mt_wd):
     def thunk():
         env = MarlCraftiumEnv(
             num_agents=2,
@@ -107,16 +104,17 @@ def thunk():
             obs_width=64,
             obs_height=64,
             frameskip=frameskip,
-            max_timesteps=2000*frameskip,
+            max_timesteps=1000*frameskip,
             rgb_observations=False,
             init_frames=200,
             sync_mode=False,
+            fps_max=fps_max,
         )
         env = DiscreteActionWrapper(
             env,
-            actions=["forward", "left", "right", "jump", "dig", "mouse x+", "mouse x-", "mouse y+", "mouse y-"],
+            # actions=["forward", "left", "right", "jump", "dig", "mouse x+", "mouse x-", "mouse y+", "mouse y-"],
+            actions=["forward", "left", "right", "jump", "dig", "mouse x+", "mouse x-"],
         )
-        # env = gym.wrappers.RecordEpisodeStatistics(env)
         env = gym.wrappers.FrameStack(env, 4)

         return env
@@ -159,6 +157,32 @@ def get_action_and_value(self, x, action=None):
         return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)


+@torch.no_grad()
+def test_agent(env, agent, device):
+    # reset the environment
+    obs, info = env.reset()
+    obs = torch.Tensor(np.array(obs)).to(device)
+
+    ep_ret, ep_len = 0, 0
+    while True:
+        # get agent's action
+        agent_input = obs.permute(1, 0, 2, 3)  # new size: (num_agents, framestack, H, W)
+        action, logprob, _, value = agent.get_action_and_value(agent_input)
+        action = action.cpu().numpy()
+        # sample the action of the second (random) player
+        action[1] = env.action_space.sample()
+
+        next_obs, reward, terminations, truncations, info = env.step(action)
+        obs = torch.Tensor(np.array(next_obs)).to(device)
+
+        ep_ret += reward[0]  # ignore the reward of the second player
+        ep_len += 1
+        if sum(terminations) + sum(truncations) > 0:
+            break
+
+    return ep_ret, ep_len
+
+
 if __name__ == "__main__":
     args = tyro.cli(Args)
     args.batch_size = int(args.num_envs * args.num_steps)
@@ -201,7 +225,12 @@ def get_action_and_value(self, x, action=None):

     # env setup
     args.num_envs = 2  # One for each agent
-    env = make_env(frameskip=args.frameskip, mt_port=args.mt_port, mt_wd=args.mt_wd)()
+    env = make_env(
+        fps_max=args.fps_max,
+        frameskip=args.frameskip,
+        mt_port=args.mt_port,
+        mt_wd=args.mt_wd
+    )()

     assert isinstance(env.action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -227,7 +256,7 @@ def get_action_and_value(self, x, action=None):

     reset_in_next_step = False
-    ep_rets, ep_len = np.zeros(args.num_envs), 0
+    ep_rets, ep_len, num_ep = np.zeros(args.num_envs), 0, 0
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
         if args.anneal_lr:
@@ -260,10 +289,18 @@ def get_action_and_value(self, x, action=None):
             ep_rets += reward
             ep_len += 1
             if sum(terminations) + sum(truncations) > 0:
+                num_ep += 1
                 print(f"global_step={global_step}, episodic_return={ep_rets}")
                 writer.add_scalar("charts/episodic_return_agent0", ep_rets[0], global_step)
                 writer.add_scalar("charts/episodic_return_agent1", ep_rets[1], global_step)
                 writer.add_scalar("charts/episodic_length", ep_len*args.frameskip, global_step)
+
+                if num_ep % 10 == 0:
+                    test_ep_ret, test_ep_len = test_agent(env, agent, device)
+                    print(f"global_step={global_step}, test_ep_ret={test_ep_ret}, test_ep_len={test_ep_len}")
+                    writer.add_scalar("charts/test_ep_ret", test_ep_ret, global_step)
+                    writer.add_scalar("charts/test_ep_len", test_ep_len, global_step)
+
                 # reset episode statistics
                 ep_rets, ep_len = np.zeros(args.num_envs), 0

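
Note: a minimal, self-contained sketch (not part of the patch) of the shape handling in the new test_agent. It assumes the wrapped MARL env returns per-step observations of shape (num_agents, H, W), so gym.wrappers.FrameStack(env, 4) stacks frames along a new leading axis and the permute yields one batch row per agent:

    import torch

    num_agents, framestack, H, W = 2, 4, 64, 64
    obs = torch.zeros(framestack, num_agents, H, W)   # shape as returned by FrameStack
    agent_input = obs.permute(1, 0, 2, 3)             # batch over agents
    assert agent_input.shape == (num_agents, framestack, H, W)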
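
Note: during evaluation only agent 0 plays the learned policy; agent 1's entry in the batched action array is overwritten with a random sample, giving a fixed random opponent. A standalone sketch, assuming DiscreteActionWrapper exposes a gym.spaces.Discrete over the listed actions (the exact size, 7 here, is an assumption):

    import numpy as np
    from gymnasium.spaces import Discrete

    action_space = Discrete(7)          # assumed: one action id per wrapped action name
    action = np.array([3, 5])           # policy actions, one per agent
    action[1] = action_space.sample()   # replace player 2's move with a random one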
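
Note: rough timing arithmetic behind the new defaults, assuming max_timesteps counts raw environment frames (consistent with the ep_len*args.frameskip episodic_length logging above):

    frameskip = 4                      # each action now repeats for 4 frames
    max_timesteps = 1000 * frameskip   # 4000 frames -> 1000 agent steps per episode
    fps_max = 200                      # frame-rate cap added by this patch
    print(max_timesteps / fps_max)     # 20.0 -> about 20 s of simulation per episode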