torch.load(pruned_model) #26
When I use torch.load() to load the pruned model, this error happens:
AttributeError: 'module' object has no attribute 'ModifiedVGG16Model'
Has anyone else met this problem?

Comments
@wuzhiyang2016 I've got the same problem. I added the ModifiedVGG16Model class to my test script, with no success. Do you still have this issue?
@jacobgil Any idea?
The code in the repo saves the entire model with pickling, instead of the state dict, which is actually bad practice. Save the state dict instead:

```python
state_dict = model.state_dict()
torch.save(state_dict, 'model.chkpt')
```

Then you can load the model like this:

```python
model = ModifiedVGG16Model()
checkpoint = torch.load(checkpoint_path,
                        map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint)
model.eval()
```
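The `map_location=lambda storage, loc: storage` argument maps every tensor onto CPU storage, so a checkpoint saved on a GPU machine can also be loaded on a CPU-only one.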
Like the author says, when we load the saved model, ModifiedVGG16Model needs to be defined in the loading script.
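A minimal sketch of a loading script that satisfies this, assuming the class lives in the repo's finetune.py and that the checkpoint filename is a placeholder; the `__main__` alias covers checkpoints that were pickled from a directly-run script:

```python
import torch
import __main__

# torch.load() on a fully pickled model stores only a reference such as
# "finetune.ModifiedVGG16Model" or "__main__.ModifiedVGG16Model", so the
# class must be resolvable under that exact module path at load time.
from finetune import ModifiedVGG16Model  # module name assumed

# If the model was saved from "python finetune.py", pickle recorded the
# class under __main__; aliasing it there lets torch.load() find it.
__main__.ModifiedVGG16Model = ModifiedVGG16Model

model = torch.load("model.chkpt", map_location=lambda storage, loc: storage)
model.eval()
```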
@jacobgil Hey! There is a problem with a size mismatch when loading the weights. Is there a way to solve it?
@jacobgil: Since the pruned model has different in and out filter counts in each layer, is there any way to load the pruned model's state dictionary with the original model class?
@ms-krajesh Probably not a general solution, but I just rewrote the model architecture with the altered layer sizes.
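For anyone hitting the same size mismatch, a hedged sketch of automating that rewrite: recover the pruned channel counts from the saved tensors themselves, then rebuild each layer to match before calling load_state_dict(). The `features.N.weight` key pattern follows the VGG-style `nn.Sequential` used here, and "pruned.chkpt" is a placeholder path:

```python
import torch

state_dict = torch.load("pruned.chkpt", map_location=lambda storage, loc: storage)

# Conv2d weights have shape (out_channels, in_channels, kH, kW), so the
# pruned filter counts can be read off the checkpoint instead of guessed.
for key, w in state_dict.items():
    if key.startswith("features.") and key.endswith(".weight") and w.dim() == 4:
        out_c, in_c, kh, kw = w.shape
        print(f"{key}: nn.Conv2d({in_c}, {out_c}, kernel_size=({kh}, {kw}))")

# Rebuilding each nn.Conv2d (and the first nn.Linear) with these counts
# produces an architecture whose state dict matches the checkpoint exactly.
```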