Skip to content

Commit

Permalink
add language paraphrase augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
mees committed May 7, 2024
1 parent 7219f0d commit c33deaa
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
76 changes: 75 additions & 1 deletion octo/data/utils/task_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,85 @@
Contains basic logic for randomly zero-ing out keys in the task specification.
"""

import pickle

import tensorflow as tf

from octo.data.utils.data_utils import to_padding


def delete_and_rephrase(
traj, pickle_file_path: str, rephrase_prob: float, keep_image_prob: float
):
traj = rephrase_instruction(traj, pickle_file_path, rephrase_prob)
traj = delete_task_conditioning(traj, keep_image_prob)
return traj


class Rephraser:
def create_static_hash_table(self, dictionary):
"""Takes a python dictionary with string keys and values and creates a tf static hash table"""
keys = list(dictionary.keys())
values = list(dictionary.values())
initializer = tf.lookup.KeyValueTensorInitializer(
keys, values, key_dtype=tf.string, value_dtype=tf.string
)
hash_table = tf.lookup.StaticHashTable(initializer, default_value="")
return hash_table

def __init__(self, pickle_file_path: str):
if isinstance(pickle_file_path, str):
with tf.io.gfile.GFile(pickle_file_path, "rb") as file:
lang_paraphrases = pickle.load(file)
# Create StaticHashTable
self.rephrase_lookup = self.create_static_hash_table(lang_paraphrases)


def rephrase_instruction(
traj: dict, pickle_file_path: str, rephrase_prob: float
) -> dict:
"""Randomly rephrases language instructions with precomputed paraphrases
Args:
traj: A dictionary containing trajectory data. Should have a "task" key.
pickle_file_path: The path to the pickle file containing the paraphrases.
rephrase_prob: The probability of augmenting the language instruction. The probability of keeping the language
instruction is 1 - rephrase_prob.
"""
rephraser = Rephraser(pickle_file_path)

if "language_instruction" not in traj["task"]:
return traj
original_language = traj["task"]["language_instruction"]
# check the language key is not empty
string_is_not_empty = tf.reduce_all(tf.strings.length(original_language) > 0)
# check dict is not empty
dict_is_not_empty = bool(rephraser.rephrase_lookup)
if dict_is_not_empty and string_is_not_empty:
rephrased_instruction = rephraser.rephrase_lookup.lookup(original_language[0])
rephrased_instruction = tf.where(
tf.strings.length(rephrased_instruction) > 0,
original_language[0] + "." + rephrased_instruction,
original_language[0],
)
split_tensor = tf.strings.split(rephrased_instruction, sep=".")
num_strings = tf.cast(tf.shape(split_tensor)[0], tf.int32)
random_index = tf.random.uniform(
(tf.shape(original_language)[0],),
minval=0,
maxval=num_strings,
dtype=tf.int32,
)
sampled_language = tf.gather(split_tensor, random_index)
rand = tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32)
sampled_language = tf.where(
rand < rephrase_prob,
sampled_language,
original_language,
)
traj["task"]["language_instruction"] = sampled_language
return traj


def delete_task_conditioning(
traj: dict,
keep_image_prob: float,
Expand Down Expand Up @@ -57,4 +131,4 @@ def delete_task_conditioning(
traj_len - 1,
)

return traj
return traj
5 changes: 5 additions & 0 deletions scripts/configs/octo_pretrain_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def get_config(config_string=None):
),
traj_transform_kwargs=dict(
future_action_window_size=3,
task_augment_strategy="delete_and_rephrase",
task_augment_kwargs=dict(
pickle_file_path="gs://rail-datasets-europe-west4/oxe/resize_256_256/paraphrases_oxe.pkl",
rephrase_prob=0.5,
),
),
batch_size=128,
shuffle_buffer_size=500000,
Expand Down

0 comments on commit c33deaa

Please sign in to comment.