From 2b8fc5529bce4d815d2e71e9c8258c29db377c87 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 16 Oct 2024 09:57:05 +0800 Subject: [PATCH] Enable RunMatMulTest all test cases support FP16 (#22440) ### Description ### Motivation and Context increase FP16 test coverage for all related EPs --- .../test/providers/cpu/math/matmul_test.cc | 88 ++++++++----------- 1 file changed, 36 insertions(+), 52 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index a5cc17fe1ca78..298e870f348fc 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -38,128 +38,125 @@ template std::vector> GenerateTestCases() { std::vector> test_cases; + auto real_expected_vals = [](const std::vector& expected_vals) { + if constexpr (std::is_same_v) { + return expected_vals; + } else if constexpr (std::is_same_v) { + std::vector expected_vals_fp16(expected_vals.size()); + std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(), + [](int32_t num) { return MLFloat16(float(num)); }); + return expected_vals_fp16; + } else { + std::vector real_expected_vals(expected_vals.size()); + std::transform(expected_vals.begin(), expected_vals.end(), real_expected_vals.begin(), + [](int32_t num) { return static_cast(num); }); + return real_expected_vals; + } + }; + test_cases.push_back( {"test padding and broadcast A > B", {3, 1, 1, 2}, {2, 2, 2}, {3, 2, 1, 2}, - {2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}}); + real_expected_vals({2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55})}); test_cases.push_back( {"test padding and broadcast B > A", {2, 3, 2}, {3, 2, 2, 1}, {3, 2, 3, 1}, - {1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}}); + real_expected_vals({1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221})}); test_cases.push_back( {"test left 1D", {2}, {3, 2, 1}, {3, 1}, - {1, 3, 5}}); + real_expected_vals({1, 3, 5})}); test_cases.push_back( {"test right 1D", {3, 1, 2}, {2}, {3, 1}, - {1, 3, 5}}); + real_expected_vals({1, 3, 5})}); test_cases.push_back( {"test left 1D right 2D", {2}, {2, 3}, {3}, - {3, 4, 5}}); + real_expected_vals({3, 4, 5})}); test_cases.push_back( {"test scalar output", {3}, {3}, {}, - {5}}); + real_expected_vals({5})}); test_cases.push_back( {"test 2D", {3, 4}, {4, 3}, {3, 3}, - {42, 48, 54, 114, 136, 158, 186, 224, 262}}); + real_expected_vals({42, 48, 54, 114, 136, 158, 186, 224, 262})}); test_cases.push_back( {"test 2D special", {2, 2, 3}, {3, 4}, {2, 2, 4}, - {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})}); test_cases.push_back( {"test 2D special 2", {2, 2, 3}, {1, 3, 4}, {2, 2, 4}, - {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})}); test_cases.push_back( {"test 2D special 3", {2, 6}, {1, 1, 6, 1}, {1, 1, 2, 1}, - {55, 145}}); + real_expected_vals({55, 145})}); test_cases.push_back( {"test 2D empty input", {3, 4}, {4, 0}, {3, 0}, - {}}); + real_expected_vals({})}); test_cases.push_back( {"test 3D batch", {3, 1, 3}, {3, 3, 2}, {3, 1, 2}, - { + real_expected_vals({ // clang-format off 10, 13, 100, 112, 298, 319, // clang-format on - }}); + })}); test_cases.push_back( {"test 4D batch", {2, 2, 1, 3}, {2, 2, 3, 2}, {2, 2, 1, 2}, - { + real_expected_vals({ // clang-format off 10, 13, 100, 112, 298, 319, 604, 634, // clang-format on - }}); - - return test_cases; -} - -template <> -std::vector> GenerateTestCases() { - std::vector> test_cases; - - // test 2D expected_vals - std::vector expected_vals = {42, 48, 54, 114, 136, 158, 186, 224, 262}; - std::vector expected_vals_fp16(expected_vals.size()); - std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(), - [](int64_t num) { return MLFloat16(float(num)); }); - test_cases.push_back( - {"test 2D MLfloat16", - {3, 4}, - {4, 3}, - {3, 3}, - expected_vals_fp16}); + })}); return test_cases; } @@ -209,19 +206,12 @@ TEST(MathOpTest, MatMulFloatType) { GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; } RunMatMulTest(7, false, false); -} - -// To Test XNNPACK, Matrix B must be constant -TEST(MathOpTest, MatMulFloatType_ConstantB) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; - } + // Note. Xnnpack only supports matmul when Matrix B is constant RunMatMulTest(7, false, true); } #if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) -TEST(MathOpTest, MatMulFloat16_ConstantB) { +TEST(MathOpTest, MatMulFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { @@ -233,7 +223,9 @@ TEST(MathOpTest, MatMulFloat16_ConstantB) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; } - RunMatMulTest(7, false, true); + RunMatMulTest(14, false, false); + // Note. Xnnpack only supports matmul when Matrix B is constant + RunMatMulTest(14, false, true); } #endif @@ -241,14 +233,6 @@ TEST(MathOpTest, MatMulDoubleType) { RunMatMulTest(7); } -TEST(MathOpTest, MatMulFloatTypeInitializer) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; - } - RunMatMulTest(7, false, true); -} - TEST(MathOpTest, MatMulInt32Type) { RunMatMulTest(9); }