diff --git a/octo/data/dataset.py b/octo/data/dataset.py
index a549c2c0..a026c31a 100644
--- a/octo/data/dataset.py
+++ b/octo/data/dataset.py
@@ -336,6 +336,7 @@ def restructure(traj):
 
         # add timestep info
         new_obs["timestep"] = tf.range(traj_len)
+        new_obs["next_action"] = old_obs["next_action"]
 
         # extracts `language_key` into the "task" dict
         task = {}
@@ -380,6 +381,7 @@ def restructure(traj):
 
     for filter_fcn_spec in filter_functions:
         full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
     full_dataset = full_dataset.traj_map(restructure, num_parallel_calls)
+    # tries to load from cache, otherwise computes on the fly
     dataset_statistics = get_dataset_statistics(
         full_dataset,
diff --git a/octo/model/components/action_heads.py b/octo/model/components/action_heads.py
index 454c2bb2..1f077c95 100644
--- a/octo/model/components/action_heads.py
+++ b/octo/model/components/action_heads.py
@@ -140,8 +140,8 @@ def discrete_loss(
     labels = discrete_tokenizer(ground_truth_value)
     labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])
 
-    loss = -jnp.sum(logits * labels_one_hot, axis=-1)
-    loss = masked_mean(loss, mask)
+    loss = jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
+    loss = -masked_mean(loss, mask)
 
     # compute accuracy between predicted actions and target actions
     pred_label = jnp.argmax(logits, axis=-1)
diff --git a/octo/model/octo_model.py b/octo/model/octo_model.py
index 7fa443d0..26e86d8c 100644
--- a/octo/model/octo_model.py
+++ b/octo/model/octo_model.py
@@ -225,6 +225,8 @@ def load_pretrained(
             tf.io.gfile.join(checkpoint_path, "config.json"), "r"
         ) as f:
             config = json.load(f)
+        if "readouts" in config["model"]:
+            config["model"]["readout_tokenizers"] = config["model"].pop("readouts")
 
         # load example batch
         with tf.io.gfile.GFile(
diff --git a/octo/model/octo_module.py b/octo/model/octo_module.py
index 7e44a215..a0e34b52 100644
--- a/octo/model/octo_module.py
+++ b/octo/model/octo_module.py
@@ -77,7 +77,7 @@ class OctoTransformer(nn.Module):
 
     observation_tokenizers: Dict[str, nn.Module]
     task_tokenizers: Dict[str, nn.Module]
-    readouts: Dict[str, int]
+    readout_tokenizers: Dict[str, int | nn.Module]
     transformer_kwargs: Dict
     token_embedding_size: int
     max_horizon: int
@@ -88,7 +88,7 @@ def __call__(
         observations: Data,
         tasks: Data,
         pad_mask: jax.Array,
-        readouts: Optional[Sequence[str]] = None,
+        readout_tokenizers: Optional[Sequence[str]] = None,
         train: bool = False,
         verbose: bool = False,
     ) -> Dict[str, TokenGroup]:
@@ -110,15 +110,15 @@ def __call__(
         Note: Horizon can be anything <= max_horizon.
""" - if readouts is None: - readouts = list(self.readouts.keys()) + if readout_tokenizers is None: + readout_tokenizers = list(self.readout_tokenizers.keys()) # # Check that all inputs are valid # - assert set(readouts).issubset( - set(self.readouts.keys()) + assert set(readout_tokenizers).issubset( + set(self.readout_tokenizers.keys()) ), "readouts must be specified in the model config" batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2] @@ -213,32 +213,58 @@ def __call__( # Finally, add the readout tokens # - for readout_name in readouts: - group_name = f"readout_{readout_name}" - # Readouts do not correspond to any inputs, just positional embeddings - n_tokens_for_readout = self.readouts[readout_name] - readout_tokens = jnp.zeros( - (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size) - ) - - # Add positional embedding - readout_tokens += self._create_positional_embedding( - group_name, readout_tokens - ) - readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout)) - readout_attention_rules = { - "task_*": AttentionRule.CAUSAL, - "obs_*": AttentionRule.CAUSAL, - group_name: AttentionRule.CAUSAL, - } # Attend to tasks, all previous observations, and *only it's own own readout* + for name, tok in self.readout_tokenizers.items(): + group_name = f"readout_{name}" + if isinstance(tok, nn.Module): + tokenizer_output: TokenGroup = tok(observations, tasks, train=train) + if tokenizer_output is None: + logging.warning(f"Skipping observation tokenizer: {group_name}") + continue + + obs_tokens = nn.Dense( + self.token_embedding_size, name=f"{group_name}_projection" + )(tokenizer_output.tokens) + # obs_tokens shape is (batch, horizon, n_tokens, token_embedding_size) + + # Add positional embedding + obs_tokens += self._create_positional_embedding(group_name, obs_tokens) + + # Update mask to account for which timesteps are padding + obs_pad_mask = jnp.logical_and(pad_mask[:, :, None], tokenizer_output.mask) + + all_timestep_groups.append( + TimestepGroup( + tokens=obs_tokens, + mask=obs_pad_mask, + name=group_name, + attention_rules=observation_attention_rules, + ) + ) + elif isinstance(tok, int): + # Readouts do not correspond to any inputs, just positional embeddings + n_tokens_for_readout = self.readout_tokenizers[name] + readout_tokens = jnp.zeros( + (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size) + ) - all_timestep_groups.append( - TimestepGroup( - tokens=readout_tokens, - mask=readout_mask, - name=group_name, - attention_rules=readout_attention_rules, + # Add positional embedding + readout_tokens += self._create_positional_embedding( + group_name, readout_tokens ) + readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout)) + readout_attention_rules = { + "task_*": AttentionRule.CAUSAL, + "obs_*": AttentionRule.CAUSAL, + group_name: AttentionRule.CAUSAL, + } # Attend to tasks, all previous observations, and *only it's own own readout* + + all_timestep_groups.append( + TimestepGroup( + tokens=readout_tokens, + mask=readout_mask, + name=group_name, + attention_rules=readout_attention_rules, + ) ) # Run the transformer! 
@@ -341,7 +367,7 @@ def create(
         observation_tokenizers: Dict[str, ModuleSpec],
         task_tokenizers: Dict[str, ModuleSpec],
         heads: Dict[str, ModuleSpec],
-        readouts: Dict[str, int],
+        readout_tokenizers: Dict[str, int | ModuleSpec],
         transformer_kwargs: Dict,
         token_embedding_size: int,
         max_horizon: int,
@@ -372,13 +398,17 @@ def create(
 
         task_tokenizer_defs = {
             k: ModuleSpec.instantiate(spec)() for k, spec in task_tokenizers.items()
         }
+        readout_tokenizer_defs = {
+            k: ModuleSpec.instantiate(spec)() if isinstance(spec, dict) else spec
+            for k, spec in readout_tokenizers.items()
+        }
         head_defs = {k: ModuleSpec.instantiate(spec)() for k, spec in heads.items()}
 
         model_def = OctoTransformer(
             observation_tokenizers=observation_tokenizer_defs,
             task_tokenizers=task_tokenizer_defs,
-            readouts=readouts,
+            readout_tokenizers=readout_tokenizer_defs,
             token_embedding_size=token_embedding_size,
             max_horizon=max_horizon,
             transformer_kwargs=transformer_kwargs,
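
A quick sanity check on the discrete_loss change (a sketch, not part of the patch): the new expression is standard softmax cross-entropy, whereas the old one multiplied raw logits by the one-hot labels, which is only correct when the logits are already log-probabilities. optax is assumed to be importable here.

# Sanity-check sketch for the discrete_loss change; optax is assumed available.
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 0.3, -0.2]])
labels = jnp.array([0, 2])
labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])

# new formulation: negative log-probability of the true class
new_loss = -jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
# old formulation: negative logit of the true class (missing the log-partition term)
old_loss = -jnp.sum(logits * labels_one_hot, axis=-1)

reference = optax.softmax_cross_entropy(logits, labels_one_hot)
print(jnp.allclose(new_loss, reference))  # True
print(jnp.allclose(old_loss, reference))  # False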
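The new module-valued readouts in OctoTransformer.__call__ assume the module can be called as tok(observations, tasks, train=train) and returns a TokenGroup (or None, in which case it is skipped), mirroring the observation tokenizers. Below is a minimal sketch of such a tokenizer; the class name, the "proprio" observation key, and the TokenGroup import path are illustrative assumptions, not part of the patch.

# Illustrative sketch of a module-valued readout tokenizer (hypothetical class).
import flax.linen as nn
import jax.numpy as jnp

from octo.model.components.base import TokenGroup  # assumed import path


class ProprioReadoutTokenizer(nn.Module):
    """Hypothetical readout tokenizer that embeds a proprio observation."""

    num_tokens: int = 1
    embedding_dim: int = 64

    @nn.compact
    def __call__(self, observations, tasks, train: bool = False):
        proprio = observations["proprio"]  # (batch, horizon, proprio_dim), assumed key
        tokens = nn.Dense(self.num_tokens * self.embedding_dim)(proprio)
        tokens = tokens.reshape(*proprio.shape[:2], self.num_tokens, self.embedding_dim)
        mask = jnp.ones(tokens.shape[:3], dtype=bool)
        return TokenGroup(tokens=tokens, mask=mask)

Note that the loop projects the returned tokens to token_embedding_size with the f"{group_name}_projection" Dense layer, so embedding_dim does not need to match the transformer width.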
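The readout_tokenizer_defs comprehension in create() leaves plain ints untouched and instantiates dict-like ModuleSpec entries. A stand-in sketch of that logic, runnable without octo, is below; instantiate() is a placeholder for ModuleSpec.instantiate(spec)(), and the "value" entry is hypothetical.

# Stand-in sketch of the readout_tokenizer_defs logic in OctoModule.create().
def instantiate(spec):
    # placeholder for ModuleSpec.instantiate(spec)(), which builds the nn.Module
    return f"<module {spec['module']}>"

readout_tokenizers = {
    "action": 1,                            # int: one positional-embedding-only readout token
    "value": {"module": "ValueTokenizer"},  # hypothetical ModuleSpec-style dict
}

readout_tokenizer_defs = {
    k: instantiate(spec) if isinstance(spec, dict) else spec
    for k, spec in readout_tokenizers.items()
}
print(readout_tokenizer_defs)  # {'action': 1, 'value': '<module ValueTokenizer>'}

Old checkpoints keep loading because load_pretrained moves config["model"]["readouts"] to "readout_tokenizers" before the model is built.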