diff --git a/sim/model_export.py b/sim/model_export.py
index 9f89379..51a812c 100644
--- a/sim/model_export.py
+++ b/sim/model_export.py
@@ -205,6 +205,52 @@ def forward(
         return actions_scaled, actions, x
 
 
+def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]:
+    all_weights = torch.load(model_path, map_location="cpu", weights_only=True)
+    weights = all_weights["model_state_dict"]
+    num_actor_obs = weights["actor.0.weight"].shape[1]
+    num_critic_obs = weights["critic.0.weight"].shape[1]
+    num_actions = weights["std"].shape[0]
+    actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)]
+    critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)]
+    actor_hidden_dims = actor_hidden_dims[:-1]
+    critic_hidden_dims = critic_hidden_dims[:-1]
+
+    ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims)
+    ac_model.load_state_dict(weights)
+
+    a_model = Actor(ac_model.actor, cfg)
+
+    # Gets the model input tensors.
+    x_vel = torch.randn(1)
+    y_vel = torch.randn(1)
+    rot = torch.randn(1)
+    t = torch.randn(1)
+    dof_pos = torch.randn(a_model.num_actions)
+    dof_vel = torch.randn(a_model.num_actions)
+    prev_actions = torch.randn(a_model.num_actions)
+    imu_ang_vel = torch.randn(3)
+    imu_euler_xyz = torch.randn(3)
+    buffer = a_model.get_init_buffer()
+    input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, imu_ang_vel, imu_euler_xyz, buffer)
+
+    jit_model = torch.jit.script(a_model)
+
+    # Add sim2sim metadata
+    robot_effort = list(a_model.robot.effort().values())
+    robot_stiffness = list(a_model.robot.stiffness().values())
+    robot_damping = list(a_model.robot.damping().values())
+    num_actions = a_model.num_actions
+    num_observations = a_model.num_observations
+
+    return a_model, {
+        "robot_effort": robot_effort,
+        "robot_stiffness": robot_stiffness,
+        "robot_damping": robot_damping,
+        "num_actions": num_actions,
+        "num_observations": num_observations,
+    }, input_tensors
+
+
 def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession:
     """Converts a PyTorch model to a ONNX format.
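A minimal usage sketch of the new get_actor_policy helper (illustration only, not part of the diff: the checkpoint path is hypothetical, and ActorCfg is assumed to be constructible with defaults):

    # Hypothetical smoke test for get_actor_policy -- not part of the patch above.
    from sim.model_export import ActorCfg, get_actor_policy

    cfg = ActorCfg()  # assumption: default config fields are sufficient here
    actor, sim2sim_info, input_tensors = get_actor_policy("checkpoint.pt", cfg)
    # The metadata dict carries the values later merged into the export config.
    print(sim2sim_info["num_actions"], sim2sim_info["num_observations"])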
diff --git a/sim/sim2sim.py b/sim/sim2sim.py
index 292462b..829e10c 100755
--- a/sim/sim2sim.py
+++ b/sim/sim2sim.py
@@ -17,10 +17,12 @@
 import onnxruntime as ort
 import pygame
 from scipy.spatial.transform import Rotation as R
+import torch
 from tqdm import tqdm
 
 from sim.h5_logger import HDF5Logger
-from sim.model_export import ActorCfg, convert_model_to_onnx
+from model_export import ActorCfg, get_actor_policy, convert_model_to_onnx
+from kinfer.export.pytorch import export_to_onnx
 
 
 @dataclass
@@ -316,6 +318,10 @@ def parse_modelmeta(
     return parsed_meta
 
 
+def new_func(args, policy_cfg):
+    actor_model = get_actor_jit(args.load_model, policy_cfg)
+    return actor_model
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Deployment script.")
     parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.")
@@ -357,9 +363,20 @@ def parse_modelmeta(
     if args.load_model.endswith(".onnx"):
         policy = ort.InferenceSession(args.load_model)
     else:
-        policy = convert_model_to_onnx(
-            args.load_model, policy_cfg, save_path="policy.onnx"
-        )
+        # Export function is able to infer input shapes
+        # actor_model = new_func(args, policy_cfg)
+        # actor_model = torch.jit.load(args.load_model)
+        actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg)
+        # Merge policy_cfg and sim2sim_info into a single config object
+        export_config = {**vars(policy_cfg), **sim2sim_info}
+        print(export_config)
+        policy = export_to_onnx(
+            actor_model,
+            input_tensors=input_tensors,
+            config=export_config,
+            save_path="kinfer_test.onnx"
+        )
+        # policy = convert_model_to_onnx(args.load_model, policy_cfg, save_path="policy.onnx")
 
     model_info = parse_modelmeta(
         policy.get_modelmeta().custom_metadata_map.items(),
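For reference, a hedged sketch of inspecting the exported artifact afterwards: export_to_onnx receives the merged export_config and parse_modelmeta later reads it back from the model metadata, so the same keys should be visible via onnxruntime (file name taken from the diff above):

    import onnxruntime as ort

    # Load the file written by export_to_onnx in the else-branch of sim2sim.py.
    session = ort.InferenceSession("kinfer_test.onnx")
    # The merged config is expected to appear as custom metadata,
    # which parse_modelmeta() consumes further down in sim2sim.py.
    for key, value in session.get_modelmeta().custom_metadata_map.items():
        print(key, value)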