diff --git a/README.md b/README.md
index ef4755a02..7f0cb0fd3 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,6 @@
# JaxSim
-JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.
-
-Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
+**JaxSim** is a **differentiable physics engine** and **multibody dynamics library** built with JAX, tailored for control and robotic learning applications.
@@ -17,41 +15,105 @@ Its design facilitates research and accelerates prototyping in the intersection
## Features
+- Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots.
+- Multibody dynamics library for model-based control algorithms.
+- Fully Python-based, leveraging [jax][jax] following a functional programming paradigm.
+- Seamless execution on CPUs, GPUs, and TPUs.
+- Supports JIT compilation and automatic vectorization for high performance.
+- Compatible with SDF models and URDF (via [sdformat][sdformat] conversion).
+
+## Usage
+
+### Using JaxSim as simulator
+
+
+```python
+import jax
+import jax.numpy as jnp
+import jaxsim.api as js
+import icub_models
+import pathlib
+
+# Load the iCub model
+model_path = icub_models.get_model_file("iCubGazeboV2_5")
+joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
+ 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
+ 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
+ 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
+ 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
+ 'r_ankle_roll')
+ndof = len(joints)
+
+# Build and reduce the model
+model_description = pathlib.Path(model_path)
+full_model = js.model.JaxSimModel.build_from_model_description(
+ model_description=model_description, time_step=0.0001, is_urdf=True
+)
+model = js.model.reduce(model=full_model, considered_joints=joints)
+
+# Initialize data and simulation
+data = js.data.JaxSimModelData.zero(model=model).reset_base_position(
+ base_position=jnp.array([0.0, 0.0, 1.0])
+)
+T = jnp.arange(start=0, stop=1.0, step=model.time_step)
+tau = jnp.zeros(ndof)
+
+# Simulate
+for t in T:
+ data, _ = js.model.step(model=model, data=data, link_forces=None, joint_force_references=tau)
-- Physics engine in reduced coordinates supporting fixed-base and floating-base robots.
-- Multibody dynamics library providing all the necessary components for developing model-based control algorithms.
-- Completely developed in Python with [`google/jax`][jax] following a functional programming paradigm.
-- Transparent support for running on CPUs, GPUs, and TPUs.
-- Full support for JIT compilation for increased performance.
-- Full support for automatic vectorization for massive parallelization of open-loop and closed-loop architectures.
-- Support for SDF models and, upon conversion with [sdformat][sdformat], URDF models.
-- Visualization based on the [passive viewer][passive_viewer_mujoco] of Mujoco.
-
-### JaxSim as a simulator
-
-- Wide range of fixed-step explicit Runge-Kutta integrators.
-- Support for variable-step integrators implemented as embedded Runge-Kutta schemes.
-- Improved stability by optionally integrating the base orientation on the $\text{SO}(3)$ manifold.
-- Soft contacts model supporting full friction cone and sticking-slipping transition.
-- Collision detection between points rigidly attached to bodies and uneven ground surfaces.
-
-### JaxSim as a multibody dynamics library
+```
-- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
-- Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion.
-- Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
-- Exposes all the necessary quantities to develop controllers in centroidal coordinates.
-- Supports running open-loop and full closed-loop control architectures on hardware accelerators.
+### Using JaxSim as a multibody dynamics library
+``` python
+import jax.numpy as jnp
+import jaxsim.api as js
+import icub_models
+import pathlib
+
+# Load the iCub model
+model_path = icub_models.get_model_file("iCubGazeboV2_5")
+joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
+ 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
+ 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
+ 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
+ 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
+ 'r_ankle_roll')
+
+# Build and reduce the model
+model_description = pathlib.Path(model_path)
+full_model = js.model.JaxSimModel.build_from_model_description(
+ model_description=model_description, time_step=0.0001, is_urdf=True
+)
+model = js.model.reduce(model=full_model, considered_joints=joints)
+
+# Initialize model data
+data = js.data.JaxSimModelData.zero(model=model).reset_base_position(
+ base_position=jnp.array([0.0, 0.0, 1.0])
+)
+
+# Frame and dynamics computations
+frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
+W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation
+W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian
+
+# Dynamics properties
+M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
+h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces
+g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces
+C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix
+
+# Print dynamics results
+print(f"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}")
-### JaxSim for robot learning
+```
+### Additional features
-- Being developed with JAX, all the RBDAs support automatic differentiation both in forward and reverse modes.
+- Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX.
- Support for automatically differentiating against kinematics and dynamics parameters.
- All fixed-step integrators are forward and reverse differentiable.
- All variable-step integrators are forward differentiable.
-- Ideal for sampling synthetic data for reinforcement learning (RL).
-- Ideal for designing physics-informed neural networks (PINNs) with loss functions requiring model-based quantities.
-- Ideal for combining model-based control with learning-based components.
+- Check the example folder for additional usecase !
[jax]: https://github.com/google/jax/
[sdformat]: https://github.com/gazebosim/sdformat
@@ -65,12 +127,6 @@ Its design facilitates research and accelerates prototyping in the intersection
> JaxSim currently focuses on locomotion applications.
> Only contacts between bodies and smooth ground surfaces are supported.
-## Documentation
-
-The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
-
-[readthedocs]: https://jaxsim.readthedocs.io/
-
## Installation
@@ -142,6 +198,13 @@ pip install --no-deps -e .
[venv]: https://docs.python.org/3/tutorial/venv.html
[jax_gpu]: https://github.com/google/jax/#installation
+## Documentation
+
+The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
+
+[readthedocs]: https://jaxsim.readthedocs.io/
+
+
## Overview
diff --git a/environment.yml b/environment.yml
index 2603b0c82..6767660fa 100644
--- a/environment.yml
+++ b/environment.yml
@@ -29,6 +29,7 @@ dependencies:
- pytest
- pytest-icdiff
- robot_descriptions
+ - icub-models
# [viz]
- lxml
- mediapy
@@ -55,6 +56,7 @@ dependencies:
- sphinx-multiversion
- sphinx_rtd_theme
- sphinx-toolbox
+ - icub-models
# ========================================
# Other dependencies for GitHub Codespaces
# ========================================
diff --git a/examples/README.md b/examples/README.md
index 533158fba..63043fbbe 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -8,11 +8,13 @@ This folder contains Jupyter notebooks that demonstrate the practical usage of J
| :--- | :---: | :--- |
| [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. |
| [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. |
+| [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. |
| [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. |
[colab_badge]: https://colab.research.google.com/assets/colab-badge.svg
[ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb
[ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb
+[jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb
[ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb
## How to run the examples
diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb
index b8ec18cfb..d63f3e307 100644
--- a/examples/jaxsim_as_physics_engine.ipynb
+++ b/examples/jaxsim_as_physics_engine.ipynb
@@ -8,9 +8,7 @@
"source": [
"# JaxSim as a hardware-accelerated parallel physics engine\n",
"\n",
- "JaxSim was originally developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n",
- "\n",
- "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n",
+ "This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n",
"\n",
"\n",
" \n",
@@ -58,15 +56,13 @@
"# Notebook imports\n",
"# ================\n",
"\n",
- "import os\n",
- "\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
- "import jaxsim\n",
- "import rod\n",
"from jaxsim import logging\n",
- "from rod.builder.primitives import SphereBuilder\n",
+ "import pathlib\n",
+ "import tempfile\n",
+ "import urllib.request\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
@@ -75,242 +71,135 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "QtCCUhdpdGFH"
+ "id": "NqjuZKvOaTG_"
},
"source": [
"## Prepare the simulation\n",
"\n",
- "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n",
- "\n",
- "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n",
+ "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model urdf.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
- "[rod]: https://github.com/ami-iit/rod"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "0emoMQhCaTG_"
- },
- "outputs": [],
- "source": [
- "# @title Create the model description of a sphere\n",
- "\n",
- "# Create a SDF model.\n",
- "# The builder takes care to compute the right inertia tensor for you.\n",
- "rod_sdf = rod.Sdf(\n",
- " version=\"1.7\",\n",
- " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
- " .build_model()\n",
- " .add_link()\n",
- " .add_inertial()\n",
- " .add_visual()\n",
- " .add_collision()\n",
- " .build(),\n",
- ")\n",
+ "[ergocub]: https://ergocub.eu/\n",
+ "[rod]: https://github.com/ami-iit/rod\n",
"\n",
- "# Rod allows to update the frames w.r.t. the poses are expressed.\n",
- "rod_sdf.model.switch_frame_convention(\n",
- " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n",
- ")\n",
- "\n",
- "# Serialize the model to a SDF string.\n",
- "model_sdf_string = rod_sdf.serialize(pretty=True)\n",
- "print(model_sdf_string)\n",
- "\n",
- "# JaxSim currently only supports collisions between points attached to bodies\n",
- "# and a ground surface modeled as a heightmap sampled from a smooth function.\n",
- "# While this approach is universal as it applies to generic meshes, the number\n",
- "# of considered points greatly affects the performance. Spheres, by default,\n",
- "# are discretized with 250 points. It's too much for this simple example.\n",
- "# This number can be decreased with the following environment variable.\n",
- "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NqjuZKvOaTG_"
- },
- "source": [
"### Create the model and its data\n",
- "\n",
- "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
+ " To define a simulation we need two main objects:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
- "- `integrator` *(Optional)*: an object that defines the integration method.\n",
- "- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n",
"\n",
- "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
- "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "etQ577cFaTHA"
- },
- "outputs": [],
- "source": [
- "# Create the JaxSim model.\n",
- "# This is shared among all the parallel instances.\n",
- "model = js.model.JaxSimModel.build_from_model_description(\n",
- " model_description=model_sdf_string,\n",
- " time_step=0.001,\n",
- " integrator=jaxsim.integrators.fixed_step.Heun2,\n",
- ")\n",
"\n",
- "# Create the data of a single model.\n",
- "# We will create a vectorized instance later.\n",
- "data_single = js.data.JaxSimModelData.zero(model=model)"
+ "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
+ "To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model."
]
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "FJF-HoWaiK9J"
- },
+ "metadata": {},
"source": [
- "### Select the contact model\n",
- "\n",
- "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n",
- "\n",
- "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n",
- "\n"
+ "### Create the model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "VAaitHRKjnwc"
+ "id": "etQ577cFaTHA"
},
"outputs": [],
"source": [
- "import jaxsim\n",
- "\n",
- "# Operate on a copy of the model.\n",
- "# When validate=True, this context manager ensures that the PyTree structure\n",
- "# of the object is not altered. This is a nice feature of JaxSim to spot\n",
- "# earlier user logic that might trigger unwanted JIT recompilations.\n",
- "# In this case, we need to disable validation since PyTree structure might\n",
- "# change if you use a contact model different from the default.\n",
- "with model.editable(validate=False) as model:\n",
- "\n",
- " # The SoftContacts class can be replaced with a different contact model.\n",
- " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n",
- "\n",
- "# JaxSim provides the following helper that estimates good contact\n",
- "# parameters. While they might not be optimal, usually are a good\n",
- "# starting point. Users are encouraged to fine-tune them.\n",
- "contacts_params = js.contact.estimate_good_contact_parameters(\n",
- " model=model,\n",
- " number_of_active_collidable_points_steady_state=4,\n",
- " max_penetration=0.001,\n",
+ "# Create the JaxSim model.\n",
+ "url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf\"\n",
+ "\n",
+ "# Retrieve the file\n",
+ "model_path = urllib.request.urlretrieve(url)\n",
+ "\n",
+ "model_description_path = pathlib.Path(model_path)\n",
+ "full_model = js.model.JaxSimModel.build_from_model_description(\n",
+ " model_description=model_description_path,\n",
+ " time_step=0.0001,\n",
+ " is_urdf=True\n",
")\n",
"\n",
- "# Print the contact parameters.\n",
- "# Note that these parameters are the nominal parameters shared among\n",
- "# all parallel instances. If needed, they can be overridden in the\n",
- "# vectorized data object that will be created later.\n",
- "print(contacts_params)\n",
+ "joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n",
+ " 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n",
+ " 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n",
+ " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n",
"\n",
- "# Update the data object with the new contact model parameters.\n",
- "data_single = data_single.replace(contacts_params=contacts_params, validate=False)"
+ "model = js.model.reduce(\n",
+ " model=full_model,\n",
+ " considered_joints=joints_list\n",
+ ")\n"
]
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "6REY2bq3lc_k"
- },
+ "metadata": {},
"source": [
- "### Select the integrator\n",
+ "### Create the data object \n",
"\n",
- "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n",
- "\n",
- "- `jaxsim.integrators.fixed_step`\n",
- "- `jaxsim.integrators.variable_step`\n",
- "\n",
- "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold."
+ "The data object is never changed by reference. Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed."
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "id": "o86Teq5piVGj"
- },
+ "metadata": {},
"outputs": [],
"source": [
- "print(f\"Using integrator: {model.integrator}\")\n",
- "\n",
- "# Initialize the simulated time.\n",
- "T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
+ "# Create the data of a single model.\n",
+ "data_zero = js.data.JaxSimModelData.zero(model=model)\n",
+ "base_position = jnp.array([0.0, 0.0, 1.0])\n",
+ "data = data_zero.reset_base_position(base_position=base_position) # Note that the reset position returns the updated data object"
]
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "V6IeD2B3m4F0"
- },
+ "metadata": {},
"source": [
- "## Sample a batch of trajectories in parallel\n",
- "\n",
- "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n",
- "\n",
- "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n",
- "\n",
- "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions."
+ "### Simulation"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "id": "vtEn0aIzr_2j"
- },
+ "metadata": {},
"outputs": [],
"source": [
- "# @title Generate batched initial data\n",
- "\n",
"# Create a random JAX key.\n",
+ "\n",
"key = jax.random.PRNGKey(seed=0)\n",
"\n",
- "# Split subkeys for sampling random initial data.\n",
- "batch_size = 10\n",
- "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n",
+ "# Initialize the simulated time.\n",
+ "T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n",
"\n",
- "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n",
- "data_batch_t0 = jax.vmap(\n",
- " lambda key: js.data.random_model_data(\n",
+ "# Simulate\n",
+ "for _t in T:\n",
+ " data, _ = js.model.step(\n",
" model=model,\n",
- " key=key,\n",
- " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n",
- " base_vel_lin_bounds=(0, 0),\n",
- " base_vel_ang_bounds=(0, 0),\n",
- " contacts_params=contacts_params,\n",
- " )\n",
- ")(jnp.vstack(subkeys))\n",
+ " data=data,\n",
+ " link_forces=None,\n",
+ " joint_force_references=None,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Vectorized simulation \n",
"\n",
- "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])"
+ "We will now vectorize the simulation on batched data using `jax.vmap`"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "id": "0tQPfsl6uxHm"
- },
+ "metadata": {},
"outputs": [],
"source": [
- "# @title Create parallel step function\n",
+ "# first we have to vmap the function\n",
"\n",
"import functools\n",
"from typing import Any\n",
@@ -343,106 +232,21 @@
" )\n",
"\n",
"\n",
- "# The first run will be slow since JAX needs to JIT-compile the functions.\n",
- "_ = step_single(model, data_single)\n",
- "_ = step_parallel(model, data_batch_t0)\n",
+ "# Then we have to create the vector of initial state\n",
"\n",
- "# Benchmark the execution of a single step.\n",
- "print(\"\\nSingle simulation step:\")\n",
- "%timeit step_single(model, data_single)\n",
+ "# Split subkeys for sampling random initial data.\n",
+ "batch_size = 5\n",
+ "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n",
"\n",
- "# On hardware accelerators, there's a range of batch_size values where\n",
- "# increasing the number of parallel instances doesn't affect computation time.\n",
- "# This range depends on the GPU/TPU specifications.\n",
- "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n",
- "%timeit step_parallel(model, data_batch_t0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VNwzT2JQ1n15"
- },
- "outputs": [],
- "source": [
- "# @title Run parallel simulation\n",
+ "# Create the batched data.\n",
+ "data_batch_t0 = jax.vmap(\n",
+ " lambda key: data_zero.reset_base_position(base_position=jnp.array([0.0, 0.0, 1.0]))\n",
+ ")(jnp.vstack(subkeys))\n",
"\n",
"data = data_batch_t0\n",
- "data_trajectory_list = []\n",
- "\n",
- "for _ in T:\n",
- "\n",
- " data, _ = step_parallel(model, data)\n",
- " data_trajectory_list.append(data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Y6n720Cr3G44"
- },
- "source": [
- "## Visualize trajectory"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "BLPODyKr3Lyg"
- },
- "outputs": [],
- "source": [
- "# Convert a list of PyTrees to a batched PyTree.\n",
- "# This operation is called 'tree transpose' in JAX.\n",
- "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
- "\n",
- "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "-jxJXy5r3RMt"
- },
- "outputs": [],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "\n",
- "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n",
- "plt.grid(True)\n",
- "plt.xlabel(\"Time [s]\")\n",
- "plt.ylabel(\"Height [m]\")\n",
- "plt.title(\"Height trajectory of the sphere\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "N92-WjPFGuua"
- },
- "source": [
- "# Conclusions\n",
- "\n",
- "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n",
- "\n",
- "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n",
- "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n",
- "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n",
- "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n",
- "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n",
- "\n",
- "Have fun!"
+ "for _t in T:\n",
+ " data, _ = step_parallel(model, data)"
]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
}
],
"metadata": {
@@ -454,7 +258,7 @@
"toc_visible": true
},
"kernelspec": {
- "display_name": "rsl",
+ "display_name": "jaxsim",
"language": "python",
"name": "python3"
},
@@ -468,7 +272,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.13.0"
}
},
"nbformat": 4,
diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb
new file mode 100644
index 000000000..1d76c9817
--- /dev/null
+++ b/examples/jaxsim_as_physics_engine_advanced.ipynb
@@ -0,0 +1,476 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "H-WgcgGQaTG7"
+ },
+ "source": [
+ "# JaxSim as a hardware-accelerated parallel physics engine-advanced usage\n",
+ "\n",
+ "JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n",
+ "\n",
+ "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n",
+ "\n",
+ "\n",
+ " \n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SgOSnrSscEkt"
+ },
+ "source": [
+ "## Prepare the environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fdqvAqMDaTG9"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Imports and setup\n",
+ "import sys\n",
+ "from IPython.display import clear_output\n",
+ "\n",
+ "IS_COLAB = \"google.colab\" in sys.modules\n",
+ "\n",
+ "# Install JAX and Gazebo\n",
+ "if IS_COLAB:\n",
+ " !{sys.executable} -m pip install --pre -qU jaxsim\n",
+ " !apt install -qq lsb-release wget gnupg\n",
+ " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
+ " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
+ " !apt -qq update\n",
+ " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
+ "\n",
+ " clear_output()\n",
+ "\n",
+ "# Set environment variable to avoid GPU out of memory errors\n",
+ "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
+ "\n",
+ "# ================\n",
+ "# Notebook imports\n",
+ "# ================\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import jaxsim.api as js\n",
+ "import jaxsim\n",
+ "import rod\n",
+ "from jaxsim import logging\n",
+ "from rod.builder.primitives import SphereBuilder\n",
+ "\n",
+ "logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
+ "print(f\"Running on {jax.devices()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QtCCUhdpdGFH"
+ },
+ "source": [
+ "## Prepare the simulation\n",
+ "\n",
+ "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n",
+ "\n",
+ "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n",
+ "\n",
+ "[sdformat]: http://sdformat.org/\n",
+ "[urdf]: http://wiki.ros.org/urdf/\n",
+ "[rod]: https://github.com/ami-iit/rod"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "0emoMQhCaTG_"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Create the model description of a sphere\n",
+ "\n",
+ "# Create a SDF model.\n",
+ "# The builder takes care to compute the right inertia tensor for you.\n",
+ "rod_sdf = rod.Sdf(\n",
+ " version=\"1.7\",\n",
+ " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
+ " .build_model()\n",
+ " .add_link()\n",
+ " .add_inertial()\n",
+ " .add_visual()\n",
+ " .add_collision()\n",
+ " .build(),\n",
+ ")\n",
+ "\n",
+ "# Rod allows to update the frames w.r.t. the poses are expressed.\n",
+ "rod_sdf.model.switch_frame_convention(\n",
+ " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n",
+ ")\n",
+ "\n",
+ "# Serialize the model to a SDF string.\n",
+ "model_sdf_string = rod_sdf.serialize(pretty=True)\n",
+ "print(model_sdf_string)\n",
+ "\n",
+ "# JaxSim currently only supports collisions between points attached to bodies\n",
+ "# and a ground surface modeled as a heightmap sampled from a smooth function.\n",
+ "# While this approach is universal as it applies to generic meshes, the number\n",
+ "# of considered points greatly affects the performance. Spheres, by default,\n",
+ "# are discretized with 250 points. It's too much for this simple example.\n",
+ "# This number can be decreased with the following environment variable.\n",
+ "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NqjuZKvOaTG_"
+ },
+ "source": [
+ "### Create the model and its data\n",
+ "\n",
+ "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
+ "\n",
+ "- `model`: an object that defines the dynamics of the system.\n",
+ "- `data`: an object that contains the state of the system.\n",
+ "- `integrator` *(Optional)*: an object that defines the integration method.\n",
+ "- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\n",
+ "\n",
+ "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
+ "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "etQ577cFaTHA"
+ },
+ "outputs": [],
+ "source": [
+ "# Create the JaxSim model.\n",
+ "# This is shared among all the parallel instances.\n",
+ "model = js.model.JaxSimModel.build_from_model_description(\n",
+ " model_description=model_sdf_string,\n",
+ " time_step=0.001,\n",
+ " integrator=jaxsim.integrators.fixed_step.Heun2,\n",
+ ")\n",
+ "\n",
+ "# Create the data of a single model.\n",
+ "# We will create a vectorized instance later.\n",
+ "data_single = js.data.JaxSimModelData.zero(model=model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FJF-HoWaiK9J"
+ },
+ "source": [
+ "### Select the contact model\n",
+ "\n",
+ "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n",
+ "\n",
+ "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VAaitHRKjnwc"
+ },
+ "outputs": [],
+ "source": [
+ "import jaxsim\n",
+ "\n",
+ "# Operate on a copy of the model.\n",
+ "# When validate=True, this context manager ensures that the PyTree structure\n",
+ "# of the object is not altered. This is a nice feature of JaxSim to spot\n",
+ "# earlier user logic that might trigger unwanted JIT recompilations.\n",
+ "# In this case, we need to disable validation since PyTree structure might\n",
+ "# change if you use a contact model different from the default.\n",
+ "with model.editable(validate=False) as model:\n",
+ "\n",
+ " # The SoftContacts class can be replaced with a different contact model.\n",
+ " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n",
+ "\n",
+ "# JaxSim provides the following helper that estimates good contact\n",
+ "# parameters. While they might not be optimal, usually are a good\n",
+ "# starting point. Users are encouraged to fine-tune them.\n",
+ "contacts_params = js.contact.estimate_good_contact_parameters(\n",
+ " model=model,\n",
+ " number_of_active_collidable_points_steady_state=4,\n",
+ " max_penetration=0.001,\n",
+ ")\n",
+ "\n",
+ "# Print the contact parameters.\n",
+ "# Note that these parameters are the nominal parameters shared among\n",
+ "# all parallel instances. If needed, they can be overridden in the\n",
+ "# vectorized data object that will be created later.\n",
+ "print(contacts_params)\n",
+ "\n",
+ "# Update the data object with the new contact model parameters.\n",
+ "data_single = data_single.replace(contacts_params=contacts_params, validate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6REY2bq3lc_k"
+ },
+ "source": [
+ "### Select the integrator\n",
+ "\n",
+ "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n",
+ "\n",
+ "- `jaxsim.integrators.fixed_step`\n",
+ "- `jaxsim.integrators.variable_step`\n",
+ "\n",
+ "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "o86Teq5piVGj"
+ },
+ "outputs": [],
+ "source": [
+ "print(f\"Using integrator: {model.integrator}\")\n",
+ "\n",
+ "# Initialize the simulated time.\n",
+ "T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "V6IeD2B3m4F0"
+ },
+ "source": [
+ "## Sample a batch of trajectories in parallel\n",
+ "\n",
+ "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n",
+ "\n",
+ "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n",
+ "\n",
+ "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vtEn0aIzr_2j"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Generate batched initial data\n",
+ "\n",
+ "# Create a random JAX key.\n",
+ "key = jax.random.PRNGKey(seed=0)\n",
+ "\n",
+ "# Split subkeys for sampling random initial data.\n",
+ "batch_size = 10\n",
+ "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n",
+ "\n",
+ "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n",
+ "data_batch_t0 = jax.vmap(\n",
+ " lambda key: js.data.random_model_data(\n",
+ " model=model,\n",
+ " key=key,\n",
+ " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n",
+ " base_vel_lin_bounds=(0, 0),\n",
+ " base_vel_ang_bounds=(0, 0),\n",
+ " contacts_params=contacts_params,\n",
+ " )\n",
+ ")(jnp.vstack(subkeys))\n",
+ "\n",
+ "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0tQPfsl6uxHm"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Create parallel step function\n",
+ "\n",
+ "import functools\n",
+ "from typing import Any\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "def step_single(\n",
+ " model: js.model.JaxSimModel,\n",
+ " data: js.data.JaxSimModelData,\n",
+ ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
+ "\n",
+ " # Close step over static arguments.\n",
+ " return js.model.step(\n",
+ " model=model,\n",
+ " data=data,\n",
+ " link_forces=None,\n",
+ " joint_force_references=None,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "@functools.partial(jax.vmap, in_axes=(None, 0))\n",
+ "def step_parallel(\n",
+ " model: js.model.JaxSimModel,\n",
+ " data: js.data.JaxSimModelData,\n",
+ ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
+ "\n",
+ " return step_single(\n",
+ " model=model, data=data\n",
+ " )\n",
+ "\n",
+ "\n",
+ "# The first run will be slow since JAX needs to JIT-compile the functions.\n",
+ "_ = step_single(model, data_single)\n",
+ "_ = step_parallel(model, data_batch_t0)\n",
+ "\n",
+ "# Benchmark the execution of a single step.\n",
+ "print(\"\\nSingle simulation step:\")\n",
+ "%timeit step_single(model, data_single)\n",
+ "\n",
+ "# On hardware accelerators, there's a range of batch_size values where\n",
+ "# increasing the number of parallel instances doesn't affect computation time.\n",
+ "# This range depends on the GPU/TPU specifications.\n",
+ "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n",
+ "%timeit step_parallel(model, data_batch_t0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VNwzT2JQ1n15"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Run parallel simulation\n",
+ "\n",
+ "data = data_batch_t0\n",
+ "data_trajectory_list = []\n",
+ "\n",
+ "for _ in T:\n",
+ "\n",
+ " data, _ = step_parallel(model, data)\n",
+ " data_trajectory_list.append(data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Y6n720Cr3G44"
+ },
+ "source": [
+ "## Visualize trajectory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BLPODyKr3Lyg"
+ },
+ "outputs": [],
+ "source": [
+ "# Convert a list of PyTrees to a batched PyTree.\n",
+ "# This operation is called 'tree transpose' in JAX.\n",
+ "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
+ "\n",
+ "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-jxJXy5r3RMt"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "\n",
+ "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n",
+ "plt.grid(True)\n",
+ "plt.xlabel(\"Time [s]\")\n",
+ "plt.ylabel(\"Height [m]\")\n",
+ "plt.title(\"Height trajectory of the sphere\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "N92-WjPFGuua"
+ },
+ "source": [
+ "# Conclusions\n",
+ "\n",
+ "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n",
+ "\n",
+ "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n",
+ "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n",
+ "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n",
+ "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n",
+ "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n",
+ "\n",
+ "Have fun!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuClass": "premium",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "rsl",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/pyproject.toml b/pyproject.toml
index 5fc968178..ad27e2508 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,7 +69,8 @@ testing = [
"pytest >=6.0",
"pytest-benchmark",
"pytest-icdiff",
- "robot-descriptions"
+ "robot-descriptions",
+ "icub-models",
]
viz = [
"lxml",