marl_self_play: Eval vs random agent and add fps_max
Mikel committed Nov 17, 2024
1 parent 9e01f09 commit d789f23
Showing 1 changed file with 49 additions and 12 deletions.
marl_self_play.py: 61 changes (49 additions, 12 deletions)
@@ -37,21 +37,18 @@ class Args:
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""

async_envs: bool = False
"""whether to use Gymnasium's async vector environment or not (disabled by default)"""
mt_wd: str = "./"
"""Directory where the Minetest working directories will be created (defaults to the current one)"""
frameskip: int = 1
frameskip: int = 4
"""Number of frames to skip between observations"""
save_agent: bool = False
"""Save the agent's model (disabled by default)"""
save_num: int = 5
"""Number of times to save the agent's model. By default it is randomly selected in the [49152, 65535] range."""
mt_port: int = np.random.randint(49152, 65535)
"""TCP port used by Minetest server and client communication. Multiple envs will use successive ports."""
fps_max: int = 200
"""Target FPS to run the environment"""

# Algorithm specific arguments
"""the id of the environment"""
@@ -97,7 +94,7 @@ class Args:
"""the number of iterations (computed in runtime)"""


def make_env(frameskip, mt_port, mt_wd):
def make_env(fps_max, frameskip, mt_port, mt_wd):
def thunk():
env = MarlCraftiumEnv(
num_agents=2,
@@ -107,16 +104,17 @@ def thunk():
obs_width=64,
obs_height=64,
frameskip=frameskip,
max_timesteps=2000*frameskip,
max_timesteps=1000*frameskip,
rgb_observations=False,
init_frames=200,
sync_mode=False,
fps_max=fps_max,
)
env = DiscreteActionWrapper(
env,
actions=["forward", "left", "right", "jump", "dig", "mouse x+", "mouse x-", "mouse y+", "mouse y-"],
# actions=["forward", "left", "right", "jump", "dig", "mouse x+", "mouse x-", "mouse y+", "mouse y-"],
actions=["forward", "left", "right", "jump", "dig", "mouse x+", "mouse x-"],
)
# env = gym.wrappers.RecordEpisodeStatistics(env)
env = gym.wrappers.FrameStack(env, 4)

return env
@@ -159,6 +157,32 @@ def get_action_and_value(self, x, action=None):
return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)


@torch.no_grad()
def test_agent(env, agent, device):
# reset the environment
obs, info = env.reset()
obs = torch.Tensor(np.array(obs)).to(device)

ep_ret, ep_len = 0, 0
while True:
# get agent's action
agent_input = obs.permute(1, 0, 2, 3) # new size: (num_agents, framestack, H, W)
action, logprob, _, value = agent.get_action_and_value(agent_input)
action = action.cpu().numpy()
# sample the action of the second (random) player
action[1] = env.action_space.sample()

next_obs, reward, terminations, truncations, info = env.step(action)
obs = torch.Tensor(np.array(next_obs)).to(device)

ep_ret += reward[0] # track only agent 0's return; ignore the reward of the second (random) player
ep_len += 1
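# the MARL env returns per-agent termination/truncation flags; stop as soon as any agent's episode ends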
if sum(terminations) + sum(truncations) > 0:
break

return ep_ret, ep_len


if __name__ == "__main__":
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
Expand Down Expand Up @@ -201,7 +225,12 @@ def get_action_and_value(self, x, action=None):

# env setup
args.num_envs = 2 # One for each agent
env = make_env(frameskip=args.frameskip, mt_port=args.mt_port, mt_wd=args.mt_wd)()
env = make_env(
fps_max=args.fps_max,
frameskip=args.frameskip,
mt_port=args.mt_port,
mt_wd=args.mt_wd
)()

assert isinstance(env.action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -227,7 +256,7 @@ def get_action_and_value(self, x, action=None):

reset_in_next_step = False

ep_rets, ep_len = np.zeros(args.num_envs), 0
ep_rets, ep_len, num_ep = np.zeros(args.num_envs), 0, 0
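# ep_rets holds per-agent episodic returns; num_ep counts finished episodes and triggers the periodic evaluation below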
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
@@ -260,10 +289,18 @@ def get_action_and_value(self, x, action=None):
ep_rets += reward
ep_len += 1
if sum(terminations) + sum(truncations) > 0:
num_ep += 1
print(f"global_step={global_step}, episodic_return={ep_rets}")
writer.add_scalar("charts/episodic_return_agent0", ep_rets[0], global_step)
writer.add_scalar("charts/episodic_return_agent1", ep_rets[1], global_step)
writer.add_scalar("charts/episodic_length", ep_len*args.frameskip, global_step)

if num_ep % 10 == 0:
test_ep_ret, test_ep_len = test_agent(env, agent, device)
print(f"global_step={global_step}, test_ep_ret={test_ep_ret}, test_ep_len={test_ep_len}")
writer.add_scalar("charts/test_ep_ret", test_ep_ret, global_step)
writer.add_scalar("charts/test_ep_len", test_ep_len, global_step)

# reset episode statistics
ep_rets, ep_len = np.zeros(args.num_envs), 0

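For readers skimming the diff, the sketch below illustrates the evaluation pattern this commit introduces: the learned policy controls agent 0 while agent 1 samples actions uniformly at random, and only agent 0's return is tracked. It is a minimal, self-contained sketch rather than code from the commit; ToyTwoAgentEnv and greedy_policy are hypothetical stand-ins for MarlCraftiumEnv and the PPO agent defined in marl_self_play.py.

# Sketch only (not from the commit): evaluate a fixed policy against a random opponent
# in a two-agent environment. ToyTwoAgentEnv and greedy_policy are hypothetical stand-ins.
import numpy as np


class ToyTwoAgentEnv:
    """Two agents, a small discrete action space, random rewards, fixed-length episodes."""

    def __init__(self, num_actions: int = 4, horizon: int = 50, seed: int = 0):
        self.num_actions = num_actions
        self.horizon = horizon
        self.rng = np.random.default_rng(seed)
        self.t = 0

    def reset(self):
        self.t = 0
        obs = self.rng.normal(size=(2, 8)).astype(np.float32)  # one observation row per agent
        return obs, {}

    def step(self, actions):
        self.t += 1
        obs = self.rng.normal(size=(2, 8)).astype(np.float32)
        reward = self.rng.normal(size=2)                 # per-agent rewards
        terminated = np.array([False, False])
        truncated = np.array([self.t >= self.horizon] * 2)
        return obs, reward, terminated, truncated, {}


def greedy_policy(obs_row, num_actions):
    """Placeholder for the trained policy: arg-max of a fixed random linear projection."""
    proj = np.random.default_rng(42).normal(size=(obs_row.shape[0], num_actions))
    return int(np.argmax(obs_row @ proj))


def evaluate_vs_random(env, episodes=3):
    """Roll out `episodes` episodes; agent 0 uses the policy, agent 1 acts randomly."""
    returns = []
    for _ in range(episodes):
        obs, _ = env.reset()
        ep_ret = 0.0
        while True:
            actions = np.array([
                greedy_policy(obs[0], env.num_actions),  # learned agent
                env.rng.integers(env.num_actions),       # random opponent
            ])
            obs, reward, terminated, truncated, _ = env.step(actions)
            ep_ret += reward[0]                          # track agent 0 only
            if terminated.any() or truncated.any():
                break
        returns.append(ep_ret)
    return float(np.mean(returns))


if __name__ == "__main__":
    print("mean eval return vs random opponent:", evaluate_vs_random(ToyTwoAgentEnv()))

With tyro's usual command-line generation, the new fields should be settable as flags, e.g. python marl_self_play.py --fps-max 200 --frameskip 4; the exact flag names are an assumption based on tyro's default underscore-to-hyphen conversion and are not shown in the diff.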
