WIP: Multi query attention #3

Open
wants to merge 9 commits into base: load-iter
7 changes: 7 additions & 0 deletions megatron/arguments.py
@@ -401,6 +401,10 @@ def _add_network_size_args(parser):
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
group.add_argument('--attention-head-type', type=str, default='multihead',
choices=['multihead', 'multiquery'],
help='Type of attention heads. `multihead` is the standard multi-head attention. '
'`multiquery` shares the keys and values across attention heads.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
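
A rough sketch of what the new flag changes in terms of projection widths (illustrative values only, not part of the patch; the query/key-value split mirrors the `transformer.py` changes below):

```python
# Assumed example sizes: hidden_size=2048, 16 heads, kv_channels=128.
hidden_size, num_heads, kv_channels = 2048, 16, 128

# multihead: fused QKV projection emits Q, K and V channels for every head
multihead_qkv_out = 3 * num_heads * kv_channels        # 6144 output channels

# multiquery: queries keep all heads, keys/values collapse to one shared head
multiquery_q_out = num_heads * kv_channels             # 2048 output channels
multiquery_kv_out = 2 * kv_channels                    # 256 output channels (one K + one V head)

print(multihead_qkv_out, multiquery_q_out + multiquery_kv_out)  # 6144 vs 2304
```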
@@ -477,6 +481,9 @@ def _add_logging_args(parser):
help="Name of wandb entity for reporting")
group.add_argument('--wandb-project-name', type=str, default=None,
help="Name of wandb project")
group.add_argument('--transformer-timers', action='store_true',
help="If set, activate the timers within the transformer layers. "
"Only for debugging, as this slows down the model.")

return parser

199 changes: 189 additions & 10 deletions megatron/model/transformer.py
@@ -27,7 +27,7 @@
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_linear_layer


from .glu_activations import GLU_ACTIVATIONS
@@ -233,6 +233,7 @@ def forward(self, query_layer, key_layer,
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
np = query_layer.size(2)

# [b, np, sq, sk]
output_size = (query_layer.size(1),
@@ -253,6 +254,7 @@
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
else:
# alibi: (batch_size * num_attention_heads, 1, max_seq_len)
matmul_input_buffer = alibi[:output_size[0]*output_size[1], :, :output_size[3]]

# Raw attention scores. [b * np, sq, sk]
@@ -307,7 +309,7 @@ def forward(self, query_layer, key_layer,

# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
np,
query_layer.size(0),
value_layer.size(3))

@@ -336,6 +338,127 @@ def forward(self, query_layer, key_layer,
return context_layer


class MultiQueryCoreAttention(CoreAttention):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
sq = query_layer.size(0)
bs = query_layer.size(1)
np = query_layer.size(2)

sk = key_layer.size(0)
# Only one head for keys and values
assert key_layer.size(2) == 1 and value_layer.size(2) == 1

# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))

# [sq, b, np, hn] -> [b, np * sq, hn]
query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1)
# [sk, b, 1, hn] -> [b, hn, sk]
key_layer = key_layer.squeeze(2).permute(1, 2, 0)
# [sk, b, 1, hn] -> [sk, b * np, hn]
# key_layer = key_layer.expand(output_size[3], output_size[0], np, -1)
# key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1)

if alibi is None:
# preallocting input tensor: [b, np * sq, sk]
matmul_input_buffer = get_global_memory_buffer().get_tensor(
(bs, np * sq, sk),
query_layer.dtype, "mpu")
else:
# alibi: (batch_size * num_attention_heads, 1, max_seq_len)
# TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk)
matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk)
# Repeat each head's bias sq times consecutively so rows line up with the
# np-major [b, np * sq, sk] layout used for query_layer above.
matmul_input_buffer = matmul_input_buffer.repeat_interleave(sq, dim=1)  # [b, np * sq, sk]

if alibi is None:
# Raw attention scores. [b, np * sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer, # [b, np * sq, hn]
key_layer, # [b, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
else:
if not hasattr(self, "logged_alibi"):
print("Using Alibi.")
self.logged_alibi = True

if self.apply_query_key_layer_scaling:
beta = 1.0 / self.layer_number
else:
beta = 1.0

matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer,
key_layer,
beta=beta, alpha=(1.0 / self.norm_factor))

# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(bs, np, sq, sk)

# ===========================
# Attention probs and dropout
# ===========================

# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.

if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)

# =========================
# Context layer. [sq, b, hp]
# =========================

# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]

# context layer shape: [b, np, sq, hn]
# (query_layer was flattened above, so use the saved sq rather than query_layer.size(0))
output_size = (value_layer.size(1),
np,
sq,
value_layer.size(3))

# [sk, b, 1, hn] -> [b, sk, hn]
value_layer = value_layer.squeeze(2).transpose(0, 1)

# change view [b, np * sq, sk]
attention_probs = attention_probs.view(bs, np * sq, -1)

# matmul: [b, np * sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)

# change view [b, np, sq, hn]
context_layer = context_layer.view(bs, np, sq, -1)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)

return context_layer
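
For reference, a self-contained sketch of the shape flow the class above implements, with masking, dropout, ALiBi and the preallocated buffers stripped out (toy sizes, not part of the patch):

```python
import torch

# q: [sq, b, np, hn]; k, v: [sk, b, 1, hn] -- a single head shared by all np query heads.
def multi_query_attention(q, k, v):
    sq, b, np_, hn = q.shape
    # [sq, b, np, hn] -> [b, np * sq, hn]
    q = q.permute(1, 2, 0, 3).reshape(b, np_ * sq, hn)
    # [sk, b, 1, hn] -> [b, hn, sk]
    k = k.squeeze(2).permute(1, 2, 0)
    # [sk, b, 1, hn] -> [b, sk, hn]
    v = v.squeeze(2).transpose(0, 1)

    scores = torch.bmm(q, k) / hn ** 0.5          # [b, np * sq, sk]
    probs = torch.softmax(scores, dim=-1)
    context = torch.bmm(probs, v)                 # [b, np * sq, hn]
    # [b, np * sq, hn] -> [sq, b, np * hn]
    return context.view(b, np_, sq, hn).permute(2, 0, 1, 3).reshape(sq, b, np_ * hn)

q = torch.randn(5, 2, 4, 8)   # sq=5, b=2, np=4, hn=8
k = torch.randn(7, 2, 1, 8)   # sk=7, single key/value head
v = torch.randn(7, 2, 1, 8)
print(multi_query_attention(q, k, v).shape)  # torch.Size([5, 2, 32])
```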


class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.

@@ -353,6 +476,7 @@ def __init__(self, init_method,
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
self.attention_head_type = args.attention_head_type

projection_size = args.kv_channels * args.num_attention_heads

@@ -364,13 +488,28 @@
args.num_attention_heads, world_size)

# Strided linear layer.
if attention_type == AttnType.self_attn:
if attention_type == AttnType.self_attn and self.attention_head_type == 'multihead':
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
# TODO: Find a way to merge the query and key-value computations?
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
# In MultiQuery attention, keys and values are shared across heads
# Use args.kv_channels instead of projection_size
# No `.fork()` so the rng tracker is shared across tensor-parallel processes.
# with mpu.get_cuda_rng_tracker():
self.key_value = get_linear_layer(
args.hidden_size,
2 * args.kv_channels,
init_method=init_method)
elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead':
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
@@ -383,9 +522,14 @@
2 * projection_size,
gather_output=False,
init_method=init_method)
else:
raise NotImplementedError("Multiquery attention not implemented for cross-attention.")

self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
if self.attention_head_type == 'multihead':
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
else:
self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'

# Output.
@@ -419,15 +563,15 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.num_attention_heads_per_partition if self.attention_head_type == "multihead" else 1,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
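
Allocating a single key/value head here is where multi-query attention pays off at inference time: the cache shrinks by a factor of the per-partition head count. A back-of-the-envelope sketch with assumed sizes (not taken from the patch):

```python
# fp16 cache, 2 bytes per element; example sizes only.
seq_len, batch, num_heads, head_dim, bytes_per_el = 2048, 8, 32, 128, 2

def kv_cache_bytes(kv_heads):
    # key + value tensors, each [seq_len, batch, kv_heads, head_dim]
    return 2 * seq_len * batch * kv_heads * head_dim * bytes_per_el

print(kv_cache_bytes(num_heads) // 2**20)  # multihead:  256 MiB per layer
print(kv_cache_bytes(1) // 2**20)          # multiquery:   8 MiB per layer
```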


def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None, alibi=None):
# hidden_states: [sq, b, h]

# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
@@ -449,7 +593,7 @@ def forward(self, hidden_states, attention_mask,
# Query, Key, and Value
# =====================

if self.attention_type == AttnType.self_attn:
if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead':
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

@@ -463,6 +607,35 @@ def forward(self, hidden_states, attention_mask,
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
# Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
mixed_kv_layer = self.key_value(hidden_states)

# [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
# new_tensor_shape = mixed_kv_layer.size()[:-1] + \
# (self.num_attention_heads_per_partition,
# 2 * self.hidden_size_per_attention_head)
# mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape)

# [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(1,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

# [sq, b, 1, 2 * hn] --> 2 [sq, b, 1, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

# Attention head [sq, b, h] --> [sq, b, np * hn]
query_layer, _ = self.query(hidden_states)
# [sq, b, np * hn] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)

# (the [sq, b, np, hn] -> [b, np * sq, hn] flattening happens inside MultiQueryCoreAttention)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -489,6 +662,7 @@ def forward(self, hidden_states, attention_mask,
# Adjust key and value for inference
# ==================================


if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
@@ -520,7 +694,6 @@ def forward(self, hidden_states, attention_mask,
# =================
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)

return output, bias
@@ -963,6 +1136,10 @@ def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
timers = get_timers()
args = get_args()

if args.transformer_timers: timers("Transformer forward").start()

# Checks.
if inference_params:
@@ -1020,4 +1197,6 @@ def forward(self, hidden_states, attention_mask,
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)

if args.transformer_timers: timers("Transformer forward").stop()

return hidden_states
3 changes: 2 additions & 1 deletion megatron/mpu/layers.py
@@ -264,7 +264,8 @@ def backward(ctx, grad_output):
handle.wait()

# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
# TODO: Is the reshape preventing us from getting a speedup here?
grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1],
grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])
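
On the TODO above: `view` never copies but requires compatible strides, while `reshape` silently materializes a copy when the tensor is non-contiguous, which is the likely overhead being asked about. A small illustration (not part of the patch):

```python
import torch

x = torch.randn(4, 6, 8).transpose(0, 1)   # non-contiguous after the transpose
print(x.is_contiguous())                    # False

try:
    x.view(-1, 8)                           # view cannot merge dims with these strides
except RuntimeError as err:
    print("view failed:", err)

y = x.reshape(-1, 8)                        # reshape succeeds by copying
print(y.data_ptr() == x.data_ptr())         # False -> new storage was allocated
```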
20 changes: 20 additions & 0 deletions megatron/optimizer/optimizer.py
@@ -265,6 +265,19 @@ def allreduce_embedding_grads(self, args):
"""All-reduce both word and position embeddings."""
self.allreduce_word_embedding_grads(args)
self.allreduce_position_embedding_grads(args)

def allreduce_key_value_grads(self, args):
"""All-reduce the shared key-value grads, which are replicated across the tensor-parallel group."""
# TODO: models[0] ?
unwrapped_model = self.models[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
for layer in unwrapped_model.language_model.encoder.layers:
kv_weight = layer.self_attention.key_value.weight
if args.DDP_impl == 'local':
grad = kv_weight.main_grad
else:
grad = kv_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_tensor_model_parallel_group())
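
Why a sum all-reduce is the right reduction here (a sketch, not part of the patch): `key_value` is an ordinary linear layer replicated on every tensor-parallel rank, while the query heads are split across ranks, so each rank's backward pass only produces the gradient contribution of its own heads; summing the per-rank partials reconstructs the full gradient.

```python
import torch

torch.manual_seed(0)
hidden, kv_out, heads, tp = 16, 8, 4, 2     # toy sizes; tp = tensor-parallel world size

x = torch.randn(5, hidden)
w = torch.randn(kv_out, hidden, requires_grad=True)   # stands in for key_value.weight

# Unsharded reference: every head contributes to the loss through the shared projection.
full_loss = sum((x @ w.t()).sum() for _ in range(heads))
(full_grad,) = torch.autograd.grad(full_loss, w)

# Tensor-parallel view: each "rank" owns heads // tp heads, so its local backward
# yields only a partial gradient for the replicated key_value weight.
partials = []
for rank in range(tp):
    local_loss = sum((x @ w.t()).sum() for _ in range(heads // tp))
    partials.append(torch.autograd.grad(local_loss, w)[0])

# The SUM all-reduce performed by allreduce_key_value_grads recovers the full gradient.
print(torch.allclose(sum(partials), full_grad))  # True
```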


def allreduce_layernorm_grads(self, args):
@@ -310,6 +323,13 @@ def reduce_model_grads(self, args, timers):
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()

# All-reduce key-value grads if needed.
if args.attention_head_type == "multiquery":
timers('backward-key-value-all-reduce').start()
self.allreduce_key_value_grads(args)
timers('backward-key-value-all-reduce').stop()



class MixedPrecisionOptimizer(MegatronOptimizer):
"""Base class for both the float-16 and the distributed optimizer.