From 0c1ccff5869ff0678653531ee46634626d102763 Mon Sep 17 00:00:00 2001 From: Tyler Yep Date: Sun, 6 Nov 2022 17:27:02 -0800 Subject: [PATCH] Fix torchvision deprecation warnings in test cases (#191) --- tests/torchinfo_xl_test.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/torchinfo_xl_test.py b/tests/torchinfo_xl_test.py index 566b04d..4c2173b 100644 --- a/tests/torchinfo_xl_test.py +++ b/tests/torchinfo_xl_test.py @@ -39,13 +39,17 @@ def test_eval_order_doesnt_matter() -> None: input_size = (1, 3, 224, 224) input_tensor = torch.ones(input_size).to(device) - model1 = torchvision.models.resnet18(pretrained=True) + model1 = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.DEFAULT + ) model1.eval() summary(model1, input_size=input_size) with torch.inference_mode(): # type: ignore[no-untyped-call] output1 = model1(input_tensor) - model2 = torchvision.models.resnet18(pretrained=True) + model2 = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.DEFAULT + ) summary(model2, input_size=input_size) model2.eval() with torch.inference_mode(): # type: ignore[no-untyped-call] @@ -121,8 +125,10 @@ def test_tmva_net_column_totals() -> None: def test_google() -> None: - summary(torchvision.models.googlenet(), (1, 3, 112, 112), depth=7) + google_net = torchvision.models.googlenet(init_weights=False) + + summary(google_net, (1, 3, 112, 112), depth=7) # Check googlenet in training mode since InceptionAux layers are used in # forward-prop in train mode but not in eval mode. - summary(torchvision.models.googlenet(), (1, 3, 112, 112), depth=7, mode="train") + summary(google_net, (1, 3, 112, 112), depth=7, mode="train")