This project fine-tunes a Wav2Vec2Bert audio encoder to generate audio embeddings that can be used in place of the traditional CLIP text encoder. By integrating audio embeddings, we can leverage the unique properties of audio data to unlock new possibilities for Stable Diffusion models.
🤗 Model on Hugging Face | 📝 Training Notebook
Audio contains a wealth of information that often goes untapped, extending far beyond what we typically imagine. With the rise of Latent Diffusion models and their impressive generative capabilities, there is a growing interest in exploring diverse conditioning techniques. The most common approach involves using CLIP (Contrastive Language-Image Pre-Training) text encoders to condition the model. However, audio data offers a rich and multifaceted source of information that can significantly enhance the conditioning process.
The core idea behind our training process is to achieve cross-modal alignment between audio and text embeddings using a two-stream architecture. This involves leveraging the powerful CLIPTextModel to generate text embeddings that serve as true labels for the audio embeddings produced by our Wav2Vec2Bert model. Here’s a detailed explanation:
- Two-Stream Architecture:
- Text Stream: We use the CLIPTextModel to generate text embeddings for given text inputs. These embeddings capture rich semantic information and serve as the ground truth labels.
- Audio Stream: Our Wav2Vec2Bert model uses a convolutional feature encoder followed by a Transformer network to process audio inputs and generate the corresponding audio embeddings.
- Cross-Modality Alignment:
- Objective: The primary goal of training is to align the audio embeddings with the text embeddings in a shared embedding space, so that semantically similar audio and text inputs are mapped close to each other.
- Loss Function: We achieve this alignment with a contrastive loss, which encourages the model to pull embeddings of matching audio-text pairs closer together while pushing apart embeddings of non-matching pairs.
This is similar to how the original CLIP model was trained to align image-text pairs. The difference is that the OpenAI CLIP model computes its contrastive loss on a single pooled ([CLS]-style) embedding per sample, while we apply the contrastive loss at the sequence level.
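As a concrete reference, the text stream can be sketched as follows. This is a minimal example, not the notebook's exact code: it assumes the Stable Diffusion 2.1 text encoder is used as the CLIP checkpoint and that captions is a list of audio descriptions; the actual preprocessing lives in train_me.ipynb.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Text stream: turn captions into CLIP text embeddings used as ground-truth labels.
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder")

captions = ["a dog barking in the distance", "heavy rain on a window"]  # example descriptions
tokens = tokenizer(captions, padding="max_length", max_length=tokenizer.model_max_length,
                   truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(**tokens).last_hidden_state  # (batch, seq_len, hidden_size)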
The following image illustrates the logic behind this loss:
The Transformers library was used for training, with the "facebook/w2v-bert-2.0" checkpoint loaded as the initial pretrained model. Data preparation, training details, and the hyperparameters used can be found in the train_me.ipynb notebook:
- Dataset: We used the nateraw/fsd50k (Freesound Database 50K) dataset, available on Hugging Face, which consists of sound events and their corresponding descriptions.
- Adapter: A convolutional adapter was added on top of the Transformer encoder to downsample the output sequence and match the CLIP text embedding size:
from transformers import Wav2Vec2BertModel

model = Wav2Vec2BertModel.from_pretrained(
    "facebook/w2v-bert-2.0",
    add_adapter=True,          # append a convolutional adapter on top of the encoder
    adapter_kernel_size=3,
    adapter_stride=2,          # each adapter layer halves the sequence length
    num_adapter_layers=2,
    layerdrop=0.0,
)
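For reference, the audio stream can then be exercised end to end. The snippet below is a sketch: waveform stands for a 16 kHz mono NumPy array (the sampling rate expected by the w2v-bert feature extractor); the exact preprocessing is in the notebook.

import torch
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
# waveform: hypothetical 1-D array of 16 kHz mono audio samples
inputs = feature_extractor(waveform, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    audio_embeddings = model(**inputs).last_hidden_state
# shape: (1, ~seq_len/4, hidden_size) -- the two stride-2 adapter layers downsample the encoder output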
- Contrastive loss implementation: We can implement the loss function with PyTorch by subclassing the Trainer class and overriding the default loss.
import torch
import torch.nn.functional as F
from transformers import Trainer

def contrastive_loss(embeddings1, embeddings2, temperature=0.15):
    # both embeddings: (batch, seq_len, dim), assumed to share the same shape
    # similarity of every audio clip against every text in the batch, per sequence position
    cos_sim = torch.cosine_similarity(embeddings1.unsqueeze(1), embeddings2.unsqueeze(0), dim=-1)  # (batch, batch, seq_len)
    cos_sim = cos_sim / temperature
    # for each position, the matching pair (the diagonal) is the positive class
    labels = torch.arange(embeddings1.size(0), device=embeddings1.device).unsqueeze(1).repeat(1, embeddings1.size(1))
    return F.cross_entropy(cos_sim, labels)

class TrainBert(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("text_embeddings")       # CLIP text embeddings used as ground truth
        outputs = model(**inputs).last_hidden_state  # audio embeddings from Wav2Vec2Bert
        loss = contrastive_loss(outputs, labels)
        return (loss, outputs) if return_outputs else loss
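Wiring this into a training run then looks roughly like the following. The hyperparameter values are placeholders (the real ones are in train_me.ipynb), and train_dataset / eval_dataset are assumed to yield input_features together with precomputed text_embeddings.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="wav2vec2bert_2_diffusion",
    per_device_train_batch_size=32,   # larger batches mean more negatives per step
    learning_rate=5e-5,               # placeholder value
    num_train_epochs=10,              # placeholder value
    remove_unused_columns=False,      # keep the "text_embeddings" column in each batch
)

trainer = TrainBert(
    model=model,
    args=training_args,
    train_dataset=train_dataset,      # assumed to yield input_features and text_embeddings
    eval_dataset=eval_dataset,
)
trainer.train()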
- Metrics: We compute Euclidean distance, cosine similarity, and mean squared error between the audio and text embeddings and use them as evaluation metrics; they give a view of how well the model is progressing toward alignment (a small helper is sketched below).
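A minimal version of such a helper, assuming the audio and text embeddings share the same (batch, seq_len, dim) shape, as the contrastive loss above also requires:

import torch
import torch.nn.functional as F

def alignment_metrics(audio_embeddings, text_embeddings):
    # both tensors: (batch, seq_len, dim)
    cosine = F.cosine_similarity(audio_embeddings, text_embeddings, dim=-1).mean()
    euclidean = torch.linalg.norm(audio_embeddings - text_embeddings, dim=-1).mean()
    mse = F.mse_loss(audio_embeddings, text_embeddings)
    return {
        "cosine_similarity": cosine.item(),
        "euclidean_distance": euclidean.item(),
        "mse": mse.item(),
    }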
NB: Batch size is a crucial hyperparameter to tune, as it determines how many negative samples each example is contrasted against.
- Install dependencies:
pip install -r requirements.txt
- Download a pretrained model:
ckpt 1728 or ckpt 2016
Wav2Vec2BertModel.from_pretrained('youzarsif/wav2vec2bert_2_diffusion')
or
Wav2Vec2BertModel.from_pretrained('youzarsif/wav2vec2bert_2_diffusion_ckpt_1728')
- Stable Diffusion:
Feel free to use any variation of Stable Diffusion, ControlNet, or similar models, as long as they utilize the same CLIP encoder.
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
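To condition generation on sound, the audio embeddings are passed to the pipeline in place of the CLIP text-encoder output via prompt_embeds. The sketch below reuses pipe from the line above and an audio_embeddings tensor produced as in the audio-stream snippet earlier (with the fine-tuned checkpoint loaded in the previous step); it assumes the embedding hidden size matches the pipeline's text encoder (1024 for Stable Diffusion 2.1).

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)
# feed the audio embeddings where the pipeline normally expects the CLIP text embeddings
embeds = audio_embeddings.to(device, dtype=pipe.unet.dtype)
image = pipe(prompt_embeds=embeds, num_inference_steps=30).images[0]
image.save("audio_to_image.png")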
- Gradio interface:
python3 app.py
Alternatively, build and run the Docker image directly:
docker build -t app.py .
docker run -p 7860:7860 app.py
During training, we observed that the dataset used is sparsely annotated and relies on generic labels. Using a more diverse and richly annotated dataset should enhance the model's performance.
Additionally, due to resource limitations, a small convolutional adapter was used. A larger adapter that matches the CLIP maximum sequence length could improve model performance, as it would allow the model to capture more information.
https://github.com/Stability-AI/stablediffusion/tree/main
https://github.com/openai/CLIP/tree/main
https://huggingface.co/docs/transformers/en/model_doc/wav2vec2-bert