bayesian layers #240

Open
neginjv opened this issue Mar 13, 2020 · 8 comments

Comments

neginjv commented Mar 13, 2020

@dustinvtran
Hi, I have a question: can I build Bayesian layers in a sequential model with an LSTM using Edward2?
I mean building a sequential model with LSTMCellReparameterization.
If so, please help me; I don't know how to build a sequential model with LSTMCellReparameterization layers.

neginjv (Author) commented Mar 13, 2020

@dustinvtran Please help me. I want to build a sequential model with LSTMCellReparameterization.
I really need help.

dustinvtran (Member) commented Mar 16, 2020

@dusenberrymw

There's a minimal code snippet in the paper. It should be plug-and-play with any LSTM example you currently have.

dusenberrymw (Member) commented
Here's a runnable snippet using the reparameterized LSTM cell. As Dustin said, ed.layers.LSTMCellReparameterization is a drop-in replacement for tf.keras.layers.LSTMCell, provided that model.losses is used as part of the loss.

import tensorflow.compat.v2 as tf
import edward2 as ed

num_examples = 2
num_timesteps = 3
input_dim = 4
rnn_dim = 10

# Toy regression data.
inputs = tf.random.normal([num_examples, num_timesteps, input_dim])
labels = tf.random.normal([num_examples])

# Drop-in replacement for tf.keras.layers.LSTMCell.
cell = ed.layers.LSTMCellReparameterization(rnn_dim)
model = tf.keras.Sequential([
  tf.keras.layers.RNN(cell),
  tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(0.001)

for i in range(10):
  with tf.GradientTape() as tape:
    outputs = model(inputs)
    nll = tf.reduce_mean((labels-outputs)**2)  # mean squared error
    kl = sum(model.losses)  # KL penalties added by the Bayesian layer
    loss = nll + kl  # negative ELBO (up to constants)
  grads = tape.gradient(loss, model.variables)
  grads_and_vars = zip(grads, model.variables)
  optimizer.apply_gradients(grads_and_vars)
  print(f"Loss at step {i}: {loss}")
Loss at step 0: 1412.79296875
Loss at step 1: 1412.1939697265625
Loss at step 2: 1411.6484375
Loss at step 3: 1411.123779296875
Loss at step 4: 1410.535888671875
Loss at step 5: 1410.012451171875
Loss at step 6: 1409.4478759765625
Loss at step 7: 1408.9200439453125
Loss at step 8: 1408.29638671875
Loss at step 9: 1407.8123779296875

neginjv (Author) commented Mar 21, 2020

@dusenberrymw Hi. Thanks a lot for your help. One more question: why is the value of the loss function so high with this approach?
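
A likely explanation, as a sketch rather than a confirmed answer from this thread: model.losses sums a KL divergence penalty over every weight in the LSTM cell, so at initialization the KL term dwarfs the mean squared error. A common remedy is to divide the summed KL by the number of training examples, which turns the loss into a per-example ELBO. Reusing model, inputs, labels, and num_examples from the snippet above:

with tf.GradientTape() as tape:
  outputs = model(inputs)
  nll = tf.reduce_mean((labels-outputs)**2)
  # Assumption: scale the summed KL by the dataset size so that nll
  # and kl are on a comparable per-example scale.
  kl = sum(model.losses) / num_examples
  loss = nll + kl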

neginjv (Author) commented Mar 22, 2020

@dusenberrymw Hi. I have some questions. What is the benefit of LSTMCellReparameterization compared to an ordinary LSTM? I ask because the value of the loss function seems so high with this approach. Also, what is the difference between LSTMCellReparameterization and LSTMCellFlipout?
And how can I calculate accuracy in every epoch?
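
For context on the second question: LSTMCellFlipout uses the Flipout estimator, which decorrelates the weight-perturbation noise across examples within a minibatch and typically reduces gradient variance relative to the plain reparameterization estimator, at some extra compute cost. For per-epoch metrics, one option (a sketch, not from this thread) is Keras' compile/fit API, which reports metrics after each epoch and automatically adds model.losses (the KL terms) to the compiled loss. This assumes a classification head rather than the regression head above; train_inputs and train_labels are hypothetical placeholders:

# Sketch: per-epoch metrics via compile/fit (assumes a classifier).
model.compile(
  optimizer=tf.keras.optimizers.Adam(0.001),
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
model.fit(train_inputs, train_labels, epochs=10)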

neginjv (Author) commented Mar 31, 2020

@dusenberrymw Hi, could you please help me with my questions? I really need the answers for my master's thesis, and there isn't much information about these questions anywhere else. Thanks a lot.

YutianPangASU commented May 7, 2020


@dustinvtran
Is this running with TF2.0? I'm getting a type error in my code:

lstm_cell = ed.layers.LSTMCellReparameterization(self.lstm_hidden_dim, activation='tanh')
lstm_out = tf.keras.layers.RNN(cell=lstm_cell, return_sequences=True)(lstm_in)

{TypeError}Failed to convert object of type <class 'tuple'> to Tensor. Contents: (Dimension(3), 256). Consider casting elements to a supported type.

The error comes from the input tensor lstm_in. I'm using TF1.14; is this causing trouble?

YutianPangASU commented

@dustinvtran
I upgraded everything to TF2.0 and it works. However, I still want to ask: is there a way to use TF1.14 with the LSTMCellReparameterization layer, e.g., by modifying the source code somewhere?
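
One possible workaround, purely as an untested assumption (nothing in this thread confirms it): the snippet above imports tensorflow.compat.v2, and TF1.14 ships tf.enable_v2_behavior(), so enabling v2 behavior before building the model may avoid the Dimension-tuple error, since Keras static shapes then become plain ints rather than Dimension objects:

# Hypothetical workaround for TF1.14 (untested assumption): enable TF2
# semantics before any graph or model construction.
import tensorflow as tf
tf.enable_v2_behavior()  # available in TF 1.14+
# ...then run the snippet above unchanged.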
