From df4de72b67694859676e73cc51fcb5e6d28397f3 Mon Sep 17 00:00:00 2001 From: Carlotta Date: Mon, 9 Dec 2024 15:47:59 +0100 Subject: [PATCH 01/13] added advanced and basic example for jaxsim simulator, added usage section in readme --- README.md | 137 +++-- examples/jaxsim_as_physics_engine.ipynb | 390 +++++--------- .../jaxsim_as_physics_engine_advanced.ipynb | 476 ++++++++++++++++++ 3 files changed, 697 insertions(+), 306 deletions(-) create mode 100644 examples/jaxsim_as_physics_engine_advanced.ipynb diff --git a/README.md b/README.md index ef4755a02..d2aca8ecb 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # JaxSim -JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX. - -Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence. +**JaxSim** is a **differentiable physics engine** and **multibody dynamics library** built with JAX, tailored for control and robotic learning applications.

@@ -17,41 +15,105 @@ Its design facilitates research and accelerates prototyping in the intersection
## Features +- Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots. +- Multibody dynamics library for model-based control algorithms. +- Fully Python-based, leveraging [jax][jax] following a functional programming paradigm. +- Seamless execution on CPUs, GPUs, and TPUs. +- Supports JIT compilation and automatic vectorization for high performance. +- Compatible with SDF models and URDF (via [sdformat][sdformat] conversion). + +## Usage + +### Using Jaxsim as simulator + + +```python +import jax +import jax.numpy as jnp +import jaxsim.api as js +import icub_models +import pathlib + +# Load the iCub model +model_path = icub_models.get_model_file("iCubGazeboV2_5") +joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', + 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', + 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', + 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', + 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', + 'r_ankle_roll') +ndof = len(joints) + +# Build and reduce the model +model_description = pathlib.Path(model_path) +full_model = js.model.JaxSimModel.build_from_model_description( + model_description=model_description, time_step=0.0001, is_urdf=True +) +model = js.model.reduce(model=full_model, considered_joints=joints) + +# Initialize data and simulation +data = js.data.JaxSimModelData.zero(model=model).reset_base_position( + base_position=jnp.array([0.0, 0.0, 1.0]) +) +T = jnp.arange(start=0, stop=1.0, step=model.time_step) +tau = jnp.zeros(ndof) + +# Simulate +for t in T: + data, _ = js.model.step(model=model, data=data, link_forces=None, joint_force_references=tau) -- Physics engine in reduced coordinates supporting fixed-base and floating-base robots. -- Multibody dynamics library providing all the necessary components for developing model-based control algorithms. -- Completely developed in Python with [`google/jax`][jax] following a functional programming paradigm. -- Transparent support for running on CPUs, GPUs, and TPUs. -- Full support for JIT compilation for increased performance. -- Full support for automatic vectorization for massive parallelization of open-loop and closed-loop architectures. -- Support for SDF models and, upon conversion with [sdformat][sdformat], URDF models. -- Visualization based on the [passive viewer][passive_viewer_mujoco] of Mujoco. - -### JaxSim as a simulator - -- Wide range of fixed-step explicit Runge-Kutta integrators. -- Support for variable-step integrators implemented as embedded Runge-Kutta schemes. -- Improved stability by optionally integrating the base orientation on the $\text{SO}(3)$ manifold. -- Soft contacts model supporting full friction cone and sticking-slipping transition. -- Collision detection between points rigidly attached to bodies and uneven ground surfaces. - -### JaxSim as a multibody dynamics library +``` -- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians. -- Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion. -- Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation]. -- Exposes all the necessary quantities to develop controllers in centroidal coordinates. -- Supports running open-loop and full closed-loop control architectures on hardware accelerators. +### Using JaxSim as a multibody dynamics library +``` python +import jax.numpy as jnp +import jaxsim.api as js +import icub_models +import pathlib + +# Load the iCub model +model_path = icub_models.get_model_file("iCubGazeboV2_5") +joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', + 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', + 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', + 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', + 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', + 'r_ankle_roll') + +# Build and reduce the model +model_description = pathlib.Path(model_path) +full_model = js.model.JaxSimModel.build_from_model_description( + model_description=model_description, time_step=0.0001, is_urdf=True +) +model = js.model.reduce(model=full_model, considered_joints=joints) + +# Initialize model data +data = js.data.JaxSimModelData.zero(model=model).reset_base_position( + base_position=jnp.array([0.0, 0.0, 1.0]) +) + +# Frame and dynamics computations +frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot") +W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation +W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian + +# Dynamics properties +M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix +h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces +g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces +C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix + +# Print dynamics results +print(f"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}") -### JaxSim for robot learning +``` +### Additional features -- Being developed with JAX, all the RBDAs support automatic differentiation both in forward and reverse modes. +- Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX. - Support for automatically differentiating against kinematics and dynamics parameters. - All fixed-step integrators are forward and reverse differentiable. - All variable-step integrators are forward differentiable. -- Ideal for sampling synthetic data for reinforcement learning (RL). -- Ideal for designing physics-informed neural networks (PINNs) with loss functions requiring model-based quantities. -- Ideal for combining model-based control with learning-based components. +- Check the example folder for additional usecase ! [jax]: https://github.com/google/jax/ [sdformat]: https://github.com/gazebosim/sdformat @@ -65,12 +127,6 @@ Its design facilitates research and accelerates prototyping in the intersection > JaxSim currently focuses on locomotion applications. > Only contacts between bodies and smooth ground surfaces are supported. -## Documentation - -The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs]. - -[readthedocs]: https://jaxsim.readthedocs.io/ - ## Installation
@@ -142,6 +198,13 @@ pip install --no-deps -e . [venv]: https://docs.python.org/3/tutorial/venv.html [jax_gpu]: https://github.com/google/jax/#installation +## Documentation + +The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs]. + +[readthedocs]: https://jaxsim.readthedocs.io/ + + ## Overview
diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index b8ec18cfb..ce8eac0c8 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -8,9 +8,7 @@ "source": [ "# JaxSim as a hardware-accelerated parallel physics engine\n", "\n", - "JaxSim was originally developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", - "\n", - "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", + "This notebook show how to use the key APIs to load a robot model and simulate multiple trajectories in parallel on GPUs.\n", "\n", "\n", " \"Open\n", @@ -58,15 +56,13 @@ "# Notebook imports\n", "# ================\n", "\n", - "import os\n", - "\n", "import jax\n", "import jax.numpy as jnp\n", "import jaxsim.api as js\n", - "import jaxsim\n", - "import rod\n", "from jaxsim import logging\n", - "from rod.builder.primitives import SphereBuilder\n", + "import pathlib\n", + "import tempfile\n", + "import urllib.request\n", "\n", "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", "print(f\"Running on {jax.devices()}\")" @@ -75,79 +71,33 @@ { "cell_type": "markdown", "metadata": { - "id": "QtCCUhdpdGFH" + "id": "NqjuZKvOaTG_" }, "source": [ "## Prepare the simulation\n", "\n", - "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n", - "\n", - "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n", + "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the ergoCub model urdf.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", - "[rod]: https://github.com/ami-iit/rod" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "0emoMQhCaTG_" - }, - "outputs": [], - "source": [ - "# @title Create the model description of a sphere\n", - "\n", - "# Create a SDF model.\n", - "# The builder takes care to compute the right inertia tensor for you.\n", - "rod_sdf = rod.Sdf(\n", - " version=\"1.7\",\n", - " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", - " .build_model()\n", - " .add_link()\n", - " .add_inertial()\n", - " .add_visual()\n", - " .add_collision()\n", - " .build(),\n", - ")\n", - "\n", - "# Rod allows to update the frames w.r.t. the poses are expressed.\n", - "rod_sdf.model.switch_frame_convention(\n", - " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n", - ")\n", + "[rod]: https://github.com/ami-iit/rod\n", "\n", - "# Serialize the model to a SDF string.\n", - "model_sdf_string = rod_sdf.serialize(pretty=True)\n", - "print(model_sdf_string)\n", - "\n", - "# JaxSim currently only supports collisions between points attached to bodies\n", - "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", - "# While this approach is universal as it applies to generic meshes, the number\n", - "# of considered points greatly affects the performance. Spheres, by default,\n", - "# are discretized with 250 points. It's too much for this simple example.\n", - "# This number can be decreased with the following environment variable.\n", - "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqjuZKvOaTG_" - }, - "source": [ "### Create the model and its data\n", - "\n", - "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + " To define a simulation we need two main objects:\n", "\n", "- `model`: an object that defines the dynamics of the system.\n", "- `data`: an object that contains the state of the system.\n", - "- `integrator` *(Optional)*: an object that defines the integration method.\n", - "- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n", + "\n", "\n", "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", - "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model." + "To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create the model " ] }, { @@ -158,159 +108,146 @@ }, "outputs": [], "source": [ - "# Create the JaxSim model.\n", - "# This is shared among all the parallel instances.\n", - "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_sdf_string,\n", - " time_step=0.001,\n", - " integrator=jaxsim.integrators.fixed_step.Heun2,\n", + "# Create the JaxSim model.\n", + "url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf\"\n", + "# Create a temporary file\n", + "with tempfile.NamedTemporaryFile(mode=\"w+\", delete=False) as urdf_robot_file:\n", + " # Retrieve the file\n", + " urllib.request.urlretrieve(url, urdf_robot_file.name)\n", + "\n", + "# print(urllib.request.urlretrieve(url, urdf_robot_file.name))\n", + "model_description_path = pathlib.Path(urdf_robot_file.name)\n", + "full_model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_description_path,\n", + " time_step=0.0001,\n", + " is_urdf=True\n", ")\n", "\n", - "# Create the data of a single model.\n", - "# We will create a vectorized instance later.\n", - "data_single = js.data.JaxSimModelData.zero(model=model)" + "joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n", + " 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n", + " 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n", + " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n", + "\n", + "model = js.model.reduce(\n", + " model=full_model,\n", + " considered_joints=joints_list\n", + ")\n" ] }, { "cell_type": "markdown", - "metadata": { - "id": "FJF-HoWaiK9J" - }, + "metadata": {}, "source": [ - "### Select the contact model\n", + "### Create the data object \n", "\n", - "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n", - "\n", - "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n", - "\n" + "The data object is never changed by reference. Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "VAaitHRKjnwc" - }, + "metadata": {}, "outputs": [], "source": [ - "import jaxsim\n", - "\n", - "# Operate on a copy of the model.\n", - "# When validate=True, this context manager ensures that the PyTree structure\n", - "# of the object is not altered. This is a nice feature of JaxSim to spot\n", - "# earlier user logic that might trigger unwanted JIT recompilations.\n", - "# In this case, we need to disable validation since PyTree structure might\n", - "# change if you use a contact model different from the default.\n", - "with model.editable(validate=False) as model:\n", - "\n", - " # The SoftContacts class can be replaced with a different contact model.\n", - " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n", - "\n", - "# JaxSim provides the following helper that estimates good contact\n", - "# parameters. While they might not be optimal, usually are a good\n", - "# starting point. Users are encouraged to fine-tune them.\n", - "contacts_params = js.contact.estimate_good_contact_parameters(\n", - " model=model,\n", - " number_of_active_collidable_points_steady_state=4,\n", - " max_penetration=0.001,\n", - ")\n", - "\n", - "# Print the contact parameters.\n", - "# Note that these parameters are the nominal parameters shared among\n", - "# all parallel instances. If needed, they can be overridden in the\n", - "# vectorized data object that will be created later.\n", - "print(contacts_params)\n", - "\n", - "# Update the data object with the new contact model parameters.\n", - "data_single = data_single.replace(contacts_params=contacts_params, validate=False)" + "# Create the data of a single model.\n", + "data_zero = js.data.JaxSimModelData.zero(model=model)\n", + "base_position = jnp.array([0.0, 0.0, 1.0])\n", + "data = data_zero.reset_base_position(base_position=base_position) # Note that the reset position returns the updated data object" ] }, { "cell_type": "markdown", - "metadata": { - "id": "6REY2bq3lc_k" - }, + "metadata": {}, "source": [ - "### Select the integrator\n", - "\n", - "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n", - "\n", - "- `jaxsim.integrators.fixed_step`\n", - "- `jaxsim.integrators.variable_step`\n", - "\n", - "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold." + "### Simulation" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "o86Teq5piVGj" - }, + "metadata": {}, "outputs": [], "source": [ - "print(f\"Using integrator: {model.integrator}\")\n", - "\n", - "# Initialize the simulated time.\n", - "T = jnp.arange(start=0, stop=1.0, step=model.time_step)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "V6IeD2B3m4F0" - }, - "source": [ - "## Sample a batch of trajectories in parallel\n", + "# Create a random JAX key.\n", "\n", - "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n", + "key = jax.random.PRNGKey(seed=0)\n", "\n", - "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n", + "# Initialize the simulated time.\n", + "T = jnp.arange(start=0, stop=1.0, step=model.time_step)\n", "\n", - "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions." + "# Simulate\n", + "for _t in T:\n", + " data, _ = js.model.step(\n", + " model=model,\n", + " data=data,\n", + " link_forces=None,\n", + " joint_force_references=None,\n", + " )" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "vtEn0aIzr_2j" - }, + "metadata": {}, "outputs": [], "source": [ - "# @title Generate batched initial data\n", + "import jax.numpy as jnp\n", + "import jaxsim.api as js\n", + "import icub_models\n", + "import pathlib\n", + "\n", + "# Load the iCub model\n", + "model_path = icub_models.get_model_file(\"iCubGazeboV2_5\")\n", + "joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',\n", + " 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',\n", + " 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',\n", + " 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n", + " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',\n", + " 'r_ankle_roll')\n", + "\n", + "# Build and reduce the model\n", + "model_description = pathlib.Path(model_path)\n", + "full_model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_description, time_step=0.0001, is_urdf=True\n", + ")\n", + "model = js.model.reduce(model=full_model, considered_joints=joints)\n", "\n", - "# Create a random JAX key.\n", - "key = jax.random.PRNGKey(seed=0)\n", + "# Initialize model data\n", + "data = js.data.JaxSimModelData.zero(model=model).reset_base_position(\n", + " base_position=jnp.array([0.0, 0.0, 1.0])\n", + ")\n", "\n", - "# Split subkeys for sampling random initial data.\n", - "batch_size = 10\n", - "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", + "# Frame and dynamics computations\n", + "frame_index = js.frame.name_to_idx(model=model, frame_name=\"l_foot\")\n", + "W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation\n", + "W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian\n", "\n", - "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n", - "data_batch_t0 = jax.vmap(\n", - " lambda key: js.data.random_model_data(\n", - " model=model,\n", - " key=key,\n", - " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n", - " base_vel_lin_bounds=(0, 0),\n", - " base_vel_ang_bounds=(0, 0),\n", - " contacts_params=contacts_params,\n", - " )\n", - ")(jnp.vstack(subkeys))\n", + "# Dynamics properties\n", + "M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix\n", + "h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces\n", + "g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces\n", + "C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix\n", + "\n", + "# Print dynamics results\n", + "print(f\"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorized simulation \n", "\n", - "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])" + "We will now vectorize the simulation on batched data using `vmap`" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "0tQPfsl6uxHm" - }, + "metadata": {}, "outputs": [], "source": [ - "# @title Create parallel step function\n", + "# first we have to vmap the function\n", "\n", "import functools\n", "from typing import Any\n", @@ -343,106 +280,21 @@ " )\n", "\n", "\n", - "# The first run will be slow since JAX needs to JIT-compile the functions.\n", - "_ = step_single(model, data_single)\n", - "_ = step_parallel(model, data_batch_t0)\n", + "# Then we have to create the vector of initial state\n", "\n", - "# Benchmark the execution of a single step.\n", - "print(\"\\nSingle simulation step:\")\n", - "%timeit step_single(model, data_single)\n", + "# Split subkeys for sampling random initial data.\n", + "batch_size = 10\n", + "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", "\n", - "# On hardware accelerators, there's a range of batch_size values where\n", - "# increasing the number of parallel instances doesn't affect computation time.\n", - "# This range depends on the GPU/TPU specifications.\n", - "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n", - "%timeit step_parallel(model, data_batch_t0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VNwzT2JQ1n15" - }, - "outputs": [], - "source": [ - "# @title Run parallel simulation\n", + "# Create the batched data.\n", + "data_batch_t0 = jax.vmap(\n", + " lambda key: data_zero.reset_base_position(base_position=jnp.array([0.0, 0.0, 1.0]))\n", + ")(jnp.vstack(subkeys))\n", "\n", "data = data_batch_t0\n", - "data_trajectory_list = []\n", - "\n", - "for _ in T:\n", - "\n", - " data, _ = step_parallel(model, data)\n", - " data_trajectory_list.append(data)" + "for _t in T:\n", + " data, _ = step_parallel(model, data)" ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Y6n720Cr3G44" - }, - "source": [ - "## Visualize trajectory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BLPODyKr3Lyg" - }, - "outputs": [], - "source": [ - "# Convert a list of PyTrees to a batched PyTree.\n", - "# This operation is called 'tree transpose' in JAX.\n", - "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n", - "\n", - "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-jxJXy5r3RMt" - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n", - "plt.grid(True)\n", - "plt.xlabel(\"Time [s]\")\n", - "plt.ylabel(\"Height [m]\")\n", - "plt.title(\"Height trajectory of the sphere\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N92-WjPFGuua" - }, - "source": [ - "# Conclusions\n", - "\n", - "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n", - "\n", - "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n", - "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n", - "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n", - "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n", - "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n", - "\n", - "Have fun!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { @@ -454,7 +306,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "rsl", + "display_name": "jaxsim", "language": "python", "name": "python3" }, @@ -468,7 +320,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.13.0" } }, "nbformat": 4, diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb new file mode 100644 index 000000000..b8ec18cfb --- /dev/null +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -0,0 +1,476 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "H-WgcgGQaTG7" + }, + "source": [ + "# JaxSim as a hardware-accelerated parallel physics engine\n", + "\n", + "JaxSim was originally developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", + "\n", + "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", + "\n", + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SgOSnrSscEkt" + }, + "source": [ + "## Prepare the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fdqvAqMDaTG9" + }, + "outputs": [], + "source": [ + "# @title Imports and setup\n", + "import sys\n", + "from IPython.display import clear_output\n", + "\n", + "IS_COLAB = \"google.colab\" in sys.modules\n", + "\n", + "# Install JAX and Gazebo\n", + "if IS_COLAB:\n", + " !{sys.executable} -m pip install --pre -qU jaxsim\n", + " !apt install -qq lsb-release wget gnupg\n", + " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", + " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", + " !apt -qq update\n", + " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", + "\n", + " clear_output()\n", + "\n", + "# Set environment variable to avoid GPU out of memory errors\n", + "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", + "\n", + "# ================\n", + "# Notebook imports\n", + "# ================\n", + "\n", + "import os\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jaxsim.api as js\n", + "import jaxsim\n", + "import rod\n", + "from jaxsim import logging\n", + "from rod.builder.primitives import SphereBuilder\n", + "\n", + "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", + "print(f\"Running on {jax.devices()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QtCCUhdpdGFH" + }, + "source": [ + "## Prepare the simulation\n", + "\n", + "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n", + "\n", + "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n", + "\n", + "[sdformat]: http://sdformat.org/\n", + "[urdf]: http://wiki.ros.org/urdf/\n", + "[rod]: https://github.com/ami-iit/rod" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "0emoMQhCaTG_" + }, + "outputs": [], + "source": [ + "# @title Create the model description of a sphere\n", + "\n", + "# Create a SDF model.\n", + "# The builder takes care to compute the right inertia tensor for you.\n", + "rod_sdf = rod.Sdf(\n", + " version=\"1.7\",\n", + " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", + " .build_model()\n", + " .add_link()\n", + " .add_inertial()\n", + " .add_visual()\n", + " .add_collision()\n", + " .build(),\n", + ")\n", + "\n", + "# Rod allows to update the frames w.r.t. the poses are expressed.\n", + "rod_sdf.model.switch_frame_convention(\n", + " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n", + ")\n", + "\n", + "# Serialize the model to a SDF string.\n", + "model_sdf_string = rod_sdf.serialize(pretty=True)\n", + "print(model_sdf_string)\n", + "\n", + "# JaxSim currently only supports collisions between points attached to bodies\n", + "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", + "# While this approach is universal as it applies to generic meshes, the number\n", + "# of considered points greatly affects the performance. Spheres, by default,\n", + "# are discretized with 250 points. It's too much for this simple example.\n", + "# This number can be decreased with the following environment variable.\n", + "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqjuZKvOaTG_" + }, + "source": [ + "### Create the model and its data\n", + "\n", + "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + "\n", + "- `model`: an object that defines the dynamics of the system.\n", + "- `data`: an object that contains the state of the system.\n", + "- `integrator` *(Optional)*: an object that defines the integration method.\n", + "- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n", + "\n", + "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", + "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "etQ577cFaTHA" + }, + "outputs": [], + "source": [ + "# Create the JaxSim model.\n", + "# This is shared among all the parallel instances.\n", + "model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_sdf_string,\n", + " time_step=0.001,\n", + " integrator=jaxsim.integrators.fixed_step.Heun2,\n", + ")\n", + "\n", + "# Create the data of a single model.\n", + "# We will create a vectorized instance later.\n", + "data_single = js.data.JaxSimModelData.zero(model=model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FJF-HoWaiK9J" + }, + "source": [ + "### Select the contact model\n", + "\n", + "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n", + "\n", + "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VAaitHRKjnwc" + }, + "outputs": [], + "source": [ + "import jaxsim\n", + "\n", + "# Operate on a copy of the model.\n", + "# When validate=True, this context manager ensures that the PyTree structure\n", + "# of the object is not altered. This is a nice feature of JaxSim to spot\n", + "# earlier user logic that might trigger unwanted JIT recompilations.\n", + "# In this case, we need to disable validation since PyTree structure might\n", + "# change if you use a contact model different from the default.\n", + "with model.editable(validate=False) as model:\n", + "\n", + " # The SoftContacts class can be replaced with a different contact model.\n", + " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n", + "\n", + "# JaxSim provides the following helper that estimates good contact\n", + "# parameters. While they might not be optimal, usually are a good\n", + "# starting point. Users are encouraged to fine-tune them.\n", + "contacts_params = js.contact.estimate_good_contact_parameters(\n", + " model=model,\n", + " number_of_active_collidable_points_steady_state=4,\n", + " max_penetration=0.001,\n", + ")\n", + "\n", + "# Print the contact parameters.\n", + "# Note that these parameters are the nominal parameters shared among\n", + "# all parallel instances. If needed, they can be overridden in the\n", + "# vectorized data object that will be created later.\n", + "print(contacts_params)\n", + "\n", + "# Update the data object with the new contact model parameters.\n", + "data_single = data_single.replace(contacts_params=contacts_params, validate=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6REY2bq3lc_k" + }, + "source": [ + "### Select the integrator\n", + "\n", + "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n", + "\n", + "- `jaxsim.integrators.fixed_step`\n", + "- `jaxsim.integrators.variable_step`\n", + "\n", + "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o86Teq5piVGj" + }, + "outputs": [], + "source": [ + "print(f\"Using integrator: {model.integrator}\")\n", + "\n", + "# Initialize the simulated time.\n", + "T = jnp.arange(start=0, stop=1.0, step=model.time_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V6IeD2B3m4F0" + }, + "source": [ + "## Sample a batch of trajectories in parallel\n", + "\n", + "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n", + "\n", + "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n", + "\n", + "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vtEn0aIzr_2j" + }, + "outputs": [], + "source": [ + "# @title Generate batched initial data\n", + "\n", + "# Create a random JAX key.\n", + "key = jax.random.PRNGKey(seed=0)\n", + "\n", + "# Split subkeys for sampling random initial data.\n", + "batch_size = 10\n", + "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", + "\n", + "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n", + "data_batch_t0 = jax.vmap(\n", + " lambda key: js.data.random_model_data(\n", + " model=model,\n", + " key=key,\n", + " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n", + " base_vel_lin_bounds=(0, 0),\n", + " base_vel_ang_bounds=(0, 0),\n", + " contacts_params=contacts_params,\n", + " )\n", + ")(jnp.vstack(subkeys))\n", + "\n", + "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0tQPfsl6uxHm" + }, + "outputs": [], + "source": [ + "# @title Create parallel step function\n", + "\n", + "import functools\n", + "from typing import Any\n", + "\n", + "\n", + "@jax.jit\n", + "def step_single(\n", + " model: js.model.JaxSimModel,\n", + " data: js.data.JaxSimModelData,\n", + ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", + "\n", + " # Close step over static arguments.\n", + " return js.model.step(\n", + " model=model,\n", + " data=data,\n", + " link_forces=None,\n", + " joint_force_references=None,\n", + " )\n", + "\n", + "\n", + "@jax.jit\n", + "@functools.partial(jax.vmap, in_axes=(None, 0))\n", + "def step_parallel(\n", + " model: js.model.JaxSimModel,\n", + " data: js.data.JaxSimModelData,\n", + ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", + "\n", + " return step_single(\n", + " model=model, data=data\n", + " )\n", + "\n", + "\n", + "# The first run will be slow since JAX needs to JIT-compile the functions.\n", + "_ = step_single(model, data_single)\n", + "_ = step_parallel(model, data_batch_t0)\n", + "\n", + "# Benchmark the execution of a single step.\n", + "print(\"\\nSingle simulation step:\")\n", + "%timeit step_single(model, data_single)\n", + "\n", + "# On hardware accelerators, there's a range of batch_size values where\n", + "# increasing the number of parallel instances doesn't affect computation time.\n", + "# This range depends on the GPU/TPU specifications.\n", + "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n", + "%timeit step_parallel(model, data_batch_t0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VNwzT2JQ1n15" + }, + "outputs": [], + "source": [ + "# @title Run parallel simulation\n", + "\n", + "data = data_batch_t0\n", + "data_trajectory_list = []\n", + "\n", + "for _ in T:\n", + "\n", + " data, _ = step_parallel(model, data)\n", + " data_trajectory_list.append(data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y6n720Cr3G44" + }, + "source": [ + "## Visualize trajectory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BLPODyKr3Lyg" + }, + "outputs": [], + "source": [ + "# Convert a list of PyTrees to a batched PyTree.\n", + "# This operation is called 'tree transpose' in JAX.\n", + "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n", + "\n", + "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-jxJXy5r3RMt" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n", + "plt.grid(True)\n", + "plt.xlabel(\"Time [s]\")\n", + "plt.ylabel(\"Height [m]\")\n", + "plt.title(\"Height trajectory of the sphere\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N92-WjPFGuua" + }, + "source": [ + "# Conclusions\n", + "\n", + "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n", + "\n", + "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n", + "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n", + "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n", + "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n", + "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n", + "\n", + "Have fun!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuClass": "premium", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "rsl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 243a00be222efcabde67db65c6206d991d11ab72 Mon Sep 17 00:00:00 2001 From: Carlotta Date: Mon, 9 Dec 2024 16:04:36 +0100 Subject: [PATCH 02/13] added icub-models as test dependencies --- environment.yml | 2 ++ pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 2603b0c82..6767660fa 100644 --- a/environment.yml +++ b/environment.yml @@ -29,6 +29,7 @@ dependencies: - pytest - pytest-icdiff - robot_descriptions + - icub-models # [viz] - lxml - mediapy @@ -55,6 +56,7 @@ dependencies: - sphinx-multiversion - sphinx_rtd_theme - sphinx-toolbox + - icub-models # ======================================== # Other dependencies for GitHub Codespaces # ======================================== diff --git a/pyproject.toml b/pyproject.toml index 5fc968178..ad27e2508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,8 @@ testing = [ "pytest >=6.0", "pytest-benchmark", "pytest-icdiff", - "robot-descriptions" + "robot-descriptions", + "icub-models", ] viz = [ "lxml", From 10be65fab3288f71f35b5096b133c24069e93678 Mon Sep 17 00:00:00 2001 From: Carlotta Date: Mon, 9 Dec 2024 16:24:11 +0100 Subject: [PATCH 03/13] removed useless part in the notebook --- examples/jaxsim_as_physics_engine.ipynb | 47 ------------------------- 1 file changed, 47 deletions(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index ce8eac0c8..eb6c51246 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -185,53 +185,6 @@ " )" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import jax.numpy as jnp\n", - "import jaxsim.api as js\n", - "import icub_models\n", - "import pathlib\n", - "\n", - "# Load the iCub model\n", - "model_path = icub_models.get_model_file(\"iCubGazeboV2_5\")\n", - "joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',\n", - " 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',\n", - " 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',\n", - " 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n", - " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',\n", - " 'r_ankle_roll')\n", - "\n", - "# Build and reduce the model\n", - "model_description = pathlib.Path(model_path)\n", - "full_model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_description, time_step=0.0001, is_urdf=True\n", - ")\n", - "model = js.model.reduce(model=full_model, considered_joints=joints)\n", - "\n", - "# Initialize model data\n", - "data = js.data.JaxSimModelData.zero(model=model).reset_base_position(\n", - " base_position=jnp.array([0.0, 0.0, 1.0])\n", - ")\n", - "\n", - "# Frame and dynamics computations\n", - "frame_index = js.frame.name_to_idx(model=model, frame_name=\"l_foot\")\n", - "W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation\n", - "W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian\n", - "\n", - "# Dynamics properties\n", - "M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix\n", - "h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces\n", - "g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces\n", - "C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix\n", - "\n", - "# Print dynamics results\n", - "print(f\"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}\")\n" - ] - }, { "cell_type": "markdown", "metadata": {}, From 56fcf7b234c45a7d12dafa37bc0fc94ba518a20a Mon Sep 17 00:00:00 2001 From: Carlotta Date: Mon, 9 Dec 2024 16:41:36 +0100 Subject: [PATCH 04/13] decreased batch size to allow execution in notebook --- examples/jaxsim_as_physics_engine.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index eb6c51246..284a935b1 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -173,7 +173,7 @@ "key = jax.random.PRNGKey(seed=0)\n", "\n", "# Initialize the simulated time.\n", - "T = jnp.arange(start=0, stop=1.0, step=model.time_step)\n", + "T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n", "\n", "# Simulate\n", "for _t in T:\n", @@ -236,7 +236,7 @@ "# Then we have to create the vector of initial state\n", "\n", "# Split subkeys for sampling random initial data.\n", - "batch_size = 10\n", + "batch_size = 5\n", "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", "\n", "# Create the batched data.\n", From 7ba0101d53918a89fb4df0a9630ba3729340fa3c Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:51:04 +0100 Subject: [PATCH 05/13] Update examples/jaxsim_as_physics_engine.ipynb Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index 284a935b1..90e2eeb76 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -8,7 +8,7 @@ "source": [ "# JaxSim as a hardware-accelerated parallel physics engine\n", "\n", - "This notebook show how to use the key APIs to load a robot model and simulate multiple trajectories in parallel on GPUs.\n", + "This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n", "\n", "\n", " \"Open\n", From 55f4a42fb4e8be09dd7eb455ff243167dcf0726b Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:51:15 +0100 Subject: [PATCH 06/13] Update examples/jaxsim_as_physics_engine_advanced.ipynb Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine_advanced.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb index b8ec18cfb..b886f43f3 100644 --- a/examples/jaxsim_as_physics_engine_advanced.ipynb +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -8,7 +8,7 @@ "source": [ "# JaxSim as a hardware-accelerated parallel physics engine\n", "\n", - "JaxSim was originally developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", + "JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", "\n", "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", "\n", From 34d64c9ffbf7f2e2f9a1ca4fe4bf52c8e74888fb Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:51:37 +0100 Subject: [PATCH 07/13] Update examples/jaxsim_as_physics_engine.ipynb Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index 90e2eeb76..3d8e944fb 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -76,10 +76,11 @@ "source": [ "## Prepare the simulation\n", "\n", - "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the ergoCub model urdf.\n", + "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model urdf.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", + "[ergocub]: https://ergocub.eu/\n", "[rod]: https://github.com/ami-iit/rod\n", "\n", "### Create the model and its data\n", From 4f0ceb44550e4fa2753d6eb8ed7e9061b2ca4c12 Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:51:57 +0100 Subject: [PATCH 08/13] Update examples/jaxsim_as_physics_engine.ipynb Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine.ipynb | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index 3d8e944fb..e22b71e01 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -111,13 +111,11 @@ "source": [ "# Create the JaxSim model.\n", "url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf\"\n", - "# Create a temporary file\n", - "with tempfile.NamedTemporaryFile(mode=\"w+\", delete=False) as urdf_robot_file:\n", - " # Retrieve the file\n", - " urllib.request.urlretrieve(url, urdf_robot_file.name)\n", "\n", - "# print(urllib.request.urlretrieve(url, urdf_robot_file.name))\n", - "model_description_path = pathlib.Path(urdf_robot_file.name)\n", + "# Retrieve the file\n", + "model_path = urllib.request.urlretrieve(url)\n", + "\n", + "model_description_path = pathlib.Path(model_path)\n", "full_model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_description_path,\n", " time_step=0.0001,\n", From 98c1e0354b9368db0c8fc2ce2fe8d37661f5d3be Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:53:08 +0100 Subject: [PATCH 09/13] Update examples/jaxsim_as_physics_engine.ipynb Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index e22b71e01..d63f3e307 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -190,7 +190,7 @@ "source": [ "### Vectorized simulation \n", "\n", - "We will now vectorize the simulation on batched data using `vmap`" + "We will now vectorize the simulation on batched data using `jax.vmap`" ] }, { From a153b2e1d442f6363258b84e78542c6ea83e75e2 Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:56:34 +0100 Subject: [PATCH 10/13] Update examples/jaxsim_as_physics_engine_advanced.ipynb Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine_advanced.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb index b886f43f3..370d59cb9 100644 --- a/examples/jaxsim_as_physics_engine_advanced.ipynb +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -144,7 +144,7 @@ "- `model`: an object that defines the dynamics of the system.\n", "- `data`: an object that contains the state of the system.\n", "- `integrator` *(Optional)*: an object that defines the integration method.\n", - "- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n", + "- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\n", "\n", "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model." From 0501bfbf467c4d2d766bd85a2e882cedc0877981 Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:35:30 +0100 Subject: [PATCH 11/13] Update README.md Co-authored-by: Alessandro Croci <57228872+xela-95@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d2aca8ecb..7f0cb0fd3 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ ## Usage -### Using Jaxsim as simulator +### Using JaxSim as simulator ```python From 98c4a262ddd1031a907b8aebca24a11f6f4984bb Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:36:56 +0100 Subject: [PATCH 12/13] Update jaxsim_as_physics_engine_advanced.ipynb --- examples/jaxsim_as_physics_engine_advanced.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb index 370d59cb9..1d76c9817 100644 --- a/examples/jaxsim_as_physics_engine_advanced.ipynb +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -6,7 +6,7 @@ "id": "H-WgcgGQaTG7" }, "source": [ - "# JaxSim as a hardware-accelerated parallel physics engine\n", + "# JaxSim as a hardware-accelerated parallel physics engine-advanced usage\n", "\n", "JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", "\n", From f32f91d46f1618bb7e98c9246ad1ae5c463b4f88 Mon Sep 17 00:00:00 2001 From: Carlotta Sartore <56030908+CarlottaSartore@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:40:39 +0100 Subject: [PATCH 13/13] Update example readme --- examples/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/README.md b/examples/README.md index 533158fba..63043fbbe 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,11 +8,13 @@ This folder contains Jupyter notebooks that demonstrate the practical usage of J | :--- | :---: | :--- | | [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. | | [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. | +| [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. | | [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. | [colab_badge]: https://colab.research.google.com/assets/colab-badge.svg [ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb [ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb +[jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb [ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb ## How to run the examples