Skip to content

Commit

Permalink
Merge pull request #79 from andrearosasco/patch-1
Browse files Browse the repository at this point in the history
Added LogSoftmax to discrete loss
  • Loading branch information
dibyaghosh authored Apr 30, 2024
2 parents bd930f9 + 89edda5 commit 7480a2a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions octo/model/components/action_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def discrete_loss(
labels = discrete_tokenizer(ground_truth_value)
labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])

loss = -jnp.sum(logits * labels_one_hot, axis=-1)
loss = masked_mean(loss, mask)
loss = jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
loss = -masked_mean(loss, mask)

# compute accuracy between predicted actions and target actions
pred_label = jnp.argmax(logits, axis=-1)
Expand Down

0 comments on commit 7480a2a

Please sign in to comment.