Reaches lose spatial relationships when shuffle=True during training in PyTorch #202
Comments
I don't see anything wrong with that, and I think it's your only good option if you want to shuffle; otherwise you're messing up temporal or spatial learning. Also, I think disaggregating that first dimension makes for a more readable data shape.
Sounds good, I'll put this on my to-do list. Unfortunately, I think this might require some upstream/downstream changes in the data prep and predict/evaluate workflows too. Before too much effort goes into it, it might be worth having a conversation about how/if we want to make the workflows more agnostic to input/output shapes (as is, they're mostly geared toward the 3-dimensional inputs/outputs currently used in the RGCN). If you have any thoughts, let me know!
I'm not sure what you mean by this. From my understanding, the input shape is [n reaches, sequence length, n features]. I'm also not sure that moving to a 4-dimensional array would work. We have that at one point in the prep steps and then we move to a 3d array (see river_dl/preproc_utils.py, lines 517 to 523, at commit 4f1500a).
We move to a 3d array because that is the shape that the TF LSTM model is expecting. I made this decision with TF in mind, but it looks like PyTorch's LSTM also asks for a 3D array (or a 2D array ... not sure how that works??): https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
@jsadler2, all true/good points! A couple of clarifications, and thanks for the detailed response; I think this is actually a more nuanced conversation than it appears on the surface.
Totally, a 3d array makes sense for a basic LSTM, but with the addition of the graph convolution layer, what the model considers a sequence moves from the reach scale to the network scale, meaning we're functionally adding a dimension.
That's how I understand it as well. When I said "input shape" what I meant was batch shape. What I'm proposing is that we leave the input shape as [n reaches, sequence length, n features] and make the batch shape [batch size, n reaches, sequence length, n features].
You're right, I don't think it would work for the LSTM, and I think we'd have to tweak the RGCN a little bit. I'm not totally sure what that would look like, because the LSTM portion does expect a 3d array, but it'd be nice to be able to shuffle our batches while maintaining their spatial relationships (basically the order that we feed the sequences into the LSTM within a given batch). Does that help at all?
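To make the "tweak the RGCN a little bit" idea concrete, here is a minimal sketch (not river-dl's actual RGCN; the module name, layer sizes, and einsum-based graph step are illustrative assumptions) of how a 4-D, network-level batch could be flattened down to the 3-D shape nn.LSTM expects and then restored so the graph step still lines up with the adjacency matrix:

```python
# Sketch only: shows the reshape bookkeeping, not river-dl's real RGCN layer.
import torch
import torch.nn as nn

class NetworkBatchRGCNSketch(nn.Module):
    def __init__(self, n_features, hidden_size, adj_matrix):
        super().__init__()
        self.lstm = nn.LSTM(n_features, hidden_size, batch_first=True)
        # adjacency matrix is [n_reaches, n_reaches]; its reach order must match the input's
        self.register_buffer("adj", torch.as_tensor(adj_matrix, dtype=torch.float32))

    def forward(self, x):
        # x: [batch, n_reaches, seq_len, n_features]
        batch, n_reaches, seq_len, n_feat = x.shape
        # flatten batch and reach dims so the LSTM sees its usual 3-D input
        lstm_in = x.reshape(batch * n_reaches, seq_len, n_feat)
        lstm_out, _ = self.lstm(lstm_in)                     # [batch*n_reaches, seq_len, hidden]
        h = lstm_out.reshape(batch, n_reaches, seq_len, -1)  # restore the reach dimension
        # graph step over the reach dimension: because reaches were never shuffled
        # within a sample, they still align with the adjacency matrix
        h_graph = torch.einsum("rj,bjtf->brtf", self.adj, h)
        return h_graph
```

The point of the sketch is just that shuffling happens along the leading batch axis, while the reach axis stays fixed relative to the adjacency matrix.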
That's a good point that I kinda overlooked. That's my impression too that an RGCN will always need "one batch" of all the (same-ordered) streams for a given date to cooperate with the adjacency matrix. My original thought from this thread was that from the ...
Yes. Very helpful. Thanks, @SimonTopp.
Yes. I think that's exactly what we want. Good, succinct summary.
I think in TF we can achieve this by converting our numpy arrays into
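The object being converted to is cut off above; assuming something like tf.data.Dataset was meant (an assumption, not confirmed by the thread), a sketch could look like:

```python
# Sketch assuming the truncated suggestion refers to tf.data.Dataset.
# Shuffling dataset elements (full-network sequences) never reorders reaches.
import numpy as np
import tensorflow as tf

# [n_sequences, n_reaches, seq_len, n_features] -- made-up placeholder sizes
x = np.random.rand(50, 42, 365, 8).astype("float32")

ds = tf.data.Dataset.from_tensor_slices(x)   # each element: [n_reaches, seq_len, n_features]
ds = ds.shuffle(buffer_size=50).batch(4)     # reorders sequences, not reaches
```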
For the torch implementation, I think it will be a little more verbose, but it looks like this might do the trick:
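The snippet referenced here didn't survive in this copy of the thread; below is a minimal sketch (an assumption about the general approach, not the code that was actually posted) using torch's TensorDataset and DataLoader so that each dataset element is a whole-network sequence and shuffling only reorders those elements:

```python
# Sketch only: each dataset element is one full-network sequence, so
# shuffle=True reorders dates/sequences but leaves reach order intact.
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# made-up placeholder sizes: [n_sequences, n_reaches, seq_len, n_features]
n_sequences, n_reaches, seq_len, n_features = 50, 42, 365, 8
x = np.random.rand(n_sequences, n_reaches, seq_len, n_features).astype("float32")
y = np.random.rand(n_sequences, n_reaches, seq_len, 1).astype("float32")

dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
loader = DataLoader(dataset, batch_size=1, shuffle=True)

for xb, yb in loader:
    # xb: [1, n_reaches, seq_len, n_features] -- reaches still in adjacency-matrix order
    pass
```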
btw, I did some experiments in a notebook here: https://colab.research.google.com/drive/1w260ctpEvRoPvPLKFNg-u2eGCOtXvJne?usp=sharing
Nice, thanks Jeff! At first glance that looks like what we're looking for! And +1 to the ...
When training the RGCN in PyTorch, if shuffle=True then the reaches get mixed up during training and no longer maintain the relationships in the adjacency matrix. shuffle is false by default, but it's an easy thing to overlook. @jsadler2, @jdiaz4302, @jds485, not sure if any of you are using river-dl RGCN workflows, but wanted to give you a heads up if so.

Not sure what the best way to safeguard against this is. Right now the RGCN treats each reach time series as a training instance. I think it's more accurate to think of an entire sequence for the entire network as a training instance, so when you shuffle, you're shuffling the order the model sees sequences for the entire network. Basically going from the input shape of [n reaches, sequence length, n features] to [batch size, n reaches, sequence length, n features]. Does that make sense? Any hot takes?
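A small numpy illustration of the proposed change (array sizes are made-up placeholders): shuffling the first axis of the current 3-D array scrambles reach order, while shuffling the first axis of the proposed 4-D array only reorders whole-network samples and leaves each sample's reach order intact.

```python
# Illustration of the shape change described above; sizes are placeholders.
import numpy as np

n_reaches, seq_len, n_features, batch_size = 42, 365, 8, 10

# current: shuffling axis 0 here scrambles the reach order the adjacency matrix relies on
x_3d = np.zeros((n_reaches, seq_len, n_features))

# proposed: shuffling axis 0 here only reorders whole-network samples;
# within each sample the reaches stay in adjacency-matrix order
x_4d = np.zeros((batch_size, n_reaches, seq_len, n_features))
perm = np.random.permutation(batch_size)
x_4d_shuffled = x_4d[perm]
```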