
Tensorflow Variable Aggregation testing in Keras 3.7.0 #20601

Open · dryglicki opened this issue Dec 5, 2024 · 4 comments

dryglicki (Contributor) commented Dec 5, 2024

Forking this from: #20568

Specifically tagging @james77777778.

tl;dr: it's SUM.


The recommendation was:

Since there is no reproducible script for debugging, this is a random guess:
Before 28d39c0, the aggregation behavior might have been broken due to incorrect propagation of the aggregation attr to the variables.
Essentially, the training would be an aggregation=None setting (the default value for tf.Variable), which is likely incorrect.

@das-apratim could you first try training the model without using tf.distribute.MirroredStrategy() to check if any NaNs occur?

If the training runs without issues, try adding back tf.distribute.MirroredStrategy() and modifying _map_aggregation in keras/src/backend/tensorflow/core.py as follows:

    mapping = {
        "none": tf.VariableAggregation.NONE,
        "sum": tf.VariableAggregation.NONE,
        "mean": tf.VariableAggregation.NONE,
        "only_first_replica": tf.VariableAggregation.NONE,
    }

This adjustment reflects the behavior in Keras 3.6. See if the training runs well with this change.

If it does, incrementally restore the original mapping to identify which key is causing the issue. Here's a general guideline:

  • "sum" is associated with metrics.
  • "mean" is associated with model/optimizer weights.
  • "only_first_replica" is associated with counters.

I can make the following report for you (the columns correspond to the "none", "sum", "mean", and "only_first_replica" keys):

NONE, NONE, NONE,  NONE -- okay
NONE,  SUM, NONE,  NONE -- fail, metrics go NaN first. Regression metrics go inf, "counting" metrics go NaN; 123/341
NONE, NONE, MEAN,  NONE -- okay
NONE, NONE, NONE, FIRST -- okay
NONE, NONE, MEAN, FIRST -- okay

If it matters at all, this was tested on a logical split of an RTX A6000 Ada.
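In case it helps with reproduction, a logical split can be set up along these lines (a sketch, not my exact configuration; the memory limits are illustrative):

    import tensorflow as tf

    # Split one physical GPU into two logical devices so MirroredStrategy
    # has two replicas to work with.
    gpus = tf.config.list_physical_devices("GPU")
    tf.config.set_logical_device_configuration(
        gpus[0],
        [
            tf.config.LogicalDeviceConfiguration(memory_limit=8192),
            tf.config.LogicalDeviceConfiguration(memory_limit=8192),
        ],
    )
    strategy = tf.distribute.MirroredStrategy()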

@dryglicki (Contributor, Author)

Follow-up comment. Here's the code for my own MAE metric. If I'm doing something wrong, please let me know.

# Imports inferred from the usage below (K -> keras, kops -> keras.ops).
import keras as K
from keras import ops as kops

class MAE(K.metrics.Metric):
    def __init__(self,
                 name = 'mae',
                 min_val = 163.,
                 max_val = 333.,
                 remap = True,
                 invert_remap = True,
                 **kwargs):
        super().__init__(name = name, **kwargs)
        self.min_val = min_val
        self.max_val = max_val
        self.remap = remap
        self.mae_tracker = K.metrics.Mean(name = f'{name}_tracker')
        self.invert_remap = invert_remap

    def _remap(self, inputs):
        if self.invert_remap:
            return self.max_val - (inputs * (self.max_val - self.min_val))
        else:
            return self.min_val + (inputs * (self.max_val - self.min_val))

    def update_state(self, y_true, y_pred, sample_weight = None):
        if self.remap:
            y_true = self._remap(y_true)
            y_pred = self._remap(y_pred)
        mae = kops.mean(kops.abs(y_true - y_pred), axis=[1,2,3,4])
#       mae = kops.mean(kops.abs(y_true - y_pred))
        self.mae_tracker.update_state(mae)

    def result(self):
        return self.mae_tracker.result()

    def reset_state(self):
        self.mae_tracker.reset_state()

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'name' : self.name,
            'min_val' : self.min_val,
            'max_val' : self.max_val,
            'remap'  : self.remap,
            'invert_remap' : self.invert_remap
            })
        return config
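
For reference, a hypothetical minimal snippet (not my real pipeline) that exercises this metric under MirroredStrategy; the toy model and shapes are placeholders, chosen only so the axis=[1,2,3,4] reduction is valid:

    import numpy as np
    import tensorflow as tf
    import keras as K

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        inputs = K.Input(shape=(4, 8, 8, 1))  # 5-D tensors: (batch, t, h, w, c)
        outputs = K.layers.Conv3D(1, 3, padding="same", activation="sigmoid")(inputs)
        model = K.Model(inputs, outputs)
        model.compile(optimizer="adam", loss="mae", metrics=[MAE()])

    x = np.random.rand(16, 4, 8, 8, 1).astype("float32")
    y = np.random.rand(16, 4, 8, 8, 1).astype("float32")
    model.fit(x, y, batch_size=8, epochs=1)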

@james77777778 (Contributor)

First of all, I'm not very familiar with tf.distribute, and it's difficult for me to debug without a simple code snippet.

It seems that the current codebase is missing tf.VariableSynchronization, as implemented in tf-keras:
https://github.com/keras-team/tf-keras/blob/master/tf_keras/metrics/base_metric.py#L347

I can try adding it, but I have no idea whether it will resolve the issue.
@fchollet any chance you could provide some input?
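
Roughly, the tf-keras pattern at that line looks like this (paraphrased from memory, not an exact copy of base_metric.py): metric weights are created with ON_READ synchronization and SUM aggregation, so each replica updates its local copy and the values are summed across replicas when read.

    import tensorflow as tf

    # Paraphrased sketch of the tf-keras metric-weight pattern referenced above.
    total = tf.Variable(
        0.0,
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM,
    )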

@dryglicki (Contributor, Author)

Hi, James.

I understand your frustration. I am frustrated, too. To be totally honest, I am pretty sure my "waiting for Keras 3" is going to get me fired or something: I had a fully functioning prototype precipitation nowcasting GAN in Keras 2/TF 2.15 that I now can't train properly, both because legacy Tensorflow support is broken and because I insisted on jumping in feet first with Keras 3. My mistake. Furthermore, it looks like Keras doesn't support sub-classed model instances with multiple component models in the JAX interface, rendering Keras totally useless for my use case. My team will now be back-porting our system to tf-keras before moving to PyTorch entirely.

It is unfortunate that, as you all are trying to fix and modify things, you don't have a stock multi-GPU testing environment (other than a Google Colab) where you can test these obviously back-breaking changes for those of us trying to migrate from Keras 2 to Keras 3. Parallel training with Tensorflow is utterly broken in Keras 3: MirroredStrategy hits the issues mentioned in this and the other threads, and MultiWorkerMirroredStrategy appears not to be supported at all.

It isn't difficult to see that Google is going to sunset Tensorflow. That much is clear. JAX is the future for Google. But you all had to have understood that the majority of your user base was Tensorflow users, right?

I suppose my request here is that you make it clear that if you have certain use cases from K2, including but not limited to:

  • Subclassed models with multiple components, such as GANs
  • Legacy code that is heavily dependent on using Tensorflow in a parallelized environment

then one shouldn't use Keras 3 for the time being.

@fchollet (Collaborator) commented Dec 6, 2024

It seems that the current codebase is missing tf.VariableSynchronization, as implemented in tf-keras:

I thought it was only relevant for MultiWorkerMirroredStrategy, which we don't support in Keras 3. But I am not a tf.distribute expert and I don't know for sure.

The fact that the value needs to be different for TPU only, in a way that isn't handled automatically by TF, is weird.

I think we can add it -- when we create a tf.Variable, we can do:

synchronization = tf.VariableSynchronization.AUTO  # default when no strategy (or a non-TPU strategy) is active
if tf.distribute.has_strategy():
    strategy = tf.distribute.get_strategy()
    if is_tpu_strategy(strategy):
        synchronization = tf.VariableSynchronization.ON_WRITE

then pass synchronization to the TF variable.
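
A rough sketch of how that could be wired where the backend variable is created (illustrative only; is_tpu_strategy is a placeholder helper here, not existing Keras code):

    import tensorflow as tf

    def is_tpu_strategy(strategy):
        # Hypothetical helper: detect TPU strategies by class name.
        return strategy.__class__.__name__.startswith("TPUStrategy")

    def make_backend_variable(initial_value, trainable=True):
        # Sketch: pick the synchronization mode as discussed above and pass it
        # through to the underlying tf.Variable.
        synchronization = tf.VariableSynchronization.AUTO
        if tf.distribute.has_strategy():
            strategy = tf.distribute.get_strategy()
            if is_tpu_strategy(strategy):
                synchronization = tf.VariableSynchronization.ON_WRITE
        return tf.Variable(
            initial_value,
            trainable=trainable,
            synchronization=synchronization,
        )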
