Commit
add value residual learning, given people are interested in this again due to notebooklm
lucidrains committed Oct 31, 2024
1 parent 602b616 commit a25e67e
Showing 3 changed files with 41 additions and 8 deletions.
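
In brief, the change caches the value tensor of the first attention layer and averages it into every later layer's values before attention, following the Value Residual Learning paper cited in the README change below. Written out (notation mine; the first layer itself is left unchanged, and V_1 denotes its values):

```latex
% value mixing as implemented in this commit, for layers \ell > 1
\hat{V}_\ell = \tfrac{1}{2} \left( V_\ell + V_1 \right)
```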
9 changes: 9 additions & 0 deletions README.md
@@ -230,3 +230,12 @@ generated_speech = model.generate(
     primaryClass = {cs.CL}
 }
 ```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.11',
+  version = '0.5.0',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
38 changes: 31 additions & 7 deletions soundstorm_pytorch/soundstorm.py
@@ -307,22 +307,32 @@ def forward(
         context = None,
         mask = None,
         rotary_emb = None,
-        attn_bias = None
+        attn_bias = None,
+        return_values = False,
+        value_residual = None
     ):
         n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
         context = default(context, x)

         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

+        if exists(value_residual):
+            v = 0.5 * (v + value_residual)
+
         if exists(rotary_emb):
             q = apply_rotary_pos_emb(rotary_emb, q)
             k = apply_rotary_pos_emb(rotary_emb, k)

         out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)

         out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
+        out = self.to_out(out)
+
+        if not return_values:
+            return out
+
+        return out, v

 class FeedForward(Module):
     def __init__(
@@ -418,18 +428,26 @@ def forward(
         x,
         mask = None,
         rotary_emb = None,
-        attn_bias = None
+        attn_bias = None,
+        attn_value_residual = None,
+        return_values = False
     ):
         x = self.ff1(x) + x

         if exists(self.gateloop):
             x = self.gateloop(x) + x

-        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
+        attn_out, attn_values = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias, value_residual = attn_value_residual, return_values = True)
+        x = attn_out + x

         x = self.conv(x, mask = mask) + x
         x = self.ff2(x) + x
         x = self.post_norm(x)
-        return x
+
+        if not return_values:
+            return x
+
+        return x, attn_values

 # Conformer

@@ -484,14 +502,20 @@ def forward(self, x, mask = None):
         rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None
         attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None

+        attn_value_residual = None
+
         for block in self.layers:
-            x = block(
+            x, attn_values = block(
                 x,
                 mask = mask,
                 rotary_emb = rotary_emb,
-                attn_bias = attn_bias
+                attn_bias = attn_bias,
+                attn_value_residual = attn_value_residual,
+                return_values = True
             )

+            attn_value_residual = default(attn_value_residual, attn_values)
+
         return x

 # conformer with sum reduction across quantized tokens at the beginning, along with heads
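
To show the pattern end to end: the first block's attention values are cached and every subsequent block averages its own values with them, which is what the `value_residual` / `return_values` plumbing above wires through the Conformer stack. The sketch below is a minimal, self-contained illustration of that idea, not the library's API; the `ToyAttention` and `ToyStack` names and shapes are made up for the example.

```python
# Minimal sketch (hypothetical names) of the value-residual pattern used above:
# the first layer returns its values, and every later layer averages its own
# values with that cached tensor, i.e. v <- 0.5 * (v + v_first).

import torch
import torch.nn.functional as F
from torch import nn

class ToyAttention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, value_residual = None, return_values = False):
        b, n, d = x.shape
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2) for t in (q, k, v))

        if value_residual is not None:
            v = 0.5 * (v + value_residual)  # mix in the first layer's values

        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, n, d)
        out = self.to_out(out)
        return (out, v) if return_values else out

class ToyStack(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.layers = nn.ModuleList([ToyAttention(dim) for _ in range(depth)])

    def forward(self, x):
        value_residual = None
        for layer in self.layers:
            out, values = layer(x, value_residual = value_residual, return_values = True)
            x = out + x
            # keep the *first* layer's values as the shared residual
            value_residual = values if value_residual is None else value_residual
        return x

x = torch.randn(2, 16, 64)
print(ToyStack(64, depth = 4)(x).shape)  # torch.Size([2, 16, 64])
```

Only the first layer's (unmixed) values are ever reused: `default(attn_value_residual, attn_values)` in the diff keeps the first non-None value tensor, which the `value_residual is None` check reproduces here.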
