From 3147d51017b5744df10a52c4ede73344ae369117 Mon Sep 17 00:00:00 2001
From: amancini-N
Date: Wed, 11 Dec 2024 14:18:49 +0000
Subject: [PATCH 1/2] Allow returning attention probs from MultiHeadAttention

---
 onnxruntime/contrib_ops/cpu/bert/attention.cc |  2 +-
 .../contrib_ops/cpu/bert/attention_cpu_base.h | 16 +++-
 .../decoder_masked_multihead_attention.cc     |  4 +-
 .../cpu/bert/multihead_attention.cc           | 12 ++-
 .../cpu/quantization/attention_quant.cc       |  2 +-
 .../core/graph/contrib_ops/bert_defs.cc       | 29 ++++++
 .../contrib_ops/attention_op_test_helper.cc   | 15 ++++
 .../contrib_ops/attention_op_test_helper.h    |  1 +
 .../multihead_attention_op_test.cc            | 43 ++++++---
 .../attention/attention_test_data.txt         | 88 +++++++++++++++++++
 10 files changed, 188 insertions(+), 24 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index ad14fb8258656..378ae81cee8ff 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -335,7 +335,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
 
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
-                        output, nullptr /* present_key */, nullptr /* present_value */,
+                        output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* attn_probs */,
                         batch_size, sequence_length, sequence_length,
                         parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
                         attention_bias, context);
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 87938f3728750..4124579636d84 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -29,6 +29,7 @@ class AttentionCPUBase : public AttentionBase {
                         Tensor* output,          // output tensor
                         Tensor* present_key,     // present K output tensor (if separating present KV)
                         Tensor* present_value,   // present V output tensor (if separating present KV)
+                        Tensor* attn_probs,      // attention probabilities output tensor (optional)
                         int batch_size,          // batch size (B)
                         int sequence_length,     // sequence length of Q (S)
                         int kv_sequence_length,  // sequence length of K or V (L)
@@ -102,10 +103,17 @@ class AttentionCPUBase : public AttentionBase {
     }
 
     // Compute the attention score.
-    size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
-    auto attention_probs = allocator->Alloc(bytes);
+    void* attention_probs = nullptr;
+    T* attn_probs_data = nullptr;
+    if (attn_probs == nullptr) {
+      size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
+      attention_probs = allocator->Alloc(bytes);
+      attn_probs_data = static_cast<T*>(attention_probs);
+    } else {
+      attn_probs_data = attn_probs->MutableData<T>();
+    }
     BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
+    ComputeAttentionProbs<T>(attn_probs_data, Q, K,
                              static_cast<T*>(mask_data),
                              batch_size, sequence_length, kv_sequence_length, past_sequence_length,
                              qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, present_data,
@@ -117,7 +125,7 @@ class AttentionCPUBase : public AttentionBase {
         allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
     BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));
 
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), attn_probs_data,
                             V, batch_size, sequence_length, kv_sequence_length, past_sequence_length,
                             v_head_size, v_hidden_size, past_data, past_value_data, present_data,
                             present_value_data, tp, past_present_share_buffer, max_sequence_length);
diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc
index e6f65f92e14f4..8def8c7383ef8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc
@@ -189,7 +189,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
                           key->Data<T>(), value->Data<T>(), mask_index,
                           nullptr /* past */, past_key, past_value,
                           output, present_key, present_value,
-                          batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
+                          nullptr /* attn_probs */, batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
                           head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk);
   }
 
@@ -205,7 +205,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
                           K.GetMutable<Tensor>()->MutableData<T>(), V.GetMutable<Tensor>()->MutableData<T>(),
                           mask_index, nullptr /* past */, past_key, past_value,
                           output, present_key, present_value,
-                          batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
+                          nullptr /* attn_probs */, batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
                           head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk,
                           parameters.past_sequence_length, true /* past_present_share_buffer */);
   }
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index ca818f09c4b1e..5aafd8acc5a0c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -102,6 +102,13 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
   Tensor* output = context->Output(0, output_shape);
 
+  std::vector<int64_t> attn_probs_shape(4);
+  attn_probs_shape[0] = static_cast<int64_t>(batch_size);
+  attn_probs_shape[1] = static_cast<int64_t>(num_heads_);
+  attn_probs_shape[2] = static_cast<int64_t>(q_sequence_length);
+  attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
+  Tensor* attn_probs = context->Output(3, attn_probs_shape);
+
   constexpr int q_bias_offset = 0;
   const int k_bias_offset = qk_hidden_size;
   const int v_bias_offset = 2 * qk_hidden_size;
@@ -134,7 +141,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
                           key->Data<T>(), value->Data<T>(), key_padding_mask,
                           nullptr /* past */, past_key, past_value,
                           output, present_k, present_v,
-                          batch_size, q_sequence_length, kv_sequence_length,
+                          attn_probs, batch_size, q_sequence_length, kv_sequence_length,
                           qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
   }
 
@@ -154,6 +161,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
       past_value == nullptr &&
       present_k == nullptr &&
       present_v == nullptr &&
+      attn_probs == nullptr && // TODO: can we support it?
       l2_cache_size_ > 0) {
     MlasFlashAttentionThreadedArgs args;
     args.batch_size = batch_size;
@@ -214,7 +222,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
                         K.GetMutable<Tensor>()->MutableData<T>(), V.GetMutable<Tensor>()->MutableData<T>(),
                         key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
-                        batch_size, q_sequence_length, kv_sequence_length,
+                        attn_probs, batch_size, q_sequence_length, kv_sequence_length,
                         qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
 }
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
index 2c897f183164f..a0af27ac87f70 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -289,7 +289,7 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q, K, V, mask_index, past_tensor, nullptr /* past_key */, nullptr /* past_value*/,
-                        output, nullptr /* present_key */, nullptr /* present_value */,
+                        output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* attn_probs */,
                         batch_size, sequence_length, sequence_length,
                         head_size, head_size, hidden_size, nullptr /* rel_pos_bias */, context);
 }
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index f2a2a52f8334f..e560aeadd2e2b 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -226,6 +226,30 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
       }
     }
   }
+
+  if (ctx.getNumOutputs() > 3) {  // has attention_probs output
+    // Output 3 has shape (batch_size, num_heads, sequence_length, total_sequence_length)
+    if (hasInputShape(ctx, 0) && hasInputShape(ctx, past_key_index)) {
+      auto& query_shape = getInputShape(ctx, 0);
+      auto& key_shape = getInputShape(ctx, 1);
+      auto& key_seqlen_dim = key_shape.dim()[1];
+      auto& past_seqlen_dim = getInputShape(ctx, past_key_index).dim()[2];
+      if (key_seqlen_dim.has_dim_value() && past_seqlen_dim.has_dim_value()) {
+        auto kv_sequence_length = key_seqlen_dim.dim_value();
+        auto past_sequence_length = past_seqlen_dim.dim_value();
+        int64_t total_sequence_length = kv_sequence_length + past_sequence_length;
+        auto num_heads = getAttribute(ctx, "num_heads", 0);
+
+        ONNX_NAMESPACE::TensorShapeProto attention_probs_shape;
+        *attention_probs_shape.add_dim() = query_shape.dim()[0];
+        attention_probs_shape.add_dim()->set_dim_value(num_heads);
+        *attention_probs_shape.add_dim() = query_shape.dim()[1];
+        attention_probs_shape.add_dim()->set_dim_value(total_sequence_length);
+        updateOutputShape(ctx, 3, attention_probs_shape);
+        propagateElemTypeFromInputToOutput(ctx, 0, 3);
+      }
+    }
+  }
 }
 
 // Type and shape inference for group query attention and sparse attention.
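[Reviewer note, not part of the patch] Output 3 is the softmax result that ComputeAttentionProbs (and the CUDA softmax kernels in the follow-up commit) write before it is multiplied with V. Below is a minimal C++ reference sketch of that quantity for one batch/head pair, assuming the usual 1/sqrt(head_size) scaling; the mask and attention-bias terms the kernels also fold in are omitted, and ReferenceAttentionProbs is a hypothetical helper (handy for cross-checking the attention_probs_data blocks added to attention_test_data.txt):

#include <algorithm>
#include <cmath>
#include <vector>

// Q: [seq_len, head_size], K: [total_seq_len, head_size], both row-major.
// Returns probs: [seq_len, total_seq_len]; every row sums to 1.
std::vector<float> ReferenceAttentionProbs(const std::vector<float>& Q,
                                           const std::vector<float>& K,
                                           int seq_len, int total_seq_len, int head_size) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  std::vector<float> probs(static_cast<size_t>(seq_len) * total_seq_len);
  for (int i = 0; i < seq_len; ++i) {
    float* row = probs.data() + static_cast<size_t>(i) * total_seq_len;
    float max_logit = -INFINITY;
    // Scaled dot products Q_i . K_j form one row of QK^T.
    for (int j = 0; j < total_seq_len; ++j) {
      float dot = 0.0f;
      for (int h = 0; h < head_size; ++h) {
        dot += Q[static_cast<size_t>(i) * head_size + h] * K[static_cast<size_t>(j) * head_size + h];
      }
      row[j] = dot * scale;
      max_logit = std::max(max_logit, row[j]);
    }
    // Row-wise softmax with max subtraction for numerical stability.
    float sum = 0.0f;
    for (int j = 0; j < total_seq_len; ++j) {
      row[j] = std::exp(row[j] - max_logit);
      sum += row[j];
    }
    for (int j = 0; j < total_seq_len; ++j) {
      row[j] /= sum;
    }
  }
  return probs;
}

Stacking these rows over all batches and heads yields exactly the (batch_size, num_heads, sequence_length, total_sequence_length) layout that the shape-inference addition above declares for output 3.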
@@ -1034,6 +1058,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)", "T", OpSchema::Optional) + .Output(3, + "attention_probs", + "Attention probabilities with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index 5df521bd6381d..1555d813ea6fb 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -76,6 +76,7 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data) { LoadTensor("CrossAttentionData_HeadSize40.bias_data", data.bias_data); LoadTensor("CrossAttentionData_HeadSize40.fp32_output_data", data.fp32_output_data); LoadTensor("CrossAttentionData_HeadSize40.fp16_output_data", data.fp16_output_data); + LoadTensor("CrossAttentionData_HeadSize40.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data) { @@ -83,6 +84,7 @@ void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data) { data.bias_data.clear(); LoadTensor("CrossAttentionData_HeadSize40_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_HeadSize40_NoBias.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d) { @@ -113,6 +115,7 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.bias_data", data.bias_data); LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d) { @@ -121,6 +124,7 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTe LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) { @@ -145,6 +149,7 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.bias_data", data.bias_data); LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data) { @@ -152,6 +157,7 @@ void 
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTes data.bias_data.clear(); LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data) { @@ -174,6 +180,7 @@ void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTes // Do not test fp32 data.fp32_output_data = {}; LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.fp16_output_data", data.fp16_output_data); + LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data", data.attention_probs_data); } void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data) { @@ -217,6 +224,7 @@ void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) { LoadTensor("CrossAttentionData_HeadSize16_8.bias_data", data.bias_data); LoadTensor("CrossAttentionData_HeadSize16_8.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_HeadSize16_8.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data) { @@ -224,6 +232,7 @@ void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data) { data.bias_data.clear(); LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_HeadSize16(AttentionTestData& data) { @@ -241,6 +250,7 @@ void GetCrossAttentionData_HeadSize16(AttentionTestData& data) { LoadTensor("CrossAttentionData_HeadSize16.bias_data", data.bias_data); LoadTensor("CrossAttentionData_HeadSize16.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_HeadSize16.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data) { @@ -248,6 +258,7 @@ void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data) { data.bias_data.clear(); LoadTensor("CrossAttentionData_HeadSize16_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttentionData_HeadSize16_NoBias.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_HeadSize8(AttentionTestData& data) { @@ -265,6 +276,7 @@ void GetCrossAttentionData_HeadSize8(AttentionTestData& data) { LoadTensor("CrossAttention_Batch1_HeadSize8.bias_data", data.bias_data); LoadTensor("CrossAttention_Batch1_HeadSize8.output", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttention_Batch1_HeadSize8.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data) { @@ -272,6 +284,7 @@ void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data) { data.bias_data.clear(); LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.output", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; + LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data); } void 
GetCrossAttentionDataWithPast(AttentionTestData& data) {
@@ -406,6 +419,7 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_key_data", data.present_key_data);
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_value_data", data.present_value_data);
   data.is_static_kv = true;
+  LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data", data.attention_probs_data);
 }
 
 void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data) {
@@ -416,6 +430,7 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestDat
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_key_data", data.present_key_data);
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_value_data", data.present_value_data);
   data.is_static_kv = true;
+  LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
 void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data) {
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
index b0dbe6e7b4ac7..2a875c2a9abba 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
@@ -38,6 +38,7 @@ struct BaseAttentionTestData {
   std::vector<float> present_key_data;
   std::vector<float> present_value_data;
 
+  std::vector<float> attention_probs_data;
   std::vector<AttentionKernelType> skip_kernel_types;  // skip some kernels if they do not supported this test case.
 };
 
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
index 6b6799d73fb56..d103ddf653139 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
@@ -36,6 +36,7 @@ static void RunMultiHeadAttentionTest(
     const std::vector<float>& past_value_data,     // past_value: [batch_size, num_heads, kv_sequence_length, head_size]
     const std::vector<float>& present_key_data,    // present_key: [batch_size, num_heads, total_sequence_length, head_size]
     const std::vector<float>& present_value_data,  // present_value: [batch_size, num_heads, total_sequence_length, head_size]
+    const std::vector<float>& attention_probs_data,  // attention_probs: [batch_size, num_heads, sequence_length, kv_sequence_length]
     const std::vector<int32_t>& key_padding_mask_data,  // key_padding_mask: see below
     AttentionMaskType mask_type,                   // 1 for [batch_size], 2 for [batch_size, kv_sequence_length]
     const std::vector<float>& output_data,         // output: [batch_size, sequence_length, v_hidden_size]
@@ -90,6 +91,7 @@ static void RunMultiHeadAttentionTest(
   std::vector<int64_t> present_key_dims =
       {batch_size, num_heads, is_static_kv ? kv_sequence_length : sequence_length + kv_sequence_length, hidden_size / num_heads};
   std::vector<int64_t> present_value_dims = present_key_dims;
+  std::vector<int64_t> attention_probs_dims = {batch_size, num_heads, sequence_length, kv_sequence_length};
 
   std::vector<float> query = (qkv_data.size() > 0 ? qkv_data : query_data);
   std::vector<float> key;
@@ -179,6 +181,12 @@ static void RunMultiHeadAttentionTest(
     } else {
       tester.AddOptionalOutputEdge<MLFloat16>();
     }
+
+    if (attention_probs_data.size()) {
+      tester.AddOutput<MLFloat16>("attention_probs", attention_probs_dims, ToFloat16(attention_probs_data), /*sort*/ false, rel_error, abs_error);
+    } else {
+      tester.AddOptionalOutputEdge<MLFloat16>();
+    }
   } else {
     tester.AddInput<float>("query", query_dims, query);
@@ -243,6 +251,12 @@ static void RunMultiHeadAttentionTest(
     } else {
       tester.AddOptionalOutputEdge<float>();
     }
+
+    if (attention_probs_data.size()) {
+      tester.AddOutput<float>("attention_probs", attention_probs_dims, attention_probs_data, /*sort*/ false, rel_error, abs_error);
+    } else {
+      tester.AddOptionalOutputEdge<float>();
+    }
   }
 
   if (enable_cuda) {
@@ -289,6 +303,7 @@ static void RunMultiHeadAttentionKernel(
     const std::vector<float>& past_value_data,     // past_value: [batch_size, num_heads, kv_sequence_length, head_size]
     const std::vector<float>& present_key_data,    // present_key: [batch_size, num_heads, total_sequence_length, head_size]
     const std::vector<float>& present_value_data,  // present_value: [batch_size, num_heads, total_sequence_length, head_size]
+    const std::vector<float>& attention_probs_data,  // attention_probs: [batch_size, num_heads, sequence_length, kv_sequence_length]
     const std::vector<int32_t>& key_padding_mask_data,  // key_padding_mask: see below
     AttentionMaskType mask_type,                   // 1 for [batch_size], 2 for [batch_size, kv_sequence_length]
     const std::vector<float>& output_data,         // output: [batch_size, sequence_length, v_hidden_size]
@@ -316,7 +331,7 @@ static void RunMultiHeadAttentionKernel(
         {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
     RunMultiHeadAttentionTest(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
-        past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
+        past_key_data, past_value_data, present_key_data, present_value_data, attention_probs_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
         hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
         disable_rocm, disable_dml);
@@ -333,7 +348,7 @@ static void RunMultiHeadAttentionKernel(
         {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
     RunMultiHeadAttentionTest(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
-        past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
+        past_key_data, past_value_data, present_key_data, present_value_data, attention_probs_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
         hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
         disable_rocm, disable_dml);
@@ -350,7 +365,7 @@ static void RunMultiHeadAttentionKernel(
         {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
     RunMultiHeadAttentionTest(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
-        past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
+        past_key_data, past_value_data, present_key_data, present_value_data, attention_probs_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
         hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
         disable_rocm, disable_dml);
@@ -368,7 +383,7 @@
static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, - past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, + past_key_data, past_value_data, present_key_data, present_value_data, attention_probs_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); @@ -387,7 +402,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, - past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, + past_key_data, past_value_data, present_key_data, present_value_data, attention_probs_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); @@ -404,7 +419,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, - past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, + past_key_data, past_value_data, present_key_data, present_value_data, attention_probs_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, disable_rocm, disable_dml); @@ -438,7 +453,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -451,7 +466,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -462,7 +477,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( 
data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -474,7 +489,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -484,7 +499,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -495,7 +510,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -506,7 +521,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } @@ -515,7 +530,7 @@ static void 
RunMultiHeadAttentionTests(AttentionTestData& data, RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, - data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.present_value_data, data.attention_probs_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } diff --git a/onnxruntime/test/testdata/attention/attention_test_data.txt b/onnxruntime/test/testdata/attention/attention_test_data.txt index 7c60efea1f0f6..fc2d36d9752c5 100644 --- a/onnxruntime/test/testdata/attention/attention_test_data.txt +++ b/onnxruntime/test/testdata/attention/attention_test_data.txt @@ -1795,6 +1795,20 @@ name:CrossAttentionData_HeadSize40.bias_data -0.12975609,0.41347277,-0.31107110,0.17745221,-0.46015862,-0.26369864,-0.03715026,-0.42254731, -0.21274829,-0.42004544,-0.22337052,-0.26180822,-0.40042144,-0.40085569,0.17293042,0.15324622, === +name:CrossAttentionData_HeadSize40.attention_probs_data +0.00063643,0.00194771,0.99741471,0.00000007,0.00000104, +0.00069742,0.00000052,0.99480903,0.00001590,0.00447714, +0.00000958,0.97702450,0.02290714,0.00000000,0.00005884, +0.99659425,0.00206349,0.00004461,0.00118587,0.00011175, +0.99964964,0.00034993,0.00000000,0.00000038,0.00000002, +0.00000422,0.02293632,0.00684622,0.94404423,0.02616905, +0.00309461,0.00105873,0.01598126,0.97911912,0.00074622, +0.00000006,0.00000015,0.00000034,0.99615604,0.00384343, +0.53175026,0.00000895,0.00797362,0.45939830,0.00086886, +0.47805423,0.00001056,0.00479097,0.10811566,0.40902859, +0.76967221,0.00000658,0.23011664,0.00000007,0.00020461, +0.47011939,0.52926159,0.00026605,0.00028461,0.00006837 +=== name:CrossAttentionData_HeadSize40.fp32_output_data -1.5234375,2.4179688,0.95751953,-1.9316406,0.012382507,1.4960938,-1.9111328,2.0234375,3.0371094,-6.7265625,1.4042969,-1.4414062,0.094665527,3.6640625,2.359375,5.6601562,-5.3828125,1.2773438,-4.0664062,1.6591797,-2.2949219,4.28125,-0.15026855,0.16455078,1.4853516,0.15344238,1.1035156,1.2519531,4.1132812,-3.9667969,-0.036193848,-0.94482422,1.9208984,2.1347656,-0.088317871,-4.8007812,-0.78320312,-2.0410156,-0.82910156,-2.3085938, @@ -1847,6 +1861,20 @@ name:CrossAttentionData_HeadSize40.fp16_output_data 1.2402344,2.2792969,0.33398438,2.2519531,0.67041016,-0.55957031,0.20666504,1.3583984,-1.9716797,2.6074219,2.2832031,-2.0546875,-2.4335938,0.53515625,-0.15100098,1.9599609,-0.51513672,0.31030273,-0.49169922,1.4677734,2.234375,0.87451172,0.54736328,-1.8681641,-4.2265625,-0.97509766,-7.296875,-1.3486328,1.3769531,-1.8427734,3.1601562,-2.4238281,-0.82421875,-2.7324219,-0.52734375,2.2089844,0.66796875,-0.42236328,-3.03125,-0.047302246, +=== +name:CrossAttentionData_HeadSize40_NoBias.attention_probs_data +0.00048362,0.00240908,0.99710613,0.00000003,0.00000105, +0.00053021,0.00000064,0.99495512,0.00000558,0.00450848, +0.00000591,0.98134965,0.01859642,0.00000000,0.00004810, +0.99746585,0.00132920,0.00001944,0.00110644,0.00007912, +0.99977440,0.00022524,0.00000000,0.00000035,0.00000001, +0.00000461,0.01610990,0.00325279,0.96042979,0.02020300, +0.00156454,0.00089026,0.06058845,0.93540740,0.00154935, +0.00000003,0.00000013,0.00000135,0.99168313,0.00831534, 
+0.36340871,0.00001017,0.04086379,0.59327871,0.00243856, +0.49791878,0.00001777,0.00956724,0.22653878,0.26595742, +0.63556498,0.00000878,0.36432067,0.00000011,0.00010548, +0.35437143,0.64478034,0.00038450,0.00043159,0.00003217 === name:CrossAttentionData_HeadSize40_NoBias.fp32_output_data -1.41742659,2.15152407,1.04562795,-2.24374223,-0.04717414,1.35415494,-2.22757769,1.86239254,2.69114566,-6.72842312,1.75712097,-1.32716835,-0.29526311,3.21845007,1.92469037,5.50302410,-5.44591045,1.70150733,-3.65152955,1.24731004,-2.27949309,4.20098972,0.03186192,-0.04943787,1.44350362,0.36441594,1.57309783,0.89751047,4.46334076,-4.33237123,0.01717144,-0.59184837,1.79316127,1.77931416,0.14474387,-4.33217955,-0.56225604,-1.76532054,-0.84570312,-2.73260164, @@ -2049,6 +2077,12 @@ name:CrossAttentionData_Batch2_HeadSize32_RightSidePadding.bias_data 0.34345043,-0.02719739,-0.39574289,-0.39339882,0.23044002,-0.06155324,0.23292047,0.39775699, 0.12789404,-0.44719657,0.12020230,0.26871282,-0.10917315,-0.29244915,0.09059817,-0.19613290, === +name:CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data +1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000, +1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000, +0.99999976,0.00000027,0.00000000,0.08087362,0.91912633,0.00000000, +0.73388237,0.26611760,0.00000000,0.39302668,0.60697329,0.00000000 +=== name:CrossAttentionData_Batch2_HeadSize32_RightSidePadding.fp32_output_data 2.42288446,1.27227366,0.74894810,1.28347683,1.39642823,-1.93045688,0.45777908,-1.26743007, 0.29003966,-3.80550122,0.80094421,0.50959778,-0.54627192,1.66060388,0.25552815,2.24310493, @@ -2086,6 +2120,12 @@ name:CrossAttentionData_Batch2_HeadSize32_RightSidePadding.fp32_output_data -1.41709673,-0.74830860,0.30404601,-0.99458563,0.22929534,-1.72507358,-0.68753922,-2.64537501, 0.58683372,0.88788664,0.54932535,1.45773280,0.96530700,-3.57728553,-0.41517627,-4.86154747, === +name:CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data +1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000, +1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000, +0.99999988,0.00000017,0.00000000,0.12468453,0.87531543,0.00000000, +0.76183236,0.23816757,0.00000000,0.42891815,0.57108188,0.00000000 +=== name:CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.fp32_output_data 2.52855659,1.00436294,0.83871710,0.97005701,1.33615291,-2.07353282,0.14190522,-1.42923164, -0.05781263,-3.81081843,1.15263164,0.62601233,-0.93824124,1.21525323,-0.17992918,2.08717370, @@ -2210,6 +2250,10 @@ name:CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.bias_data 0.06171834,-0.42181283,-0.41170910,0.40969193,-0.01510030,0.07973170,-0.18156880,0.21522856, 0.03915739,-0.20913908,-0.47068381,0.35633272,-0.35124153,0.36624825,-0.05567622,-0.35343069, === +name:CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data +0.00000000,0.14050138,0.85949868,0.00000000,0.98746777,0.01253215, +0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,1.00000000 +=== name:CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.fp32_output_data 0.23503941,2.87619758,0.01845241,-0.75242990,1.76869011,-0.40492195,-1.65323853,0.34011719, -2.10573196,0.13281155,0.97480160,2.74546146,-1.21957457,-0.73649400,2.52938581,1.65599120, @@ -2231,6 +2275,10 @@ name:CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.fp32_output_data -0.63150787,-2.29512286,-2.56171679,2.49406147,1.68984890,-3.61196756,-1.40276003,-1.38667703, 
-2.05177927,-1.23729944,-2.25812149,2.70134830,2.44814849,2.18869901,1.41840470,0.74720055, === +name:CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data +0.00000000,0.12114097,0.87885910,0.00000000,0.98517215,0.01482786, +0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,1.00000000 +=== name:CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.fp32_output_data 0.38947105,2.65047002,0.05826539,-1.08475602,1.75782788,-0.59195572,-2.00590920,0.17207618, -2.52548885,0.12185203,1.33714449,2.95572400,-1.67562902,-1.20251048,2.11437178,1.54135668, @@ -2416,6 +2464,10 @@ name:CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.kv_data -0.77993584,1.37333393,-1.16019452,-0.91983509,0.20466281,1.09339333, -0.99191529,3.42685890, +=== +name:CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data +0.19743514,0.27805042,0.52451444,0.00186652,0.99633980,0.00179376, +0.91027802,0.07591849,0.01380343,0.01734933,0.13505954,0.84759116 === name:CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.fp16_output_data -0.18665725,1.53655565,-1.16219902,-0.53553712,-1.76899862,-0.67172408, @@ -2684,6 +2736,10 @@ name:CrossAttentionData_HeadSize16_8.bias_data 0.34785229,0.00531715,-0.35168743,-0.11641458,0.39196932,0.44535065,0.43545735,0.15593112, 0.06171834,-0.42181283,-0.41170910,0.40969193,-0.01510030,0.07973170,-0.18156880,0.21522856, +=== +name:CrossAttentionData_HeadSize16_8.attention_probs_data +0.15171406,0.01985410,0.82843184,0.43151501,0.43381554,0.13466942,0.04487716,0.25137186,0.70375103, +0.00439296,0.94590873,0.04969835,0.99877971,0.00051274,0.00070760,0.04258159,0.33776718,0.61965120 === name:CrossAttentionData_HeadSize16_8.fp32_output_data 0.70553654,2.84393549,-0.06753168,-0.78168947,1.67733526,-0.32306066,-1.46519339,-0.24197246, @@ -2694,6 +2750,10 @@ name:CrossAttentionData_HeadSize16_8.fp32_output_data 0.59563273,0.71862715,0.57042938,1.61676264,1.43126500,2.88902473,0.78586847,1.13364232, -0.24963731,-1.69403267,-2.38265419,1.86863625,0.37573546,-2.40374231,-0.73219091,-1.54168916, +=== +name:CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data +0.21398182,0.01707544,0.76894271,0.50088972,0.36535633,0.13375396,0.08453909,0.48650715,0.42895377, +0.00783394,0.92731458,0.06485146,0.99906713,0.00058053,0.00035246,0.07878280,0.30196917,0.61924797 === name:CrossAttentionData_HeadSize16_8_NoBias.fp32_output_data 0.88660234,2.46094799,0.10754689,-1.06147599,1.46027637,-0.32641891,-1.61505651,-0.62761730, @@ -2757,6 +2817,10 @@ name:CrossAttentionData_HeadSize16.bias_data 0.06171834,-0.42181283,-0.41170910,0.40969193,-0.01510030,0.07973170,-0.18156880,0.21522856, 0.03915739,-0.20913908,-0.47068381,0.35633272,-0.35124153,0.36624825,-0.05567622,-0.35343069, === +name:CrossAttentionData_HeadSize16.attention_probs_data +0.15171406,0.01985411,0.82843184,0.00300759,0.97847885,0.01851352, +0.43151510,0.43381545,0.13466942,0.08332951,0.82590133,0.09076917 +=== name:CrossAttentionData_HeadSize16.fp32_output_data 0.70553654,2.84393549,-0.06753166,-0.78168941,1.67733538,-0.32306066,-1.46519351,-0.24197248, -1.95703733,0.05333783,0.71154630,2.09348249,-1.24223638,-0.52374214,2.15032387,1.41931129, @@ -2768,6 +2832,10 @@ name:CrossAttentionData_HeadSize16.fp32_output_data -0.21351013,0.15201563,0.47484040,-0.79835385,-2.11648011,-0.13788147,-0.01865268,-0.69491959, 0.34924412,0.05382843,-1.21107709,2.20767021,-0.16749848,1.72250605,-1.32190716,0.45872629, +=== +name:CrossAttentionData_HeadSize16_NoBias.attention_probs_data 
+0.21398182,0.01707545,0.76894277,0.00491562,0.97517157,0.01991289, +0.50088978,0.36535639,0.13375394,0.10961151,0.78822690,0.10216155 === name:CrossAttentionData_HeadSize16_NoBias.fp32_output_data 0.88660234,2.46094799,0.10754693,-1.06147599,1.46027637,-0.32641891,-1.61505640,-0.62761730, @@ -3031,6 +3099,11 @@ name:CrossAttentionData_DiffSequenceLengths_HeadSize8.bias_data -0.10567203,0.26791072,-0.08976898,0.31341976,0.06027532,0.14307594,0.31587386,0.16180152, 0.34785229,0.00531715,-0.35168743,-0.11641458,0.39196932,0.44535065,0.43545735,0.15593112, +=== +name:CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data +0.32941905,0.09115683,0.36670843,0.21271567,0.49403846,0.09097324,0.07906519,0.33592317, +0.37942606,0.03424129,0.12042803,0.46590456,0.15880427,0.27026406,0.23315622,0.33777541 + === name:CrossAttentionData_DiffSequenceLengths_HeadSize8.fp32_output_data -0.73531479,0.17652693,0.43294340,0.10832195,1.06569219,0.84791648,0.37950000,-0.19036117, @@ -3062,6 +3135,11 @@ name:CrossAttentionData_DiffSequenceLengths_HeadSize8.present_value_data 0.85667104,2.14919043,0.50618559,2.20632005,0.34294793,0.40473318,1.98550463,-0.09497684, 1.55557942,-0.98876214,-0.50508159,0.42920581,1.68902707,0.15883744,-0.46605104,-0.93880188, +=== +name:CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data +0.31059057,0.09501101,0.33080846,0.26358992,0.44437748,0.09045862,0.06804446,0.39711943, +0.39367217,0.04053140,0.10096846,0.46482795,0.16198799,0.31451625,0.19218433,0.33131137 + === name:CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.fp32_output_data -0.65824336,-0.13227919,0.50652772,-0.20617434,0.96178705,0.75033671,0.07374202,-0.33853477, @@ -5051,6 +5129,11 @@ name:CrossAttention_Batch1_HeadSize8.packed_key_value 1.63199806,-1.44967401,-2.54707336,0.78083873,-0.19109090,0.59508920,0.58886564,0.81380880, 0.50881875,2.14387321,0.85787302,2.32273459,-0.04902139,-0.04061748,1.55004728,-0.25090796 +==== +name:CrossAttention_Batch1_HeadSize8.attention_probs_data +0.41842452,0.11578643,0.46578908,0.74394774,0.13699204,0.11906030, +0.71040881,0.06411081,0.22548041,0.23980425,0.40811545,0.35208032 + ==== name:CrossAttention_Batch1_HeadSize8.output -0.61998826,0.38731366,0.38371456,0.17248757,1.26609111,0.61097330,0.38864893,-0.34083632, @@ -5059,6 +5142,11 @@ name:CrossAttention_Batch1_HeadSize8.output -0.25185302,0.10573119,0.01646931,0.40613887,1.61315691,0.59776157,0.70979917,-1.10025024, 1.16315329,0.47766802,-0.03506046,1.33826876,1.36242199,0.06935713,0.58279711,-0.82380491 +==== +name:CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data +0.42176309,0.12901917,0.44921777,0.73709041,0.15004401,0.11286558, +0.73559928,0.07573527,0.18866542,0.24224728,0.47034785,0.28740484 + ==== name:CrossAttention_Batch1_HeadSize8_NoBias.output -0.51569921,0.13232709,0.43551767,-0.12155488,1.21165323,0.45272583,0.08948315,-0.53300208, From 239df8b8ed13a86730d425f47d9b81622364a945 Mon Sep 17 00:00:00 2001 From: andrea_mancini Date: Thu, 12 Dec 2024 16:44:48 +0000 Subject: [PATCH 2/2] Add CUDA implementation for attn_probs --- .../cpu/bert/multihead_attention.cc | 2 +- .../contrib_ops/cuda/bert/attention_impl.cu | 27 ++++--- .../contrib_ops/cuda/bert/attention_impl.h | 1 + .../cuda/bert/multihead_attention.cc | 15 +++- .../contrib_ops/attention_op_test_helper.cc | 75 +++++++++++-------- .../contrib_ops/attention_op_test_helper.h | 32 ++++---- .../multihead_attention_op_test.cc | 57 ++++++++++++++ 7 files changed, 150 insertions(+), 59 deletions(-) diff --git 
a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 5aafd8acc5a0c..6dfd7f4dce1fc 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -161,7 +161,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
       past_value == nullptr &&
       present_k == nullptr &&
       present_v == nullptr &&
-      attn_probs == nullptr && // TODO: can we support it?
+      attn_probs == nullptr &&  // TODO: can we support it?
       l2_cache_size_ > 0) {
     MlasFlashAttentionThreadedArgs args;
     args.batch_size = batch_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 9e017544d7cff..e287b8e933388 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -590,15 +590,22 @@ Status UnfusedAttention(
 
   DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
 
-  constexpr size_t element_size = sizeof(T);
-  const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
-                                               sequence_length, total_sequence_length);
-  T* scratch2 = data.scratch + (bytes / element_size);
+  T* softmax_storage;
+  if (data.attn_probs == nullptr) {
+    constexpr size_t element_size = sizeof(T);
+    const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
+                                                 sequence_length, total_sequence_length);
+    T* scratch2 = data.scratch + (bytes / element_size);
+    softmax_storage = scratch2;
+  }
+  else {
+    softmax_storage = data.attn_probs;
+  }
 
   const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
   const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;
 
-  // Apply softmax and store result R to scratch2: BxNxSxT
+  // Apply softmax and store result R to softmax_storage: BxNxSxT
   if (use_raw_attention_mask) {  // 2d, 3d or 4d attention mask
     const int mask_dimension = static_cast<int>(mask_index_dims.size());
@@ -612,7 +619,7 @@ Status UnfusedAttention(
         ComputeSoftmaxWithRawMask<T>(
             ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
             mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+            data.scratch, softmax_storage, parameters.is_unidirectional, scale, mask_dimension,
             parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
             parameters.mask_filter_value));
   } else if (nullptr != mask_index) {  // 1d mask index
@@ -622,16 +629,16 @@ Status UnfusedAttention(
     ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
         stream, total_sequence_length, sequence_length, batch_size, num_heads,
         mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-        data.scratch, scratch2, parameters.is_unidirectional));
+        data.scratch, softmax_storage, parameters.is_unidirectional));
   } else {  // no mask
     ORT_RETURN_IF_ERROR(
         ComputeSoftmax<T>(
            stream, total_sequence_length, sequence_length, batch_size, num_heads,
            data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-           data.scratch, scratch2, parameters.is_unidirectional));
+           data.scratch, softmax_storage, parameters.is_unidirectional));
   }
 
-  DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
+  DUMP_TENSOR_D("Softmax", softmax_storage, batch_size, num_heads, sequence_length, total_sequence_length);
 
   // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
   T* temp_output = data.q;
@@ -639,7 +646,7 @@ Status UnfusedAttention(
       cublas, CUBLAS_OP_N, CUBLAS_OP_N,
       v_head_size, sequence_length, total_sequence_length,
       &one, data.v, v_head_size, present_size_per_batch_v,
-      scratch2, total_sequence_length, sequence_length * total_sequence_length,
+      softmax_storage, total_sequence_length, sequence_length * total_sequence_length,
      &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));
 
   // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 7d111a1ee21bf..a79369153825c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -81,6 +81,7 @@ struct AttentionData {
   T* present = nullptr;
   T* present_key = nullptr;
   T* present_value = nullptr;
+  T* attn_probs = nullptr;
 
   void* fused_runner = nullptr;
   const void* fused_cross_attention_kernel = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index e2587d172af94..ee8ba11cc5b81 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -113,6 +113,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
   Tensor* output = context->Output(0, output_shape);
 
+  TensorShapeVector attn_probs_shape(4);
+  attn_probs_shape[0] = static_cast<int64_t>(parameters.batch_size);
+  attn_probs_shape[1] = static_cast<int64_t>(parameters.num_heads);
+  attn_probs_shape[2] = static_cast<int64_t>(sequence_length);
+  attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
+  Tensor* attn_probs = context->Output(3, attn_probs_shape);
+
   std::vector<int64_t> present_dims{
       parameters.batch_size, parameters.num_heads, parameters.total_sequence_length, parameters.head_size};
   TensorShape present_shape(present_dims);
@@ -172,6 +179,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       parameters.past_sequence_length > 0 &&
       nullptr == attention_bias &&
       nullptr == key_padding_mask &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       parameters.head_size == parameters.v_head_size &&
       onnxruntime::lean::is_supported(device_prop,
                                       parameters.head_size,
@@ -216,6 +224,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !disable_flash_attention_ &&
       nullptr == attention_bias &&
       nullptr == key_padding_mask &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       parameters.head_size == parameters.v_head_size &&
       onnxruntime::flash::is_supported(device_prop,
                                        parameters.head_size,
@@ -280,7 +289,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
-      nullptr == past_key && nullptr == present_key &&
+      nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
       (parameters.qkv_format == Q_K_V_BSNH ||
        (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
       has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
@@ -305,7 +314,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
      !is_unidirectional_ &&
      nullptr == attention_bias &&
      (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
-      nullptr == past_key && nullptr == present_key &&
+      nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
      is_mask_none_or_1d_k_len &&
      parameters.hidden_size == parameters.v_hidden_size &&
      parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
@@ -339,6 +348,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_memory_efficient_attention_ &&
       is_long_sequence &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       // Check whether the attention bias alignment is good for memory efficient attention.
       (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
       (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
@@ -369,6 +379,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
   data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
   data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
+  data.attn_probs = (nullptr == attn_probs) ? nullptr : reinterpret_cast<CudaT*>(attn_probs->MutableData<T>());
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.fused_cross_attention_kernel = fused_cross_attention_kernel;
   data.use_flash_attention = use_flash_attention;
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
index 1555d813ea6fb..5fc87d2dab61a 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
@@ -62,7 +62,7 @@ void GetAttentionBias(std::vector<float>& bias_data, int elements, int start_off
   SampleAttentionWeight(data, bias_data, elements, start_offset, step);
 }
 
-void GetCrossAttentionData_HeadSize40(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize40(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 80;
   data.v_hidden_size = 80;
   data.num_heads = 2;
@@ -76,18 +76,20 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data, bool get_attentio
   LoadTensor("CrossAttentionData_HeadSize40.bias_data", data.bias_data);
   LoadTensor("CrossAttentionData_HeadSize40.fp32_output_data", data.fp32_output_data);
   LoadTensor("CrossAttentionData_HeadSize40.fp16_output_data", data.fp16_output_data);
-  LoadTensor("CrossAttentionData_HeadSize40.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize40.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_HeadSize40(data);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_HeadSize40_NoBias.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_HeadSize40_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize40_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool
is_mask_1d) { +void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs) { data.hidden_size = 64; data.v_hidden_size = 64; data.num_heads = 2; @@ -115,19 +117,21 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.bias_data", data.bias_data); LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d) { +void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs) { GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, is_mask_1d); data.bias_data.clear(); LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) { +void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data, bool get_attention_probs) { data.hidden_size = 32; data.v_hidden_size = 32; data.num_heads = 1; @@ -149,18 +153,20 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.bias_data", data.bias_data); LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data) { +void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data, bool get_attention_probs) { GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); data.bias_data.clear(); LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data) { +void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data, bool get_attention_probs) { data.hidden_size = 32; data.v_hidden_size = 32; data.num_heads = 1; 
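[Reviewer note, not part of the patch] On CUDA the new output is only filled by the unfused path, which is why every fused/flash/memory-efficient branch above now also checks nullptr == attn_probs. Below is a sketch of how a client could request the optional output through the ONNX Runtime C++ API; "mha.onnx", the I/O names, and the shapes are hypothetical placeholders for a graph that binds output 3 of a com.microsoft MultiHeadAttention node to "attention_probs":

#include <onnxruntime_cxx_api.h>

#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "mha_probs");
  Ort::SessionOptions opts;
  // On Windows builds the model path argument is wide (L"mha.onnx").
  Ort::Session session(env, "mha.onnx", opts);

  // Dummy [batch=1, seq=2, hidden=64] query; a real model also takes key/value.
  std::vector<float> query(1 * 2 * 64, 0.0f);
  std::vector<int64_t> dims{1, 2, 64};
  Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(mem, query.data(), query.size(),
                                                     dims.data(), dims.size());

  const char* input_names[] = {"query"};
  // Requesting "attention_probs" steers execution to the unfused kernel; omit it
  // and the faster fused paths stay eligible.
  const char* output_names[] = {"output", "attention_probs"};
  std::vector<Ort::Value> results = session.Run(Ort::RunOptions{nullptr}, input_names,
                                                &input, 1, output_names, 2);

  // results[1] has shape [batch_size, num_heads, sequence_length, total_sequence_length].
  std::vector<int64_t> probs_shape = results[1].GetTensorTypeAndShapeInfo().GetShape();
  (void)probs_shape;
  return 0;
}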
@@ -180,7 +186,8 @@ void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTes // Do not test fp32 data.fp32_output_data = {}; LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.fp16_output_data", data.fp16_output_data); - LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data", data.attention_probs_data); } void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data) { @@ -206,7 +213,7 @@ void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTes LoadTensor("SelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV.fp16_output_data", data.fp16_output_data); } -void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) { +void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data, bool get_attention_probs) { data.hidden_size = 48; data.v_hidden_size = 24; data.num_heads = 3; @@ -224,18 +231,20 @@ void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) { LoadTensor("CrossAttentionData_HeadSize16_8.bias_data", data.bias_data); LoadTensor("CrossAttentionData_HeadSize16_8.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_HeadSize16_8.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_HeadSize16_8.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data) { +void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data, bool get_attention_probs) { GetCrossAttentionData_HeadSize16_8(data); data.bias_data.clear(); LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_HeadSize16(AttentionTestData& data) { +void GetCrossAttentionData_HeadSize16(AttentionTestData& data, bool get_attention_probs) { data.hidden_size = 32; data.v_hidden_size = 32; data.num_heads = 2; @@ -250,18 +259,20 @@ void GetCrossAttentionData_HeadSize16(AttentionTestData& data) { LoadTensor("CrossAttentionData_HeadSize16.bias_data", data.bias_data); LoadTensor("CrossAttentionData_HeadSize16.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_HeadSize16.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_HeadSize16.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data) { +void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data, bool get_attention_probs) { GetCrossAttentionData_HeadSize16(data); data.bias_data.clear(); LoadTensor("CrossAttentionData_HeadSize16_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttentionData_HeadSize16_NoBias.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_HeadSize16_NoBias.attention_probs_data", 
data.attention_probs_data); } -void GetCrossAttentionData_HeadSize8(AttentionTestData& data) { +void GetCrossAttentionData_HeadSize8(AttentionTestData& data, bool get_attention_probs) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -276,15 +287,17 @@ void GetCrossAttentionData_HeadSize8(AttentionTestData& data) { LoadTensor("CrossAttention_Batch1_HeadSize8.bias_data", data.bias_data); LoadTensor("CrossAttention_Batch1_HeadSize8.output", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttention_Batch1_HeadSize8.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttention_Batch1_HeadSize8.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data) { +void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs) { GetCrossAttentionData_HeadSize8(data); data.bias_data.clear(); LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.output", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data); } void GetCrossAttentionDataWithPast(AttentionTestData& data) { @@ -394,7 +407,7 @@ void GetCrossAttentionData_DiffSequenceLengths(AttentionTestData& data) { data.is_static_kv = true; } -void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data) { +void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data, bool get_attention_probs) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -419,10 +432,11 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_key_data", data.present_key_data); LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_value_data", data.present_value_data); data.is_static_kv = true; - LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data", data.attention_probs_data); } -void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data) { +void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs) { GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); data.bias_data.clear(); LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.fp32_output_data", data.fp32_output_data); @@ -430,7 +444,8 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestDat LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_key_data", data.present_key_data); LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_value_data", data.present_value_data); data.is_static_kv = true; - LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data); + if (get_attention_probs) + LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data); } void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data) { diff --git 
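Every getter changed above follows the same pattern: the expected attention probabilities are loaded only when the new flag is set, so call sites that leave it at its default pay no extra cost. Below is a minimal, self-contained sketch of that pattern; the struct fields, the LoadTensor signature, and the "Example" names are simplified stand-ins for the real test-helper declarations, not code from this patch.

    #include <string>
    #include <vector>

    // Simplified stand-ins for the test-helper types (illustration only).
    struct AttentionTestData {
      std::vector<float> fp32_output_data;
      std::vector<float> attention_probs_data;  // stays empty unless requested
    };

    // Loads a named tensor from the test data file; provided by the test framework.
    void LoadTensor(const std::string& name, std::vector<float>& out);

    // Expected attention probabilities are loaded only on request, so callers
    // that ignore the new output keep their old behavior unchanged.
    void GetExampleCrossAttentionData(AttentionTestData& data, bool get_attention_probs = false) {
      LoadTensor("Example.fp32_output_data", data.fp32_output_data);
      if (get_attention_probs)
        LoadTensor("Example.attention_probs_data", data.attention_probs_data);
    }
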
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
index 2a875c2a9abba..5535461426adf 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
@@ -58,30 +58,30 @@ struct PackedAttentionTestData : public BaseAttentionTestData {
 void GetAttentionWeight(std::vector<float>& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step = 1);
 void GetAttentionBias(std::vector<float>& bias_data, int elements = 3 * 64, int offset = 0, int step = 1);
 
-void GetCrossAttentionData_HeadSize40(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data);
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d);
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d);
-void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data);
-void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data);
-
-void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data);
+void GetCrossAttentionData_HeadSize40(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data, bool get_attention_probs = false);
+
+void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data, bool get_attention_probs = false);
 void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data);
+void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize16(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data, bool get_attention_probs = false);
 
-void GetCrossAttentionData_HeadSize8(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data);
+void GetCrossAttentionData_HeadSize8(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs = false);
 
 void GetCrossAttentionDataWithPast(AttentionTestData& data);
 void GetSelfAttentionData_WithPast_WithAttnBias_ForT5(AttentionTestData& data);
 
 void GetCrossAttentionData_DiffSequenceLengths(AttentionTestData& data);
-void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data);
-void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data);
+void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs = false);
 void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data);
 void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(AttentionTestData& data);
 void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(AttentionTestData& data);
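Because every new get_attention_probs parameter in the header above is defaulted to false, the change is source-compatible: existing tests compile and behave as before. For illustration only, a call-site view using the declarations above:

    AttentionTestData data;
    GetCrossAttentionData_HeadSize16(data);        // existing call sites compile unchanged; no probs loaded
    GetCrossAttentionData_HeadSize16(data, true);  // opt in: data.attention_probs_data is also populated
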
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
index d103ddf653139..241c18962ffc4 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
@@ -545,6 +545,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
 
   GetCrossAttentionData_HeadSize40_NoBias(data);
   RunMultiHeadAttentionTests(data);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize40(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
+
+  GetCrossAttentionData_HeadSize40_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
@@ -555,6 +562,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
 
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true);
   RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
@@ -564,6 +578,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
 
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false);
   RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
@@ -573,12 +594,23 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma
 
   GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data);
   RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) {
   AttentionTestData data;
   GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data);
   RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) {
@@ -595,6 +627,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
 
   GetCrossAttentionData_HeadSize16_8_NoBias(data);
   RunMultiHeadAttentionTests(data);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize16_8(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
+
+  GetCrossAttentionData_HeadSize16_8_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
@@ -604,12 +643,23 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
 
   GetCrossAttentionData_HeadSize16_NoBias(data);
   RunMultiHeadAttentionTests(data);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize16(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
+
+  GetCrossAttentionData_HeadSize16_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) {
   AttentionTestData data;
   GetCrossAttentionData_HeadSize8_NoBias(data);
   RunMultiHeadAttentionTests(data, DISABLE_CUDA);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize8_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA);
 }
 
 // TODO (pavignol): Fix this regression
@@ -648,6 +698,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
 
   GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data);
   RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) {
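Each affected test now runs twice, once without and once with expected attention probabilities. The runner itself is not shown in this patch; purely as a sketch, one way RunMultiHeadAttentionTests could bind the optional fourth output through OpTester is shown below (the "attn_probs" output name and the dimension variables are assumptions, not taken from this patch):

    // Hypothetical wiring, not the actual RunMultiHeadAttentionTests body.
    // Assumes `tester` is an OpTester for MultiHeadAttention and that the
    // dimension variables come from the test data.
    if (!data.attention_probs_data.empty()) {
      std::vector<int64_t> probs_dims{batch_size, num_heads, sequence_length, total_sequence_length};
      tester.AddOutput<float>("attn_probs", probs_dims, data.attention_probs_data);
    } else {
      // Leave the optional output unbound so the kernel skips materializing it.
      tester.AddOptionalOutputEdge<float>();
    }
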