Skip to content

Commit

Permalink
Bundled changes from wam experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianVes authored and famura committed Oct 15, 2020
1 parent b1d2300 commit 6637c3b
Show file tree
Hide file tree
Showing 28 changed files with 278 additions and 457 deletions.
39 changes: 31 additions & 8 deletions Pyrado/pyrado/algorithms/bayrn.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(self,
self.acq_restarts = acq_restarts
self.acq_samples = acq_samples
self.acq_param = acq_param
self.policy_param_init = policy_param_init.detach() if policy_param_init is not None else None
self.policy_param_init = policy_param_init
self.valuefcn_param_init = valuefcn_param_init.detach() if valuefcn_param_init is not None else None
self.warmstart = warmstart
self.num_eval_rollouts_real = num_eval_rollouts_real
Expand All @@ -164,6 +164,11 @@ def __init__(self,
self.curr_cand_value = -pyrado.inf # for the stopping criterion
self.uc_normalizer = UnitCubeProjector(bounds[0, :], bounds[1, :])

if self.policy_param_init is not None:
if to.is_tensor(self.policy_param_init):
self.policy_param_init.detach()
else:
self.policy_param_init = to.tensor(self.policy_param_init)
# Set the flag to run the initialization phase. This is overruled if load_snapshot is called.
self.initialized = False
if num_init_cand > 0:
Expand Down Expand Up @@ -269,7 +274,8 @@ def eval_policy(save_dir: [str, None],
policy: Policy,
mc_estimator: bool,
prefix: str,
num_rollouts: int) -> to.Tensor:
num_rollouts: int,
num_parallel_envs: int = 1) -> to.Tensor:
"""
Evaluate a policy on the target system (real-world platform).
This method is static to facilitate evaluation of specific policies in hindsight.
Expand All @@ -281,6 +287,8 @@ def eval_policy(save_dir: [str, None],
bound (`False`) obtained from bootrapping
:param prefix: to control the saving for the evaluation of an initial policy, `None` to deactivate
:param num_rollouts: number of rollouts to collect on the target system
:param prefix: to control the saving for the evaluation of an initial policy, `None` to deactivate
:param num_parallel_envs: number of environments for the parallel sampler (only used for SimEnv)
:return: estimated return in the target domain
"""
if save_dir is not None:
Expand All @@ -290,10 +298,15 @@ def eval_policy(save_dir: [str, None],
if isinstance(inner_env(env), RealEnv):
# Evaluate sequentially when conducting a sim-to-real experiment
for i in range(num_rollouts):
rets_real[i] = rollout(env, policy, eval=True, no_close=False).undiscounted_return()
rets_real[i] = rollout(env, policy, eval=True).undiscounted_return()
# If a reward of -1 is given, skip evaluation ahead and set all returns to zero
if rets_real[i] == -1:
print_cbt('Set all returns for this policy to zero.', color='c')
rets_real = to.zeros(num_rollouts)
break
elif isinstance(inner_env(env), SimEnv):
# Create a parallel sampler when conducting a sim-to-sim experiment
sampler = ParallelRolloutSampler(env, policy, num_workers=1, min_rollouts=num_rollouts)
sampler = ParallelRolloutSampler(env, policy, num_workers=num_parallel_envs, min_rollouts=num_rollouts)
ros = sampler.sample()
for i in range(num_rollouts):
rets_real[i] = ros[i].undiscounted_return()
Expand Down Expand Up @@ -425,8 +438,8 @@ def load_snapshot(self, load_dir: str = None, meta_info: dict = None):
if not len(found_policies) == len(found_cands):
print_cbt(f'Found {len(found_policies)} policies, but {len(found_cands)} candidates!', 'r')
n = len(found_cands) - len(found_policies)
delete = input('Delete the superfluous candidates? [y / any other]').lower() == 'y'
if n > 0 and delete:
delete = input_timeout('Delete the superfluous candidates? [any / n]', default='').lower()
if n > 0 and not delete == 'n':
# Delete the superfluous candidates
print_cbt(f'Candidates before:\n{self.cands.numpy()}', 'w')
self.cands = self.cands[:-n, :]
Expand Down Expand Up @@ -461,7 +474,7 @@ def load_snapshot(self, load_dir: str = None, meta_info: dict = None):
found_evals = natural_sort(found_evals) # the order determines the rows of the tensor

# Reconstruct candidates_values.pt
self.cands_values = to.empty(self.cands.shape[0])
self.cands_values = to.zeros(self.cands.shape[0])
for i, fe in enumerate(found_evals):
# Get the return estimate from the raw evaluations as in eval_policy()
if self.mc_estimator:
Expand All @@ -473,7 +486,7 @@ def load_snapshot(self, load_dir: str = None, meta_info: dict = None):

if len(found_evals) < len(found_cands):
print_cbt(f'Found {len(found_evals)} real-world evaluation files but {len(found_cands)} candidates.'
f' Now evaluation the remaining ones.', 'c', bright=True)
f' Now evaluating the remaining ones.', 'c', bright=True)
for i in range(len(found_cands) - len(found_evals)):
# Evaluate the current policy in the target domain
if len(found_evals) < self.num_init_cand:
Expand All @@ -489,6 +502,10 @@ def load_snapshot(self, load_dir: str = None, meta_info: dict = None):

if len(found_cands) < self.num_init_cand:
print_cbt('Found less candidates than the number of initial candidates.', 'y')
if input('Do you want to skip training and evaluating the remaining ones? [any / n]') == 'y':
self.initialized = True
else:
print_cbt('Redoing all init policies', 'y', bright=True)
else:
self.initialized = True

Expand Down Expand Up @@ -593,6 +610,12 @@ def train_argmax_policy(load_dir: str,
bounds = to.load(osp.join(load_dir, 'bounds.pt'))
uc_normalizer = UnitCubeProjector(bounds[0, :], bounds[1, :])

if cands.shape[0] > cands_values.shape[0]:
print_cbt(
f'There are {cands.shape[0]} candidates but only {cands_values.shape[0]} evaluations. Ignoring the'
f'candidates without evaluation for computing the argmax.', 'y')
cands = cands[:cands_values.shape[0], :]

# Find the maximizer
argmax_cand = BayRn.argmax_posterior_mean(cands, cands_values, uc_normalizer, num_restarts, num_samples)

Expand Down
2 changes: 1 addition & 1 deletion Pyrado/pyrado/algorithms/svpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(self,

# Particle factory
actor = FNNPolicy(spec=env.spec, **particle_hparam['actor'])
value_fcn = FNNPolicy(spec=EnvSpec(env.obs_space, ValueFunctionSpace), **particle_hparam['value_fcn'])
value_fcn = FNNPolicy(spec=EnvSpec(env.obs_space, ValueFunctionSpace), **particle_hparam['valuefcn'])
critic = GAE(value_fcn, **particle_hparam['critic'])
self.register_as_logger_parent(critic)
particle = SVPGParticle(env.spec, actor, critic)
Expand Down
5 changes: 3 additions & 2 deletions Pyrado/pyrado/domain_randomization/default_randomizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,10 @@ def get_default_randomizer_wambic() -> DomainRandomizer:
NormalDomainParam(name='ball_mass', mean=dp_nom['ball_mass'], std=dp_nom['ball_mass']/10, clip_lo=1e-2),
UniformDomainParam(name='joint_damping', mean=dp_nom['joint_damping'], halfspan=dp_nom['joint_damping']/2,
clip_lo=0.),
UniformDomainParam(name='joint_stiction', mean=dp_nom['joint_stiction'], halfspan=dp_nom['joint_stiction'],
UniformDomainParam(name='joint_stiction', mean=dp_nom['joint_stiction'], halfspan=dp_nom['joint_stiction']/2,
clip_lo=0.),
UniformDomainParam(name='rope_damping', mean=dp_nom['rope_damping'], halfspan=5e-4, clip_lo=1e-6),
UniformDomainParam(name='rope_damping', mean=dp_nom['rope_damping'], halfspan=dp_nom['rope_damping']/2,
clip_lo=1e-6),
)


Expand Down
2 changes: 2 additions & 0 deletions Pyrado/pyrado/environments/barrett_wam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@


# 4 DoF arm, 2 DoF actuated
init_pose_des_4dof = np.array([0., 0.6, 0., 1.25])
act_min_wam_4dof = np.array([-1.985, -0.9, -10*np.pi, -10*np.pi])
act_max_wam_4dof = np.array([1.985, np.pi, 10*np.pi, 10*np.pi])
labels_4dof = [r'$q_{1,des}$', r'$q_{3,des}$', r'$\dot{q}_{1,des}$', r'$\dot{q}_{3,des}$']
act_space_wam_4dof = BoxSpace(act_min_wam_4dof, act_max_wam_4dof, labels=labels_4dof)

# 7 DoF arm, 3 DoF actuated
init_pose_des_7dof = np.array([0., 0.5876, 0., 1.36, 0.5, -0.321, -1.57])
act_min_wam_7dof = np.array([-1.985, -0.9, -np.pi/2, -10*np.pi, -10*np.pi, -10*np.pi])
act_max_wam_7dof = np.array([1.985, np.pi, np.pi/2, 10*np.pi, 10*np.pi, 10*np.pi])
labels_7dof = [r'$q_{1,des}$', r'$q_{3,des}$', r'$q_{5,des}$',
Expand Down
26 changes: 17 additions & 9 deletions Pyrado/pyrado/environments/barrett_wam/wam.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from init_args_serializer import Serializable

import pyrado
from pyrado.environments.barrett_wam import act_space_wam_7dof, act_space_wam_4dof
from pyrado.environments.barrett_wam import init_pose_des_4dof, init_pose_des_7dof, act_space_wam_7dof, \
act_space_wam_4dof
from pyrado.environments.real_base import RealEnv
from pyrado.spaces import BoxSpace
from pyrado.spaces.base import Space
Expand Down Expand Up @@ -77,21 +78,25 @@ def __init__(self,
# Call the base class constructor to initialize fundamental members
super().__init__(dt, max_steps)

# Create the robcom client and connect to it
# Create the robcom client and connect to it. Use a Process to timeout if connection cannot be established
self._connected = False
self._client = robcom.Client()
if ip is not None:
with completion_context('Connecting to the Barret WAM client', color='c'):
self._client.start(ip, 2013) # IP address and port
self._dc = None # Goto command
try:
self._client.start(ip, 2013, 1000) # ip address, port, timeout in ms
self._connected = True
print_cbt('Connected to the Barret WAM client.', 'c', bright=True)
except RuntimeError:
print_cbt('Connection to the Barret WAM client failed.', 'r', bright=True)
self._dc = None # direct-control process

# Number of controlled joints (dof)
self.num_dof = num_dof

# Desired joint position for the initial state
if self.num_dof == 4:
self.init_pose_des = np.array([0.0, 0.6, 0.0, 1.25])
self.init_pose_des = init_pose_des_4dof
elif self.num_dof == 7:
self.init_pose_des = np.array([0.0, 0.5876, 0.0, 1.36, 0.0, -0.321, -1.57])
self.init_pose_des = init_pose_des_7dof
else:
raise pyrado.ValueErr(given=self.num_dof, eq_constraint="4 or 7")

Expand Down Expand Up @@ -145,6 +150,10 @@ def _create_spaces(self):
self._obs_space = BoxSpace(np.array([0.]), np.array([1.]), labels=['$t$'])

def reset(self, init_state: np.ndarray = None, domain_param: dict = None) -> np.ndarray:
if not self._connected:
print_cbt('Not connected to Barret WAM client.', 'r', bright=True)
raise pyrado.ValueErr(given=self._connected, eq_constraint=True)

# Create robcom GoTo process
gt = self._client.create(robcom.Goto, 'RIGHT_ARM', '')

Expand Down Expand Up @@ -231,7 +240,6 @@ def _callback(self, jg, eg, data_provider):
:param eg: end-effector group
:param data_provider: additional data stream
"""

# Check if max_steps is reached
if self._curr_step_rr >= self.max_steps:
return True
Expand Down
34 changes: 19 additions & 15 deletions Pyrado/pyrado/environments/mujoco/wam.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
from init_args_serializer import Serializable

import pyrado
from pyrado.environments.barrett_wam import act_space_wam_4dof, act_space_wam_7dof
from pyrado.environments.barrett_wam import init_pose_des_4dof, init_pose_des_7dof, act_space_wam_4dof, \
act_space_wam_7dof
from pyrado.environments.mujoco.base import MujocoSimEnv
from pyrado.spaces.base import Space
from pyrado.spaces.singular import SingularStateSpace
from pyrado.tasks.base import Task
from pyrado.environments.mujoco.base import MujocoSimEnv
from pyrado.spaces.box import BoxSpace
from pyrado.tasks.base import Task
from pyrado.tasks.condition_only import ConditionOnlyTask
from pyrado.tasks.desired_state import DesStateTask
from pyrado.tasks.final_reward import BestStateFinalRewTask, FinalRewTask, FinalRewMode
Expand Down Expand Up @@ -160,10 +161,10 @@ def __init__(self,
self.num_dof = num_dof
if num_dof == 4:
graph_file_name = 'wam_4dof_bic.xml'
self.init_pose_des = np.array([0.0, 0.6, 0.0, 1.25])
self.init_pose_des = init_pose_des_4dof
elif num_dof == 7:
graph_file_name = 'wam_7dof_bic.xml'
self.init_pose_des = np.array([0.0, 0.5876, 0.0, 1.36, 0.0, -0.321, -1.57])
self.init_pose_des = init_pose_des_7dof
else:
raise pyrado.ValueErr(given=num_dof, eq_constraint='4 or 7')

Expand Down Expand Up @@ -259,16 +260,19 @@ def _create_spaces(self):
self._obs_space = BoxSpace(np.array([0.]), np.array([1.]), labels=['$t$'])

def _create_task(self, task_args: dict) -> Task:
# Create two (or three) parallel running task.
# 1.) Main task: Desired state task for the cartesian ball distance
# 2.) Deviation task: Desired state task for the cartesian- and joint deviation from the init position
# 3.) Binary Bonus: Adds a binary bonus when ball is catched [inactive by default]
return ParallelTasks([self._create_main_task(task_args),
self._create_deviation_task(task_args),
self._create_main_task(dict(
sparse_rew_fcn=True,
success_bonus=task_args.get('success_bonus', 0)))
])
if task_args.get('sparse_rew_fcn', False):
# Create a task with binary reward
return self._create_main_task(task_args)
else:
# Create two (or three) parallel running task.
# 1.) Main task: Desired state task for the cartesian ball distance
# 2.) Deviation task: Desired state task for the cartesian- and joint deviation from the init position
# 3.) Binary Bonus: Adds a binary bonus when ball is catched [inactive by default]
return ParallelTasks([
self._create_main_task(task_args),
self._create_deviation_task(task_args),
self._create_main_task(dict(sparse_rew_fcn=True, success_bonus=task_args.get('success_bonus', 0)))
])

def _create_main_task(self, task_args: dict) -> Task:
# Create a DesStateTask that masks everything but the ball position
Expand Down
6 changes: 3 additions & 3 deletions Pyrado/pyrado/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def set_style(style_name: str = 'default'):
ValueError("Unknown style name! Got {}, but expected 'default', 'ggplot', 'dark_background',"
"'seaborn', or 'seaborn-muted'.".format(style_name))

plt.rc('font', size=11)
plt.rc('xtick', labelsize=11)
plt.rc('ytick', labelsize=11)
plt.rc('font', size=10)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)
# plt.rc('savefig', bbox='tight') # 'tight' is incompatible with pipe-based animation backends
plt.rc('savefig', pad_inches=0)

Expand Down
7 changes: 3 additions & 4 deletions Pyrado/pyrado/policies/environment_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
# POSSIBILITY OF SUCH DAMAGE.

import math
from typing import Union

import numpy as np
import time
import torch as to
Expand Down Expand Up @@ -561,7 +559,7 @@ def __init__(self, positive: bool = True, cnt_done: int = 250):
:param positive: direction switch
"""
self.done = False
self.th_lim = 10.
self.th_lim = pyrado.inf
self.sign = 1 if positive else -1
self.u_max = 0.9
self.cnt = 0
Expand All @@ -577,7 +575,8 @@ def __call__(self, meas: to.Tensor) -> to.Tensor:
# Unpack the raw measurement (is not an observation)
th = meas[0].item()

if abs(th - self.th_lim) > 1e-8:
if abs(th - self.th_lim) > 1e-6:
# Recognized significant change in theta
self.cnt = 0
self.th_lim = th
else:
Expand Down
28 changes: 3 additions & 25 deletions Pyrado/pyrado/policies/fnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,10 @@ def __init__(self,
super().__init__() # init nn.Module

# Store settings
self._hidden_nonlin = hidden_nonlin if isinstance(hidden_nonlin, Iterable) else len(hidden_sizes)*[
hidden_nonlin]
self._dropout = dropout
self.output_nonlin = output_nonlin
# TODO One day replace legacy above with below
# self.hidden_nonlin = hidden_nonlin if isinstance(hidden_nonlin, Iterable) else len(hidden_sizes)*[hidden_nonlin]
# self.dropout = dropout
# self.output_nonlin = output_nonlin
self.hidden_nonlin = hidden_nonlin if isinstance(hidden_nonlin, Iterable) else len(hidden_sizes)*[hidden_nonlin]
self.dropout = dropout
self.output_nonlin = output_nonlin

# Create hidden layers (stored in ModuleList so their parameters are tracked)
self.hidden_layers = nn.ModuleList()
Expand All @@ -101,24 +97,6 @@ def device(self) -> str:
""" Get the device (CPU or GPU) on which the FNN is stored. """
return self._device

@property
def dropout(self):
""" For legacy compatibility """
return self._dropout

@property
def hidden_nonlin(self):
""" For legacy compatibility """
if isinstance(self._hidden_nonlin, Iterable):
return self._hidden_nonlin
else:
return len(self.hidden_layers)*[self._hidden_nonlin]

@property
def output_nonlin(self):
""" For legacy compatibility """
return self.output_nonlin

@property
def param_values(self) -> to.Tensor:
"""
Expand Down
3 changes: 3 additions & 0 deletions Pyrado/pyrado/utils/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def get_argparser() -> argparse.ArgumentParser:
parser.add_argument('--load_all', action='store_true', default=False,
help="load all quantities e.g. policies (default: False)")

parser.add_argument('--load_name', type=str, default='policy',
help="name of the policy to load without type extension, (default: 'policy')")

parser.add_argument('-d', '--ex_dir', type=str, nargs='?',
help="path to the experiment directory to load from")

Expand Down
1 change: 0 additions & 1 deletion Pyrado/pyrado/utils/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"""
Only tested for 1-dim inputs, e.g. time series of rewards.
"""

import numpy as np
import torch as to
from typing import Union, Tuple
Expand Down
10 changes: 3 additions & 7 deletions Pyrado/pyrado/utils/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,9 @@ def load_experiment(ex_dir: str, args: Any = None) -> (Union[SimEnv, EnvWrapper]
last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
env.adapt_randomizer(last_cand.numpy())
print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
# Policy
if args.iter == -1:
policy = to.load(osp.join(ex_dir, 'policy.pt'))
print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
else:
policy = to.load(osp.join(ex_dir, f'iter_{args.iter}_policy.pt'))
print_cbt(f"Loaded {osp.join(ex_dir, f'iter_{args.iter}_policy.pt')}", 'g')
# Policy
policy = to.load(osp.join(ex_dir, f'{args.load_name}.pt'))
print_cbt(f"Loaded {osp.join(ex_dir, f'{args.load_name}.pt')}", 'g')
# Value function (optional)
if any([a.name in hparams.get('subrtn_name', '') for a in [PPO, PPO2, A2C]]):
try:
Expand Down
Loading

0 comments on commit 6637c3b

Please sign in to comment.