diff --git a/README.md b/README.md index ef4755a02..7f0cb0fd3 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # JaxSim -JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX. - -Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence. +**JaxSim** is a **differentiable physics engine** and **multibody dynamics library** built with JAX, tailored for control and robotic learning applications.

@@ -17,41 +15,105 @@ Its design facilitates research and accelerates prototyping in the intersection
## Features +- Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots. +- Multibody dynamics library for model-based control algorithms. +- Fully Python-based, leveraging [jax][jax] following a functional programming paradigm. +- Seamless execution on CPUs, GPUs, and TPUs. +- Supports JIT compilation and automatic vectorization for high performance. +- Compatible with SDF models and URDF (via [sdformat][sdformat] conversion). + +## Usage + +### Using JaxSim as simulator + + +```python +import jax +import jax.numpy as jnp +import jaxsim.api as js +import icub_models +import pathlib + +# Load the iCub model +model_path = icub_models.get_model_file("iCubGazeboV2_5") +joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', + 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', + 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', + 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', + 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', + 'r_ankle_roll') +ndof = len(joints) + +# Build and reduce the model +model_description = pathlib.Path(model_path) +full_model = js.model.JaxSimModel.build_from_model_description( + model_description=model_description, time_step=0.0001, is_urdf=True +) +model = js.model.reduce(model=full_model, considered_joints=joints) + +# Initialize data and simulation +data = js.data.JaxSimModelData.zero(model=model).reset_base_position( + base_position=jnp.array([0.0, 0.0, 1.0]) +) +T = jnp.arange(start=0, stop=1.0, step=model.time_step) +tau = jnp.zeros(ndof) + +# Simulate +for t in T: + data, _ = js.model.step(model=model, data=data, link_forces=None, joint_force_references=tau) -- Physics engine in reduced coordinates supporting fixed-base and floating-base robots. -- Multibody dynamics library providing all the necessary components for developing model-based control algorithms. -- Completely developed in Python with [`google/jax`][jax] following a functional programming paradigm. -- Transparent support for running on CPUs, GPUs, and TPUs. -- Full support for JIT compilation for increased performance. -- Full support for automatic vectorization for massive parallelization of open-loop and closed-loop architectures. -- Support for SDF models and, upon conversion with [sdformat][sdformat], URDF models. -- Visualization based on the [passive viewer][passive_viewer_mujoco] of Mujoco. - -### JaxSim as a simulator - -- Wide range of fixed-step explicit Runge-Kutta integrators. -- Support for variable-step integrators implemented as embedded Runge-Kutta schemes. -- Improved stability by optionally integrating the base orientation on the $\text{SO}(3)$ manifold. -- Soft contacts model supporting full friction cone and sticking-slipping transition. -- Collision detection between points rigidly attached to bodies and uneven ground surfaces. - -### JaxSim as a multibody dynamics library +``` -- Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians. -- Provides all the quantities included in the Euler-Poincarè formulation of the equations of motion. -- Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation]. -- Exposes all the necessary quantities to develop controllers in centroidal coordinates. -- Supports running open-loop and full closed-loop control architectures on hardware accelerators. +### Using JaxSim as a multibody dynamics library +``` python +import jax.numpy as jnp +import jaxsim.api as js +import icub_models +import pathlib + +# Load the iCub model +model_path = icub_models.get_model_file("iCubGazeboV2_5") +joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', + 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', + 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', + 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', + 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', + 'r_ankle_roll') + +# Build and reduce the model +model_description = pathlib.Path(model_path) +full_model = js.model.JaxSimModel.build_from_model_description( + model_description=model_description, time_step=0.0001, is_urdf=True +) +model = js.model.reduce(model=full_model, considered_joints=joints) + +# Initialize model data +data = js.data.JaxSimModelData.zero(model=model).reset_base_position( + base_position=jnp.array([0.0, 0.0, 1.0]) +) + +# Frame and dynamics computations +frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot") +W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation +W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian + +# Dynamics properties +M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix +h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces +g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces +C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix + +# Print dynamics results +print(f"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}") -### JaxSim for robot learning +``` +### Additional features -- Being developed with JAX, all the RBDAs support automatic differentiation both in forward and reverse modes. +- Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX. - Support for automatically differentiating against kinematics and dynamics parameters. - All fixed-step integrators are forward and reverse differentiable. - All variable-step integrators are forward differentiable. -- Ideal for sampling synthetic data for reinforcement learning (RL). -- Ideal for designing physics-informed neural networks (PINNs) with loss functions requiring model-based quantities. -- Ideal for combining model-based control with learning-based components. +- Check the example folder for additional usecase ! [jax]: https://github.com/google/jax/ [sdformat]: https://github.com/gazebosim/sdformat @@ -65,12 +127,6 @@ Its design facilitates research and accelerates prototyping in the intersection > JaxSim currently focuses on locomotion applications. > Only contacts between bodies and smooth ground surfaces are supported. -## Documentation - -The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs]. - -[readthedocs]: https://jaxsim.readthedocs.io/ - ## Installation
@@ -142,6 +198,13 @@ pip install --no-deps -e . [venv]: https://docs.python.org/3/tutorial/venv.html [jax_gpu]: https://github.com/google/jax/#installation +## Documentation + +The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs]. + +[readthedocs]: https://jaxsim.readthedocs.io/ + + ## Overview
diff --git a/environment.yml b/environment.yml index 2603b0c82..6767660fa 100644 --- a/environment.yml +++ b/environment.yml @@ -29,6 +29,7 @@ dependencies: - pytest - pytest-icdiff - robot_descriptions + - icub-models # [viz] - lxml - mediapy @@ -55,6 +56,7 @@ dependencies: - sphinx-multiversion - sphinx_rtd_theme - sphinx-toolbox + - icub-models # ======================================== # Other dependencies for GitHub Codespaces # ======================================== diff --git a/examples/README.md b/examples/README.md index 533158fba..63043fbbe 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,11 +8,13 @@ This folder contains Jupyter notebooks that demonstrate the practical usage of J | :--- | :---: | :--- | | [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. | | [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. | +| [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. | | [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. | [colab_badge]: https://colab.research.google.com/assets/colab-badge.svg [ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb [ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb +[jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb [ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb ## How to run the examples diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index b8ec18cfb..d63f3e307 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -8,9 +8,7 @@ "source": [ "# JaxSim as a hardware-accelerated parallel physics engine\n", "\n", - "JaxSim was originally developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", - "\n", - "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", + "This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n", "\n", "\n", " \"Open\n", @@ -58,15 +56,13 @@ "# Notebook imports\n", "# ================\n", "\n", - "import os\n", - "\n", "import jax\n", "import jax.numpy as jnp\n", "import jaxsim.api as js\n", - "import jaxsim\n", - "import rod\n", "from jaxsim import logging\n", - "from rod.builder.primitives import SphereBuilder\n", + "import pathlib\n", + "import tempfile\n", + "import urllib.request\n", "\n", "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", "print(f\"Running on {jax.devices()}\")" @@ -75,242 +71,135 @@ { "cell_type": "markdown", "metadata": { - "id": "QtCCUhdpdGFH" + "id": "NqjuZKvOaTG_" }, "source": [ "## Prepare the simulation\n", "\n", - "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n", - "\n", - "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n", + "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model urdf.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", - "[rod]: https://github.com/ami-iit/rod" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "0emoMQhCaTG_" - }, - "outputs": [], - "source": [ - "# @title Create the model description of a sphere\n", - "\n", - "# Create a SDF model.\n", - "# The builder takes care to compute the right inertia tensor for you.\n", - "rod_sdf = rod.Sdf(\n", - " version=\"1.7\",\n", - " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", - " .build_model()\n", - " .add_link()\n", - " .add_inertial()\n", - " .add_visual()\n", - " .add_collision()\n", - " .build(),\n", - ")\n", + "[ergocub]: https://ergocub.eu/\n", + "[rod]: https://github.com/ami-iit/rod\n", "\n", - "# Rod allows to update the frames w.r.t. the poses are expressed.\n", - "rod_sdf.model.switch_frame_convention(\n", - " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n", - ")\n", - "\n", - "# Serialize the model to a SDF string.\n", - "model_sdf_string = rod_sdf.serialize(pretty=True)\n", - "print(model_sdf_string)\n", - "\n", - "# JaxSim currently only supports collisions between points attached to bodies\n", - "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", - "# While this approach is universal as it applies to generic meshes, the number\n", - "# of considered points greatly affects the performance. Spheres, by default,\n", - "# are discretized with 250 points. It's too much for this simple example.\n", - "# This number can be decreased with the following environment variable.\n", - "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqjuZKvOaTG_" - }, - "source": [ "### Create the model and its data\n", - "\n", - "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + " To define a simulation we need two main objects:\n", "\n", "- `model`: an object that defines the dynamics of the system.\n", "- `data`: an object that contains the state of the system.\n", - "- `integrator` *(Optional)*: an object that defines the integration method.\n", - "- `integrator_state` *(Optional)*: an object that contains the state of the integrator.\n", "\n", - "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", - "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "etQ577cFaTHA" - }, - "outputs": [], - "source": [ - "# Create the JaxSim model.\n", - "# This is shared among all the parallel instances.\n", - "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_sdf_string,\n", - " time_step=0.001,\n", - " integrator=jaxsim.integrators.fixed_step.Heun2,\n", - ")\n", "\n", - "# Create the data of a single model.\n", - "# We will create a vectorized instance later.\n", - "data_single = js.data.JaxSimModelData.zero(model=model)" + "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", + "To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model." ] }, { "cell_type": "markdown", - "metadata": { - "id": "FJF-HoWaiK9J" - }, + "metadata": {}, "source": [ - "### Select the contact model\n", - "\n", - "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n", - "\n", - "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n", - "\n" + "### Create the model " ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "VAaitHRKjnwc" + "id": "etQ577cFaTHA" }, "outputs": [], "source": [ - "import jaxsim\n", - "\n", - "# Operate on a copy of the model.\n", - "# When validate=True, this context manager ensures that the PyTree structure\n", - "# of the object is not altered. This is a nice feature of JaxSim to spot\n", - "# earlier user logic that might trigger unwanted JIT recompilations.\n", - "# In this case, we need to disable validation since PyTree structure might\n", - "# change if you use a contact model different from the default.\n", - "with model.editable(validate=False) as model:\n", - "\n", - " # The SoftContacts class can be replaced with a different contact model.\n", - " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n", - "\n", - "# JaxSim provides the following helper that estimates good contact\n", - "# parameters. While they might not be optimal, usually are a good\n", - "# starting point. Users are encouraged to fine-tune them.\n", - "contacts_params = js.contact.estimate_good_contact_parameters(\n", - " model=model,\n", - " number_of_active_collidable_points_steady_state=4,\n", - " max_penetration=0.001,\n", + "# Create the JaxSim model.\n", + "url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf\"\n", + "\n", + "# Retrieve the file\n", + "model_path = urllib.request.urlretrieve(url)\n", + "\n", + "model_description_path = pathlib.Path(model_path)\n", + "full_model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_description_path,\n", + " time_step=0.0001,\n", + " is_urdf=True\n", ")\n", "\n", - "# Print the contact parameters.\n", - "# Note that these parameters are the nominal parameters shared among\n", - "# all parallel instances. If needed, they can be overridden in the\n", - "# vectorized data object that will be created later.\n", - "print(contacts_params)\n", + "joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n", + " 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n", + " 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n", + " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n", "\n", - "# Update the data object with the new contact model parameters.\n", - "data_single = data_single.replace(contacts_params=contacts_params, validate=False)" + "model = js.model.reduce(\n", + " model=full_model,\n", + " considered_joints=joints_list\n", + ")\n" ] }, { "cell_type": "markdown", - "metadata": { - "id": "6REY2bq3lc_k" - }, + "metadata": {}, "source": [ - "### Select the integrator\n", + "### Create the data object \n", "\n", - "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n", - "\n", - "- `jaxsim.integrators.fixed_step`\n", - "- `jaxsim.integrators.variable_step`\n", - "\n", - "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold." + "The data object is never changed by reference. Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "o86Teq5piVGj" - }, + "metadata": {}, "outputs": [], "source": [ - "print(f\"Using integrator: {model.integrator}\")\n", - "\n", - "# Initialize the simulated time.\n", - "T = jnp.arange(start=0, stop=1.0, step=model.time_step)" + "# Create the data of a single model.\n", + "data_zero = js.data.JaxSimModelData.zero(model=model)\n", + "base_position = jnp.array([0.0, 0.0, 1.0])\n", + "data = data_zero.reset_base_position(base_position=base_position) # Note that the reset position returns the updated data object" ] }, { "cell_type": "markdown", - "metadata": { - "id": "V6IeD2B3m4F0" - }, + "metadata": {}, "source": [ - "## Sample a batch of trajectories in parallel\n", - "\n", - "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n", - "\n", - "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n", - "\n", - "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions." + "### Simulation" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "vtEn0aIzr_2j" - }, + "metadata": {}, "outputs": [], "source": [ - "# @title Generate batched initial data\n", - "\n", "# Create a random JAX key.\n", + "\n", "key = jax.random.PRNGKey(seed=0)\n", "\n", - "# Split subkeys for sampling random initial data.\n", - "batch_size = 10\n", - "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", + "# Initialize the simulated time.\n", + "T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n", "\n", - "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n", - "data_batch_t0 = jax.vmap(\n", - " lambda key: js.data.random_model_data(\n", + "# Simulate\n", + "for _t in T:\n", + " data, _ = js.model.step(\n", " model=model,\n", - " key=key,\n", - " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n", - " base_vel_lin_bounds=(0, 0),\n", - " base_vel_ang_bounds=(0, 0),\n", - " contacts_params=contacts_params,\n", - " )\n", - ")(jnp.vstack(subkeys))\n", + " data=data,\n", + " link_forces=None,\n", + " joint_force_references=None,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorized simulation \n", "\n", - "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])" + "We will now vectorize the simulation on batched data using `jax.vmap`" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "0tQPfsl6uxHm" - }, + "metadata": {}, "outputs": [], "source": [ - "# @title Create parallel step function\n", + "# first we have to vmap the function\n", "\n", "import functools\n", "from typing import Any\n", @@ -343,106 +232,21 @@ " )\n", "\n", "\n", - "# The first run will be slow since JAX needs to JIT-compile the functions.\n", - "_ = step_single(model, data_single)\n", - "_ = step_parallel(model, data_batch_t0)\n", + "# Then we have to create the vector of initial state\n", "\n", - "# Benchmark the execution of a single step.\n", - "print(\"\\nSingle simulation step:\")\n", - "%timeit step_single(model, data_single)\n", + "# Split subkeys for sampling random initial data.\n", + "batch_size = 5\n", + "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", "\n", - "# On hardware accelerators, there's a range of batch_size values where\n", - "# increasing the number of parallel instances doesn't affect computation time.\n", - "# This range depends on the GPU/TPU specifications.\n", - "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n", - "%timeit step_parallel(model, data_batch_t0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VNwzT2JQ1n15" - }, - "outputs": [], - "source": [ - "# @title Run parallel simulation\n", + "# Create the batched data.\n", + "data_batch_t0 = jax.vmap(\n", + " lambda key: data_zero.reset_base_position(base_position=jnp.array([0.0, 0.0, 1.0]))\n", + ")(jnp.vstack(subkeys))\n", "\n", "data = data_batch_t0\n", - "data_trajectory_list = []\n", - "\n", - "for _ in T:\n", - "\n", - " data, _ = step_parallel(model, data)\n", - " data_trajectory_list.append(data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Y6n720Cr3G44" - }, - "source": [ - "## Visualize trajectory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BLPODyKr3Lyg" - }, - "outputs": [], - "source": [ - "# Convert a list of PyTrees to a batched PyTree.\n", - "# This operation is called 'tree transpose' in JAX.\n", - "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n", - "\n", - "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-jxJXy5r3RMt" - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n", - "plt.grid(True)\n", - "plt.xlabel(\"Time [s]\")\n", - "plt.ylabel(\"Height [m]\")\n", - "plt.title(\"Height trajectory of the sphere\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N92-WjPFGuua" - }, - "source": [ - "# Conclusions\n", - "\n", - "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n", - "\n", - "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n", - "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n", - "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n", - "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n", - "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n", - "\n", - "Have fun!" + "for _t in T:\n", + " data, _ = step_parallel(model, data)" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { @@ -454,7 +258,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "rsl", + "display_name": "jaxsim", "language": "python", "name": "python3" }, @@ -468,7 +272,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.13.0" } }, "nbformat": 4, diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb new file mode 100644 index 000000000..1d76c9817 --- /dev/null +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -0,0 +1,476 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "H-WgcgGQaTG7" + }, + "source": [ + "# JaxSim as a hardware-accelerated parallel physics engine-advanced usage\n", + "\n", + "JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", + "\n", + "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", + "\n", + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SgOSnrSscEkt" + }, + "source": [ + "## Prepare the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fdqvAqMDaTG9" + }, + "outputs": [], + "source": [ + "# @title Imports and setup\n", + "import sys\n", + "from IPython.display import clear_output\n", + "\n", + "IS_COLAB = \"google.colab\" in sys.modules\n", + "\n", + "# Install JAX and Gazebo\n", + "if IS_COLAB:\n", + " !{sys.executable} -m pip install --pre -qU jaxsim\n", + " !apt install -qq lsb-release wget gnupg\n", + " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", + " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", + " !apt -qq update\n", + " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", + "\n", + " clear_output()\n", + "\n", + "# Set environment variable to avoid GPU out of memory errors\n", + "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", + "\n", + "# ================\n", + "# Notebook imports\n", + "# ================\n", + "\n", + "import os\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jaxsim.api as js\n", + "import jaxsim\n", + "import rod\n", + "from jaxsim import logging\n", + "from rod.builder.primitives import SphereBuilder\n", + "\n", + "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", + "print(f\"Running on {jax.devices()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QtCCUhdpdGFH" + }, + "source": [ + "## Prepare the simulation\n", + "\n", + "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n", + "\n", + "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n", + "\n", + "[sdformat]: http://sdformat.org/\n", + "[urdf]: http://wiki.ros.org/urdf/\n", + "[rod]: https://github.com/ami-iit/rod" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "0emoMQhCaTG_" + }, + "outputs": [], + "source": [ + "# @title Create the model description of a sphere\n", + "\n", + "# Create a SDF model.\n", + "# The builder takes care to compute the right inertia tensor for you.\n", + "rod_sdf = rod.Sdf(\n", + " version=\"1.7\",\n", + " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", + " .build_model()\n", + " .add_link()\n", + " .add_inertial()\n", + " .add_visual()\n", + " .add_collision()\n", + " .build(),\n", + ")\n", + "\n", + "# Rod allows to update the frames w.r.t. the poses are expressed.\n", + "rod_sdf.model.switch_frame_convention(\n", + " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n", + ")\n", + "\n", + "# Serialize the model to a SDF string.\n", + "model_sdf_string = rod_sdf.serialize(pretty=True)\n", + "print(model_sdf_string)\n", + "\n", + "# JaxSim currently only supports collisions between points attached to bodies\n", + "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", + "# While this approach is universal as it applies to generic meshes, the number\n", + "# of considered points greatly affects the performance. Spheres, by default,\n", + "# are discretized with 250 points. It's too much for this simple example.\n", + "# This number can be decreased with the following environment variable.\n", + "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqjuZKvOaTG_" + }, + "source": [ + "### Create the model and its data\n", + "\n", + "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + "\n", + "- `model`: an object that defines the dynamics of the system.\n", + "- `data`: an object that contains the state of the system.\n", + "- `integrator` *(Optional)*: an object that defines the integration method.\n", + "- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\n", + "\n", + "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", + "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "etQ577cFaTHA" + }, + "outputs": [], + "source": [ + "# Create the JaxSim model.\n", + "# This is shared among all the parallel instances.\n", + "model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_sdf_string,\n", + " time_step=0.001,\n", + " integrator=jaxsim.integrators.fixed_step.Heun2,\n", + ")\n", + "\n", + "# Create the data of a single model.\n", + "# We will create a vectorized instance later.\n", + "data_single = js.data.JaxSimModelData.zero(model=model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FJF-HoWaiK9J" + }, + "source": [ + "### Select the contact model\n", + "\n", + "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n", + "\n", + "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VAaitHRKjnwc" + }, + "outputs": [], + "source": [ + "import jaxsim\n", + "\n", + "# Operate on a copy of the model.\n", + "# When validate=True, this context manager ensures that the PyTree structure\n", + "# of the object is not altered. This is a nice feature of JaxSim to spot\n", + "# earlier user logic that might trigger unwanted JIT recompilations.\n", + "# In this case, we need to disable validation since PyTree structure might\n", + "# change if you use a contact model different from the default.\n", + "with model.editable(validate=False) as model:\n", + "\n", + " # The SoftContacts class can be replaced with a different contact model.\n", + " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n", + "\n", + "# JaxSim provides the following helper that estimates good contact\n", + "# parameters. While they might not be optimal, usually are a good\n", + "# starting point. Users are encouraged to fine-tune them.\n", + "contacts_params = js.contact.estimate_good_contact_parameters(\n", + " model=model,\n", + " number_of_active_collidable_points_steady_state=4,\n", + " max_penetration=0.001,\n", + ")\n", + "\n", + "# Print the contact parameters.\n", + "# Note that these parameters are the nominal parameters shared among\n", + "# all parallel instances. If needed, they can be overridden in the\n", + "# vectorized data object that will be created later.\n", + "print(contacts_params)\n", + "\n", + "# Update the data object with the new contact model parameters.\n", + "data_single = data_single.replace(contacts_params=contacts_params, validate=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6REY2bq3lc_k" + }, + "source": [ + "### Select the integrator\n", + "\n", + "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n", + "\n", + "- `jaxsim.integrators.fixed_step`\n", + "- `jaxsim.integrators.variable_step`\n", + "\n", + "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o86Teq5piVGj" + }, + "outputs": [], + "source": [ + "print(f\"Using integrator: {model.integrator}\")\n", + "\n", + "# Initialize the simulated time.\n", + "T = jnp.arange(start=0, stop=1.0, step=model.time_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V6IeD2B3m4F0" + }, + "source": [ + "## Sample a batch of trajectories in parallel\n", + "\n", + "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n", + "\n", + "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n", + "\n", + "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vtEn0aIzr_2j" + }, + "outputs": [], + "source": [ + "# @title Generate batched initial data\n", + "\n", + "# Create a random JAX key.\n", + "key = jax.random.PRNGKey(seed=0)\n", + "\n", + "# Split subkeys for sampling random initial data.\n", + "batch_size = 10\n", + "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", + "\n", + "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n", + "data_batch_t0 = jax.vmap(\n", + " lambda key: js.data.random_model_data(\n", + " model=model,\n", + " key=key,\n", + " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n", + " base_vel_lin_bounds=(0, 0),\n", + " base_vel_ang_bounds=(0, 0),\n", + " contacts_params=contacts_params,\n", + " )\n", + ")(jnp.vstack(subkeys))\n", + "\n", + "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0tQPfsl6uxHm" + }, + "outputs": [], + "source": [ + "# @title Create parallel step function\n", + "\n", + "import functools\n", + "from typing import Any\n", + "\n", + "\n", + "@jax.jit\n", + "def step_single(\n", + " model: js.model.JaxSimModel,\n", + " data: js.data.JaxSimModelData,\n", + ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", + "\n", + " # Close step over static arguments.\n", + " return js.model.step(\n", + " model=model,\n", + " data=data,\n", + " link_forces=None,\n", + " joint_force_references=None,\n", + " )\n", + "\n", + "\n", + "@jax.jit\n", + "@functools.partial(jax.vmap, in_axes=(None, 0))\n", + "def step_parallel(\n", + " model: js.model.JaxSimModel,\n", + " data: js.data.JaxSimModelData,\n", + ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", + "\n", + " return step_single(\n", + " model=model, data=data\n", + " )\n", + "\n", + "\n", + "# The first run will be slow since JAX needs to JIT-compile the functions.\n", + "_ = step_single(model, data_single)\n", + "_ = step_parallel(model, data_batch_t0)\n", + "\n", + "# Benchmark the execution of a single step.\n", + "print(\"\\nSingle simulation step:\")\n", + "%timeit step_single(model, data_single)\n", + "\n", + "# On hardware accelerators, there's a range of batch_size values where\n", + "# increasing the number of parallel instances doesn't affect computation time.\n", + "# This range depends on the GPU/TPU specifications.\n", + "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n", + "%timeit step_parallel(model, data_batch_t0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VNwzT2JQ1n15" + }, + "outputs": [], + "source": [ + "# @title Run parallel simulation\n", + "\n", + "data = data_batch_t0\n", + "data_trajectory_list = []\n", + "\n", + "for _ in T:\n", + "\n", + " data, _ = step_parallel(model, data)\n", + " data_trajectory_list.append(data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y6n720Cr3G44" + }, + "source": [ + "## Visualize trajectory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BLPODyKr3Lyg" + }, + "outputs": [], + "source": [ + "# Convert a list of PyTrees to a batched PyTree.\n", + "# This operation is called 'tree transpose' in JAX.\n", + "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n", + "\n", + "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-jxJXy5r3RMt" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n", + "plt.grid(True)\n", + "plt.xlabel(\"Time [s]\")\n", + "plt.ylabel(\"Height [m]\")\n", + "plt.title(\"Height trajectory of the sphere\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N92-WjPFGuua" + }, + "source": [ + "# Conclusions\n", + "\n", + "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n", + "\n", + "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n", + "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n", + "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n", + "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n", + "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n", + "\n", + "Have fun!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuClass": "premium", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "rsl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pyproject.toml b/pyproject.toml index 5fc968178..ad27e2508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,8 @@ testing = [ "pytest >=6.0", "pytest-benchmark", "pytest-icdiff", - "robot-descriptions" + "robot-descriptions", + "icub-models", ] viz = [ "lxml",