-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add LabelScorer base class #80
base: collapsed-vector
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please write const references as std::vector<int> const&
instead of const std::vector<int>&
for better consistency with recently written RASR code.
// Return value of scoring function | ||
struct ScoreWithTime { | ||
Score score; | ||
Speech::TimeframeIndex timestep; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type is called TimeframeIndex, so I would suggest to call this member timeframe instead of timestep.
// Return value of batched scoring function | ||
struct ScoresWithTimes { | ||
std::vector<Score> scores; | ||
Core::CollapsedVector<Speech::TimeframeIndex> timesteps; // Timesteps vector is internally collapsed if all timesteps are the same (e.g. time-sync decoding) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see above
// Return score and timeframe index of the corresponding output | ||
// May not return a value if the LabelScorer is not ready to score the request yet | ||
// (e.g. not enough features received) | ||
virtual std::optional<ScoreWithTime> getScoreWithTime(const Request request) = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the request is passed by value it does not need to be const. But I guess you want to pass by reference here. Also: what do you think about calling this function computeScoreWithTime to better indicate that this is a more costly call?
|
||
namespace Nn { | ||
|
||
typedef Search::Score Score; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to put these into the Nn namespace? Wouldn't it be sufficient to put them in the LabelScorer class?
#include <Search/Types.hh> | ||
#include <Speech/Feature.hh> | ||
#include <Speech/Types.hh> | ||
#include <optional> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std lib imports should be in a separate block before the RASR imports. Also: newline before importing from the same module (i.e. between Speech/Types.hh and ScoringContext.hh.
|
||
/* | ||
* ============================= | ||
* === ScoringContext ============ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
too many = chars in this line.
|
||
namespace Search { | ||
|
||
enum TransitionType { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with Tina that this might logically be related to Search, but by putting into the Search namespace we add a dependency for Search in the Nn module. I would like to avoid that as we otherwise might end up having to compile libRasrSearch..a whenever we want to use libRasrNn..a . So I would prefer to keep it with the LabelScorer in the Nn moudle.
Abstract base class for scoring tokens within an ASR search algorithm.
This class provides an interface for different types of label scorers in an ASR system. Label Scorers compute the scores of tokens based on input features and a scoring context. Children of this base class should represent various ASR model architectures and cover a wide range of possibilities such as CTC, transducer, AED or other models.
The usage is intended as follows:
getInitialScoringContext
should be called and used for the first hypothesesgetScoreWithTime
. This also returns the timestamp of the successor.getScoresWithTimes
which can handle an entire batch of requests at once and might be implemented more efficiently (e.g. using batched model forwarding).signalNoMoreFeatures
function is called to inform the label scorer that it doesn't need to wait for more features and can score as much as possible. This is especially important when the label scorer internally uses an encoder or window with right context.reset
function is called to clean up any internal data (e.g. feature buffer) or reset flags of the LabelScorer. Afterwards it is ready to receive features for the next segment.Each concrete subclass internally implements a concrete type of scoring context which the outside search algorithm is agnostic to. Depending on the model, this scoring context can consist of things like the current timestep, a label history, a hidden state or other values.
This PR is dependent on #78.