Skip to content

Commit

Permalink
Adding state function to multiwalker (#1149)
Browse files Browse the repository at this point in the history
  • Loading branch information
ffelten authored Dec 23, 2023
1 parent f7a9ef1 commit 41d47d7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pettingzoo/sisl/multiwalker/multiwalker.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(self, *args, **kwargs):
# spaces
self.action_spaces = dict(zip(self.agents, self.env.action_space))
self.observation_spaces = dict(zip(self.agents, self.env.observation_space))
self.state_space = self.env.state_space
self.steps = 0

def observation_space(self, agent):
Expand Down Expand Up @@ -191,6 +192,9 @@ def close(self):
def render(self):
return self.env.render()

def state(self):
return self.env.state()

def observe(self, agent):
return self.env.observe(self.agent_name_mapping[agent])

Expand Down
22 changes: 22 additions & 0 deletions pettingzoo/sisl/multiwalker/multiwalker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,14 @@ def setup(self):
]
self.observation_space = [agent.observation_space for agent in self.walkers]
self.action_space = [agent.action_space for agent in self.walkers]
self.state_space = spaces.Box(
low=-np.float32(np.inf),
high=+np.float32(np.inf),
shape=(
self.n_walkers * 24 + 3,
), # 24 is the observation space of each walker, 3 is the package observation space
dtype=np.float32,
)

self.package_scale = self.n_walkers / 1.75
self.package_length = PACKAGE_LENGTH / SCALE * self.package_scale
Expand Down Expand Up @@ -545,6 +553,20 @@ def observe(self, agent):
o = np.array(o, dtype=np.float32)
return o

def state(self):
all_walker_obs = self.get_last_obs()
all_walker_obs = np.array(list(all_walker_obs.values())).flatten()
package_obs = np.array(
[
self.package.position.x,
self.package.position.y,
self.package.angle,
]
)
global_state = np.concatenate((all_walker_obs, package_obs)).astype(np.float32)

return global_state

def render(self, close=False):
if close:
self.close()
Expand Down

0 comments on commit 41d47d7

Please sign in to comment.