diff --git a/tests/test_embeds.py b/tests/test_embeds.py
index 4ff68a6..ab45589 100644
--- a/tests/test_embeds.py
+++ b/tests/test_embeds.py
@@ -1,12 +1,11 @@
 """
 Tests the embed_text module.
 """
+import numpy as np
 from torch.utils.data import DataLoader
 from datasets import load_dataset, Dataset
-from transformers import AutoModel, AutoConfig
+from transformers import AutoConfig
 from embed_text_package.embed_text import Embedder
-import wandb
-import numpy as np
 
 # TEST DIFFERENT MODEL SIZES ITERATIVELY (small --> large)
 # Nested for-loop
@@ -41,9 +40,11 @@ def test_workflow():
         if ds_name == "proteinea/fluorescence":
             cols_to_be_embded = cols_to_be_embded_fluor
             ds_split = ds_split_fluor
-        if ds_name == "allenai/reward-bench":
+        elif ds_name == "allenai/reward-bench":
             cols_to_be_embded = cols_to_be_embded_rwben
             ds_split = ds_split_rwben
+        else:
+            raise ValueError(f"Unknown dataset name: {ds_name}")
         print(f"dataset_name = {ds_name}\n")
 
         dataset = load_dataset(ds_name)[ds_split]
@@ -65,7 +66,7 @@ def test_workflow():
         assert len(emb) == len(sub_ds)
         for col in cols_to_be_embded:
             assert len(emb[col][0]) == config.hidden_size
-    
+
 
 if __name__ == "__main__":
     test_workflow()