added readout tokenizers and fixed discrete head
andrearosasco committed Apr 22, 2024
1 parent c4c222a commit ca9ab7e
Showing 4 changed files with 68 additions and 34 deletions.
2 changes: 2 additions & 0 deletions octo/data/dataset.py
@@ -336,6 +336,7 @@ def restructure(traj):

# add timestep info
new_obs["timestep"] = tf.range(traj_len)
+ new_obs['next_action'] = old_obs['next_action']

# extracts `language_key` into the "task" dict
task = {}
@@ -380,6 +381,7 @@ def restructure(traj):
for filter_fcn_spec in filter_functions:
full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
full_dataset = full_dataset.traj_map(restructure, num_parallel_calls)

# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
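Note: the next_action passthrough above assumes every source dataset already exposes a next_action key in its raw observation dict. As a minimal illustration (not part of the commit), the restructure pattern it extends looks roughly like the sketch below; the helper name is made up and the key handling is simplified:

import tensorflow as tf

def restructure_sketch(traj):
    # Pared-down sketch of the restructure step: build the new observation dict
    # and copy next_action through from the raw observations.
    old_obs = traj["observation"]
    traj_len = tf.shape(tf.nest.flatten(old_obs)[0])[0]
    new_obs = {
        "timestep": tf.range(traj_len),
        # Assumes the raw dataset provides this key; datasets without it would
        # need a guard or a default value.
        "next_action": old_obs["next_action"],
    }
    return dict(traj, observation=new_obs)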
4 changes: 2 additions & 2 deletions octo/model/components/action_heads.py
@@ -140,8 +140,8 @@ def discrete_loss(
labels = discrete_tokenizer(ground_truth_value)
labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])

- loss = -jnp.sum(logits * labels_one_hot, axis=-1)
- loss = masked_mean(loss, mask)
+ loss = jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
+ loss = -masked_mean(loss, mask)

# compute accuracy between predicted actions and target actions
pred_label = jnp.argmax(logits, axis=-1)
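For reference, the change above replaces a loss that treated raw logits as if they were already log-probabilities with a proper cross-entropy: log_softmax normalizes the logits first, and the negation is applied after the masked mean. A self-contained sketch of the corrected computation, using a simplified stand-in for octo's masked_mean helper:

import jax
import jax.numpy as jnp

def masked_mean(x, mask):
    # Simplified stand-in for octo's masked_mean: average over unmasked elements only.
    mask = jnp.broadcast_to(mask, x.shape)
    return jnp.sum(x * mask) / jnp.maximum(jnp.sum(mask), 1e-5)

def discrete_action_loss(logits, labels, mask):
    # logits: (..., vocab) raw scores, labels: integer action bins, mask: 1 = valid step.
    labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    # log_softmax turns logits into log-probabilities; multiplying by the one-hot
    # targets and summing picks out the log-likelihood of the ground-truth bin.
    loglik = jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
    return -masked_mean(loglik, mask)

# Tiny usage example: batch of 2, horizon 3, 256 action bins.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 256))
labels = jnp.zeros((2, 3), dtype=jnp.int32)
mask = jnp.ones((2, 3))
print(discrete_action_loss(logits, labels, mask))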
2 changes: 2 additions & 0 deletions octo/model/octo_model.py
@@ -225,6 +225,8 @@ def load_pretrained(
tf.io.gfile.join(checkpoint_path, "config.json"), "r"
) as f:
config = json.load(f)
+ if 'readouts' in config['model']:
+     config['model']['readout_tokenizers'] = config['model'].pop('readouts')

# load example batch
with tf.io.gfile.GFile(
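A small sketch of what this backward-compatibility shim does when an older checkpoint is loaded (the inline JSON is a made-up stand-in for config.json):

import json

config = json.loads('{"model": {"readouts": {"action": 1}}}')  # stand-in for an old config.json
if "readouts" in config["model"]:
    # Older configs stored the readout spec under "readouts"; the module now expects
    # "readout_tokenizers", so the key is renamed in place.
    config["model"]["readout_tokenizers"] = config["model"].pop("readouts")
print(config)  # {'model': {'readout_tokenizers': {'action': 1}}}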
94 changes: 62 additions & 32 deletions octo/model/octo_module.py
@@ -77,7 +77,7 @@ class OctoTransformer(nn.Module):

observation_tokenizers: Dict[str, nn.Module]
task_tokenizers: Dict[str, nn.Module]
- readouts: Dict[str, int]
+ readout_tokenizers: Dict[str, int | nn.Module]
transformer_kwargs: Dict
token_embedding_size: int
max_horizon: int
@@ -88,7 +88,7 @@ def __call__(
observations: Data,
tasks: Data,
pad_mask: jax.Array,
- readouts: Optional[Sequence[str]] = None,
+ readout_tokenizers: Optional[Sequence[str]] = None,
train: bool = False,
verbose: bool = False,
) -> Dict[str, TokenGroup]:
@@ -110,15 +110,15 @@ def __call__(
Note: Horizon can be anything <= max_horizon.
"""
- if readouts is None:
-     readouts = list(self.readouts.keys())
+ if readout_tokenizers is None:
+     readout_tokenizers = list(self.readout_tokenizers.keys())

#
# Check that all inputs are valid
#

- assert set(readouts).issubset(
-     set(self.readouts.keys())
+ assert set(readout_tokenizers).issubset(
+     set(self.readout_tokenizers.keys())
), "readouts must be specified in the model config"

batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2]
@@ -213,32 +213,58 @@ def __call__(
# Finally, add the readout tokens
#

-         for readout_name in readouts:
-             group_name = f"readout_{readout_name}"
-             # Readouts do not correspond to any inputs, just positional embeddings
-             n_tokens_for_readout = self.readouts[readout_name]
-             readout_tokens = jnp.zeros(
-                 (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
-             )
-
-             # Add positional embedding
-             readout_tokens += self._create_positional_embedding(
-                 group_name, readout_tokens
-             )
-             readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
-             readout_attention_rules = {
-                 "task_*": AttentionRule.CAUSAL,
-                 "obs_*": AttentionRule.CAUSAL,
-                 group_name: AttentionRule.CAUSAL,
-             } # Attend to tasks, all previous observations, and *only it's own own readout*
-
-             all_timestep_groups.append(
-                 TimestepGroup(
-                     tokens=readout_tokens,
-                     mask=readout_mask,
-                     name=group_name,
-                     attention_rules=readout_attention_rules,
-                 )
-             )
+         for name, tok in self.readout_tokenizers.items():
+             group_name = f"readout_{name}"
+             if isinstance(tok, nn.Module):
+                 tokenizer_output: TokenGroup = tok(observations, tasks, train=train)
+                 if tokenizer_output is None:
+                     logging.warning(f"Skipping observation tokenizer: {group_name}")
+                     continue
+
+                 obs_tokens = nn.Dense(
+                     self.token_embedding_size, name=f"{group_name}_projection"
+                 )(tokenizer_output.tokens)
+                 # obs_tokens shape is (batch, horizon, n_tokens, token_embedding_size)
+
+                 # Add positional embedding
+                 obs_tokens += self._create_positional_embedding(group_name, obs_tokens)
+
+                 # Update mask to account for which timesteps are padding
+                 obs_pad_mask = jnp.logical_and(pad_mask[:, :, None], tokenizer_output.mask)
+
+                 all_timestep_groups.append(
+                     TimestepGroup(
+                         tokens=obs_tokens,
+                         mask=obs_pad_mask,
+                         name=group_name,
+                         attention_rules=observation_attention_rules,
+                     )
+                 )
+             elif isinstance(tok, int):
+                 # Readouts do not correspond to any inputs, just positional embeddings
+                 n_tokens_for_readout = self.readout_tokenizers[name]
+                 readout_tokens = jnp.zeros(
+                     (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
+                 )
+
+                 # Add positional embedding
+                 readout_tokens += self._create_positional_embedding(
+                     group_name, readout_tokens
+                 )
+                 readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
+                 readout_attention_rules = {
+                     "task_*": AttentionRule.CAUSAL,
+                     "obs_*": AttentionRule.CAUSAL,
+                     group_name: AttentionRule.CAUSAL,
+                 } # Attend to tasks, all previous observations, and *only it's own own readout*
+
+                 all_timestep_groups.append(
+                     TimestepGroup(
+                         tokens=readout_tokens,
+                         mask=readout_mask,
+                         name=group_name,
+                         attention_rules=readout_attention_rules,
+                     )
+                 )

# Run the transformer!
@@ -341,7 +367,7 @@ def create(
observation_tokenizers: Dict[str, ModuleSpec],
task_tokenizers: Dict[str, ModuleSpec],
heads: Dict[str, ModuleSpec],
- readouts: Dict[str, int],
+ readout_tokenizers: Dict[str, int | ModuleSpec],
transformer_kwargs: Dict,
token_embedding_size: int,
max_horizon: int,
@@ -372,13 +398,17 @@ def create(
task_tokenizer_defs = {
k: ModuleSpec.instantiate(spec)() for k, spec in task_tokenizers.items()
}
+ readout_tokenizer_defs = {
+     k: ModuleSpec.instantiate(spec)() if isinstance(spec, dict) else spec
+     for k, spec in readout_tokenizers.items()
+ }

head_defs = {k: ModuleSpec.instantiate(spec)() for k, spec in heads.items()}

model_def = OctoTransformer(
observation_tokenizers=observation_tokenizer_defs,
task_tokenizers=task_tokenizer_defs,
- readouts=readouts,
+ readout_tokenizers=readout_tokenizer_defs,
token_embedding_size=token_embedding_size,
max_horizon=max_horizon,
transformer_kwargs=transformer_kwargs,
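With this change a readout entry can be either a plain token count (the previous behaviour) or a flax module that builds readout tokens from the observations and tasks. A hypothetical tokenizer of the second kind is sketched below; the class name, the use of a "proprio" observation, and the import path for TokenGroup are assumptions, and the only interface the transformer relies on is a returned TokenGroup-like object with .tokens of shape (batch, horizon, n_tokens, dim) and .mask of shape (batch, horizon, n_tokens):

import flax.linen as nn
import jax.numpy as jnp
from octo.model.components.base import TokenGroup  # assumed import path

class ProprioReadoutTokenizer(nn.Module):
    # Hypothetical readout tokenizer: derives readout tokens from proprioception
    # instead of starting from zero-initialized tokens.
    num_tokens: int = 1
    embedding_dim: int = 256

    @nn.compact
    def __call__(self, observations, tasks, train: bool = False):
        # Assumes observations["proprio"] has shape (batch, horizon, proprio_dim).
        feats = nn.Dense(self.embedding_dim)(observations["proprio"])
        tokens = jnp.repeat(feats[:, :, None, :], self.num_tokens, axis=2)
        mask = jnp.ones(tokens.shape[:-1], dtype=bool)
        # OctoTransformer still projects .tokens to token_embedding_size and adds its
        # own positional embedding, so embedding_dim does not need to match it.
        return TokenGroup(tokens=tokens, mask=mask)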

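At configuration time the create classmethod now accepts a mix of both styles: ints pass through unchanged, while ModuleSpec dicts are instantiated by the new readout_tokenizer_defs comprehension. A hypothetical example, assuming ModuleSpec.create works here the same way it does for observation and task tokenizers, and reusing the sketched ProprioReadoutTokenizer above:

from octo.utils.spec import ModuleSpec  # assumed import path

readout_tokenizers = {
    # Plain int: one zero-initialized readout token per timestep, as before this commit.
    "action": 1,
    # ModuleSpec (a dict): instantiated into a flax module inside create().
    "proprio_readout": ModuleSpec.create(ProprioReadoutTokenizer, num_tokens=1),
}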