Skip to content

Commit

Permalink
Add types and update Ray RLlib support/notebook (#52)
Browse files Browse the repository at this point in the history
* Update ray version in example

Remove num_agents from env, which is now a read-only property.

* Add/fix typing. Try to configure possible agents in multi-agent env

* Update ray rllib notebook + smaller fixes

* Increment version number

* Support Python 3.9-3.12

* Fix linting
  • Loading branch information
stefanbschneider authored Nov 10, 2024
1 parent 2d4d13d commit c62dc02
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 258 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.DS_Store
.idea
.vscode
results*
Expand Down
338 changes: 115 additions & 223 deletions examples/rllib.ipynb

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions mobile_env/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@
from mobile_env.core.util import BS_SYMBOL, deep_dict_merge
from mobile_env.core.utilities import BoundedLogUtility
from mobile_env.handlers.central import MComCentralHandler
from mobile_env.handlers.handler import Handler


class MComCore(gymnasium.Env):
NOOP_ACTION = 0
metadata = {"render_modes": ["rgb_array", "human"]}

def __init__(self, stations, users, config={}, render_mode=None):
def __init__(
self,
stations: list[BaseStation],
users: list[UserEquipment],
config={},
render_mode=None,
):
super().__init__()

self.render_mode = render_mode
Expand Down Expand Up @@ -70,7 +77,7 @@ def __init__(self, stations, users, config={}, render_mode=None):

# set object that handles calls to action(), reward() & observation()
# set action & observation space according to handler
self.handler = config["handler"]
self.handler: Handler = config["handler"]
self.action_space = self.handler.action_space(self)
self.observation_space = self.handler.observation_space(self)

Expand Down
8 changes: 4 additions & 4 deletions mobile_env/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(
self.noise = noise
self.height = height

self.x: float = None
self.y: float = None
self.stime: int = None
self.extime: int = None
self.x: float
self.y: float
self.stime: int
self.extime: int

@property
def point(self):
Expand Down
13 changes: 6 additions & 7 deletions mobile_env/handlers/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,42 @@
from gymnasium.spaces.space import Space


class Handler:
class Handler(abc.ABC):
"""Defines Gymnasium interface methods called by core simulation."""

@classmethod
@abc.abstractmethod
def action_space(cls, env) -> Space:
"""Defines action space for passed environment."""
pass

@classmethod
@abc.abstractmethod
def ue_obs_size(cls, env) -> int:
"""Size of the observation space."""

@classmethod
@abc.abstractmethod
def observation_space(cls, env) -> Space:
"""Defines observation space for passed environment."""
pass

@classmethod
@abc.abstractmethod
def action(cls, env, action) -> Dict[int, int]:
"""Transform passed action(s) to dict shape expected by simulation."""
pass

@classmethod
@abc.abstractmethod
def observation(cls, env):
"""Computes observations for agent."""
pass

@classmethod
@abc.abstractmethod
def reward(cls, env):
"""Computes rewards for agent."""
pass

@classmethod
def check(cls, env):
"""Check if handler is applicable to simulation configuration."""
pass

@classmethod
def info(cls, env):
Expand Down
11 changes: 6 additions & 5 deletions mobile_env/handlers/multi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gymnasium
import numpy as np

from mobile_env.core.base import MComCore
from mobile_env.handlers.handler import Handler


Expand All @@ -20,7 +21,7 @@ def ue_obs_size(cls, env) -> int:
return sum(env.feature_sizes[ftr] for ftr in cls.features)

@classmethod
def action_space(cls, env) -> gymnasium.spaces.Dict:
def action_space(cls, env: MComCore) -> gymnasium.spaces.Dict:
return gymnasium.spaces.Dict(
{
ue.ue_id: gymnasium.spaces.Discrete(env.NUM_STATIONS + 1)
Expand All @@ -29,7 +30,7 @@ def action_space(cls, env) -> gymnasium.spaces.Dict:
)

@classmethod
def observation_space(cls, env) -> gymnasium.spaces.Dict:
def observation_space(cls, env: MComCore) -> gymnasium.spaces.Dict:
size = cls.ue_obs_size(env)
space = {
ue_id: gymnasium.spaces.Box(low=-1, high=1, shape=(size,), dtype=np.float32)
Expand All @@ -39,7 +40,7 @@ def observation_space(cls, env) -> gymnasium.spaces.Dict:
return gymnasium.spaces.Dict(space)

@classmethod
def reward(cls, env):
def reward(cls, env: MComCore):
"""UE's reward is their utility and the avg. utility of nearby BSs."""
# compute average utility of UEs for each BS
# set to lower bound if no UEs are connected
Expand All @@ -64,7 +65,7 @@ def ue_utility(ue):
return rewards

@classmethod
def observation(cls, env) -> Dict[int, np.ndarray]:
def observation(cls, env: MComCore) -> Dict[int, np.ndarray]:
"""Select features for MA setting & flatten each UE's features."""

# get features for currently active UEs
Expand All @@ -85,6 +86,6 @@ def observation(cls, env) -> Dict[int, np.ndarray]:
return obs

@classmethod
def action(cls, env, action: Dict[int, int]):
def action(cls, env: MComCore, action: Dict[int, int]):
"""Base environment by default expects action dictionary."""
return action
19 changes: 11 additions & 8 deletions mobile_env/wrappers/multi_agent.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import Tuple
from typing import Optional, Tuple

import gymnasium
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.typing import MultiAgentDict

from mobile_env.core.base import MComCore


class RLlibMAWrapper(MultiAgentEnv):
def __init__(self, env):
def __init__(self, env: MComCore):
super().__init__()

# class wraps environment object
self.env = env

# set number of overall controllable actors
self.num_agents = len(self.env.users)
# set max. number of steps for RLlib trainer
self.max_episode_steps = self.env.EP_MAX_TIME

Expand All @@ -27,7 +29,7 @@ def __init__(self, env):

# track UE IDs of last observation's dictionary, i.e.,
# what UEs were active in the previous step
self.prev_step_ues = None
self.prev_step_ues: Optional[set[int]] = None

def reset(self, *, seed=None, options=None) -> MultiAgentDict:
obs, info = self.env.reset(seed=seed, options=options)
Expand All @@ -36,13 +38,14 @@ def reset(self, *, seed=None, options=None) -> MultiAgentDict:

def step(
self, action_dict: MultiAgentDict
) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
obs, rews, terminated, truncated, infos = self.env.step(action_dict)

# UEs that are not active after `step()` are done (here: truncated)
# NOTE: `truncateds` keys are keys of previous observation dictionary
assert self.prev_step_ues is not None
inactive_ues = self.prev_step_ues - set([ue.ue_id for ue in self.env.active])
truncateds = {
truncateds: MultiAgentDict = {
ue_id: True if ue_id in inactive_ues else False
for ue_id in self.prev_step_ues
}
Expand All @@ -51,7 +54,7 @@ def step(
assert (
not terminated
), "There is no natural episode termination. terminated should be False."
terminateds = {ue_id: False for ue_id in self.prev_step_ues}
terminateds: MultiAgentDict = {ue_id: False for ue_id in self.prev_step_ues}
terminateds["__all__"] = False

# update keys of previous observation dictionary
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 88
max-line-length = 100
extend-ignore = E203
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@

setup(
name="mobile-env",
version="2.0.1",
version="2.0.2",
author="Stefan Schneider, Stefan Werner",
description="mobile-env: An Open Environment for Autonomous Coordination in "
"Wireless Mobile Networks",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/stefanbschneider/mobile-env",
packages=find_packages(),
python_requires=">=3.8.0",
python_requires=">=3.9.0",
install_requires=requirements,
zip_safe=False,
classifiers=[
Expand Down
11 changes: 6 additions & 5 deletions tests/test_env_stepping.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
"""Simple test of small env similar to test notebook."""

import gymnasium
import pytest

# importing mobile_env automatically registers the predefined scenarios in Gym
import mobile_env # noqa: F401
from mobile_env.scenarios.registry import handlers, scenarios


@pytest.mark.parametrize(
"env_name",
["mobile-small-central-v0", "mobile-medium-central-v0", "mobile-large-central-v0"],
)
def test_env_stepping(env_name):
@pytest.mark.parametrize("scenario", list(scenarios.keys()))
@pytest.mark.parametrize("handler", list(handlers.keys()))
def test_env_stepping(scenario: str, handler: str):
"""Create a mobile-env and run with random actions until done.
Just to ensure that it does not crash.
"""
env_name: str = f"mobile-{scenario}-{handler}-v0"
# create a small mobile environment for a single, centralized control agent
env = gymnasium.make(env_name)
obs, info = env.reset()
Expand Down

0 comments on commit c62dc02

Please sign in to comment.