Skip to content

Commit

Permalink
Merge pull request #237 from Mattk70/TFJS-optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
kahst authored Jan 22, 2024
2 parents 65dd3e6 + a97a486 commit 9894557
Showing 1 changed file with 47 additions and 44 deletions.
91 changes: 47 additions & 44 deletions checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,52 +40,55 @@ class MelSpecLayerSimple extends tf.layers.Layer {
}

// Define the layer's forward pass
call(input) {
call(inputs) {
return tf.tidy(() => {

// inputs is a tensor representing the input data
input = input[0].squeeze()

// Normalize values between -1 and 1
input = tf.sub(input, tf.min(input, -1, true));
input = tf.div(input, tf.max(input, -1, true).add(0.000001));
input = tf.sub(input, 0.5);
input = tf.mul(input, 2.0);

// Perform STFT
let spec = tf.signal.stft(
input,
this.frameLength,
this.frameStep,
this.frameLength,
tf.signal.hannWindow,
);

// Cast from complex to float
spec = tf.cast(spec, 'float32');

// Apply mel filter bank
spec = tf.matMul(spec, this.melFilterbank);

// Convert to power spectrogram
spec = spec.pow(2.0);

// Apply nonlinearity
spec = spec.pow(tf.div(1.0, tf.add(1.0, tf.exp(this.magScale.read()))));

// Flip the spectrogram
spec = tf.reverse(spec, -1);

// Swap axes to fit input shape
spec = tf.transpose(spec)

// Adding the channel dimension
spec = spec.expandDims(-1);

// Adding batch dimension
spec = spec.expandDims(0);

return spec;
inputs = inputs[0];
// Split 'inputs' along batch dimension into array of tensors with length == batch size
const inputList = tf.split(inputs, inputs.shape[0])
// Perform STFT on each tensor in the array
const specBatch = inputList.map(input =>{
input = input.squeeze();
// Normalize values between -1 and 1
input = tf.sub(input, tf.min(input, -1, true));
input = tf.div(input, tf.max(input, -1, true).add(0.000001));
input = tf.sub(input, 0.5);
input = tf.mul(input, 2.0);

// Perform STFT
let spec = tf.signal.stft(
input,
this.frameLength,
this.frameStep,
this.frameLength,
tf.signal.hannWindow,
);

// Cast from complex to float
spec = tf.cast(spec, 'float32');

// Apply mel filter bank
spec = tf.matMul(spec, this.melFilterbank);

// Convert to power spectrogram
spec = spec.pow(2.0);

// Apply nonlinearity
spec = spec.pow(tf.div(1.0, tf.add(1.0, tf.exp(this.magScale.read()))));

// Flip the spectrogram
spec = tf.reverse(spec, -1);

// Swap axes to fit input shape
spec = tf.transpose(spec)

// Adding the channel dimension
spec = spec.expandDims(-1);

return spec;
})
// Convert tensor array into batch tensor
return tf.stack(specBatch)
});
}

Expand Down

0 comments on commit 9894557

Please sign in to comment.