No convergence using bayesian LSTMCell #474
Label: bug

Comments
@dusenberrymw has quite a bit of experience with Bayesian LSTMs and may be able to help. A codebase that uses them is https://github.com/Google-Health/records-research/tree/master/model-uncertainty. For me, the weirdest thing is the loss curve: it apparently takes quite a few epochs before the loss drops below 100, whereas the other RNNs start below 1. Some ideas:
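As a hedged illustration (the names `model`, `x_batch`, and `y_batch` are assumptions, not from this thread), one way to see where such large loss values come from is to compare the data-fit term with the KL terms that Keras collects from the Bayesian layers:

```python
import tensorflow as tf

# Rough diagnostic sketch: compare the negative log-likelihood with the summed
# KL regularizers that the Bayesian layers register in `model.losses`.
nll = tf.keras.losses.BinaryCrossentropy()(y_batch, model(x_batch, training=True))
kl = tf.add_n(model.losses)  # per-layer KL penalties added by the Bayesian layers
print("NLL:", float(nll), "KL:", float(kl))
# If the KL term is orders of magnitude larger than the NLL, it has most likely
# not been rescaled by 1/dataset_size and dominates the total loss.
```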
Many thanks for your ideas @dustinvtran! You are right that they are not the same unit-wise. As seen in Experiment No 3, in the first line of the code snippet, there the
This leads me to two questions regarding edward2 though:
I feel like this could improve the usability quite a bit and make it less complicated to switch from deterministic Keras models.
Thanks for following up.
Oops, you're correct. That's the right constant to scale by! To answer your questions:
```python
ed.layers.LSTMCellFlipout(
    units=512,
    kernel_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1./dataset_size),
    recurrent_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1./dataset_size),
)
```

Here's an example doing that for a Wide ResNet CIFAR baseline. We thought about including the dataset size as a required argument to the layer, but ultimately it seemed to complicate the Keras abstraction, since that's a special setting of how regularizers work more broadly.
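For context, here is a minimal sketch of how such a rescaled cell could be dropped into a Keras sequence classifier. The input encoding, unit counts, and dataset size are assumptions for illustration, not from the thread:

```python
import edward2 as ed
import tensorflow as tf

dataset_size = 25000  # assumed number of training examples

cell = ed.layers.LSTMCellFlipout(
    units=512,
    kernel_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1. / dataset_size),
    recurrent_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1. / dataset_size),
)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, 7)),  # e.g. one-hot encoded Reber symbols
    tf.keras.layers.RNN(cell),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Keras adds the rescaled KL terms (collected in model.losses) to the data-fit
# loss, so the total objective is a per-example ELBO-style loss.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```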
Hi! What I had in mind was changing

```python
ed.layers.LSTMCellFlipout(
    units=512,
    kernel_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1./dataset_size),
    recurrent_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1./dataset_size),
)
```

to

```python
ed.layers.LSTMCellFlipout(
    units=512,
    scaling_factor=1./dataset_size,
)
```

I think it would also be closer to the way it's implemented in
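As a user-side workaround, a small factory function (hypothetical, not part of edward2's API) can already provide that convenience by setting both regularizers from a single dataset_size argument:

```python
import edward2 as ed

def make_lstm_cell_flipout(units, dataset_size, **kwargs):
    """Hypothetical convenience wrapper: builds an LSTMCellFlipout whose kernel
    and recurrent KL penalties are both rescaled by 1/dataset_size."""
    return ed.layers.LSTMCellFlipout(
        units=units,
        kernel_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1. / dataset_size),
        recurrent_regularizer=ed.regularizers.NormalKLDivergence(scale_factor=1. / dataset_size),
        **kwargs,
    )
```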
Got it. Regarding 1: It's preferable to keep the regularizer semantics because that flexibility is quite often needed for tweaking BNN layers (e.g., adding L2 regularization on top, or even swapping the KL penalty to maximize for entropy or an alternative divergence).
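To illustrate the kind of flexibility meant here, a hedged sketch of a custom regularizer (the class name and L2 weight are made up) that adds a small L2 penalty on top of the rescaled KL:

```python
import edward2 as ed
import tensorflow as tf

class ScaledKLPlusL2(tf.keras.regularizers.Regularizer):
    """Hypothetical example: rescaled KL divergence plus an L2 term on the sampled weights."""

    def __init__(self, dataset_size, l2=1e-4):
        self.kl = ed.regularizers.NormalKLDivergence(scale_factor=1. / dataset_size)
        self.l2 = l2

    def __call__(self, weights):
        # The KL term uses the weight posterior; the L2 term acts on the sampled values.
        return self.kl(weights) + self.l2 * tf.reduce_sum(tf.square(weights))

# Usage (sketch): pass ScaledKLPlusL2(dataset_size) as kernel_regularizer /
# recurrent_regularizer to ed.layers.LSTMCellFlipout.
```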
I've been trying to use Bayesian LSTM layers in my research and I always experience the same issue: the model's loss converges, but the accuracy stays somewhere around 0.5 for a binary classification task.
To make sure it is not actually a problem with my data, I set up a number of experiments using EmbeddedReberGrammar, classifying whether or not a string is a valid ERG. This is a fairly simple task for RNNs, and especially for LSTMs, but it is also used as a benchmark in Long Short-Term Memory (Hochreiter & Schmidhuber, 1997). I set up four experiments, which all run on the same data:
Experiment No 1: Simple RNN Cell
The code to build the model is simply:
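A minimal sketch of what such a model could look like, assuming one-hot encoded Reber symbols and an arbitrarily chosen hidden size (both are assumptions, not the author's exact code); Experiment No 2 would presumably only swap the SimpleRNN layer for an LSTM layer:

```python
import tensorflow as tf

# Experiment No 1 (sketch): plain SimpleRNN binary classifier on one-hot encoded
# embedded Reber grammar strings.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, 7)),  # 7 Reber symbols, variable-length strings
    tf.keras.layers.SimpleRNN(64),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# Experiment No 2: replace SimpleRNN(64) with tf.keras.layers.LSTM(64).
```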
As seen in the training metrics below, the model converges slowly but steadily and learns to classify valid strings with an accuracy of around 80%.
Experiment No 2: Standard LSTM layer
This model converges not only faster but also to a higher accuracy of around 90%. This is the expected behaviour.
Experiment No 3: Bayesian Dense layers as output layers
In this case I am using a DenseFlipout layer from tensorflow-probability before the output layer. This model still works fine; the training behaviour is very similar to Experiment No 2.
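A sketch of what the Experiment No 3 setup might look like, assuming a tfp.layers.DenseFlipout layer between the LSTM and the output layer (unit counts are assumptions):

```python
import tensorflow as tf
import tensorflow_probability as tfp

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, 7)),
    tf.keras.layers.LSTM(64),
    tfp.layers.DenseFlipout(32, activation="relu"),  # Bayesian dense layer before the output
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```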
Experiment No 4: LSTMCellFlipout by edward2 without the DenseFlipout layer
This is where the problems start: as you can see in the training metrics below, the loss is still converging but the accuracy is not.
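For reference, a sketch of the Experiment No 4 model as described (the unit count is an assumption); with the default regularizers, the KL penalties are not rescaled by the dataset size, which is what the rest of this thread revolves around:

```python
import edward2 as ed
import tensorflow as tf

# Experiment No 4 (sketch): the Bayesian LSTM cell wrapped in a Keras RNN layer,
# with a deterministic output layer. The default regularizers sum the KL over all
# weights without dividing by the dataset size.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, 7)),
    tf.keras.layers.RNN(ed.layers.LSTMCellFlipout(64)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```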
To be clear: this is obviously a very simple model, but I have tried several different things, none of which improved the behaviour of the model. This is a list of things I tried:
- tf.keras.layers.RNN(ed.layers.LSTMCellFlipout(n))
- the LSTMCellReparameterization class instead of the Flipout version

Do you have any examples using the Bayesian LSTM cells where this behaviour does not occur? Do you have any idea where this behaviour could come from?