Unconstrained Channel Pruning · Paper
UPSCALE: Unconstrained Channel Pruning @ ICML 2023
Alvin Wan, Hanxiang Hao, Kaushik Patnaik, Yueyang Xu, Omer Hadad, David Güera, Zhile Ren, Qi Shan
By removing constraints from existing pruners, we improve ImageNet accuracy for post-training pruned models by 2.1 points on average, benefiting DenseNet (+16.9), EfficientNetV2 (+7.9), and ResNet (+6.2). Furthermore, for these unconstrained pruned models, UPSCALE improves inference speed by up to 2x over a baseline export.
Install our package.
pip install apple-upscale
Mask and prune channels, using the default magnitude pruner.
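import torch
import torchvision
from upscale import MaskingManager, PruningManager
x = torch.rand((1, 3, 224, 224), device='cuda')  # sample input passed to the export below
model = torchvision.models.get_model('resnet18', weights='DEFAULT').cuda()  # get any pytorch model
MaskingManager(model).importance().mask()  # score channels (default: magnitude) and zero out the lowest-scoring ones
PruningManager(model).compute([x]).prune()  # export: physically remove the zeroed channels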
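Conceptually, a magnitude pruner scores each channel by the size of its weights and zeroes the lowest-scoring ones. The pure-Python sketch below illustrates that idea on a weight stored as nested lists. It is not UPSCALE's implementation. The helper names `l1_channel_importance` and `mask_lowest_channels`, the use of the L1 norm, scoring output channels, and rounding `amount * n_channels` to the nearest integer are all assumptions made for illustration.

```python
# Conceptual sketch of an L1-magnitude channel pruner (not UPSCALE's code).
# A weight is modelled as nested lists: weight[output_channel] = flattened values.

def l1_channel_importance(weight):
    """Return the L1 norm of each output channel (row) of `weight`."""
    return [sum(abs(v) for v in row) for row in weight]


def mask_lowest_channels(weight, amount):
    """Zero the `amount` fraction of output channels with the smallest L1 norm.

    Returns a new weight and the sorted list of zeroed channel indices.
    """
    scores = l1_channel_importance(weight)
    # Number of channels to zero, rounded to the nearest integer (assumption).
    n_prune = int(round(amount * len(weight)))
    # Channel indices ordered from least to most important.
    order = sorted(range(len(weight)), key=lambda i: scores[i])
    pruned = sorted(order[:n_prune])
    masked = [
        [0.0] * len(row) if i in pruned else list(row)
        for i, row in enumerate(weight)
    ]
    return masked, pruned


weight = [
    [0.9, -0.8],    # channel 0: L1 = 1.7
    [0.1, 0.05],    # channel 1: L1 = 0.15
    [-0.5, 0.4],    # channel 2: L1 = 0.9
    [0.02, -0.01],  # channel 3: L1 = 0.03
]
masked, pruned = mask_lowest_channels(weight, amount=0.5)
print(pruned)  # [1, 3]
```

Masking only zeroes values; the tensor keeps its shape until the export step removes the zeroed channels.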
We provide a number of pruning heuristics out of the box: magnitude (L1 and L2), LAMP, FPGM, and HRank.
Pass the desired heuristic to MaskingManager.importance, and configure the pruning ratio with the amount argument of MaskingManager.mask. A value of 0.25 means 25% of channels are set to zero.
from upscale.importance import LAMP
MaskingManager(model).importance(LAMP()).mask(amount=0.25)  # zero 25% of channels, ranked by LAMP score
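LAMP (layer-adaptive magnitude-based pruning) divides each squared magnitude by the sum of squares of all magnitudes in the same layer that are at least as large. The largest entry in every layer therefore scores 1.0, which makes scores comparable across layers. The sketch below applies that formula with one magnitude per channel. Treating a channel's norm as the quantity LAMP rescales is an assumption, as are the hypothetical helpers `lamp_scores` and `lowest_scoring`. The `LAMP` class imported above may compute its scores differently. `lowest_scoring` also shows what `amount` means: with amount=0.25 and 4 channels, one channel is selected.

```python
# Sketch of LAMP-style scores over per-channel magnitudes (illustrative only;
# assumes each channel is summarised by one magnitude, e.g. its L2 norm).

def lamp_scores(magnitudes):
    """Return m_u^2 / sum(m_v^2 for every v at least as large as u), per entry.

    Entries are ranked by ascending magnitude, so the largest entry scores 1.0.
    """
    order = sorted(range(len(magnitudes)), key=lambda i: magnitudes[i])
    scores = [0.0] * len(magnitudes)
    tail = 0.0  # running sum of squares, accumulated from the largest entry down
    for i in reversed(order):
        square = magnitudes[i] ** 2
        tail += square
        scores[i] = square / tail if tail > 0 else 0.0
    return scores


def lowest_scoring(scores, amount):
    """Indices of the `amount` fraction of entries with the lowest score."""
    k = int(round(amount * len(scores)))
    return sorted(sorted(range(len(scores)), key=lambda i: scores[i])[:k])


scores = lamp_scores([3.0, 1.0, 2.0, 4.0])
print(scores)  # [0.36, 1/30, 4/29, 1.0] as floats
print(lowest_scoring(scores, amount=0.25))  # [1]
```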
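You can also zero out channels using any method you see fit. For example, to zero input channel 24 of one convolution in the ResNet-18 above:
with torch.no_grad():  # in-place edits to parameters must not be tracked by autograd
    model.layer1[0].conv1.weight[:, 24] = 0  # weight shape is (out, in, kH, kW); this zeros input channel 24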
Then, run our export.
PruningManager(model).compute([x]).prune()
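Why can the export delete zeroed channels rather than leave them in place as zeros? The sketch below models two 1x1 convolutions as plain matrices with a ReLU between them. If the consumer's weights for input channel 1 are all zero, you can remove that output channel from the producer and that input channel from the consumer. The result stays the same while both layers shrink. This is a hypothetical single-branch illustration with made-up helper names (`matvec`, `relu`, `remove_channel`), not how `PruningManager` is implemented.

```python
# A 1x1 convolution on a single pixel is a matrix-vector product, so plain
# lists are enough to show that removing a zeroed channel is lossless.

def matvec(m, v):
    """Multiply matrix `m` (list of rows) by vector `v`."""
    return [sum(row[k] * v[k] for k in range(len(v))) for row in m]


def relu(v):
    return [max(0.0, a) for a in v]


def remove_channel(producer, consumer, c):
    """Drop output channel c of `producer` (row c) and input channel c of `consumer` (column c)."""
    new_producer = [row for i, row in enumerate(producer) if i != c]
    new_consumer = [[w for j, w in enumerate(row) if j != c] for row in consumer]
    return new_producer, new_consumer


W1 = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]  # producer: 2 -> 3 channels
W2 = [[2.0, 0.0, 1.0], [-1.0, 0.0, 4.0]]    # consumer: 3 -> 2 channels; input channel 1 is zeroed
x = [1.0, 2.0]

dense = matvec(W2, relu(matvec(W1, x)))
P, C = remove_channel(W1, W2, 1)
pruned = matvec(C, relu(matvec(P, x)))
print(dense, pruned)  # [11.5, 1.0] [11.5, 1.0]
```

Whatever the producer writes to channel 1 is multiplied by zero in the consumer, so dropping it from both layers cannot change the output.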
You may want direct access to network segments to build a heavily-customized pruning algorithm.
for segment in MaskingManager(model).segments():
    # prune each segment in the network independently
    for layer in segment.layers:
        # inspect or modify each layer in the segment here
        pass
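A customized algorithm must decide how to handle branches that meet at an addition, such as a residual connection. The paper's CLI exposes this choice as `--method {constrained,unconstrained}`. The sketch below contrasts the two with plain lists and hypothetical importance scores, without using the segment API. The helpers are made up, and ranking the constrained case by summed importance is an assumption, not necessarily what UPSCALE's constrained pruners do.

```python
# Two branches whose outputs are summed share a channel dimension.
# Each branch has hypothetical per-channel importance scores.

def keep_top(scores, n_keep):
    """Indices of the n_keep highest-scoring channels, in ascending index order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:n_keep])


def constrained_keep(branch_scores, n_keep):
    """One shared channel set for every branch, ranked by summed importance (assumption)."""
    summed = [sum(col) for col in zip(*branch_scores)]
    shared = keep_top(summed, n_keep)
    return [shared for _ in branch_scores]


def unconstrained_keep(branch_scores, n_keep):
    """Each branch keeps its own highest-scoring channels."""
    return [keep_top(scores, n_keep) for scores in branch_scores]


branch_a = [0.9, 0.1, 0.6, 0.2]
branch_b = [0.2, 0.9, 0.1, 0.7]
print(constrained_keep([branch_a, branch_b], 2))    # [[0, 1], [0, 1]]
print(unconstrained_keep([branch_a, branch_b], 2))  # [[0, 2], [1, 3]]
```

Here the constrained choice forces branch_a to keep channel 1 (score 0.1) and branch_b to keep channel 0 (score 0.2), while the unconstrained choice lets each branch keep its two most important channels.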
NOTE: See src/upscale/pruning/README.md for more details on how the core export algorithm code is organized.
Clone and setup.
git clone git@github.com:apple/ml-upscale.git
cd ml-upscale
pip install -e .
Run tests.
py.test src tests --doctest-modules
Follow the development installation instructions above so that the paper code under paper/ is available.
To run the baseline unconstrained export, pass baseline=True to PruningManager.prune.
PruningManager(model).compute([x]).prune(baseline=True)
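One intuition for why exporting an unconstrained model is harder than exporting a constrained one is offered here as an assumption, not as a description of what `baseline=True` does internally. When two summed branches keep different channels, their compact outputs no longer line up, so each must be scattered back to its original channel positions before the addition. The hypothetical `scatter_add` helper below shows this with plain lists. In a constrained model both branches keep the same indices, so their compact outputs could be added elementwise without this step.

```python
# Two pruned branches keep different channels of a 4-channel sum.
# Their compact outputs must be scattered to full width before adding.

def scatter_add(width, branches):
    """Sum compact branch outputs into a zero-initialised list of `width` channels.

    `branches` is a list of (kept_indices, values) pairs.
    """
    out = [0.0] * width
    for kept, values in branches:
        for idx, val in zip(kept, values):
            out[idx] += val  # one extra read/write per kept channel
    return out


result = scatter_add(4, [([0, 2], [1.0, 2.0]), ([1, 2], [10.0, 20.0])])
print(result)  # [1.0, 10.0, 22.0, 0.0]
```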
To reproduce the paper results, run
python paper/main.py resnet18
Plug in any model in the torchvision.models namespace.
usage: main.py [-h] [--side {input,output} [{input,output} ...]]
               [--method {constrained,unconstrained} [{constrained,unconstrained} ...]]
               [--amount AMOUNT [AMOUNT ...]] [--epochs EPOCHS]
               [--heuristic {l1,l2,lamp,fpgm,hrank}] [--global] [--out OUT]
               [--force] [--latency] [--clean]
               model

positional arguments:
  model                 model to prune

options:
  -h, --help            show this help message and exit
  --side {input,output} [{input,output} ...]
                        prune which "side" -- producers, or consumers
  --method {constrained,unconstrained} [{constrained,unconstrained} ...]
                        how to handle multiple branches
  --amount AMOUNT [AMOUNT ...]
                        amounts to prune by. .6 means 60 percent pruned
  --epochs EPOCHS       number of epochs to train for
  --heuristic {l1,l2,lamp,fpgm,hrank}
                        pruning heuristic
  --global              apply heuristic globally
  --out OUT             directory to write results.csv to
  --force               force latency rerun
  --latency             measure latency locally
  --clean               clean the dataframe
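The `--global` flag switches the heuristic from per-layer ranking to network-wide ranking. The rough sketch below assumes `amount` applies within each layer in the local case and to all channels pooled together in the global case. This is an assumption, not taken from paper/main.py, and the helper names and scores are hypothetical. Global ranking only makes sense when scores are comparable across layers, which is the motivation behind normalized heuristics such as LAMP.

```python
# Local: prune the `amount` fraction of channels within each layer separately.
# Global: rank all channels from all layers together and prune the lowest overall.

def local_prune(layer_scores, amount):
    pruned = {}
    for name, scores in layer_scores.items():
        k = int(round(amount * len(scores)))
        order = sorted(range(len(scores)), key=lambda i: scores[i])
        pruned[name] = sorted(order[:k])
    return pruned


def global_prune(layer_scores, amount):
    # Pool (score, layer, channel) triples from every layer.
    flat = [(s, name, i) for name, scores in layer_scores.items() for i, s in enumerate(scores)]
    k = int(round(amount * len(flat)))
    flat.sort()
    pruned = {name: [] for name in layer_scores}
    for _, name, i in flat[:k]:
        pruned[name].append(i)
    return {name: sorted(idx) for name, idx in pruned.items()}


scores = {"conv1": [0.9, 0.8, 0.7, 0.6], "conv2": [0.1, 0.65, 0.2, 0.3]}
print(local_prune(scores, 0.5))   # {'conv1': [2, 3], 'conv2': [0, 2]}
print(global_prune(scores, 0.5))  # {'conv1': [3], 'conv2': [0, 2, 3]}
```

With global ranking, the layer with uniformly smaller scores loses most of its channels, which is why the choice of heuristic matters more in that mode.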
If you find this useful for your research, please consider citing
@inproceedings{wan2023upscale,
title={UPSCALE: Unconstrained Channel Pruning},
author={Alvin Wan and Hanxiang Hao and Kaushik Patnaik and Yueyang Xu and Omer Hadad and David Guera and Zhile Ren and Qi Shan},
booktitle={ICML},
year={2023}
}