Skip to content

Commit

Permalink
Merge pull request #219 from ami-iit/191-impossible-to-run-with-jax_d…
Browse files Browse the repository at this point in the history
…isable_jit-set-to-true-model-with-zero-dof

Allow models with a single link in RBDAs with `JAX_DISABLE_JIT`
  • Loading branch information
flferretti authored Aug 20, 2024
2 parents a0efffd + d87e076 commit 713bcb0
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
36 changes: 24 additions & 12 deletions src/jaxsim/rbda/crba.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ def propagate_kinematics(

return (i_X_0,), None

(i_X_0,), _ = jax.lax.scan(
f=propagate_kinematics,
init=forward_pass_carry,
xs=jnp.arange(start=1, stop=model.number_of_links()),
(i_X_0,), _ = (
jax.lax.scan(
f=propagate_kinematics,
init=forward_pass_carry,
xs=jnp.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(i_X_0,), None]
)

# ===================
Expand Down Expand Up @@ -128,10 +132,14 @@ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
operand=carry,
)

(j, Fi, M), _ = jax.lax.scan(
f=inner_fn,
init=carry_inner_fn,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
(j, Fi, M), _ = (
jax.lax.scan(
f=inner_fn,
init=carry_inner_fn,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
)
if model.number_of_links() > 1
else [(j, Fi, M), None]
)

Fi = i_X_0[j].T @ Fi
Expand All @@ -143,10 +151,14 @@ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:

# This scan performs the backward pass to compute Mbj, Mjb and Mjj, that
# also includes a fake while loop implemented with a scan and two cond.
(Mc, M), _ = jax.lax.scan(
f=backward_pass,
init=backward_pass_carry,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
(Mc, M), _ = (
jax.lax.scan(
f=backward_pass,
init=backward_pass_carry,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
)
if model.number_of_links() > 1
else [(Mc, M), None]
)

# Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
Expand Down
12 changes: 8 additions & 4 deletions src/jaxsim/rbda/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ def propagate_kinematics(

return (W_X_i,), None

(W_X_i,), _ = jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=jnp.arange(start=1, stop=model.number_of_links()),
(W_X_i,), _ = (
jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=jnp.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(W_X_i,), None]
)

return jax.vmap(Adjoint.to_transform)(W_X_i)
Expand Down
36 changes: 24 additions & 12 deletions src/jaxsim/rbda/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,14 @@ def propagate_kinematics(

return (i_X_0,), None

(i_X_0,), _ = jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=np.arange(start=1, stop=model.number_of_links()),
(i_X_0,), _ = (
jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=np.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(i_X_0,), None]
)

# ============================
Expand Down Expand Up @@ -105,10 +109,14 @@ def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix:

return J, None

L_J_WL_B, _ = jax.lax.scan(
f=compute_jacobian,
init=J,
xs=np.arange(start=1, stop=model.number_of_links()),
L_J_WL_B, _ = (
jax.lax.scan(
f=compute_jacobian,
init=J,
xs=np.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [J, None]
)

return L_J_WL_B
Expand Down Expand Up @@ -184,10 +192,14 @@ def compute_full_jacobian(

return (B_X_i, J), None

(B_X_i, J), _ = jax.lax.scan(
f=compute_full_jacobian,
init=compute_full_jacobian_carry,
xs=np.arange(start=1, stop=model.number_of_links()),
(B_X_i, J), _ = (
jax.lax.scan(
f=compute_full_jacobian,
init=compute_full_jacobian_carry,
xs=np.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(B_X_i, J), None]
)

# Convert adjoints to SE(3) transforms.
Expand Down

0 comments on commit 713bcb0

Please sign in to comment.