
Can we prune pre-trained models like VGG16 etc. using this optimization library? #40

Closed
tejalal opened this issue Jul 25, 2019 · 13 comments
Labels: technique:pruning (Regarding tfmot.sparsity.keras APIs and docs)

@tejalal commented Jul 25, 2019

I tried to create a model like:

```python
from tensorflow.keras import Model, Sequential
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Dropout, Flatten

def Vgg16():
    vgg16 = VGG16(include_top=False,
                  weights='imagenet',
                  input_shape=(32, 32, 3))
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
    top_model.add(Dense(512, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(10, activation='sigmoid'))
    model = Model(vgg16.input, top_model(vgg16.output))
    return model
```

and when I call

```python
new_pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.5,
                                                 final_sparsity=0.80,
                                                 begin_step=0,
                                                 end_step=end_step,
                                                 frequency=100)
}

pruned_model = sparsity.prune_low_magnitude(loaded_model, **new_pruning_params)
```

it raises the following error:

```
Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.sequential.Sequential'>
```

@s36srini commented Jul 25, 2019

Firstly, include_top=False means that you are changing the input, so you'll first want to do this:
model = Model(top_model.input, vgg16.output). This combines the input of top_model with its sequential layers, connects it to vgg16 (without the input layer), and keeps the output the same as vgg16's.

Secondly, by pruning the whole model, you don't get to specify which layers you want to prune; it is only necessary to prune layers that have a high number of trainable parameters. In my code, I only prune pointwise convolutional layers as they contain 76% of the model's parameters.
Here's my code for reference:

```python
mobileNet = tf.keras.applications.MobileNet(weights=None)  # Not ImageNet 2012 trained weights

end_step = np.ceil(1.0 * NUM_TRAIN_SAMPLES / FLAGS.batch_size).astype(np.int32) * EPOCHS

pruning_schedule = sparsity.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=0, end_step=end_step, frequency=100)

pruned_model = tf.keras.Sequential()
for layer in mobileNet.layers:
    if re.match(r"conv_pw_\d+$", layer.name):
        pruned_model.add(sparsity.prune_low_magnitude(
            layer,
            pruning_schedule,
            block_size=(1, 1)
        ))
    else:
        pruned_model.add(layer)

opt = tf.train.AdamOptimizer()
pruned_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
```

Lastly, I forgot to mention: you will want to use keras from tensorflow.python, i.e. from tensorflow.python import keras. This is different from import keras, and using keras without this import will lead to headaches.

@s36srini

I also forgot to mention: initialize the sequential model as tf.keras.Sequential(), not keras.Sequential().

@Cospel commented Oct 6, 2019

I'm using TF 2.0.0 and get the same error:

```
ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.training.Model'>
```

My code looks like this:

```python
model = tf.keras.Sequential(
    [
        tf.keras.applications.MobileNetV2(weights="imagenet", input_shape=(224, 224, 3), include_top=False),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation="relu", name="descriptor"),
        tf.keras.layers.Dense(2, activation="softmax", name="probs"),
    ]
)

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5, begin_step=3, end_step=5
    ))
```

@alanchiao commented Nov 11, 2019

In general, yes you can.

There are some caveats, e.g. the lack of support for subclassed models and for nesting models within models, as in both examples (tejalal@ and Cospel@). Created #155 in light of this to make subclassed support better.

@Cospel commented Nov 12, 2019

Thank you @alanchiao. Most models nowadays are subclassed or nested; it would be very useful if we could prune them.

@alanchiao

@nutsiepully, @raziel for visibility

@raziel commented Nov 13, 2019

We understand the need. The caveat is that going the subclassing route basically diminishes the usability of the Keras abstractions we rely on. Our suggestion, for now, would be to abstract some of the subclass logic into Keras layers and then apply pruning in the same manner as we currently do for the built-in layers.

@nutsiepully wdyt? Do we have an example to point folks to?

@alanchiao

Closing this issue since #155 was created. I will update this thread once #155 is fixed; we'll have almost complete coverage at that point.

@nutsiepully (Contributor) commented Jan 21, 2020

Sorry, I seem to have missed this issue.

For now, as @raziel suggested, the best approach is to apply pruning on a per-layer basis. You can choose the layers most important to you and prune just those. For parts of your model that are purely custom, you can use the PrunableLayer abstraction to control them.

@alanchiao alanchiao self-assigned this Jan 28, 2020
@alanchiao alanchiao added the technique:pruning Regarding tfmot.sparsity.keras APIs and docs label Feb 6, 2020
@jiayiliu commented Feb 6, 2020

```python
mobileNet = tf.keras.applications.MobileNet(weights=None)  # Not ImageNet 2012 trained weights

end_step = np.ceil(1.0 * NUM_TRAIN_SAMPLES / FLAGS.batch_size).astype(np.int32) * EPOCHS

pruning_schedule = sparsity.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=0, end_step=end_step, frequency=100)

pruned_model = tf.keras.Sequential()
for layer in mobileNet.layers:
    if re.match(r"conv_pw_\d+$", layer.name):
        pruned_model.add(sparsity.prune_low_magnitude(
            layer,
            pruning_schedule,
            block_size=(1, 1)
        ))
    else:
        pruned_model.add(layer)
```

Thank you @s36srini for sharing the code. It works well for MobileNet, but it fails for MobileNetV2, because we cannot rebuild it with model.add(): it raises A merge layer should be called on a list of inputs.

@sushruta commented Nov 6, 2020

Hello, what's the correct way of getting past the above error:

A merge layer should be called on a list of inputs.

when we use model.add(...)?

I get the same error when I try to define the layers I need for pruning EfficientNet-B6.


What I do is the following:

- I load EfficientNet-B6 with ImageNet weights.
- I freeze the first 150 layers as untouchable.
- From the 151st layer onwards, I set layer.trainable to True, check whether the layer is an expand_conv or project_conv, and if it is, mark it as a target for pruning in the exact same way as described by @s36srini.

It gives me an error pointing at the model.add(...) step.

I peeked at the code of EfficientNet and it looks like it's using the Functional API. Could mixing the Sequential API with the Functional API result in errors like these?

@NonlinearNimesh

Hi, I have a trained frozen model. Is it possible to prune it? Any references would be a great help.

Thanks.

@gnhearx commented Mar 17, 2022

Hi everyone :)
I have a similar issue with pruning nested models: even if I apply the pruning wrappers per layer inside all the nested Functional API models, they don't prune.

Is this expected behaviour for nested models? I would think that if any layer in a model has the wrapper, it would be pruned when the pruning callback runs during training. Unfortunately, this does not happen. Instead, everything that is not nested (and has pruning wrappers) does prune, and anything inside a nested model does not.

I can also confirm that if I create a model with no nested models at all, then everything I set to prune does in fact prune the way it should.

Side note:
My nested model is a pretrained VGG16 from Keras, and I apply pruning wrappers to each layer within the nested model.

If anyone has a solution or workaround, that would be very helpful. Thank you.
