-
Notifications
You must be signed in to change notification settings - Fork 923
/
svc_train_retrieval.py
114 lines (92 loc) · 3.98 KB
/
svc_train_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import logging
import multiprocessing
from functools import partial
from pathlib import Path
import faiss
from feature_retrieval import (
train_index,
FaissIVFFlatTrainableFeatureIndexBuilder,
OnConditionFeatureTransform,
MinibatchKmeansFeatureTransform,
DummyFeatureTransform,
)
# Module logger; its level is set by logging.basicConfig in main().
logger = logging.getLogger(__name__)


def get_speaker_list(base_path: Path) -> list:
    """Return the names of all speaker directories under ``base_path/waves-16k``.

    Every immediate sub-directory of ``waves-16k`` is treated as one speaker;
    plain files are ignored. The names are sorted because ``Path.iterdir``
    yields entries in arbitrary, filesystem-dependent order, and a stable
    order keeps processing and log output reproducible between runs.

    Args:
        base_path: Dataset root that contains the ``waves-16k`` folder.

    Returns:
        Sorted list of speaker directory names.

    Raises:
        FileNotFoundError: If ``base_path/waves-16k`` does not exist.
    """
    speakers_path = base_path / "waves-16k"
    if not speakers_path.exists():
        raise FileNotFoundError(f"path {speakers_path} does not exist")
    return sorted(entry.name for entry in speakers_path.iterdir() if entry.is_dir())
def create_indexes_path(base_path: Path) -> Path:
    """Ensure an ``indexes`` folder exists directly inside ``base_path``; return it.

    Only the final ``indexes`` directory is created. ``base_path`` itself must
    already exist, because ``mkdir`` is called without ``parents=True``.
    An already-existing ``indexes`` folder is accepted silently.
    """
    target_dir = base_path.joinpath("indexes")
    logger.info("create indexes folder %s", target_dir)
    target_dir.mkdir(exist_ok=True)
    return target_dir
def create_index(
    feature_name: str,
    prefix: str,
    speaker: str,
    base_path: Path,
    indexes_path: Path,
    compress_features_after: int,
    n_clusters: int,
    n_parallel: int,
    train_batch_size: int = 8192,
) -> None:
    """Train a FAISS index for one speaker's features and store it on disk.

    Features are read from ``base_path/<feature_name>/<speaker>``. The index
    target path is ``indexes_path/<speaker>/<prefix><feature_name>.index``.
    That path is handed to ``train_index`` (presumably it writes the trained
    index there).

    The index is built by ``FaissIVFFlatTrainableFeatureIndexBuilder`` with
    the L2 metric. If the feature matrix has more than
    ``compress_features_after`` rows, it is first compressed with
    ``MinibatchKmeansFeatureTransform(n_clusters, n_parallel)``. Otherwise it
    is passed through ``DummyFeatureTransform`` unchanged.

    Args:
        feature_name: Feature sub-folder name (e.g. ``"hubert"``).
        prefix: String prepended to the index filename.
        speaker: Speaker folder name.
        base_path: Dataset root that contains the feature folders.
        indexes_path: Root folder holding one index folder per speaker.
        compress_features_after: Row-count threshold above which the features
            are compressed before indexing.
        n_clusters: Number of centroids the features are compressed to.
        n_parallel: Number of parallel jobs used by the k-means compression.
        train_batch_size: Batch size passed to the index builder.

    Raises:
        ValueError: If the feature folder for ``speaker`` does not exist.
    """
    features_path = base_path / feature_name / speaker
    if not features_path.exists():
        raise ValueError(f'features not found by path {features_path}')

    index_dir = indexes_path / speaker
    # parents=True lets this work even when indexes_path has not been created
    # yet (e.g. when create_index is called directly rather than from main()).
    index_dir.mkdir(parents=True, exist_ok=True)
    index_filepath = index_dir / f"{prefix}{feature_name}.index"
    logger.debug('index will be saved to %s', index_filepath)

    builder = FaissIVFFlatTrainableFeatureIndexBuilder(train_batch_size, distance=faiss.METRIC_L2)
    transform = OnConditionFeatureTransform(
        # shape[0] is presumably the number of feature vectors (one per row).
        condition=lambda matrix: matrix.shape[0] > compress_features_after,
        on_condition=MinibatchKmeansFeatureTransform(n_clusters, n_parallel),
        otherwise=DummyFeatureTransform(),
    )
    train_index(features_path, index_filepath, builder, transform)
def main() -> None:
    """Command-line entry point: build hubert and whisper indexes per speaker.

    Data is read from ``./data_svc`` (relative to the current working
    directory), and indexes are written below ``./data_svc/indexes``.
    For each speaker, the hubert index is built first, then the whisper index.
    """
    # cpu_count() - 1 is 0 on a single-CPU machine, which is not a valid job
    # count; always use at least one job.
    default_n_parallel = max(1, multiprocessing.cpu_count() - 1)

    # The text must go to `description`; the first positional parameter of
    # ArgumentParser is `prog` and would replace the program name in usage.
    arg_parser = argparse.ArgumentParser(
        description="create faiss indexes for feature retrieval"
    )
    arg_parser.add_argument("--debug", action="store_true")
    arg_parser.add_argument("--prefix", default='', help="add prefix to index filename")
    arg_parser.add_argument('--speakers', nargs="+",
                            help="speaker names to create an index. "
                                 "By default all speakers found in data_svc are used")
    arg_parser.add_argument("--compress-features-after", type=int, default=200_000,
                            help="If the number of features is greater than this value, "
                                 "compress feature vectors using MiniBatchKMeans.")
    arg_parser.add_argument("--n-clusters", type=int, default=10_000,
                            help="Number of centroids to which features will be compressed")
    arg_parser.add_argument("--n-parallel", type=int, default=default_n_parallel,
                            help="Number of parallel jobs of MinibatchKmeans. "
                                 "Default is cpus-1 (at least 1)")
    args = arg_parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    base_path = Path(".").absolute() / "data_svc"
    # nargs="+" guarantees a non-empty list when --speakers is given.
    speakers = args.speakers if args.speakers else get_speaker_list(base_path)
    logger.info("got %s speakers: %s", len(speakers), speakers)

    indexes_path = create_indexes_path(base_path)
    create_index_func = partial(
        create_index,
        prefix=args.prefix,
        base_path=base_path,
        indexes_path=indexes_path,
        compress_features_after=args.compress_features_after,
        n_clusters=args.n_clusters,
        n_parallel=args.n_parallel,
    )

    for speaker in speakers:
        for feature_name in ("hubert", "whisper"):
            logger.info("create %s index for speaker %s", feature_name, speaker)
            create_index_func(feature_name=feature_name, speaker=speaker)
    logger.info("done!")


if __name__ == '__main__':
    main()