Skip to content

Commit

Permalink
Adds IRFFT Op to Signal Library (#2137)
Browse files Browse the repository at this point in the history
Inverse-RFFT as part of Signal library ops.
Testing via current FFT Op tests.

BUG=[287346710](http://b/287346710)
  • Loading branch information
suleshahid authored Jul 20, 2023
1 parent ed11500 commit 55037d2
Show file tree
Hide file tree
Showing 20 changed files with 1,310 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/tflite_micro/python_ops_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddGreaterEqual();
AddHardSwish();
AddIf();
AddIrfft();
AddL2Normalization();
AddL2Pool2D();
AddLeakyRelu();
Expand Down
2 changes: 2 additions & 0 deletions python/tflite_micro/signal/ops/fft_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def _fft_auto_scale(input_tensor, name=default_name):


rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
irfft = _fft_wrapper(gen_fft_ops.signal_irfft, "signal_irfft")
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
"signal_fft_auto_scale")
tf.no_gradient("signal_rfft")
tf.no_gradient("signal_irfft")
tf.no_gradient("signal_fft_auto_scale")
55 changes: 55 additions & 0 deletions python/tflite_micro/signal/ops/fft_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,61 @@ def testFftLengthNoEven(self):
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(fft_ops.rfft(fft_input, 127))

def testIrfftTest(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random(fft_length).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, fft_length).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[0], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[0], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 6500)
elif dtype == np.int32:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 7875)
else:
self.assertArrayNear(fft_input, ifft_output, 5e-7)
fft_length = 2 * fft_length

def testIrfftLargeOuterDimension(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random([2, 5, fft_length
]).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, [2, 5, fft_length]).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[-1], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[-1], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
elif dtype == np.int32:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
else:
self.assertAllClose(fft_input, ifft_output, rtol=5e-7, atol=5e-7)
fft_length = 2 * fft_length

def testAutoScale(self):
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')

Expand Down
3 changes: 3 additions & 0 deletions signal/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ cc_library(
"filter_bank_spectral_subtraction.cc",
"filter_bank_square_root.cc",
"framer.cc",
"irfft.cc",
"overlap_add.cc",
"rfft.cc",
"stacker.cc",
"window.cc",
],
hdrs = [
"irfft.h",
"rfft.h",
],
copts = micro_copts(),
Expand All @@ -36,6 +38,7 @@ cc_library(
"//signal/src:filter_bank_log",
"//signal/src:filter_bank_spectral_subtraction",
"//signal/src:filter_bank_square_root",
"//signal/src:irfft",
"//signal/src:overlap_add",
"//signal/src:rfft",
"//signal/src:window",
Expand Down
158 changes: 158 additions & 0 deletions signal/micro/kernels/fft_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,164 @@ TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
g_gen_data_size_fft_length_512_int32, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Float) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const float input[] = {256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const float golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_float,
g_gen_data_size_fft_length_64_int16, output, 1e-7));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Int16) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int16_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int16_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int16,
g_gen_data_size_fft_length_64_int16, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Int32) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int32_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Int32OuterDims4) {
constexpr int kOutputLen = 64;
constexpr int kOuterDim = 2;
int input_shape[] = {3, kOuterDim, kOuterDim, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {3, kOuterDim, kOuterDim, kOutputLen};
const int32_t golden[] = {
256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOuterDim * kOuterDim * kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength512Float) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, tflite::kIrfftFloatLength512Input,
output_shape, tflite::kIrfftFloatLength512Golden,
*registration, g_gen_data_fft_length_512_float,
g_gen_data_size_fft_length_512_float, output, 1e-7));
}

TF_LITE_MICRO_TEST(IrfftTestLength512Int16) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int16_t>(
input_shape, tflite::kIrfftInt16Length512Input,
output_shape, tflite::kIrfftInt16Length512Golden,
*registration, g_gen_data_fft_length_512_int16,
g_gen_data_size_fft_length_512_int16, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength512Int32) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int32_t>(
input_shape, tflite::kIrfftInt32Length512Input,
output_shape, tflite::kIrfftInt32Length512Golden,
*registration, g_gen_data_fft_length_512_int32,
g_gen_data_size_fft_length_512_int32, output, 0));
}

TF_LITE_MICRO_TEST(FftAutoScaleTestSmall) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
Expand Down
Loading

0 comments on commit 55037d2

Please sign in to comment.