Difference between SentenceLabelDataset and GroupByLabelBatchSampler? #2920

Open
vibhas-singh opened this issue Aug 31, 2024 · 1 comment

@vibhas-singh
vibhas-singh commented Aug 31, 2024

Hi @tomaarsen
First of all - kudos to you for maintaining such an awesome and pragmatic library.

I am having some difficulty using the GROUP_BY_LABEL batch sampler in v3.0 and want to highlight the issues to check whether there is any way to mitigate them.

I went through the issues and found this: #2698 (comment)
You mentioned there that the idea is to replace SentenceLabelDataset with GroupByLabelBatchSampler, but I think there are drastic differences between the two, and in switching to GroupByLabelBatchSampler we haven't retained the functionality of SentenceLabelDataset.

Here is a detailed example to explain the differences:

Let's take a simple example with a list of integers representing classes, and we'll use it to illustrate how the two approaches handle homogeneity in batch construction.

Example Data:

  • Classes: [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
  • Number of Samples per Label:
    • Class 1: 4 samples
    • Class 2: 3 samples
    • Class 3: 2 samples
    • Class 4: 1 sample
  • Batch Size: 4
  • Samples per Label in SentenceLabelDataset: 2

GroupByLabelBatchSampler Behavior:

Step 1: Initialization

  • Grouping by Labels:

    • Class 1: [0, 1, 2, 3]
    • Class 2: [4, 5, 6]
    • Class 3: [7, 8]
    • Class 4: [9]
  • Truncate to Even Number:

    • Class 1: [0, 1, 2, 3] (all 4 samples)
    • Class 2: [4, 5] (only 2 samples, 1 sample is discarded)
    • Class 3: [7, 8] (both samples)
    • Class 4: [9] (removed because it doesn't have at least 2 samples)
  • Resulting Groups:

    • Class 1: [0, 1, 2, 3]
    • Class 2: [4, 5]
    • Class 3: [7, 8]

Step 2: Batch Construction

  • Batch 1:

    • Randomly selects Class 1 and takes all samples: [0, 1, 2, 3]
    • Result: [0, 1, 2, 3] (4 samples from Class 1, fully homogeneous)
  • Batch 2:

    • Randomly selects Class 2 and takes both samples: [4, 5]
    • Randomly selects Class 3 and takes both samples: [7, 8]
    • Result: [4, 5, 7, 8] (2 samples from Class 2, 2 samples from Class 3)
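Step 2 can be sketched as follows, assuming (as the walkthrough describes) that the sampler visits labels in a random order, appends each label's whole group to a buffer, and emits a batch whenever the buffer holds batch_size indices. The function build_batches and its label_order and seed parameters are hypothetical; passing label_order=[1, 2, 3] reproduces the walkthrough.

```python
import random
from typing import Optional


def build_batches(
    groups: dict[int, list[int]],
    batch_size: int,
    drop_last: bool = False,
    label_order: Optional[list[int]] = None,
    seed: int = 0,
) -> list[list[int]]:
    """Concatenate label groups in a (random) label order and slice the result into batches."""
    if label_order is None:
        label_order = list(groups)
        random.Random(seed).shuffle(label_order)  # random order of labels

    batches: list[list[int]] = []
    buffer: list[int] = []
    for label in label_order:
        buffer.extend(groups[label])  # a label's whole group is appended at once
        while len(buffer) >= batch_size:
            batches.append(buffer[:batch_size])
            buffer = buffer[batch_size:]
    if buffer and not drop_last:
        batches.append(buffer)  # trailing partial batch
    return batches


groups = {1: [0, 1, 2, 3], 2: [4, 5], 3: [7, 8]}
print(build_batches(groups, batch_size=4, label_order=[1, 2, 3]))
# [[0, 1, 2, 3], [4, 5, 7, 8]]
```

Because each group is appended in one piece, any label with at least batch_size samples fills whole batches by itself, which is where the homogeneous batches come from.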

SentenceLabelDataset Behavior:

Step 1: Initialization

  • Label Range: [1, 2, 3, 4]
  • Already Seen Dictionary: Empty

Step 2: Batch Construction

  • Batch 1:

    • Starts with Class 1, takes 2 samples (e.g., [0, 1])
    • Result: [0, 1] (2 samples from Class 1)
    • Moves to Class 2, takes 2 samples (e.g., [4, 5])
    • Result: [0, 1, 4, 5] (2 samples from Class 1, 2 samples from Class 2)
  • Batch 2:

    • Continues with Class 1, takes remaining 2 samples (e.g., [2, 3])
    • Result: [2, 3] (2 samples from Class 1)
    • Continues with Class 2 but only 1 sample is left, so skips it and moves to Class 3
    • Takes both samples from Class 3 (e.g., [7, 8])
    • Result: [2, 3, 7, 8] (2 samples from Class 1, 2 samples from Class 3)
  • Batch 3:

    • If continuing, it would reset, shuffle, and start over, potentially including Class 4 depending on whether with_replacement is True or False.
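The SentenceLabelDataset walkthrough can be approximated with the sketch below. It cycles through the labels in a fixed order, takes samples_per_label unused samples from every label that still has at least that many, and cuts the resulting stream into batches. This is a simplified reconstruction of the steps above (no shuffling, no with_replacement), not the original SentenceLabelDataset code, and the function name is hypothetical.

```python
def label_stream_batches(
    labels: list[int], batch_size: int, samples_per_label: int = 2
) -> list[list[int]]:
    """Interleave labels in chunks of `samples_per_label`, then cut the stream into batches."""
    # Unused sample indices per label, in order of first appearance
    remaining: dict[int, list[int]] = {}
    for idx, label in enumerate(labels):
        remaining.setdefault(label, []).append(idx)

    stream: list[int] = []
    progress = True
    while progress:  # keep cycling until no label can contribute a full chunk
        progress = False
        for label, indices in remaining.items():
            if len(indices) >= samples_per_label:  # skip labels with too few samples left
                stream.extend(indices[:samples_per_label])
                remaining[label] = indices[samples_per_label:]
                progress = True

    # Keep only full batches; leftover samples are dropped for simplicity
    return [
        stream[i : i + batch_size]
        for i in range(0, len(stream) - batch_size + 1, batch_size)
    ]


labels = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
print(label_stream_batches(labels, batch_size=4))
# [[0, 1, 4, 5], [7, 8, 2, 3]]
```

The second batch has the same composition as Batch 2 in the walkthrough (two samples each from Class 1 and Class 3), only in a different order.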

Comparison of Homogeneity:

  • GroupByLabelBatchSampler:

    • Batch 1: Fully homogeneous batch from Class 1: [0, 1, 2, 3].
    • Batch 2: Mixes samples from two classes (Class 2 and Class 3) because neither can fill a batch on its own: [4, 5, 7, 8].
    • Overall: Prioritizes keeping batches homogeneous when possible, especially when a class has enough samples to fill an entire batch. It avoids mixing classes unless necessary.
  • SentenceLabelDataset:

    • Batch 1: Mixed from the start, taking 2 samples from Class 1 and 2 from Class 2: [0, 1, 4, 5].
    • Batch 2: Also mixed, with samples from Class 1 and Class 3: [2, 3, 7, 8].
    • Overall: More likely to mix classes in every batch. It enforces a fixed number of samples per label and sequentially processes labels, leading to less homogeneous batches overall.
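One simple way to quantify this comparison is to count the distinct labels in each batch. The batches below are copied from the two walkthroughs; the helper name labels_per_batch is hypothetical, and the same function could be applied to the output of a real sampler.

```python
def labels_per_batch(batches: list[list[int]], labels: list[int]) -> list[int]:
    """Count distinct labels in each batch; 1 means the batch is fully homogeneous."""
    return [len({labels[i] for i in batch}) for batch in batches]


labels = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
group_by_label_batches = [[0, 1, 2, 3], [4, 5, 7, 8]]  # GroupByLabelBatchSampler walkthrough
sentence_label_batches = [[0, 1, 4, 5], [2, 3, 7, 8]]  # SentenceLabelDataset walkthrough
print(labels_per_batch(group_by_label_batches, labels))  # [1, 2]
print(labels_per_batch(sentence_label_batches, labels))  # [2, 2]
```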

TL;DR:

I am trying to fine-tune Sentence Transformers models on a dataset with this label distribution:

  • Class 1: 5000 samples
  • Class 2: 5000 samples
  • Class 3: 3000 samples
  • Classes 5 to 50: fewer than 50 samples each

With the new GroupByLabelBatchSampler, most of the batches are homogeneous, and I don't observe much improvement after fine-tuning.
IMO, this kind of data would have worked well with SentenceLabelDataset, as it ensures there are at most N samples from each label in a batch. Intuitively, ST models should benefit from in-batch negatives and more heterogeneous batches.
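A rough calculation shows why most batches come out homogeneous with this distribution. The sketch assumes, as in the walkthrough, that GroupByLabelBatchSampler appends each label's samples as one contiguous run, and it assumes a batch size of 32, which the issue does not state. Classes 5 to 50 are 46 classes, each contributing at most 48 samples after truncation to an even count. A contiguous run of n samples can only be cut at its two ends, so it fully covers at least n // batch_size - 1 batches.

```python
import math

# Assumption: the issue does not state the batch size; 32 is only an example value
batch_size = 32
large_classes = [5000, 5000, 3000]  # classes 1, 2 and 3 (already even, nothing truncated)
num_small, max_small = 46, 48  # classes 5..50: "< 50 samples" is at most 48 after truncation

# Each contiguous same-label run fully covers at least n // batch_size - 1 batches
min_homogeneous = sum(n // batch_size - 1 for n in large_classes)
# Upper bound on the number of batches per epoch (including a trailing partial batch)
max_batches = math.ceil((sum(large_classes) + num_small * max_small) / batch_size)

print(min_homogeneous, max_batches, round(min_homogeneous / max_batches, 3))
# 402 476 0.845
```

Under these assumptions, at least about 84% of the batches in an epoch contain a single label, which is consistent with the observation that most batches are homogeneous.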

Can you help me verify whether my understanding is correct? If so, is there any way to opt for the older logic?

@vibhas-singh vibhas-singh changed the title Difference between SentenceLabelDataset and GroupByLabelBatchSampler Difference between SentenceLabelDataset and GroupByLabelBatchSampler? Aug 31, 2024
@tomaarsen
Collaborator

Hello!

You're right in your analysis: GroupByLabelBatchSampler was designed to replace SentenceLabelDataset, and the former aims for homogeneous batches whereas the latter produces more mixed ones. For reference, here is the docstring of the new GroupByLabelBatchSampler:

This sampler groups samples by their labels and aims to create batches such that
each batch contains samples where the labels are as homogeneous as possible.
This sampler is meant to be used alongside the `Batch...TripletLoss` classes, which
require that each batch contains at least 2 examples per label class.

This sampler is meant for the Batch...TripletLoss classes, which require that each batch contains at least 2 examples per label class. These losses compare all samples with the same label within a batch, so they benefit from 1) larger batches and 2) more samples with the same label in each batch. At least, that is my understanding.
As a result, in theory a more homogeneous batch should result in a better training signal for these losses. However, I admit that I haven't tested it out in practice, and I may be wrong.
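Both this argument and the in-batch-negatives argument from the issue can be made concrete by counting pairs within a batch: same-label pairs can serve as anchor-positive pairs for the Batch...TripletLoss classes, while different-label pairs are the only source of in-batch negatives. The helper pair_counts is a hypothetical illustration, applied to the labels of the two GroupByLabelBatchSampler batches from the walkthrough.

```python
from collections import Counter


def pair_counts(batch_labels: list[int]) -> tuple[int, int]:
    """Return (same-label pairs, different-label pairs) among all unordered pairs in a batch."""
    n = len(batch_labels)
    total_pairs = n * (n - 1) // 2
    same = sum(c * (c - 1) // 2 for c in Counter(batch_labels).values())
    return same, total_pairs - same


print(pair_counts([1, 1, 1, 1]))  # (6, 0): fully homogeneous, no in-batch negatives
print(pair_counts([2, 2, 3, 3]))  # (2, 4)
```

A fully homogeneous batch therefore offers the most positive pairs but no negatives within the batch, while a mixed batch trades some positives for negatives.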

Is there any way to opt for the older logic?

Yes and no. You can override the SentenceTransformerTrainer's get_batch_sampler method:

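from __future__ import annotations
import torch
from datasets import Dataset
from torch.utils.data import BatchSampler

def get_batch_sampler(
    self,
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] | None = None,
    generator: torch.Generator | None = None,
) -> BatchSampler:
    # Return your own batch sampler with the desired batching behaviour here
    ...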

Your override can then immediately return a custom batch sampler with the behaviour you want. So yes, you can use the older logic, but no, you'd have to write it yourself.
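As a starting point for writing it yourself, here is a minimal sketch of a batch sampler that approximates the older SentenceLabelDataset behaviour. It is not the original SentenceLabelDataset code, and the class name LabelInterleavedBatchSampler and its parameters are hypothetical. Each epoch it shuffles every label's indices, splits them into chunks of samples_per_label (dropping incomplete chunks, so single-sample labels disappear), and builds the index stream in rounds: each round visits the labels that still have chunks in a freshly shuffled order and takes one chunk from each. The stream is then cut into batches of batch_size. It is kept framework-free so it can be tested on its own.

```python
import random
from collections import Counter
from typing import Hashable, Iterator, Sequence


class LabelInterleavedBatchSampler:
    """Hypothetical sampler interleaving labels in chunks, SentenceLabelDataset-style."""

    def __init__(
        self,
        labels: Sequence[Hashable],
        batch_size: int,
        samples_per_label: int = 2,
        drop_last: bool = False,
        seed: int = 0,
    ) -> None:
        if batch_size % samples_per_label != 0:
            raise ValueError("batch_size must be a multiple of samples_per_label")
        self.labels = list(labels)
        self.batch_size = batch_size
        self.samples_per_label = samples_per_label
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Use a different shuffle per epoch if the training loop calls this
        self.epoch = epoch

    def _index_stream(self) -> list[int]:
        rng = random.Random(self.seed + self.epoch)
        k = self.samples_per_label
        by_label: dict[Hashable, list[int]] = {}
        for idx, label in enumerate(self.labels):
            by_label.setdefault(label, []).append(idx)

        # Shuffle each label's indices and split them into full chunks of size k
        chunk_lists: list[list[list[int]]] = []
        for indices in by_label.values():
            rng.shuffle(indices)
            chunks = [indices[i : i + k] for i in range(0, len(indices) - k + 1, k)]
            if chunks:  # labels with fewer than k samples are dropped
                chunk_lists.append(chunks)

        # Rounds: shuffled label order, one chunk from each label that still has chunks
        stream: list[int] = []
        while chunk_lists:
            rng.shuffle(chunk_lists)
            for chunks in chunk_lists:
                stream.extend(chunks.pop())
            chunk_lists = [chunks for chunks in chunk_lists if chunks]
        return stream

    def __iter__(self) -> Iterator[list[int]]:
        stream = self._index_stream()
        for start in range(0, len(stream), self.batch_size):
            batch = stream[start : start + self.batch_size]
            if len(batch) == self.batch_size or not self.drop_last:
                yield batch

    def __len__(self) -> int:
        k = self.samples_per_label
        num_samples = sum(c // k * k for c in Counter(self.labels).values())
        if self.drop_last:
            return num_samples // self.batch_size
        return -(-num_samples // self.batch_size)  # ceiling division


labels = [0] * 10 + [1] * 10 + [2] * 10
sampler = LabelInterleavedBatchSampler(labels, batch_size=6, samples_per_label=2, seed=42)
for batch in sampler:
    print(sorted(labels[i] for i in batch))  # [0, 0, 1, 1, 2, 2] for every batch
```

Within one round each label contributes at most samples_per_label samples, so large classes are interleaved with small ones instead of being concatenated. This does not strictly cap every batch at samples_per_label samples per label: a batch can straddle two rounds, and once only one or two large labels have chunks left, their chunks necessarily end up next to each other. To plug it in, you could override get_batch_sampler in a SentenceTransformerTrainer subclass and return something like LabelInterleavedBatchSampler(dataset["label"], batch_size, drop_last=drop_last), assuming your label column is named "label". Depending on your sentence-transformers and Accelerate versions, the trainer may expect a torch.utils.data.BatchSampler subclass rather than a plain iterable; in that case, subclass it and keep the same __iter__ and __len__.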

Hope this helps a bit.

  • Tom Aarsen
