Skip to content

Commit

Permalink
Addressing issue 145 (#160)
Browse files Browse the repository at this point in the history
* Addressing issue 145

* Fixed setup script env_params to reflect new changes

* Fixed pep8 issues

* Added new green_wave_env requirements to rllib/green_wave.py

* Added new env params to the benchmark experiments as well

* Added env params to the baselines as well. nose2 is not passing for some figure_eight scenarios
  • Loading branch information
kjang96 authored and eugenevinitsky committed Sep 21, 2018
1 parent 4eb9b80 commit 560b8b4
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 16 deletions.
6 changes: 4 additions & 2 deletions examples/rllab/green_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def run_task(*_):

additional_env_params = {
"target_velocity": 50,
"num_steps": 500,
"switch_time": 3.0
"switch_time": 3.0,
"num_observed": 2,
"discrete": False,
"tl_type": "controlled"
}
env_params = EnvParams(additional_params=additional_env_params)

Expand Down
8 changes: 7 additions & 1 deletion examples/rllib/green_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ def get_non_flow_params(enter_speed, additional_net_params):
"rl_veh": rl_veh
}

additional_env_params = {"target_velocity": 50, "switch_time": 3.0}
additional_env_params = {
"target_velocity": 50,
"switch_time": 3.0,
"num_observed": 2,
"discrete": False,
"tl_type": "controlled"
}

additional_net_params = {
"speed_limit": 35,
Expand Down
6 changes: 4 additions & 2 deletions flow/benchmarks/baselines/grid0.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ class needed to run simulations
evaluate=True, # Set to True to evaluate traffic metrics
horizon=HORIZON,
additional_params={
"switch_time": 2.0,
"target_velocity": 50,
"switch_time": 2,
"num_observed": 2,
"tl_type": "actuated",
"discrete": False,
"tl_type": "controlled"
},
)

Expand Down
6 changes: 4 additions & 2 deletions flow/benchmarks/baselines/grid1.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ class needed to run simulations
evaluate=True, # Set to True to evaluate traffic metrics
horizon=HORIZON,
additional_params={
"switch_time": 2.0,
"target_velocity": 50,
"switch_time": 2,
"num_observed": 2,
"tl_type": "actuated",
"discrete": False,
"tl_type": "controlled"
},
)

Expand Down
5 changes: 4 additions & 1 deletion flow/benchmarks/grid0.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@
env=EnvParams(
horizon=HORIZON,
additional_params={
"switch_time": 2.0,
"target_velocity": 50,
"switch_time": 2,
"num_observed": 2,
"discrete": False,
"tl_type": "controlled"
},
),

Expand Down
5 changes: 4 additions & 1 deletion flow/benchmarks/grid1.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@
env=EnvParams(
horizon=HORIZON,
additional_params={
"switch_time": 2.0,
"target_velocity": 50,
"switch_time": 2,
"num_observed": 2,
"discrete": False,
"tl_type": "controlled"
},
),

Expand Down
23 changes: 18 additions & 5 deletions flow/envs/green_wave_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
"discrete": False,
}

ADDITIONAL_PO_ENV_PARAMS = {
# number of vehicles the agent can observe on each incoming edge
"num_observed": 2,
# velocity to use in reward functions
"target_velocity": 30,
}


class TrafficLightGridEnv(Env):
"""Environment used to train traffic lights to regulate traffic flow
Expand Down Expand Up @@ -57,6 +64,12 @@ class TrafficLightGridEnv(Env):
"""

def __init__(self, env_params, sumo_params, scenario):

for p in ADDITIONAL_ENV_PARAMS.keys():
if p not in env_params.additional_params:
raise KeyError(
'Environment parameter "{}" not supplied'.format(p))

self.grid_array = scenario.net_params.additional_params["grid_array"]
self.rows = self.grid_array["row_num"]
self.cols = self.grid_array["col_num"]
Expand Down Expand Up @@ -458,15 +471,15 @@ class PO_TrafficLightGridEnv(TrafficLightGridEnv):
def __init__(self, env_params, sumo_params, scenario):
super().__init__(env_params, sumo_params, scenario)

for p in ADDITIONAL_PO_ENV_PARAMS.keys():
if p not in env_params.additional_params:
raise KeyError(
'Environment parameter "{}" not supplied'.format(p))

# number of vehicles nearest each intersection that are observed in the
# state space; defaults to 2
self.num_observed = env_params.additional_params.get("num_observed", 2)

# used while computing the reward
self.env_params.additional_params["target_velocity"] = \
max(self.scenario.speed_limit(edge)
for edge in self.scenario.get_edge_list())

# used during visualization
self.observed_ids = []

Expand Down
5 changes: 3 additions & 2 deletions tests/setup_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,9 @@ def grid_mxn_exp_setup(row_num=1,
# set default env_params configuration
additional_env_params = {
"target_velocity": 50,
"num_steps": 100,
"switch_time": 3.0
"switch_time": 3.0,
"tl_type": "controlled",
"discrete": False
}

env_params = EnvParams(
Expand Down

0 comments on commit 560b8b4

Please sign in to comment.