Add LabelScorer base class #80

Open
wants to merge 4 commits into
base: collapsed-vector

Collaborator

@SimBe195 commented Oct 31, 2024

Abstract base class for scoring tokens within an ASR search algorithm.

This class provides an interface for different types of label scorers in an ASR system. Label scorers compute the scores of tokens based on input features and a scoring context. Subclasses of this base class represent various ASR model architectures, such as CTC, transducer, AED or other models.

The usage is intended as follows:

  • Before or during the search, features can be added
  • At the beginning of search, getInitialScoringContext should be called and used for the first hypotheses
  • For a given hypothesis in search, its scoring context together with a successor token and transition type is packed into a request and scored via getScoreWithTime. This also returns the timestamp of the successor.
    • Note: The scoring function may return no value. In this case it is not ready yet and needs more input features.
    • Note: There is also the function getScoresWithTimes which can handle an entire batch of requests at once and might be implemented more efficiently (e.g. using batched model forwarding).
  • For all hypotheses that survive pruning, the LabelScorer can compute a new scoring context that extends the previous scoring context of that hypothesis with a given successor token. This new scoring context can then be used as context in subsequent search steps.
  • After all features have been passed, the signalNoMoreFeatures function is called to inform the label scorer that it doesn't need to wait for more features and can score as much as possible. This is especially important when the label scorer internally uses an encoder or window with right context.
  • When all necessary scores for the current segment have been computed, the reset function is called to clean up any internal data (e.g. feature buffer) or reset flags of the LabelScorer. Afterwards it is ready to receive features for the next segment.
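The call sequence above can be sketched with a toy scorer. This is a minimal sketch, not the interface from this PR: `ToyLabelScorer`, `addFeature`, `extendedScoringContext`, `decodeSegment` and all simplified types are hypothetical stand-ins, the scorer is assumed to need one frame of right context, and lower scores are assumed to be better.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <vector>

// Simplified stand-ins for the real RASR types (assumptions for illustration).
using Score          = float;
using TimeframeIndex = std::size_t;

struct ScoringContext {
    TimeframeIndex timeframe;  // toy context: only the current timeframe
};
using ScoringContextRef = std::shared_ptr<ScoringContext const>;

struct Request {
    ScoringContextRef context;
    int               nextToken;
};

struct ScoreWithTime {
    Score          score;
    TimeframeIndex timeframe;
};

// Toy scorer: score(t, token) = feature[t] * token. It needs one frame of
// right context unless the end of the segment has been signalled.
class ToyLabelScorer {
public:
    void addFeature(float feature) { features_.push_back(feature); }  // hypothetical name
    void signalNoMoreFeatures() { segmentEnd_ = true; }
    void reset() {
        features_.clear();
        segmentEnd_ = false;
    }

    ScoringContextRef getInitialScoringContext() const {
        return std::make_shared<ScoringContext const>(ScoringContext{0});
    }

    std::optional<ScoreWithTime> getScoreWithTime(Request const& request) const {
        TimeframeIndex const t      = request.context->timeframe;
        std::size_t const    needed = segmentEnd_ ? t + 1 : t + 2;
        if (features_.size() < needed) {
            return std::nullopt;  // not ready yet (or segment exhausted)
        }
        return ScoreWithTime{features_[t] * static_cast<float>(request.nextToken), t};
    }

    // Hypothetical name for "extend the context with the successor token".
    ScoringContextRef extendedScoringContext(Request const& request) const {
        return std::make_shared<ScoringContext const>(ScoringContext{request.context->timeframe + 1});
    }

private:
    std::vector<float> features_;
    bool               segmentEnd_ = false;
};

// Greedy single-hypothesis "search" over the tokens {1, 2}.
// Returns the accumulated score of the best path (lower is better here).
Score decodeSegment(ToyLabelScorer& scorer, std::vector<float> const& features) {
    ScoringContextRef context     = scorer.getInitialScoringContext();
    Score             total       = 0;
    std::size_t       nextFeature = 0;
    bool              ended       = false;
    while (true) {
        Request                      request{context, 1};
        std::optional<ScoreWithTime> best = scorer.getScoreWithTime(request);
        if (!best) {
            // No score available: feed more input, then signal the segment end.
            if (nextFeature < features.size()) {
                scorer.addFeature(features[nextFeature++]);
            }
            else if (!ended) {
                scorer.signalNoMoreFeatures();
                ended = true;
            }
            else {
                break;  // nothing left to score in this segment
            }
            continue;
        }
        Request                      other{context, 2};
        std::optional<ScoreWithTime> otherScore = scorer.getScoreWithTime(other);
        if (otherScore && otherScore->score < best->score) {
            request = other;
            best    = otherScore;
        }
        assert(best->timeframe == context->timeframe);
        total += best->score;
        // Only the surviving hypothesis gets an extended scoring context.
        context = scorer.extendedScoringContext(request);
    }
    scorer.reset();  // ready for the next segment
    return total;
}
```

Note that the last frame can only be scored after signalNoMoreFeatures has been called, which is the right-context case mentioned in the list.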

Each concrete subclass internally implements a concrete type of scoring context which the outside search algorithm is agnostic to. Depending on the model, this scoring context can consist of things like the current timestep, a label history, a hidden state or other values.
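One way to keep the search agnostic to the context type is to hand out references to an abstract base and let only the scorer cast back to its own concrete type. The following is a sketch under that assumption; `LabelHistoryContext`, `HistoryScorer`, `extendedScoringContext` and `historyLength` are hypothetical names and not taken from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Opaque base type: the search only stores and passes around references to it.
struct ScoringContext {
    virtual ~ScoringContext() = default;
};
using ScoringContextRef = std::shared_ptr<ScoringContext const>;

// Concrete context known only to one scorer (here: a label history).
struct LabelHistoryContext : ScoringContext {
    explicit LabelHistoryContext(std::vector<int> h)
            : history(std::move(h)) {}
    std::vector<int> history;
};

class HistoryScorer {
public:
    ScoringContextRef getInitialScoringContext() const {
        return std::make_shared<LabelHistoryContext const>(std::vector<int>{});
    }

    // Returns a new context; the old one stays valid for other hypotheses.
    ScoringContextRef extendedScoringContext(ScoringContextRef const& context, int token) const {
        std::vector<int> history = concrete(context).history;
        history.push_back(token);
        return std::make_shared<LabelHistoryContext const>(std::move(history));
    }

    std::size_t historyLength(ScoringContextRef const& context) const {
        return concrete(context).history.size();
    }

private:
    // The scorer is the only place that knows the concrete context type.
    static LabelHistoryContext const& concrete(ScoringContextRef const& context) {
        assert(dynamic_cast<LabelHistoryContext const*>(context.get()) != nullptr);
        return static_cast<LabelHistoryContext const&>(*context);
    }
};
```

Because contexts are immutable and shared, several hypotheses can branch off the same predecessor context without copying it.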

This PR depends on #78.

Contributor

@curufinwe left a comment

Please write const references as std::vector<int> const& instead of const std::vector<int>& for better consistency with recently written RASR code.
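For reference, both spellings denote exactly the same type, so this is purely a consistency question. A small illustration (the function `sum` and the helper `checkSum` are made up for this example):

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Both spellings name the same type; only the position of `const` differs.
static_assert(std::is_same<std::vector<int> const&, const std::vector<int>&>::value,
              "east const and west const are the same type");

// Requested style: const placed after the type it qualifies.
int sum(std::vector<int> const& values) {
    int total = 0;
    for (int v : values) {
        total += v;
    }
    return total;
}

// Small self-check that can be called from a test or from main().
void checkSum() {
    assert(sum({1, 2, 3}) == 6);
}
```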

// Return value of scoring function
struct ScoreWithTime {
Score score;
Speech::TimeframeIndex timestep;

The type is called TimeframeIndex, so I would suggest calling this member timeframe instead of timestep.

// Return value of batched scoring function
struct ScoresWithTimes {
std::vector<Score> scores;
Core::CollapsedVector<Speech::TimeframeIndex> timesteps; // Timesteps vector is internally collapsed if all timesteps are the same (e.g. time-sync decoding)

see above
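The code comment on the collapsed vector describes storing a single element as long as all entries are equal. The sketch below only illustrates that idea; it is a hypothetical stand-in (`CollapsedVectorSketch`, `storedSize`), and the actual Core::CollapsedVector from #78 may have a different interface and behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in, NOT the Core::CollapsedVector from #78.
template<typename T>
class CollapsedVectorSketch {
public:
    void push_back(T const& value) {
        bool const collapsed = (data_.size() == 1);
        if (collapsed && value == data_.front()) {
            // Same value again: only the logical size grows.
        }
        else {
            if (collapsed && logicalSize_ > 1) {
                // First differing value: expand to one stored element per entry.
                T const first = data_.front();
                data_.resize(logicalSize_, first);
            }
            data_.push_back(value);
        }
        ++logicalSize_;
    }

    T const& operator[](std::size_t i) const {
        assert(i < logicalSize_);
        return data_.size() == 1 ? data_.front() : data_[i];
    }

    std::size_t size() const { return logicalSize_; }        // logical number of entries
    std::size_t storedSize() const { return data_.size(); }  // elements actually stored

private:
    std::vector<T> data_;
    std::size_t    logicalSize_ = 0;
};
```

In time-synchronous decoding all requests of a batch share one timeframe, so such a container would store a single index regardless of the batch size.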

// Return score and timeframe index of the corresponding output
// May not return a value if the LabelScorer is not ready to score the request yet
// (e.g. not enough features received)
virtual std::optional<ScoreWithTime> getScoreWithTime(const Request request) = 0;

If the request is passed by value it does not need to be const. But I guess you want to pass by reference here. Also: what do you think about calling this function computeScoreWithTime to better indicate that this is a more costly call?
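To illustrate the point about const on by-value parameters: top-level const is not part of the function type, whereas a const reference is and avoids copying the request. The sketch below uses simplified stand-in types and a free function with the suggested name computeScoreWithTime; the `features` parameter and the scoring rule are assumptions made only to keep the example self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <vector>

// Simplified stand-ins for illustration only.
struct Request {
    std::vector<int> labelHistory;
    int              nextToken;
};

struct ScoreWithTime {
    float       score;
    std::size_t timeframe;
};

// Top-level const on a by-value parameter does not change the function type.
static_assert(std::is_same<void(Request const), void(Request)>::value,
              "const on a by-value parameter is not part of the signature");
// A const reference is a different type and avoids copying the request.
static_assert(!std::is_same<void(Request const&), void(Request)>::value,
              "passing by const reference changes the signature");

// Returns no value if not enough features have been received for this request.
std::optional<ScoreWithTime> computeScoreWithTime(Request const& request, std::vector<float> const& features) {
    std::size_t const timeframe = request.labelHistory.size();
    if (timeframe >= features.size()) {
        return std::nullopt;
    }
    return ScoreWithTime{features[timeframe] + static_cast<float>(request.nextToken), timeframe};
}

// Caller side: a missing value means "not ready", not an error.
void checkComputeScore() {
    std::vector<float> const features{0.5f, 1.5f};
    Request const            ready{{}, 2};
    Request const            tooEarly{{1, 2}, 1};
    assert(computeScoreWithTime(ready, features).has_value());
    assert(!computeScoreWithTime(tooEarly, features).has_value());
}
```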


namespace Nn {

typedef Search::Score Score;

Do we need to put these into the Nn namespace? Wouldn't it be sufficient to put them in the LabelScorer class?
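A sketch of what scoping the alias to the class could look like. The `Search::Score` definition here is a stand-in for the one in Search/Types.hh, and the class body is reduced to the alias.

```cpp
#include <type_traits>

namespace Search {
typedef float Score;  // stand-in for the typedef in Search/Types.hh (assumption)
}

namespace Nn {

class LabelScorer {
public:
    // Alias is scoped to the class instead of the whole Nn namespace.
    typedef Search::Score Score;

    virtual ~LabelScorer() = default;
};

}  // namespace Nn

// Users outside the class spell the type as Nn::LabelScorer::Score.
static_assert(std::is_same<Nn::LabelScorer::Score, Search::Score>::value,
              "class-scoped alias refers to the same type");
```

With this layout nothing named Score is introduced into Nn itself, so other classes in the namespace remain free to define their own.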

#include <Search/Types.hh>
#include <Speech/Feature.hh>
#include <Speech/Types.hh>
#include <optional>

std lib imports should be in a separate block before the RASR imports. Also: newline before importing from the same module (i.e. between Speech/Types.hh and ScoringContext.hh).


/*
* =============================
* === ScoringContext ============

Too many = characters in this line.


namespace Search {

enum TransitionType {

I agree with Tina that this might logically be related to Search, but by putting it into the Search namespace we add a dependency on Search in the Nn module. I would like to avoid that, as we might otherwise end up having to compile libRasrSearch..a whenever we want to use libRasrNn..a. So I would prefer to keep it with the LabelScorer in the Nn module.
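A sketch of the suggested layout, with the enum living next to the LabelScorer in Nn so that the dependency points from Search to Nn only. The enumerator names and the helper `emitsLabel` are illustrative assumptions and not taken from this PR.

```cpp
namespace Nn {

class LabelScorer {
public:
    // Enumerator names are illustrative assumptions.
    enum TransitionType {
        LABEL_TO_LABEL,
        LABEL_TO_BLANK,
        BLANK_TO_LABEL,
        BLANK_TO_BLANK,
    };

    virtual ~LabelScorer() = default;
};

}  // namespace Nn

namespace Search {

// Search code may refer to the Nn type; Nn never needs anything from Search here.
constexpr bool emitsLabel(Nn::LabelScorer::TransitionType type) {
    return type == Nn::LabelScorer::LABEL_TO_LABEL || type == Nn::LabelScorer::BLANK_TO_LABEL;
}

}  // namespace Search
```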

4 participants