MAML_ML1.py
# yapf: disable
import click
import metaworld
import torch

from garage import wrap_experiment
from garage.envs import MetaWorldSetTaskEnv
from garage.experiment import (MetaEvaluator, MetaWorldTaskSampler,
                               SetTaskSampler)
from garage.experiment.deterministic import set_seed
from garage.sampler import RaySampler
from garage.torch.algos import MAMLTRPO
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer

# yapf: enable


@click.command()
@click.option('--seed', default=1)
@click.option('--epochs', default=300)
@click.option('--rollouts_per_task', default=10)
@click.option('--meta_batch_size', default=20)
@wrap_experiment(snapshot_mode='all')
def MAML_ML1(ctxt, seed, epochs, rollouts_per_task, meta_batch_size):
    """Set up environment and algorithm and run the task.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        epochs (int): Number of training epochs.
        rollouts_per_task (int): Number of rollouts per epoch per task
            for training.
        meta_batch_size (int): Number of tasks sampled per batch.

    """
    set_seed(seed)
    # ML1 meta-learns across goal variations of a single Meta-World task
    # (here: button-press-v2).
    ml1 = metaworld.ML1('button-press-v2')
    tasks = MetaWorldTaskSampler(ml1, 'train')
    env = tasks.sample(1)[0]()
    # env.visualize()
    test_sampler = SetTaskSampler(MetaWorldSetTaskEnv,
                                  env=MetaWorldSetTaskEnv(ml1, 'test'))

    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=[32, 32],
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    # Evaluates post-adaptation performance on held-out test tasks.
    meta_evaluator = MetaEvaluator(test_task_sampler=test_sampler,
                                   n_test_tasks=1,
                                   n_exploration_eps=rollouts_per_task)

    # One Ray worker per task in the meta-batch collects rollouts in parallel.
    sampler = RaySampler(agents=policy,
                         envs=env,
                         max_episode_length=env.spec.max_episode_length,
                         n_workers=meta_batch_size)

    trainer = Trainer(ctxt)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    sampler=sampler,
                    task_sampler=tasks,
                    value_function=value_function,
                    meta_batch_size=meta_batch_size,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,        # step size of the inner adaptation
                    num_grad_updates=1,  # inner gradient steps per task
                    meta_evaluator=meta_evaluator)

    trainer.setup(algo, env)
    trainer.train(n_epochs=epochs,
                  batch_size=rollouts_per_task * env.spec.max_episode_length)


MAML_ML1()
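
For orientation, the sketch below shows the inner/outer update pattern that MAMLTRPO implements: each task in the meta-batch adapts a copy of the policy parameters with num_grad_updates inner gradient steps at inner_lr, and the post-adaptation losses form the outer meta-objective (which MAMLTRPO optimizes with TRPO rather than plain gradient descent). This is an illustration only, not garage's actual code; task_loss is a hypothetical callable standing in for the per-task surrogate loss computed from sampled rollouts.

    # Minimal MAML sketch (illustration only; MAMLTRPO's real update uses
    # TRPO machinery and garage's samplers). Each `task_loss` is a
    # hypothetical callable mapping a parameter list to a scalar loss.
    import torch

    def maml_meta_loss(params, task_losses, inner_lr=0.1, num_grad_updates=1):
        """Average post-adaptation loss over a meta-batch of tasks."""
        meta_loss = 0.
        for task_loss in task_losses:          # one entry per sampled task
            adapted = [p.clone() for p in params]
            for _ in range(num_grad_updates):  # inner-loop adaptation
                loss = task_loss(adapted)
                grads = torch.autograd.grad(loss, adapted, create_graph=True)
                adapted = [p - inner_lr * g for p, g in zip(adapted, grads)]
            meta_loss = meta_loss + task_loss(adapted)  # after adapting
        return meta_loss / len(task_losses)    # outer-loop (meta) objective

    # Toy usage: adapt one weight vector on two quadratic "tasks".
    w = [torch.zeros(2, requires_grad=True)]
    targets = [torch.tensor([1., 0.]), torch.tensor([0., 1.])]
    losses = [lambda ps, t=t: ((ps[0] - t) ** 2).sum() for t in targets]
    maml_meta_loss(w, losses).backward()  # gradient w.r.t. initial weights

To run the actual experiment, invoke the script with the click options defined above (e.g. python MAML_ML1.py --seed 1 --epochs 300 --rollouts_per_task 10 --meta_batch_size 20). With snapshot_mode='all', garage saves a snapshot of the experiment state every epoch under the experiment's log directory.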