Skip to content

Commit

Permalink
update production of the data
Browse files Browse the repository at this point in the history
  • Loading branch information
budzianowski committed Nov 18, 2024
1 parent 09677cd commit fa9bce0
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 58 deletions.
126 changes: 126 additions & 0 deletions sim/h5_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
""" Logger for logging data to HDF5 files """
import os
import uuid
from typing import Dict

import h5py
import matplotlib.pyplot as plt
import numpy as np


class HDF5Logger:
def __init__(self, data_name: str, num_actions: int, max_timesteps: int, num_observations: int):
self.data_name = data_name
self.num_actions = num_actions
self.max_timesteps = max_timesteps
self.num_observations = num_observations
self.max_threshold = 1e3 # Adjust this threshold as needed
self.h5_file, self.h5_dict = self._create_h5_file()
self.current_timestep = 0

def _create_h5_file(self):
# Create a unique file ID
idd = str(uuid.uuid4())
h5_file = h5py.File(f"{self.data_name}/{idd}.h5", "w")

# Create datasets for logging actions and observations
dset_actions = h5_file.create_dataset("actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_2D_command = h5_file.create_dataset("observations/2D_command", (self.max_timesteps, 2), dtype=np.float32)
dset_3D_command = h5_file.create_dataset("observations/3D_command", (self.max_timesteps, 3), dtype=np.float32)
dset_q = h5_file.create_dataset("observations/q", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_dq = h5_file.create_dataset("observations/dq", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_ang_vel = h5_file.create_dataset("observations/ang_vel", (self.max_timesteps, 3), dtype=np.float32)
dset_euler = h5_file.create_dataset("observations/euler", (self.max_timesteps, 3), dtype=np.float32)
dset_t = h5_file.create_dataset("observations/t", (self.max_timesteps, 1), dtype=np.float32)
dset_buffer = h5_file.create_dataset("observations/buffer", (self.max_timesteps, self.num_observations), dtype=np.float32)

# Map datasets for easy access
h5_dict = {
"actions": dset_actions,
"2D_command": dset_2D_command,
"3D_command": dset_3D_command,
"joint_pos": dset_q,
"joint_vel": dset_dq,
"ang_vel": dset_ang_vel,
"euler_rotation": dset_euler,
"t": dset_t,
"buffer": dset_buffer,
}
return h5_file, h5_dict

def log_data(self, data: Dict[str, np.ndarray]):
if self.current_timestep >= self.max_timesteps:
print(f"Warning: Exceeded maximum timesteps ({self.max_timesteps})")
return

for key, dataset in self.h5_dict.items():
if key in data:
dataset[self.current_timestep] = data[key]

self.current_timestep += 1

def close(self):
for key, dataset in self.h5_dict.items():
max_val = np.max(np.abs(dataset[:]))
if max_val > self.max_threshold:
print(f"Warning: Found very large values in {key}: {max_val}")
print("File will not be saved to prevent corrupted data")
self.h5_file.close()
# Delete the file
os.remove(self.h5_file.filename)
return

self.h5_file.close()

@staticmethod
def visualize_h5(h5_file_path: str):
"""
Visualizes the data from an HDF5 file by plotting each variable one by one.
Args:
h5_file_path (str): Path to the HDF5 file.
"""
try:
# Open the HDF5 file
with h5py.File(h5_file_path, "r") as h5_file:
# Extract all datasets
for key in h5_file.keys():
group = h5_file[key]
if isinstance(group, h5py.Group):
for subkey in group.keys():
dataset = group[subkey][:]
HDF5Logger._plot_dataset(f"{key}/{subkey}", dataset)
else:
dataset = group[:]
HDF5Logger._plot_dataset(key, dataset)

except Exception as e:
print(f"Failed to visualize HDF5 file: {e}")

@staticmethod
def _plot_dataset(name: str, data: np.ndarray):
"""
Helper method to plot a single dataset.
Args:
name (str): Name of the dataset.
data (np.ndarray): Data to be plotted.
"""
plt.figure(figsize=(10, 5))
if data.ndim == 2: # Handle multi-dimensional data
for i in range(data.shape[1]):
plt.plot(data[:, i], label=f"{name}[{i}]")
else:
plt.plot(data, label=name)

plt.title(f"Visualization of {name}")
plt.xlabel("Timesteps")
plt.ylabel("Values")
plt.legend(loc="upper right")
plt.grid(True)
plt.tight_layout()
plt.show()


if __name__ == "__main__":
HDF5Logger.visualize_h5("stompypro/6dc85e02-fc8e-42e1-a396-b0bd578e0816.h5")
8 changes: 4 additions & 4 deletions sim/produce_sim_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Produce simulation data for training."""
import argparse
import multiprocessing as mp
import subprocess
Expand Down Expand Up @@ -46,15 +47,14 @@ def run_parallel_sims(num_threads: int, args: argparse.Namespace) -> None:
parser.add_argument("--num_threads", type=int, default=10, help="Number of parallel simulations to run")
parser.add_argument("--embodiment", default="stompypro", type=str, help="Embodiment name")
parser.add_argument("--load_model", default="examples/walking_pro.onnx", type=str, help="Path to model to load")

parser.add_argument("--num_examples", default=1000, type=int, help="Number of examples to run")
args = parser.parse_args()

# Run 100 examples total, in parallel batches
num_examples = 2000
num_batches = (num_examples + args.num_threads - 1) // args.num_threads
num_batches = (args.num_examples + args.num_threads - 1) // args.num_threads

for batch in range(num_batches):
examples_remaining = num_examples - (batch * args.num_threads)
examples_remaining = args.num_examples - (batch * args.num_threads)
threads_this_batch = min(args.num_threads, examples_remaining)
print(f"\nRunning batch {batch+1}/{num_batches} ({threads_this_batch} simulations)")
run_parallel_sims(threads_this_batch, args)
2 changes: 1 addition & 1 deletion sim/resources/stompypro/robot_fixed.xml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,6 @@
<gyro name="angular-velocity" site="imu" noise="0.005" cutoff="34.9"/>
</sensor>
<keyframe>
<key name="default" qpos="0 0 0.63 0. 0.0 0.0 1.0 -0.23 0.0 0.0 0.441 -0.258 -0.23 0.0 0.0 0.441 -0.258"/>
<key name="default" qpos="0 0 0.63 1. 0.0 0.0 0.0 -0.23 0.0 0.0 0.441 -0.258 -0.23 0.0 0.0 0.441 -0.258"/>
</keyframe>
</mujoco>
71 changes: 18 additions & 53 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
python sim/sim2sim.py --load_model examples/standing_pro.pt --embodiment stompypro
python sim/sim2sim.py --load_model examples/standing_micro.pt --embodiment stompymicro
"""

import argparse
import math
import os
Expand All @@ -22,47 +21,10 @@
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

from sim.h5_logger import HDF5Logger
from sim.model_export import ActorCfg, convert_model_to_onnx


def log_hdf5(data_name, num_actions, stop_state_log):
# Create data directory if it doesn't exist
# os.makedirs(data_name, exist_ok=True)
idd = str(uuid.uuid4())
h5_file = h5py.File(f"{data_name}/{idd}.h5", "w")

# Create dataset for actions
max_timesteps = stop_state_log
dset_actions = h5_file.create_dataset("actions", (max_timesteps, num_actions), dtype=np.float32)

# Create dataset of observations
dset_2D_command = h5_file.create_dataset(
"observations/2D_command", (max_timesteps, 2), dtype=np.float32
) # sin and cos commands
dset_3D_command = h5_file.create_dataset(
"observations/3D_command", (max_timesteps, 3), dtype=np.float32
) # x, y, yaw commands
dset_q = h5_file.create_dataset("observations/q", (max_timesteps, num_actions), dtype=np.float32) # joint positions
dset_dq = h5_file.create_dataset(
"observations/dq", (max_timesteps, num_actions), dtype=np.float32
) # joint velocities
dset_ang_vel = h5_file.create_dataset(
"observations/ang_vel", (max_timesteps, 3), dtype=np.float32
) # root angular velocity
dset_euler = h5_file.create_dataset("observations/euler", (max_timesteps, 3), dtype=np.float32) # root orientation

h5_dict = {
"actions": dset_actions,
"2D_command": dset_2D_command,
"3D_command": dset_3D_command,
"joint_pos": dset_q,
"joint_vel": dset_dq,
"ang_vel": dset_ang_vel,
"euler_rotation": dset_euler,
}
return h5_file, h5_dict


@dataclass
class Sim2simCfg:
sim_duration: float = 60.0
Expand Down Expand Up @@ -183,7 +145,8 @@ def run_mujoco(
except:
print("No default position found, using zero initialization")
default = np.zeros(model_info["num_actions"]) # 3 for pos, 4 for quat, cfg.num_actions for joints

default += np.random.uniform(-0.03, 0.03, size=default.shape)
print("Default position:", default)
mujoco.mj_step(model, data)
for ii in range(len(data.ctrl) + 1):
print(data.joint(ii).id, data.joint(ii).name)
Expand Down Expand Up @@ -214,8 +177,8 @@ def run_mujoco(
}

if log_h5:
stop_state_log = int(cfg.sim_duration / cfg.dt)
h5_file, h5_dict = log_hdf5(embodiment, model_info["num_actions"], stop_state_log)
stop_state_log = int(cfg.sim_duration / cfg.dt) / cfg.decimation
logger = HDF5Logger(embodiment, model_info["num_actions"], stop_state_log, model_info["num_observations"])

# Initialize variables for tracking upright steps and average speed
upright_steps = 0
Expand Down Expand Up @@ -270,18 +233,20 @@ def run_mujoco(
input_data["buffer.1"] = hist_obs.astype(np.float32)

positions, actions, hist_obs = policy.run(None, input_data)
# actions = np.zeros_like(actions)
target_q = positions

if args.log_h5:
t += 1
h5_dict["2D_command"][t] = np.array([x_vel_cmd, y_vel_cmd], dtype=np.float32)
h5_dict["3D_command"][t] = np.array([x_vel_cmd, y_vel_cmd, yaw_vel_cmd], dtype=np.float32)
h5_dict["joint_pos"][t] = cur_pos_obs.astype(np.float32)
h5_dict["joint_vel"][t] = cur_vel_obs.astype(np.float32)
h5_dict["actions"][t] = actions.astype(np.float32)
h5_dict["ang_vel"][t] = omega.astype(np.float32)
h5_dict["euler_rotation"][t] = eu_ang.astype(np.float32)
logger.log_data({
"t": np.array([count_lowlevel * cfg.dt], dtype=np.float32),
"2D_command": np.array([x_vel_cmd, y_vel_cmd], dtype=np.float32),
"3D_command": np.array([x_vel_cmd, y_vel_cmd, yaw_vel_cmd], dtype=np.float32),
"joint_pos": cur_pos_obs.astype(np.float32),
"joint_vel": cur_vel_obs.astype(np.float32),
"actions": actions.astype(np.float32),
"ang_vel": omega.astype(np.float32),
"euler_rotation": eu_ang.astype(np.float32),
"buffer": hist_obs.astype(np.float32)
})

# Generate PD control
tau = pd_control(target_q, q, kps, dq, kds, default) # Calc torques
Expand All @@ -308,7 +273,7 @@ def run_mujoco(
print(f"Average speed: {average_speed:.4f} m/s")

if args.log_h5:
h5_file.close()
logger.close()


def parse_modelmeta(
Expand Down Expand Up @@ -360,7 +325,7 @@ def parse_modelmeta(
sim_duration=10.0,
dt=0.001,
decimation=10,
tau_factor=3.0,
tau_factor=4.0,
)
elif args.embodiment == "stompymicro":
policy_cfg.cycle_time = 0.2
Expand Down

0 comments on commit fa9bce0

Please sign in to comment.