forked from WilsonWangTHU/POPLIN
-
Notifications
You must be signed in to change notification settings - Fork 1
/
mbexp.py
56 lines (43 loc) · 2.04 KB
/
mbexp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import argparse
import pprint
import copy
from dotmap import DotMap
from dmbrl.misc.MBExp import MBExperiment
from dmbrl.controllers.MPC import MPC
from dmbrl.config import create_config
from dmbrl.misc import logger
def main(env, ctrl_type, ctrl_args, overrides, logdir, args):
    """Build the experiment configuration, save it to disk and run the experiment.

    Arguments:
        env: Environment identifier, forwarded to ``create_config``.
        ctrl_type: Controller type, forwarded to ``create_config``. A policy
            is attached to the configuration only when this equals "MPC".
        ctrl_args: Iterable of (key, value) pairs; collected into a DotMap
            before being handed to ``create_config``.
        overrides: Forwarded unchanged to ``create_config``.
        logdir: Forwarded unchanged to ``create_config``.
        args: Accepted for the caller's convenience; not read in this function.
    """
    # Fold the (key, value) pairs into a DotMap of controller options.
    controller_options = DotMap(
        **{name: value for (name, value) in ctrl_args}
    )
    cfg = create_config(env, ctrl_type, controller_options, overrides, logdir)
    logger.info('\n' + pprint.pformat(cfg))

    # Only the "MPC" controller type gets a policy object attached here.
    if ctrl_type == "MPC":
        cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)

    # Store a shallow copy of the whole configuration on the experiment config.
    cfg.exp_cfg.misc = copy.copy(cfg)
    experiment = MBExperiment(cfg.exp_cfg)

    # Make sure the experiment's log directory exists before writing into it.
    if not os.path.exists(experiment.logdir):
        os.makedirs(experiment.logdir)

    # Keep a pretty-printed snapshot of the configuration next to the results.
    config_path = os.path.join(experiment.logdir, "config.txt")
    with open(config_path, "w") as config_file:
        config_file.write(pprint.pformat(cfg.toDict()))

    experiment.run_experiment()
if __name__ == "__main__":
    # Command-line entry point: parse the options, then launch one experiment.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-env',
        type=str,
        required=True,
        help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]',
    )

    # Repeatable options: every occurrence takes two tokens and appends them
    # as one pair to the resulting list.
    parser.add_argument(
        '-ca', '--ctrl_arg',
        action='append',
        nargs=2,
        default=[],
        help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments',
    )
    parser.add_argument(
        '-o', '--override',
        action='append',
        nargs=2,
        default=[],
        help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides',
    )

    parser.add_argument(
        '-logdir',
        type=str,
        default='log',
        help='Directory to which results will be logged (default: ./log)',
    )
    parser.add_argument(
        '-e_popsize',
        type=int,
        default=500,
        help='different popsize to use',
    )

    args = parser.parse_args()

    # The controller type is fixed to "MPC"; the full parsed namespace is
    # forwarded as well (it is the only carrier of the -e_popsize value).
    main(
        env=args.env,
        ctrl_type="MPC",
        ctrl_args=args.ctrl_arg,
        overrides=args.override,
        logdir=args.logdir,
        args=args,
    )
# import mbbl test
# from mbbl.env.gym_env import acrobot
# print("import successfully")