Skip to content

Commit

Permalink
add additional audio attack effects and update conf
Browse files Browse the repository at this point in the history
  • Loading branch information
hastagAB committed Nov 26, 2024
1 parent f593185 commit 14d0d72
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
205 changes: 205 additions & 0 deletions audiocraft/utils/audio_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,208 @@ def aac_compression(
tensor, get_aac, sr=sample_rate, bitrate=bitrate, lowpass_freq=lowpass_freq
)
return audio_effect_return(tensor=out, mask=mask)

@staticmethod
def pitch_shift(
    tensor: torch.Tensor,
    n_steps: float = 2.0,
    sample_rate: int = 16000,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Shift the pitch of the audio signal by ``n_steps`` steps.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - n_steps (float): Number of pitch steps to shift (positive raises the pitch, negative lowers it).
    - sample_rate (int): Sample rate of the audio signal.
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The pitch-shifted audio as packaged by ``audio_effect_return`` together with ``mask``.
    """
    # Build the transform first, then apply it, so the two steps read separately.
    shifter = torchaudio.transforms.PitchShift(sample_rate, n_steps=n_steps)
    shifted = shifter(tensor)
    return audio_effect_return(tensor=shifted, mask=mask)

@staticmethod
def reverse(
    tensor: torch.Tensor,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Play the audio signal backwards by flipping its last axis.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The reversed audio as packaged by ``audio_effect_return`` together with ``mask``.
    """
    # NOTE(review): the mask is passed through unflipped - confirm whether it
    # should be reversed along with the audio.
    flipped = tensor.flip(-1)
    return audio_effect_return(tensor=flipped, mask=mask)

@staticmethod
def clipping(
    tensor: torch.Tensor,
    clip_value: float = 0.5,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Hard-clip the audio signal to the range [-clip_value, clip_value].

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - clip_value (float): Magnitude above which samples are clipped.
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The clipped audio as packaged by ``audio_effect_return`` together with ``mask``.
    """
    lower, upper = -clip_value, clip_value
    limited = tensor.clamp(min=lower, max=upper)
    return audio_effect_return(tensor=limited, mask=mask)

@staticmethod
def time_stretch(
    tensor: torch.Tensor,
    stretch_factor: float = 1.2,
    mask: tp.Optional[torch.Tensor] = None,
    sample_rate: int = 16000,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Stretch the audio signal in time without changing its pitch.

    The signal is moved to the STFT domain, its frames are resampled with a
    phase vocoder, and the result is transformed back to a waveform whose
    length is ``round(time * stretch_factor)``.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - stretch_factor (float): Factor by which to stretch the audio (greater than 1 for slower, less than 1 for faster).
    - mask (torch.Tensor): Optional mask tensor. When its last axis has the same length as the
      audio it is resampled (nearest sample) to the new length.
    - sample_rate (int): Sample rate of the audio signal; only used to size the analysis window
      (about 32 ms). Added last, with a default, so existing positional callers are unaffected.
    Returns:
    - The time-stretched audio as packaged by ``audio_effect_return`` together with the mask.
    Raises:
    - ValueError: If ``stretch_factor`` or ``sample_rate`` is not positive.
    """
    import math  # local import: the file-level import block is not visible from here

    if stretch_factor <= 0:
        raise ValueError(f"stretch_factor must be positive, got {stretch_factor}")
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    n_samples = tensor.shape[-1]
    # ~32 ms analysis window rounded to a power of two (512 samples at 16 kHz).
    n_fft = 2 ** max(6, int(round(math.log2(0.032 * sample_rate))))
    hop_length = n_fft // 4

    # Nothing to do for a unit factor; inputs shorter than half a window cannot
    # be reflect-padded by the centered STFT, so they are returned unchanged.
    if stretch_factor == 1.0 or n_samples <= n_fft // 2:
        return audio_effect_return(tensor=tensor, mask=mask)

    target_len = max(1, int(round(n_samples * stretch_factor)))

    window = torch.hann_window(n_fft, device=tensor.device, dtype=tensor.dtype)
    # torch.stft only accepts (batch, time), so fold all leading axes together.
    flat = tensor.reshape(-1, n_samples)
    spec = torch.stft(
        flat, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True
    )
    # Expected phase increment per hop for each frequency bin.
    phase_advance = torch.linspace(
        0, math.pi * hop_length, spec.shape[-2], device=tensor.device, dtype=tensor.dtype
    )[..., None]
    # phase_vocoder takes a speed-up rate, i.e. the inverse of the stretch factor.
    stretched_spec = torchaudio.functional.phase_vocoder(
        spec, 1.0 / stretch_factor, phase_advance
    )
    stretched = torch.istft(
        stretched_spec, n_fft=n_fft, hop_length=hop_length, window=window
    )

    # The ISTFT length is a multiple of the hop; pad or trim to the exact target.
    produced = stretched.shape[-1]
    if produced < target_len:
        stretched = torch.nn.functional.pad(stretched, (0, target_len - produced))
    else:
        stretched = stretched[..., :target_len]
    stretched_tensor = stretched.reshape(*tensor.shape[:-1], target_len)

    # Assumes the mask's last axis is time-aligned with the audio - TODO confirm.
    # Only masks whose length matches the input are touched.
    if mask is not None and mask.shape[-1] == n_samples:
        source_index = (
            torch.arange(target_len, device=mask.device) * n_samples
        ) // target_len
        mask = mask[..., source_index]

    return audio_effect_return(tensor=stretched_tensor, mask=mask)

@staticmethod
def tremolo(
    tensor: torch.Tensor,
    frequency: float = 5.0,
    depth: float = 0.5,
    sample_rate: int = 16000,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Apply a tremolo effect by multiplying the audio with a slow sinusoidal gain.

    The gain is ``(1 + depth * sin(2*pi*frequency*t)) / 2``, so it oscillates
    between ``(1 - depth) / 2`` and ``(1 + depth) / 2``.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - frequency (float): Frequency of the tremolo effect in Hz.
    - depth (float): Depth of modulation (between 0 and 1).
    - sample_rate (int): Sample rate of the audio signal.
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The modulated audio as packaged by ``audio_effect_return`` together with ``mask``.
    """
    # NOTE(review): the gain is centred on 0.5, so depth=0 halves the signal
    # rather than leaving it untouched - confirm this is intended.
    n_samples = tensor.shape[-1]
    seconds = torch.arange(n_samples, device=tensor.device) / sample_rate
    lfo = torch.sin(2 * torch.pi * frequency * seconds)
    gain = (1.0 + depth * lfo) / 2.0
    # Add two leading singleton axes so the gain lines up with the time axis.
    modulated = tensor * gain[None, None, :]
    return audio_effect_return(tensor=modulated, mask=mask)

@staticmethod
def flanger(
    tensor: torch.Tensor,
    delay: float = 0.002,
    depth: float = 0.002,
    rate: float = 0.25,
    sample_rate: int = 16000,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Apply a flanger effect by mixing the signal with a copy of itself whose delay
    is swept by a low-frequency oscillator.

    For each output sample ``t`` the delayed copy reads input sample
    ``t - d[t]``, where ``d[t] = (delay + depth * sin(2*pi*rate*t/sample_rate)) * sample_rate``
    truncated to an integer and floored at zero. Positions before the start of
    the signal contribute silence.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - delay (float): Base delay time in seconds.
    - depth (float): Depth of the delay modulation, in seconds.
    - rate (float): Rate of modulation in Hz.
    - sample_rate (int): Sample rate of the audio signal.
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The dry signal plus its modulated delayed copy, as packaged by
      ``audio_effect_return`` together with ``mask``.
    """
    n_samples = tensor.shape[-1]
    positions = torch.arange(n_samples, device=tensor.device)
    seconds = positions / sample_rate
    lfo = torch.sin(2 * torch.pi * rate * seconds) * depth + delay
    # Per-sample delay in whole samples; a negative delay would read the future.
    delay_samples = (lfo * sample_rate).long().clamp(min=0)

    # The delay is an offset from the current position, not an absolute index:
    # output sample t must read input sample t - delay_samples[t].
    read_positions = positions - delay_samples
    in_range = read_positions >= 0
    gathered = tensor[..., read_positions.clamp(min=0)]
    # Reads that fall before the first sample are silence, not a repeat of sample 0.
    delayed_signal = torch.where(in_range, gathered, torch.zeros_like(gathered))

    flanger_tensor = tensor + delayed_signal
    return audio_effect_return(tensor=flanger_tensor, mask=mask)

@staticmethod
def bit_crusher(
    tensor: torch.Tensor,
    bit_depth: int = 8,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Apply a bit crusher effect by rounding every sample to a multiple of ``1 / 2**bit_depth``.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - bit_depth (int): Bit depth to reduce to (e.g., 8 bits).
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The quantized audio as packaged by ``audio_effect_return`` together with ``mask``.
    """
    levels = 2 ** bit_depth
    # Scale up, snap to the integer grid, scale back down.
    quantized = tensor.mul(levels).round().div(levels)
    return audio_effect_return(tensor=quantized, mask=mask)

@staticmethod
def ring_modulation(
    tensor: torch.Tensor,
    modulation_frequency: float = 30.0,
    sample_rate: int = 16000,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Apply ring modulation by multiplying the audio with a sine carrier.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - modulation_frequency (float): Frequency of the sine carrier in Hz.
    - sample_rate (int): Sample rate of the audio signal.
    - mask (torch.Tensor): Optional mask tensor, forwarded unchanged.
    Returns:
    - The ring-modulated audio as packaged by ``audio_effect_return`` together with ``mask``.
    """
    n_samples = tensor.shape[-1]
    seconds = torch.arange(n_samples, device=tensor.device) / sample_rate
    carrier = torch.sin(2 * torch.pi * modulation_frequency * seconds)
    # Add two leading singleton axes so the carrier lines up with the time axis.
    modulated = tensor * carrier[None, None, :]
    return audio_effect_return(tensor=modulated, mask=mask)

@staticmethod
def granulate(
    tensor: torch.Tensor,
    grain_size: int = 512,
    overlap: float = 0.5,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Apply a granulation effect by cutting the audio into overlapping grains and
    concatenating them back to back.

    Grains start every ``max(1, int(grain_size * (1 - overlap)))`` samples and
    every grain that fits entirely inside the signal is used, so with a
    positive overlap the output is longer than the input.

    Parameters:
    - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time).
    - grain_size (int): Size of each grain in samples.
    - overlap (float): Overlap ratio between grains (0 to 1, 1 excluded).
    - mask (torch.Tensor): Optional mask tensor. When its last axis has the same length as the
      audio it is cut into the same grains so it stays aligned with the output.
    Returns:
    - The granulated audio as packaged by ``audio_effect_return`` together with the mask.
      Inputs shorter than one grain are returned unchanged.
    Raises:
    - ValueError: If ``grain_size`` is not positive or ``overlap`` is 1 or more.
    """
    if grain_size <= 0:
        raise ValueError(f"grain_size must be positive, got {grain_size}")
    if overlap >= 1:
        raise ValueError(f"overlap must be smaller than 1, got {overlap}")

    n_samples = tensor.shape[-1]
    if n_samples < grain_size:
        # Not even one full grain fits; leave the signal as it is.
        return audio_effect_return(tensor=tensor, mask=mask)

    # Overlaps very close to 1 would truncate to a zero step; advance by at
    # least one sample so the grain sequence always terminates.
    step_size = max(1, int(grain_size * (1 - overlap)))

    def _concat_grains(signal: torch.Tensor) -> torch.Tensor:
        # unfold yields (..., n_grains, grain_size), including the final grain
        # that starts exactly at time - grain_size; flattening the last two
        # axes lays the grains out one after another.
        grains = signal.unfold(signal.dim() - 1, grain_size, step_size)
        return grains.reshape(*signal.shape[:-1], grains.shape[-2] * grain_size)

    granulated_tensor = _concat_grains(tensor)

    # Assumes the mask's last axis is time-aligned with the audio - TODO confirm.
    # Only masks whose length matches the input are touched.
    if mask is not None and mask.shape[-1] == n_samples:
        mask = _concat_grains(mask)

    return audio_effect_return(tensor=granulated_tensor, mask=mask)
34 changes: 34 additions & 0 deletions config/augmentations/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@ audio_effects:
encodec:
ckpt: "//pretrained/facebook/encodec_24khz"
n_qs: [4, 8, 16]
pitch_shift:
sample_rate: ${sample_rate}
n_steps: 2.0
clipping:
clip_value: 0.5
time_stretch:
sample_rate: ${sample_rate}
stretch_factor: 1.2
tremolo:
frequency: 5.0
depth: 0.5
sample_rate: ${sample_rate}
flanger:
delay: 0.002
depth: 0.002
rate: 0.25
sample_rate: ${sample_rate}
bit_crusher:
bit_depth: 8
ring_modulation:
modulation_frequency: 30.0
sample_rate: ${sample_rate}
granulate:
grain_size: 512
overlap: 0.5

select_aug_mode:
"use_eval" # other options are 'all' and 'use_eval_acc'; used to sample augmentations: `fixed` uses the probabilities from aug_weights, `all` uses all augmentations every step
Expand All @@ -61,5 +86,14 @@ aug_weights:
aac_compression: 0.1 # eval only never use in training even if eval_acc low
encodec: 0.1
identity: 1 # no augmentation
pitch_shift: 0.1
reverse: 0.1
clipping: 0.1
time_stretch: 0.1
tremolo: 0.1
flanger: 0.1
bit_crusher: 0.1
ring_modulation: 0.1
granulate: 0.1

n_max_aug: null

0 comments on commit 14d0d72

Please sign in to comment.