Train speech language ID classification head #450

Open · wants to merge 91 commits into main

Commits (91)
738c135
classification head class
am831 May 6, 2024
feac449
finetune script progress
am831 May 7, 2024
dc6cbdb
add comment
am831 May 7, 2024
350f6e2
add layers
am831 May 7, 2024
3c81c28
Model freeze and save classification head
zrthxn May 7, 2024
74f1f2d
Implement train loop
zrthxn May 7, 2024
3280b64
Implement train loop
zrthxn May 7, 2024
ca42666
calc loss, class head params
am831 May 8, 2024
13901cb
Refactor
zrthxn May 8, 2024
db3df0c
fix errors
am831 May 8, 2024
1b84a5b
get vector dimensions within classification head
am831 May 8, 2024
c343cd4
hidden_dim
am831 May 8, 2024
bc7a37f
log and capture interrupts
am831 May 8, 2024
d2a65f8
dataset prep and plotting loss
am831 May 8, 2024
4c65310
Merge branch 'facebookresearch:main' into language_id
am831 May 10, 2024
24781c9
Language ID Dataloader (#2)
zrthxn May 10, 2024
ce04f48
Model Fixes (#3)
zrthxn May 11, 2024
6584115
Merge branch 'facebookresearch:main' into language_id
am831 May 12, 2024
2567ebe
Merge branch 'language_id' of https://github.com/am831/seamless_commu…
am831 May 13, 2024
5efb06c
classification head class
am831 May 6, 2024
d0e8efc
finetune script progress
am831 May 7, 2024
6d76952
add comment
am831 May 7, 2024
64ca2be
add layers
am831 May 7, 2024
373312c
Model freeze and save classification head
zrthxn May 7, 2024
9ee465f
Implement train loop
zrthxn May 7, 2024
e1ab896
Implement train loop
zrthxn May 7, 2024
67dd5fb
calc loss, class head params
am831 May 8, 2024
be33445
Refactor
zrthxn May 8, 2024
70ef27c
fix errors
am831 May 8, 2024
4be035e
get vector dimensions within classification head
am831 May 8, 2024
22786a0
hidden_dim
am831 May 8, 2024
70f93da
log and capture interrupts
am831 May 8, 2024
8eb5195
dataset prep and plotting loss
am831 May 8, 2024
0794e02
Language ID Dataloader (#2)
zrthxn May 10, 2024
c485442
Model Fixes (#3)
zrthxn May 11, 2024
7d2c589
get embed_dim dynamically (#4)
am831 May 13, 2024
d9d6dc0
Model Fixes
zrthxn May 15, 2024
fc7984c
Merge branch 'language_id' of https://github.com/am831/seamless_commu…
am831 May 16, 2024
3de26f4
save plot as pkl
am831 May 16, 2024
caf9a08
address some feedback
zrthxn May 17, 2024
a928a51
Merge pull request #6 from am831/changes_lanID
am831 May 17, 2024
1347b46
switch model to train mode
am831 May 17, 2024
9184204
Code cleanup
zrthxn May 18, 2024
842e1c8
BCE loss
zrthxn May 18, 2024
19653d7
Remove Label smoothing
zrthxn May 18, 2024
588b0a8
change model to increase train loss
am831 May 19, 2024
480d318
classification head class
am831 May 6, 2024
fcb9e90
finetune script progress
am831 May 7, 2024
3581ff9
add comment
am831 May 7, 2024
3c5adc2
add layers
am831 May 7, 2024
2c47cde
Model freeze and save classification head
zrthxn May 7, 2024
c63c427
Implement train loop
zrthxn May 7, 2024
2748412
Implement train loop
zrthxn May 7, 2024
efc93e8
calc loss, class head params
am831 May 8, 2024
0bed503
Refactor
zrthxn May 8, 2024
e1f75fd
fix errors
am831 May 8, 2024
432e692
get vector dimensions within classification head
am831 May 8, 2024
56f72de
hidden_dim
am831 May 8, 2024
eb65fa9
log and capture interrupts
am831 May 8, 2024
04eea4f
dataset prep and plotting loss
am831 May 8, 2024
f41814f
Language ID Dataloader (#2)
zrthxn May 10, 2024
3839e24
Model Fixes (#3)
zrthxn May 11, 2024
4d2f435
classification head class
am831 May 6, 2024
d1ab8ad
finetune script progress
am831 May 7, 2024
631305f
add comment
am831 May 7, 2024
6b81e8e
add layers
am831 May 7, 2024
b266e77
Model freeze and save classification head
zrthxn May 7, 2024
2f4b117
Implement train loop
zrthxn May 7, 2024
2ce0eca
Implement train loop
zrthxn May 7, 2024
228bf5f
calc loss, class head params
am831 May 8, 2024
b33a1d1
Refactor
zrthxn May 8, 2024
2eb1732
fix errors
am831 May 8, 2024
cbaef16
get vector dimensions within classification head
am831 May 8, 2024
b5358de
hidden_dim
am831 May 8, 2024
cc85576
log and capture interrupts
am831 May 8, 2024
2271edc
dataset prep and plotting loss
am831 May 8, 2024
e7b2195
Language ID Dataloader (#2)
zrthxn May 10, 2024
c2488e9
Model Fixes (#3)
zrthxn May 11, 2024
9eb5876
get embed_dim dynamically (#4)
am831 May 13, 2024
7244db5
Model Fixes
zrthxn May 15, 2024
d2810d8
address some feedback
zrthxn May 17, 2024
622675c
save plot as pkl
am831 May 16, 2024
8980655
switch model to train mode
am831 May 17, 2024
13bcdea
Code cleanup
zrthxn May 18, 2024
f86a201
BCE loss
zrthxn May 18, 2024
e0bffc1
Remove Label smoothing
zrthxn May 18, 2024
e269cef
change model to increase train loss
am831 May 19, 2024
27c8c5d
lid classification head training script
mavlyutovr May 22, 2024
671746e
fixes
mavlyutovr May 22, 2024
71d248c
Merge pull request #8 from zrthxn/lid/ruslan-fixes
am831 May 25, 2024
6485277
Merge branch 'language_id' of https://github.com/am831/seamless_commu…
am831 Jun 3, 2024
197 changes: 197 additions & 0 deletions src/seamless_communication/cli/m4t/classification_head/dataloader.py
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.


import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
import torchaudio
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from fairseq2.data.text import TextTokenEncoder
from fairseq2.models.nllb import NllbTokenizer
from fairseq2.data.audio import WaveformToFbankConverter
from torch import Tensor
from torch.nn.functional import pad as pad_tensor
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder

from seamless_communication.datasets.datatypes import LangPairSample
from seamless_communication.models.unity.unit_tokenizer import (
UnitTokenEncoder,
UnitTokenizer,
)

logger = logging.getLogger(__name__)


@dataclass
class SeqsBatch:
src_tokens: Optional[Tensor]
src_lengths: Optional[Tensor]

def __del__(self) -> None:
"""Explicitly delete tensors
to force GPU memory cleanup"""
for tensor in [self.src_tokens, self.src_lengths]:
if tensor is not None:
del tensor


@dataclass
class BatchingConfig:
fbank_feats_pad_idx: int = 0
"""The pad index to use in fbanks batching."""

batch_size: int = 5
"""Fixed batch size to use"""

max_audio_length_sec: float = 15.0
""" Drop samples with source audio sample length above the threshold."""

rank: int = 0
"""The rank of this worker in the process group."""

world_size: int = 1
"""The world size of the process group."""

num_workers: int = 2
"""Parallelism in dataset preparation."""

float_dtype: torch.dtype = torch.float16
"""Select between fp16/fp32 for float tensors """

    langs: Tuple[str, ...] = ("eng", "fra", "deu", "rus", "spa")
"""Class labels"""


def worker_init_fn(worker_id: int) -> None:
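    # Give each DataLoader worker its own numpy seed so random ops differ across workers.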
np.random.seed(np.random.get_state()[1][0] + worker_id) # type: ignore


class UnitYLanguageIDDataLoader:
SAMPLE_RATE = 16_000

def __init__(
self,
num_languages: int,
text_tokenizer: NllbTokenizer,
unit_tokenizer: UnitTokenizer,
dataset_manifest_path: str,
batching_config: BatchingConfig,
):
self.num_languages = num_languages
self.text_tokenizer = text_tokenizer
self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
self.unit_tokenizer = unit_tokenizer
self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
self.batching_config = batching_config
self._fbank_extract_params = {
"num_mel_bins": 80,
"waveform_scale": 32768,
"channel_last": True,
"standardize": True,
"device": torch.device("cpu"),
"dtype": self.batching_config.float_dtype,
}
self.dataset = self._load_manifest(dataset_manifest_path)

def get_dataloader(self) -> DataLoader[SeqsBatch]:
subset = split_dataset_by_node(
self.dataset,
rank=self.batching_config.rank,
world_size=self.batching_config.world_size,
)
data_loader = DataLoader(
dataset=subset,
batch_size=self.batching_config.batch_size,
shuffle=True,
num_workers=self.batching_config.num_workers,
collate_fn=self._collate,
worker_init_fn=worker_init_fn,
)
return data_loader

def __iter__(self) -> Iterable[Any]:
return self.get_dataloader().__iter__()

def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
        assert (
            int(sample_rate) == self.SAMPLE_RATE
        ), f"sample rate {sample_rate} != {self.SAMPLE_RATE}, please resample"
assert len(wav.shape) in (1, 2)
if len(wav.shape) == 1:
wav = wav.unsqueeze(-1)
elif wav.shape[0] <= 2: # channel is first, should be second
wav = wav.transpose(0, 1)
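        # Convert the (time, channel) waveform into 80-dim mel filterbank features.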
return WaveformToFbankConverter(**self._fbank_extract_params)( # type: ignore
{
"waveform": wav,
"sample_rate": self.SAMPLE_RATE,
}
)["fbank"]

def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
padding_size = max(tensor.shape[0] for tensor in tensors)
dims = len(tensors[0].shape)
padded_tensors = []
for tensor in tensors:
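            # F.pad takes (left, right) pairs ordered from the last dim backwards; the
            # final entry right-pads dim 0 (time) so every tensor matches the batch max.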
padding = [0] * 2 * dims
padding[-1] = padding_size - tensor.shape[0]
padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack(padded_tensors, dim=0)

def _is_long_src_audio(self, sample: LangPairSample) -> bool:
        # HACK: audios that fail to load are treated as "too long" so they get filtered out downstream.
try:
wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
length_s: float = max(wav.shape) / sample_rate
return length_s > self.batching_config.max_audio_length_sec
except Exception:
logger.exception(
f"Failed to load sample path: {sample.source.audio_local_path}"
)
return True

def _collate(self, raw_samples: List[Dict[str, Any]]) -> Tuple[SeqsBatch, torch.Tensor]:
samples = [LangPairSample.from_json(sample) for sample in raw_samples]

# Input Speech

# 1 - filter long audio samples
filtered_samples = [
sample for sample in samples if not self._is_long_src_audio(sample)
]
samples = (
filtered_samples if filtered_samples else [samples[0]]
) # keep at least one sample
src_tokens_list = [self._get_source_fbank(sample) for sample in samples]

        # 2 - filter NaNs in fbanks
with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
assert len(samples) > 0
src_tokens_list = [
tok for tok, skip in zip(src_tokens_list, with_nans) if not skip
]
src_tokens = self._batch_tensors(
src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
).to(self.batching_config.float_dtype)
src_lengths = torch.LongTensor([tok.shape[0] for tok in src_tokens_list])
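        # Labels: index of each sample's source language within batching_config.langs.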
source_lang_ids = torch.LongTensor([self.batching_config.langs.index(sample.source.lang) for sample in samples])
# logger.info(f"Batch size {source_lang_ids.shape}, lengths: {src_lengths}, labels: {source_lang_ids}")

return SeqsBatch(src_tokens=src_tokens, src_lengths=src_lengths), source_lang_ids

def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
with open(dataset_manifest_path) as fp_in:
dataset = [json.loads(line) for line in fp_in]
return Dataset.from_list(dataset)
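
For reviewers, a minimal sketch of how this dataloader could be driven from a training script. The checkpoint name, tokenizer loaders, and manifest path below are illustrative assumptions, not part of this diff:

import torch

from seamless_communication.cli.m4t.classification_head.dataloader import (
    BatchingConfig,
    UnitYLanguageIDDataLoader,
)
from seamless_communication.models.unity import (
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)

# Hypothetical model card name; any UnitY card that ships both tokenizers should work.
text_tokenizer = load_unity_text_tokenizer("seamlessM4T_medium")
unit_tokenizer = load_unity_unit_tokenizer("seamlessM4T_medium")

config = BatchingConfig(batch_size=5, float_dtype=torch.float16)
dataloader = UnitYLanguageIDDataLoader(
    num_languages=len(config.langs),
    text_tokenizer=text_tokenizer,
    unit_tokenizer=unit_tokenizer,
    dataset_manifest_path="train_manifest.json",  # placeholder: one JSON sample per line
    batching_config=config,
)

for seqs_batch, lang_ids in dataloader:
    # seqs_batch.src_tokens: padded fbanks, shape (batch, frames, 80)
    # lang_ids: LongTensor of indices into config.langs
    ...

Each iteration yields the (SeqsBatch, LongTensor) pair produced by _collate, so a training loop can feed src_tokens/src_lengths to the speech encoder and use lang_ids as classification targets.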