Skip to content

Commit

Permalink
flake8 repairs
Browse files Browse the repository at this point in the history
  • Loading branch information
LMorlok committed Jul 22, 2024
1 parent 832f6bf commit c7ded6a
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions tests/test_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from embed_text_package.embed_text import Embedder
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModel

# Load dataset:
Expand All @@ -15,6 +14,7 @@
embdr = Embedder()
embdr.load(model_name)


def test_workflow():

# create batch-structure:
Expand All @@ -28,19 +28,17 @@ def test_workflow():
# get embeddings
emb = embdr.get_embeddings(batches_sentences, model_name)


# Check dimension:
assert(len(emb[0]) == batch_size)
assert (len(emb[0]) == batch_size)
# load model to check dimensions
model = AutoModel.from_pretrained(model_name)
assert(len(emb[0][0]) == model.config.hidden_size)
assert (len(emb[0][0]) == model.config.hidden_size)
# Time-reasons: only do first and last batch for now
#if len(dataset)%batch_size != 0:
# assert(len(dataset)//batch_size +1 == len(emb))
#else:
# assert(len(dataset)//batch_size == len(emb))

# if len(dataset)%batch_size != 0:
# assert (len(dataset)//batch_size +1 == len(emb))
# else:
# assert (len(dataset)//batch_size == len(emb))


if __name__ == "__main__":
test_workflow()
test_workflow()

0 comments on commit c7ded6a

Please sign in to comment.