-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_dataset.py
98 lines (80 loc) · 3.1 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gym
import numpy as np
import os
import torch
import argparse
from logger import Logger
from utils import stack_frames, add_distractor
def _str2bool(value):
    """Convert a command-line token to a real ``bool``.

    argparse hands option values over as strings, and every non-empty string
    (including ``'False'``) is truthy, so boolean options need an explicit
    converter to be switchable from the command line.

    Args:
        value: The raw token (or an actual ``bool``, returned unchanged).

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


# Command-line interface of the dataset generator.
parser = argparse.ArgumentParser()
parser.add_argument('--env-name', type=str, default='Pendulum-v1',
                    help='Environment name.')
parser.add_argument('--observation-dim-w', type=int, default=84,
                    help='Width of the input measurements (RGB images).')
parser.add_argument('--observation-dim-h', type=int, default=84,
                    help='Height of the input measurements (RGB images).')
# Boolean options go through _str2bool so that `--test False` really
# selects the training split instead of yielding the truthy string 'False'.
parser.add_argument('--test', type=_str2bool, default=True,
                    help='Generate training or testing dataset.')
parser.add_argument('--training-dataset', type=str, default='pendulum-train-moving.pkl',
                    help='Training dataset.')
parser.add_argument('--testing-dataset', type=str, default='pendulum-test-moving.pkl',
                    help='Testing dataset.')
parser.add_argument('--random-policy', type=_str2bool, default=True,
                    help='Use random action policy.')
parser.add_argument('--render-mode', type=str, default='rgb_array',
                    help='Render mode (human or rgb_array)')
parser.add_argument('--distractors', type=str, default='moving',
                    help='Distractors type (none, fixed, moving)')
args = parser.parse_args()

# Plain module-level aliases for the parsed options used further down.
env_name = args.env_name
test = args.test
obs_dim1, obs_dim2 = args.observation_dim_w, args.observation_dim_h
random_policy = args.random_policy
distractors = args.distractors

# The testing split is small (5 episodes, seed 7); the training split is
# larger (50 episodes, seed 1). Each split is written to its own file.
seed, num_episodes, data_file_name = (
    (7, 5, args.testing_dataset) if test else (1, 50, args.training_dataset)
)
# Build the environment and read the leading dimension of its observation
# and action spaces, plus the first component of the action upper bound.
env = gym.make(env_name, render_mode=args.render_mode)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Datasets go into a `data/` folder next to this script. The path is joined
# component-wise (the empty last component keeps the trailing separator the
# original string had) instead of hard-coding '/' inside a one-argument join.
directory = os.path.dirname(os.path.abspath(__file__))
folder = os.path.join(directory, 'data', '')
# Make sure the folder exists up front; harmless if it is already there.
os.makedirs(folder, exist_ok=True)
logger = Logger(folder)

# Set seeds
env.action_space.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

max_steps = 200  # default for pendulum-v1
def _with_distractor(image):
    """Return ``image`` with the distractor selected by ``--distractors``.

    ``'fixed'`` calls ``add_distractor(image, is_random=False)`` and
    ``'moving'`` calls ``add_distractor(image, is_random=True)``; any other
    setting returns the frame unchanged.
    """
    if distractors == 'fixed':
        return add_distractor(image, is_random=False)
    if distractors == 'moving':
        return add_distractor(image, is_random=True)
    return image


# Roll out `num_episodes` episodes and log one transition per step:
# (obs, action, reward, next_obs, done, state), where obs/next_obs are built
# by stack_frames from two consecutive rendered frames.
try:
    if not random_policy:
        # This branch used to be a bare `pass`, which left `action` undefined
        # and crashed with NameError at env.step(); fail with a clear message.
        raise NotImplementedError(
            'Only the random action policy is implemented.')

    for episode in range(num_episodes):
        # Seed the environment on the first reset only: passing the same seed
        # to every reset restarts the environment RNG identically each time.
        state, _ = env.reset(seed=seed if episode == 0 else None)
        prev_frame = np.array(env.render())
        # Copy so that prev_frame stays the raw render even if add_distractor
        # were to draw in place (TODO confirm whether it does).
        frame = _with_distractor(prev_frame.copy())
        print('Episode: ', episode)
        for step in range(max_steps):
            obs = stack_frames(prev_frame, frame, obs_dim1, obs_dim2)
            action = env.action_space.sample()
            next_state, reward, done, truncated, info = env.step(action)
            # Apply the distractor to the new frame *before* stacking it, so
            # the next_obs logged here is exactly the obs of the next step.
            next_frame = _with_distractor(np.array(env.render()))
            next_obs = stack_frames(frame, next_frame, obs_dim1, obs_dim2)
            # End the episode on termination, on truncation reported by the
            # environment, or when the local step budget is exhausted.
            done = bool(done or truncated or step == max_steps - 1)
            logger.obslog((obs, action, reward, next_obs, done, state))
            prev_frame = frame
            frame = next_frame
            state = next_state
            if done:
                break
finally:
    # Release the environment (and its renderer) even if a rollout fails.
    env.close()

logger.save_obslog(filename=data_file_name)