Skip to content

Commit

Permalink
add more logging (#114)
Browse files Browse the repository at this point in the history
* update logging

* save previous action

* rename actions to prev_actions to comply
  • Loading branch information
budzianowski authored Nov 19, 2024
1 parent fa9bce0 commit 616e5fd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
6 changes: 3 additions & 3 deletions sim/h5_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _create_h5_file(self):
h5_file = h5py.File(f"{self.data_name}/{idd}.h5", "w")

# Create datasets for logging actions and observations
dset_actions = h5_file.create_dataset("actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_actions = h5_file.create_dataset("prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_2D_command = h5_file.create_dataset("observations/2D_command", (self.max_timesteps, 2), dtype=np.float32)
dset_3D_command = h5_file.create_dataset("observations/3D_command", (self.max_timesteps, 3), dtype=np.float32)
dset_q = h5_file.create_dataset("observations/q", (self.max_timesteps, self.num_actions), dtype=np.float32)
Expand All @@ -36,7 +36,7 @@ def _create_h5_file(self):

# Map datasets for easy access
h5_dict = {
"actions": dset_actions,
"prev_actions": dset_actions,
"2D_command": dset_2D_command,
"3D_command": dset_3D_command,
"joint_pos": dset_q,
Expand Down Expand Up @@ -123,4 +123,4 @@ def _plot_dataset(name: str, data: np.ndarray):


if __name__ == "__main__":
HDF5Logger.visualize_h5("stompypro/6dc85e02-fc8e-42e1-a396-b0bd578e0816.h5")
HDF5Logger.visualize_h5("stompypro/5a7dc371-445c-4f56-be05-4e65c5cc38bc.h5")
4 changes: 2 additions & 2 deletions sim/produce_sim_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Produce simulation data for training."""
""" Produce simulation data for training."""
import argparse
import multiprocessing as mp
import subprocess
Expand Down Expand Up @@ -47,7 +47,7 @@ def run_parallel_sims(num_threads: int, args: argparse.Namespace) -> None:
parser.add_argument("--num_threads", type=int, default=10, help="Number of parallel simulations to run")
parser.add_argument("--embodiment", default="stompypro", type=str, help="Embodiment name")
parser.add_argument("--load_model", default="examples/walking_pro.onnx", type=str, help="Path to model to load")
parser.add_argument("--num_examples", default=1000, type=int, help="Number of examples to run")
parser.add_argument("--num_examples", default=10000, type=int, help="Number of examples to run")
args = parser.parse_args()

# Run --num_examples examples in total, in parallel batches
Expand Down
37 changes: 27 additions & 10 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
import argparse
import math
import os
import uuid
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import h5py
import mujoco
import mujoco_viewer
import numpy as np
Expand All @@ -31,6 +29,8 @@ class Sim2simCfg:
dt: float = 0.001
decimation: int = 10
tau_factor: float = 3
cycle_time: float = 0.4



def handle_keyboard_input() -> None:
Expand Down Expand Up @@ -185,7 +185,6 @@ def run_mujoco(
total_speed = 0.0
step_count = 0

t = 0
for _ in tqdm(range(int(cfg.sim_duration / cfg.dt)), desc="Simulating..."):
if keyboard_use:
handle_keyboard_input()
Expand Down Expand Up @@ -232,22 +231,28 @@ def run_mujoco(

input_data["buffer.1"] = hist_obs.astype(np.float32)

positions, actions, hist_obs = policy.run(None, input_data)
target_q = positions

if args.log_h5:
logger.log_data({
"t": np.array([count_lowlevel * cfg.dt], dtype=np.float32),
"2D_command": np.array([x_vel_cmd, y_vel_cmd], dtype=np.float32),
"2D_command": np.array(
[
np.sin(2 * math.pi * count_lowlevel * cfg.dt / cfg.cycle_time),
np.cos(2 * math.pi * count_lowlevel * cfg.dt / cfg.cycle_time),
],
dtype=np.float32,
),
"3D_command": np.array([x_vel_cmd, y_vel_cmd, yaw_vel_cmd], dtype=np.float32),
"joint_pos": cur_pos_obs.astype(np.float32),
"joint_vel": cur_vel_obs.astype(np.float32),
"actions": actions.astype(np.float32),
"prev_actions": actions.astype(np.float32),
"ang_vel": omega.astype(np.float32),
"euler_rotation": eu_ang.astype(np.float32),
"buffer": hist_obs.astype(np.float32)
})

positions, actions, hist_obs = policy.run(None, input_data)
target_q = positions

# Generate PD control
tau = pd_control(target_q, q, kps, dq, kds, default) # Calc torques
tau = np.clip(tau, -tau_limit, tau_limit) # Clamp torques
Expand Down Expand Up @@ -326,6 +331,7 @@ def parse_modelmeta(
dt=0.001,
decimation=10,
tau_factor=4.0,
cycle_time=policy_cfg.cycle_time,
)
elif args.embodiment == "stompymicro":
policy_cfg.cycle_time = 0.2
Expand All @@ -334,16 +340,27 @@ def parse_modelmeta(
dt=0.001,
decimation=10,
tau_factor=2,
cycle_time=policy_cfg.cycle_time,
)

if args.load_model.endswith(".onnx"):
policy = ort.InferenceSession(args.load_model)
else:
policy = convert_model_to_onnx(args.load_model, policy_cfg, save_path="policy.onnx")
policy = convert_model_to_onnx(
args.load_model, policy_cfg, save_path="policy.onnx"
)

model_info = parse_modelmeta(
policy.get_modelmeta().custom_metadata_map.items(),
verbose=True,
)

run_mujoco(args.embodiment, policy, cfg, model_info, args.keyboard_use, args.log_h5, args.render)
run_mujoco(
args.embodiment,
policy,
cfg,
model_info,
args.keyboard_use,
args.log_h5,
args.render,
)

0 comments on commit 616e5fd

Please sign in to comment.