diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 5aafd8acc5a0c..6dfd7f4dce1fc 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -161,7 +161,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
       past_value == nullptr &&
       present_k == nullptr &&
       present_v == nullptr &&
-      attn_probs == nullptr && // TODO: can we support it?
+      attn_probs == nullptr &&  // TODO: can we support it?
       l2_cache_size_ > 0) {
     MlasFlashAttentionThreadedArgs args;
     args.batch_size = batch_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 9e017544d7cff..e287b8e933388 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -590,15 +590,22 @@ Status UnfusedAttention(
   DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
 
-  constexpr size_t element_size = sizeof(T);
-  const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
-                                               sequence_length, total_sequence_length);
-  T* scratch2 = data.scratch + (bytes / element_size);
+  T* softmax_storage;
+  if (data.attn_probs == nullptr) {
+    constexpr size_t element_size = sizeof(T);
+    const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
+                                                 sequence_length, total_sequence_length);
+    T* scratch2 = data.scratch + (bytes / element_size);
+    softmax_storage = scratch2;
+  }
+  else {
+    softmax_storage = data.attn_probs;
+  }
 
   const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
   const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;
 
-  // Apply softmax and store result R to scratch2: BxNxSxT
+  // Apply softmax and store result R to softmax_storage: BxNxSxT
   if (use_raw_attention_mask) {  // 2d, 3d or 4d attention mask
     const int mask_dimension = static_cast<int>(mask_index_dims.size());
@@ -612,7 +619,7 @@
         ComputeSoftmaxWithRawMask<T>(
             ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
             mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+            data.scratch, softmax_storage, parameters.is_unidirectional, scale, mask_dimension,
            parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
            parameters.mask_filter_value));
   } else if (nullptr != mask_index) {  // 1d mask index
@@ -622,16 +629,16 @@
     ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
         stream, total_sequence_length, sequence_length, batch_size, num_heads,
         mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-        data.scratch, scratch2, parameters.is_unidirectional));
+        data.scratch, softmax_storage, parameters.is_unidirectional));
   } else {  // no mask
     ORT_RETURN_IF_ERROR(
         ComputeSoftmax<T>(
             stream, total_sequence_length, sequence_length, batch_size, num_heads, data.attention_bias,
             broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional));
+            data.scratch, softmax_storage, parameters.is_unidirectional));
   }
 
-  DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
+  DUMP_TENSOR_D("Softmax", softmax_storage, batch_size, num_heads, sequence_length, total_sequence_length);
 
   // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
   T* temp_output = data.q;
@@ -639,7 +646,7 @@
       cublas, CUBLAS_OP_N, CUBLAS_OP_N,
       v_head_size, sequence_length, total_sequence_length,
       &one, data.v, v_head_size, present_size_per_batch_v,
-      scratch2, total_sequence_length, sequence_length * total_sequence_length,
+      softmax_storage, total_sequence_length, sequence_length * total_sequence_length,
       &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));
 
   // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 7d111a1ee21bf..a79369153825c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -81,6 +81,7 @@ struct AttentionData {
   T* present = nullptr;
   T* present_key = nullptr;
   T* present_value = nullptr;
+  T* attn_probs = nullptr;
 
   void* fused_runner = nullptr;
   const void* fused_cross_attention_kernel = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index e2587d172af94..ee8ba11cc5b81 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -113,6 +113,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
   Tensor* output = context->Output(0, output_shape);
 
+  TensorShapeVector attn_probs_shape(4);
+  attn_probs_shape[0] = static_cast<int64_t>(parameters.batch_size);
+  attn_probs_shape[1] = static_cast<int64_t>(parameters.num_heads);
+  attn_probs_shape[2] = static_cast<int64_t>(sequence_length);
+  attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
+  Tensor* attn_probs = context->Output(3, attn_probs_shape);
+
   std::vector<int64_t> present_dims{
       parameters.batch_size, parameters.num_heads, parameters.total_sequence_length, parameters.head_size};
   TensorShape present_shape(present_dims);
@@ -172,6 +179,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
                            parameters.past_sequence_length > 0 &&
                            nullptr == attention_bias &&
                            nullptr == key_padding_mask &&
+                           nullptr == attn_probs &&  // TODO: support attn_probs
                            parameters.head_size == parameters.v_head_size &&
                            onnxruntime::lean::is_supported(device_prop,
                                                            parameters.head_size,
@@ -216,6 +224,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
                              !disable_flash_attention_ &&
                              nullptr == attention_bias &&
                              nullptr == key_padding_mask &&
+                             nullptr == attn_probs &&  // TODO: support attn_probs
                              parameters.head_size == parameters.v_head_size &&
                              onnxruntime::flash::is_supported(device_prop,
                                                               parameters.head_size,
@@ -280,7 +289,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
-      nullptr == past_key && nullptr == present_key &&
+      nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
       (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
       has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
@@ -305,7 +314,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !is_unidirectional_ &&
       nullptr == attention_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
-      nullptr == past_key && nullptr == present_key &&
+      nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
       is_mask_none_or_1d_k_len &&
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
@@ -339,6 +348,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_memory_efficient_attention_ &&
       is_long_sequence &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       // Check whether the attention bias alignment is good for memory efficient attention.
       (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
       (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
@@ -369,6 +379,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
   data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
   data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
+  data.attn_probs = (nullptr == attn_probs) ? nullptr : reinterpret_cast<CudaT*>(attn_probs->MutableData<T>());
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.fused_cross_attention_kernel = fused_cross_attention_kernel;
   data.use_flash_attention = use_flash_attention;
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
index 1555d813ea6fb..5fc87d2dab61a 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
@@ -62,7 +62,7 @@ void GetAttentionBias(std::vector<float>& bias_data, int elements, int start_off
   SampleAttentionWeight(data, bias_data, elements, start_offset, step);
 }
 
-void GetCrossAttentionData_HeadSize40(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize40(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 80;
   data.v_hidden_size = 80;
   data.num_heads = 2;
@@ -76,18 +76,20 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data) {
   LoadTensor("CrossAttentionData_HeadSize40.bias_data", data.bias_data);
   LoadTensor("CrossAttentionData_HeadSize40.fp32_output_data", data.fp32_output_data);
   LoadTensor("CrossAttentionData_HeadSize40.fp16_output_data", data.fp16_output_data);
-  LoadTensor("CrossAttentionData_HeadSize40.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize40.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_HeadSize40(data);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_HeadSize40_NoBias.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_HeadSize40_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize40_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d) {
+void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs) {
   data.hidden_size = 64;
   data.v_hidden_size = 64;
   data.num_heads = 2;
@@ -115,19 +117,21 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData&
   LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.bias_data", data.bias_data);
   LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d) {
+void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs) {
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, is_mask_1d);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) {
+void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 32;
   data.v_hidden_size = 32;
   data.num_heads = 1;
@@ -149,18 +153,20 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData&
   LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.bias_data", data.bias_data);
   LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data) {
+void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 32;
   data.v_hidden_size = 32;
   data.num_heads = 1;
@@ -180,7 +186,8 @@ void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTes
   // Do not test fp32
   data.fp32_output_data = {};
   LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.fp16_output_data", data.fp16_output_data);
-  LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data", data.attention_probs_data);
 }
 
 void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data) {
@@ -206,7 +213,7 @@ void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTes
   LoadTensor("SelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV.fp16_output_data", data.fp16_output_data);
 }
 
-void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 48;
   data.v_hidden_size = 24;
   data.num_heads = 3;
@@ -224,18 +231,20 @@
   LoadTensor("CrossAttentionData_HeadSize16_8.bias_data", data.bias_data);
   LoadTensor("CrossAttentionData_HeadSize16_8.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_HeadSize16_8.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize16_8.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_HeadSize16_8(data);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize16(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize16(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 32;
   data.v_hidden_size = 32;
   data.num_heads = 2;
@@ -250,18 +259,20 @@
   LoadTensor("CrossAttentionData_HeadSize16.bias_data", data.bias_data);
   LoadTensor("CrossAttentionData_HeadSize16.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_HeadSize16.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize16.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_HeadSize16(data);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_HeadSize16_NoBias.fp32_output_data", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttentionData_HeadSize16_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_HeadSize16_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize8(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize8(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 16;
   data.v_hidden_size = 16;
   data.num_heads = 2;
@@ -276,15 +287,17 @@
   LoadTensor("CrossAttention_Batch1_HeadSize8.bias_data", data.bias_data);
   LoadTensor("CrossAttention_Batch1_HeadSize8.output", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttention_Batch1_HeadSize8.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttention_Batch1_HeadSize8.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_HeadSize8(data);
   data.bias_data.clear();
   LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.output", data.fp32_output_data);
   data.fp16_output_data = data.fp32_output_data;
-  LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
 void GetCrossAttentionDataWithPast(AttentionTestData& data) {
@@ -394,7 +407,7 @@ void GetCrossAttentionData_DiffSequenceLengths(AttentionTestData& data) {
   data.is_static_kv = true;
 }
 
-void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data) {
+void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data, bool get_attention_probs) {
   data.hidden_size = 16;
   data.v_hidden_size = 16;
   data.num_heads = 2;
@@ -419,10 +432,11 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_key_data", data.present_key_data);
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_value_data", data.present_value_data);
   data.is_static_kv = true;
-  LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data", data.attention_probs_data);
 }
 
-void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data) {
+void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs) {
   GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data);
   data.bias_data.clear();
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.fp32_output_data", data.fp32_output_data);
@@ -430,7 +444,8 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestDat
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_key_data", data.present_key_data);
   LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_value_data", data.present_value_data);
   data.is_static_kv = true;
-  LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
+  if (get_attention_probs)
+    LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
 }
 
 void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data) {
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
index 2a875c2a9abba..5535461426adf 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
@@ -58,30 +58,30 @@ struct PackedAttentionTestData : public BaseAttentionTestData {
 void GetAttentionWeight(std::vector<float>& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step = 1);
 void GetAttentionBias(std::vector<float>& bias_data, int elements = 3 * 64, int offset = 0, int step = 1);
 
-void GetCrossAttentionData_HeadSize40(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data);
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d);
-void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d);
-void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data);
-void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data);
-
-void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data);
+void GetCrossAttentionData_HeadSize40(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data, bool get_attention_probs = false);
+
+void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data, bool get_attention_probs = false);
 void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data);
 
-void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data);
+void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize16(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data, bool get_attention_probs = false);
 
-void GetCrossAttentionData_HeadSize8(AttentionTestData& data);
-void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data);
+void GetCrossAttentionData_HeadSize8(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs = false);
 
 void GetCrossAttentionDataWithPast(AttentionTestData& data);
 void GetSelfAttentionData_WithPast_WithAttnBias_ForT5(AttentionTestData& data);
 
 void GetCrossAttentionData_DiffSequenceLengths(AttentionTestData& data);
-void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data);
-void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data);
+void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data, bool get_attention_probs = false);
+void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data, bool get_attention_probs = false);
 void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data);
 void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(AttentionTestData& data);
 void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(AttentionTestData& data);
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
index d103ddf653139..241c18962ffc4 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
@@ -545,6 +545,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
 
   GetCrossAttentionData_HeadSize40_NoBias(data);
   RunMultiHeadAttentionTests(data);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize40(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
+
+  GetCrossAttentionData_HeadSize40_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
@@ -555,6 +562,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
 
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true);
   RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
@@ -564,6 +578,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
 
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false);
   RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
@@ -573,12 +594,23 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma
 
   GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data);
   RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) {
   AttentionTestData data;
   GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data);
   RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) {
@@ -595,6 +627,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
 
   GetCrossAttentionData_HeadSize16_8_NoBias(data);
   RunMultiHeadAttentionTests(data);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize16_8(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
+
+  GetCrossAttentionData_HeadSize16_8_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
@@ -604,12 +643,23 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
 
   GetCrossAttentionData_HeadSize16_NoBias(data);
   RunMultiHeadAttentionTests(data);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize16(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
+
+  GetCrossAttentionData_HeadSize16_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data);
 }
 
 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) {
   AttentionTestData data;
   GetCrossAttentionData_HeadSize8_NoBias(data);
   RunMultiHeadAttentionTests(data, DISABLE_CUDA);
+
+  // Test attention probs output
+  GetCrossAttentionData_HeadSize8_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA);
 }
 
 // TODO (pavignol): Fix this regression
@@ -648,6 +698,13 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
 
   GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data);
   RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
+
+  // Test attention probs output
+  GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
+
+  GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data, /* get_attention_probs */ true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
 }
 
 TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) {
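
Note (illustrative, not part of the patch): the core behavioral change in attention_impl.cu is that the softmax probabilities are written either into the internal workspace (scratch2) or, when the new optional attn_probs output is bound, directly into that output buffer of shape batch_size x num_heads x sequence_length x total_sequence_length. The sketch below replays that storage-selection logic on the host with plain float buffers; names such as Workspace and pick_softmax_storage are hypothetical, and the real kernel operates on device memory with half/float types.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for AttentionData: an internal workspace plus the
// optional user-visible attention-probs buffer (null when the output is not requested).
struct Workspace {
  std::vector<float> scratch;   // holds Q*K' scores followed by scratch2
  float* attn_probs = nullptr;  // optional output, B x N x S x T
};

// Rough analogue of GetAttentionScratchSize: bytes for one B x N x S x T buffer.
std::size_t attention_scratch_bytes(std::size_t element_size, int batch, int heads,
                                    int seq_len, int total_seq_len) {
  return element_size * static_cast<std::size_t>(batch) * heads * seq_len * total_seq_len;
}

// Mirrors the patched UnfusedAttention: write softmax to attn_probs when it exists,
// otherwise to scratch2 (the second half of the internal scratch buffer).
float* pick_softmax_storage(Workspace& ws, int batch, int heads, int seq_len, int total_seq_len) {
  if (ws.attn_probs != nullptr) {
    return ws.attn_probs;
  }
  const std::size_t bytes = attention_scratch_bytes(sizeof(float), batch, heads, seq_len, total_seq_len);
  return ws.scratch.data() + bytes / sizeof(float);  // scratch2
}

int main() {
  const int batch = 1, heads = 2, seq_len = 3, total_seq_len = 3;
  Workspace ws;
  ws.scratch.resize(2u * batch * heads * seq_len * total_seq_len);  // scores + scratch2

  float* storage = pick_softmax_storage(ws, batch, heads, seq_len, total_seq_len);
  std::cout << "softmax goes to " << (storage == ws.attn_probs ? "attn_probs" : "scratch2") << "\n";

  std::vector<float> probs(static_cast<std::size_t>(batch) * heads * seq_len * total_seq_len);
  ws.attn_probs = probs.data();  // caller bound the optional output
  storage = pick_softmax_storage(ws, batch, heads, seq_len, total_seq_len);
  std::cout << "softmax goes to " << (storage == ws.attn_probs ? "attn_probs" : "scratch2") << "\n";
  return 0;
}

Either way, the chosen pointer is what the follow-up R*V GEMM consumes, which is why the patch only has to swap scratch2 for softmax_storage in the cuBLAS call and in the debug dump.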