Skip to content

Commit

Permalink
pylint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LMorlok committed Jul 23, 2024
1 parent 95e8fc3 commit eda906b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/test_embeds.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""
Tests the embed_text module.
"""
import numpy as np
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import AutoModel, AutoConfig
from transformers import AutoConfig
from embed_text_package.embed_text import Embedder
import wandb
import numpy as np

# TEST DIFFERENT MODEL SIZES ITERATIVELY (small --> large)
# Nested for-loop
Expand Down Expand Up @@ -41,9 +40,11 @@ def test_workflow():
if ds_name == "proteinea/fluorescence":
cols_to_be_embded = cols_to_be_embded_fluor
ds_split = ds_split_fluor
if ds_name == "allenai/reward-bench":
elif ds_name == "allenai/reward-bench":
cols_to_be_embded = cols_to_be_embded_rwben
ds_split = ds_split_rwben
else:
raise ValueError(f"Unknown dataset name: {ds_name}")
print(f"dataset_name = {ds_name}\n")
dataset = load_dataset(ds_name)[ds_split]

Expand All @@ -65,7 +66,7 @@ def test_workflow():
assert len(emb) == len(sub_ds)
for col in cols_to_be_embded:
assert len(emb[col][0]) == config.hidden_size


if __name__ == "__main__":
test_workflow()

0 comments on commit eda906b

Please sign in to comment.