This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Adding State Updating and # of Tasks to LSTM and RGCN models #104

Merged - 34 commits, Jun 7, 2021

Commits (the diff below shows changes from 3 of the 34 commits)
94d0454 - adding DA capabilities to LSTM (May 25, 2021)
b9f5393 - updating loss functions and adding tasks to RGCN (May 26, 2021)
a1ff04f - updating parameter definitions (May 26, 2021)
b7e9046 - updating dropout argument names (Jun 2, 2021)
494b931 - cleaning LSTM (jzwart, Jun 2, 2021)
b3f8df9 - Being explicit about the number of tasks (jzwart, Jun 2, 2021)
34e9983 - cleaning up init states call (jzwart, Jun 2, 2021)
10ffdbd - Updating rgcn for clarity (Jun 2, 2021)
c97f179 - Merge branch 'adding_da_to_models' of https://github.com/USGS-R/river… (Jun 2, 2021)
cf95324 - updating lambda_aux argument (Jun 2, 2021)
088d4a7 - making stateful based on return_state argument (Jun 2, 2021)
4101879 - separating output, rgcn layers (jsadler2, Jun 2, 2021)
6fc687d - making a more generic multitask loss function (jsadler2, Jun 2, 2021)
7c36a4f - adding SingletaskLSTMModel (jsadler2, Jun 2, 2021)
c582d3d - Convert to MultitaskLSTM; update GRU classes (jsadler2, Jun 2, 2021)
ab4cff1 - explicit about number of tasks for y_data_components (jzwart, Jun 3, 2021)
3129bd5 - convenience fxn weighted_masked_rmse (jsadler2, Jun 4, 2021)
c8c140a - lambda_aux -> lambdas in rnns (jsadler2, Jun 4, 2021)
3e54465 - num_tasks, lambdas in train functions (jsadler2, Jun 4, 2021)
e939460 - Merge branch 'adding_da_to_models' of github.com:USGS-R/river-dl into… (jsadler2, Jun 4, 2021)
17bfab8 - [#106] taking train_step out in rnns (jsadler2, Jun 4, 2021)
2c7580b - [#106] provide loss_func to train func; compiles rnns (jsadler2, Jun 4, 2021)
3799509 - [#98] multitask nse, kge functions; rm weights (jsadler2, Jun 4, 2021)
a4d54e2 - [#106] match train cli with train.py fxn (jsadler2, Jun 4, 2021)
0f568ae - take out unneeded check on h_/c_init in RGCN (jsadler2, Jun 4, 2021)
6cd7b52 - [#98] don't pass weights to `fit` call (jsadler2, Jun 4, 2021)
9bc45ca - add `num_tasks` to predict fxns (jsadler2, Jun 4, 2021)
63bb515 - Snakefile updates for lambdas, num_tasks, loss_func (jsadler2, Jun 4, 2021)
8c80003 - RGCN `states` attribute; just final states (jsadler2, Jun 4, 2021)
3eca6f2 - typo in predict (jsadler2, Jun 4, 2021)
4527d67 - Black formatting and docstring corrections (jsadler2, Jun 4, 2021)
8d061f8 - attr for rnns (jsadler2, Jun 4, 2021)
1c1641a - "outputs" -> "variables" in `num_task` docstring (jsadler2, Jun 7, 2021)
8362981 - can provide h_/c_init to initalize rnn (jsadler2, Jun 7, 2021)
117 changes: 103 additions & 14 deletions river_dl/RGCN.py
@@ -11,22 +11,36 @@


class RGCN(layers.Layer):
def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
def __init__(
self,
hidden_size,
A,
tasks=1,
dropout=0, # I propose changing this to 'recurrent_dropout' and adding another option for 'dropout' since these will map to the options for the tf LSTM layers https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell ; and also https://arxiv.org/pdf/1512.05287.pdf
flow_in_temp=False,
rand_seed=None,
return_state=False
):
"""

:param hidden_size: [int] the number of hidden units
:param A: [numpy array] adjacency matrix
:param tasks: [int] number of prediction tasks to perform - currently supports either 1 or 2 prediction tasks
:param dropout: [float] value between 0 and 1 for the probability of a recurrent element being set to zero
:param flow_in_temp: [bool] whether the flow predictions should feed
into the temp predictions
:param rand_seed: [int] the random seed for initialization
:param return_state: [bool] whether the hidden (h) and cell (c) states of the LSTM should be returned
"""

Member (on flow_in_temp): This parameter name suggests it'd be hard to generalize to other variables. Prod mostly to @jsadler2 to think about whether/how this parameter name and/or functionality should be adjusted to accommodate, say, DO predictions.

Collaborator: I think this should be moved to the model level, not the layer level, because we will be separating out the output layer from the RGCN layer.

super().__init__()
self.hidden_size = hidden_size
self.A = A.astype("float32")
self.tasks = tasks
self.flow_in_temp = flow_in_temp
self.return_state = return_state

# set up the layer
self.lstm = tf.keras.layers.LSTMCell(hidden_size)
self.lstm = tf.keras.layers.LSTMCell(hidden_size, recurrent_dropout=dropout)
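
Relating to the 'recurrent_dropout' proposal in the signature above, a hedged sketch of the two standard tf.keras.layers.LSTMCell dropout options that the proposal maps onto (the placeholder values and the separate recurrent_dropout argument are assumptions, not current code in this PR):

import tensorflow as tf

hidden_size, dropout, recurrent_dropout = 20, 0.1, 0.3   # placeholder values
lstm_cell = tf.keras.layers.LSTMCell(
    hidden_size,
    dropout=dropout,                      # dropout applied to the inputs
    recurrent_dropout=recurrent_dropout,  # dropout applied to the recurrent state
)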

### set up the weights ###
w_initializer = tf.random_normal_initializer(
@@ -88,6 +102,7 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
shape=[hidden_size], initializer="zeros", name="b_c"
)

# will be doing two task predictions if flow_in_temp == True
if self.flow_in_temp:
# was W2
self.W_out_flow = self.add_weight(
Member: Could this variable be W_out_task0 instead of specifying flow? And then W_out_temp could be W_out_task1? Or the reverse, if that's how it goes - is temp task 1 or 2 in your conventions, @jsadler2?
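
Taken together with the earlier suggestion to move the output layer up to the model level, one hedged sketch of a task-generic output head (the class name, argument names, and use of a Dense layer are assumptions for illustration, not code from this PR):

import tensorflow as tf

class MultitaskOutput(tf.keras.layers.Layer):
    # a single dense head sized by the number of tasks, instead of separate flow/temp weight matrices
    def __init__(self, num_tasks):
        super().__init__()
        self.out = tf.keras.layers.Dense(num_tasks)

    def call(self, h):
        # h: [batch, time, hidden_size] -> predictions: [batch, time, num_tasks]
        return self.out(h)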

@@ -108,25 +123,45 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
shape=[1], initializer="zeros", name="b_out"
)
else:
# was W2
self.W_out = self.add_weight(
shape=[hidden_size, 2], initializer=w_initializer, name="W_out"
)
# was b2
self.b_out = self.add_weight(
shape=[2], initializer="zeros", name="b_out"
)
if self.tasks == 2:
# was W2
self.W_out = self.add_weight(
shape=[hidden_size, 2], initializer=w_initializer, name="W_out"
)
# was b2
self.b_out = self.add_weight(
shape=[2], initializer="zeros", name="b_out"
)
else:
# was W2
self.W_out = self.add_weight(
shape=[hidden_size, 1], initializer=w_initializer, name="W_out"
)
# was b2
self.b_out = self.add_weight(
shape=[1], initializer="zeros", name="b_out"
)

@tf.function
def call(self, inputs, **kwargs):
Collaborator: We probably need to add a docstring to this function so that we are clear on what the arguments are.

h_list = []
c_list = []
graph_size = self.A.shape[0]
hidden_state_prev, cell_state_prev = (
tf.zeros([graph_size, self.hidden_size]),
tf.zeros([graph_size, self.hidden_size]),
)
out = []
n_steps = inputs.shape[1]
h_update = tf.cast(kwargs['h_init'], tf.float32)
c_update = tf.cast(kwargs['c_init'], tf.float32)
if self.return_state:
# set the initial h & c states to the supplied h and c states if using DA
hidden_state_prev = h_update
cell_state_prev = c_update
for t in range(n_steps):
seq, state = self.lstm(inputs[:, t, :], states=[h_update, c_update])
h, c = state # are these used anywhere?
h_graph = tf.nn.tanh(
tf.matmul(
self.A,
@@ -176,23 +211,77 @@ def call(self, inputs, **kwargs):

hidden_state_prev = h_update
cell_state_prev = c_update

h_list.append(h_update)
c_list.append(c_update)

h_list = tf.stack(h_list)
c_list = tf.stack(c_list)
h_list = tf.transpose(h_list, [1, 0, 2])
c_list = tf.transpose(c_list, [1, 0, 2])
out = tf.stack(out)
out = tf.transpose(out, [1, 0, 2])
Member: Can you add a comment explaining how lines 218-223 reshape h_list, c_list, and out (i.e., from what initial shapes to what final shapes)?
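
A hedged reading of those lines, inferred from the loop above (the shape names are assumptions, not stated in the PR): each appended h_update/c_update is an [n_segments, hidden_size] state and each appended prediction is [n_segments, n_out], so the stack/transpose pairs move the time axis into the middle, e.g.:

# shapes inferred from the loop above (assumed):
# h_list: python list of n_steps tensors, each [n_segments, hidden_size]
h_list = tf.stack(h_list)                 # -> [n_steps, n_segments, hidden_size]
h_list = tf.transpose(h_list, [1, 0, 2])  # -> [n_segments, n_steps, hidden_size]
# c_list and out are reshaped the same way, yielding a [batch, time, features]-style layout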

return out

if self.return_state:
return out, h_list, c_list
else:
return out


class RGCNModel(tf.keras.Model):
def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
def __init__(
self,
hidden_size,
A,
tasks=1,
Member: again I think num_tasks would be slightly clearer

dropout=0, # I propose changing this to 'recurrent_dropout' and adding another option for 'dropout' since these will map to the options for the tf LSTM layers https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell ; and also https://arxiv.org/pdf/1512.05287.pdf

Member: Yeah do it

flow_in_temp=False,

Member: See renaming comment above. This could become something like task0_in_task1 (or the reverse?)

rand_seed=None,
return_state=False
):
"""
:param hidden_size: [int] the number of hidden units
:param A: [numpy array] adjacency matrix
:param tasks: [int] number of prediction tasks to perform - currently supports either 1 or 2 prediction tasks
:param dropout: [float] value between 0 and 1 for the probability of a recurrent element being set to zero
:param flow_in_temp: [bool] whether the flow predictions should feed
into the temp predictions
:param rand_seed: [int] the random seed for initialization
:param return_state: [bool] whether the hidden (h) and cell (c) states of the LSTM should be returned
"""
super().__init__()
self.rgcn_layer = RGCN(hidden_size, A, flow_in_temp, rand_seed)
self.return_state = return_state
self.hidden_size = hidden_size
self.tasks = tasks
self.dropout = dropout
self.rnn_layer = tf.keras.layers.LSTM(
hidden_size,
return_sequences=True,
stateful=True,
return_state=return_state,
recurrent_dropout=dropout)
Member: How is it that this wasn't being called before? And where does self.rnn_layer get used?


self.rgcn_layer = RGCN(
hidden_size,
A,
tasks,
dropout,
flow_in_temp,
rand_seed,
return_state)

self.h_gr = None
self.c_gr = None
Member: Consider setting up a states property rather than h_gr and c_gr properties, to be more similar to the LSTM states setup.

Collaborator: What do you think about this idea, @jzwart? If we do this, you'd access it like

model = RGCNModel(bla, bla, bla)
h, c = model.states # would just be the final states

Member Author: Yeah I think that would be good.
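
Following that exchange, a minimal sketch of what the suggested states property might look like, assuming it simply exposes the attributes set in call below (an illustration, not code from the PR):

@property
def states(self):
    # final hidden and cell states captured by the most recent forward pass
    return self.h_gr, self.c_gr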


def call(self, inputs, **kwargs):
output = self.rgcn_layer(inputs)
batch_size = inputs.shape[0]
h_init = kwargs.get('h_init', tf.zeros([batch_size, self.hidden_size]))
c_init = kwargs.get('c_init', tf.zeros([batch_size, self.hidden_size]))
if self.return_state:
output, h_gr, c_gr = self.rgcn_layer(inputs, h_init=h_init, c_init=c_init)
self.h_gr = h_gr
self.c_gr = c_gr
Member: Does it work to make these three lines a one-liner? If so, I think in a way that'd be clearer (b/c we don't have to wonder where else h_gr might get used).

Suggested change (collapsing the three lines above into one):
output, self.h_gr, self.c_gr = self.rgcn_layer(inputs, h_init=h_init, c_init=c_init)

else:
output = self.rgcn_layer(inputs, h_init=h_init, c_init=c_init)

return output
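
To make the new arguments concrete, a hedged usage sketch of the model as defined in this diff; the import path, array shapes, and hyperparameter values are assumptions for illustration:

import numpy as np
from river_dl.RGCN import RGCNModel

n_segments, seq_len, n_features, hidden_size = 42, 365, 10, 20   # placeholder sizes
A = np.eye(n_segments, dtype=np.float32)                          # placeholder adjacency matrix
x = np.random.rand(n_segments, seq_len, n_features).astype(np.float32)

model = RGCNModel(hidden_size, A, tasks=2, dropout=0.1, return_state=True)

# initial states, e.g. carried over from a data assimilation update; shape [n_segments, hidden_size]
h0 = np.zeros((n_segments, hidden_size), dtype=np.float32)
c0 = np.zeros((n_segments, hidden_size), dtype=np.float32)

preds = model(x, h_init=h0, c_init=c0)       # predictions for both tasks
h_states, c_states = model.h_gr, model.c_gr  # states stored because return_state=True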
40 changes: 24 additions & 16 deletions river_dl/loss_functions.py
@@ -69,19 +69,23 @@ def samplewise_nnse_loss(y_true, y_pred):
return 1 - nnse_val


def nnse_masked_one_var(data, y_pred, var_idx):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx)
def nnse_masked_one_var(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return nnse_loss(y_true, y_pred)


def nnse_one_var_samplewise(data, y_pred, var_idx):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx)
def nnse_one_var_samplewise(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return samplewise_nnse_loss(y_true, y_pred)


def y_data_components(data, y_pred, var_idx):
weights = data[:, :, -2:]
y_true = data[:, :, :-2]
def y_data_components(data, y_pred, var_idx, tasks):
if tasks == 2:
weights = data[:, :, -2:]
y_true = data[:, :, :-2]
else:
weights = data[:, :, -1:]
y_true = data[:, :, :-1]

# ensure y_pred, weights, and y_true are all tensors the same data type
y_true = tf.convert_to_tensor(y_true)
@@ -99,23 +103,27 @@ def y_data_components(data, y_pred, var_idx):
return y_true, y_pred, weights
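
For readers new to this layout, a hedged illustration of the array structure y_data_components appears to assume (the variable names and the masking-weight interpretation are assumptions, not from the PR): the observed targets come first along the last axis, and the per-variable weights occupy the final tasks slices.

import numpy as np

n_batch, seq_len, tasks = 4, 365, 2
y_obs = np.random.rand(n_batch, seq_len, tasks)       # observed values for each task
weights = np.ones((n_batch, seq_len, tasks))          # e.g. 0 where observations are missing
data = np.concatenate([y_obs, weights], axis=-1)      # shape [n_batch, seq_len, 2 * tasks]

# y_data_components then splits this back apart:
w = data[:, :, -tasks:]
y_true = data[:, :, :-tasks]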


def rmse_masked_one_var(data, y_pred, var_idx):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx)
def rmse_masked_one_var(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return rmse(y_true, y_pred)


def weighted_masked_rmse(lamb=0.5):
def weighted_masked_rmse(lamb=0.5, tasks=1):
"""
calculate a weighted, masked rmse.
:param lamb: [float] (short for lambda). The factor that the auxiliary loss
will be multiplied by before being added to the main loss.
:param tasks: [int] number of prediction tasks to perform - currently supports either 1 or 2 prediction tasks
"""

def rmse_masked_combined(data, y_pred):
rmse_main = rmse_masked_one_var(data, y_pred, 0)
rmse_aux = rmse_masked_one_var(data, y_pred, 1)
rmse_loss = rmse_main + lamb * rmse_aux
return rmse_loss
rmse_main = rmse_masked_one_var(data, y_pred, 0, tasks)
if tasks == 2:
rmse_aux = rmse_masked_one_var(data, y_pred, 1, tasks)
rmse_loss = rmse_main + lamb * rmse_aux
return rmse_loss
else:
return rmse_main

Member: I like this change. I suggest renaming rmse_main to rmse_1task and renaming rmse_loss to rmse_2tasks.

return rmse_masked_combined
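
A hedged sketch of how this closure might be wired into training; the compile-based flow and parameter values are assumptions based on the PR's commit messages (e.g. "provide loss_func to train func; compiles rnns"), and `model` is assumed to be the RGCNModel built earlier:

import tensorflow as tf
from river_dl.loss_functions import weighted_masked_rmse

loss_func = weighted_masked_rmse(lamb=0.5, tasks=2)   # returns rmse_masked_combined(data, y_pred)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-2), loss=loss_func)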

@@ -181,8 +189,8 @@ def kge_norm_loss(y_true, y_pred):
return 1 - norm_kge(y_true, y_pred)


def kge_loss_one_var(data, y_pred, var_idx):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx)
def kge_loss_one_var(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return kge_loss(y_true, y_pred)

