Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Keras 3 example for "Audio track separation" #2003

Merged

Conversation

johacks
Copy link
Contributor

@johacks johacks commented Dec 10, 2024

Hi, I saw "Audio track separation" in the call for contributions, so I implemented an example that separates the vocal track from songs in the MUSDB18 dataset.

Some notes:

  • The code has been tested on all frameworks.
  • I have also uploaded the notebook running for 3 epochs on all frameworks with A100 GPU:
    • JAX: 405s on average per epoch.
    • Torch: 615s on average per epoch.
    • Tensorflow: 261s on average per epoch.
  • I can confirm the convert script runs ok and the keras page generates correctly.

Please let me know if there is anything wrong with the provided script I may have missed.

Thanks!

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! This is excellent work, I enjoyed reading through it 👍

import numpy as np
import soundfile as sf
from IPython import display
from keras import callbacks, layers, ops, optimizers, saving, utils
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only keep layers, ops, callbacks -- they are used many times. However optimizers, saving, utils are only used 1-2x so you can just do e.g. keras.optimizers.Adam.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up needing many more uses of saving in refactor, so I kept that import, removed the rest



@saving.register_keras_serializable()
class TDF(layers.Layer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can there be a more explicit name?



@saving.register_keras_serializable()
class TFC(layers.Layer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here?

else:
model = tfc_tdf_net(keras.Input(sample_batch_x.shape[1:]), name="tfc_tdf_net")

model.summary()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be useful to plot the model, or is the result too busy? e.g. keras.utils.plot_model

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how it would look that after a refactoring i've made grouping the TFC_TDF blocks.
It's a little long, but actually simple in structure. Regardless, It can be further refactored to group into decoder and encoder blocks, but perhaps it wont be as informative.

model

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok -- your call on whether to include it.

@johacks
Copy link
Contributor Author

johacks commented Dec 15, 2024

Hi @fchollet ,

Thanks for the feedback!

I did a small refactor to further group TFC_TDF and Downsample/Upsample blocks into custom layers so they could be better visualized in a plot. I also renamed symbols to further avoid abbreviations and updated some docstrings to reflect the changes.

The code is currently working on all frameworks after these changes. Also the convert script still runs correctly.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update -- the changes are looking good! Please add the generated files.

@johacks
Copy link
Contributor Author

johacks commented Dec 16, 2024

Hi again,

I have just pushed the generated files. I also updated layer TimeFrequencyTransformBlock to use a flat list instead of list of tuples, because the latter was resulting in weights not being properly tracked if saving the model.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - Thank you for the excellent contribution!

@fchollet fchollet merged commit 0ddb810 into keras-team:master Dec 16, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants