-
Notifications
You must be signed in to change notification settings - Fork 9
/
adt_expam.py
254 lines (223 loc) · 11 KB
/
adt_expam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.distributions.utils import clamp_probs
from torchvision import datasets, transforms
from models.wideresnet import *
from models.resnet import *
from models.generator import define_G, get_scheduler, set_requires_grad
# Command-line configuration for the classifier (SGD) and the perturbation
# generator G. Parsed once at import time into the module-level `args`.
parser = argparse.ArgumentParser(description='PyTorch Adversarial Distributional Training')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--epochs', type=int, default=76, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# type=float is required: without it a value supplied on the command line
# stays a str and `args.epsilon * grad.sign()` in train() raises TypeError.
# (argparse does not apply `type` to the non-string default, so the default
# is unchanged.)
parser.add_argument('--epsilon', type=float, default=8.0/255.0,
                    help='perturbation')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-dir', default='./model-cifar-wideResNet',
                    help='directory of model for saving checkpoint')
parser.add_argument('--save-freq', '-s', default=5, type=int, metavar='N',
                    help='save frequency')
# parameters for the generator
parser.add_argument('--net_G', type=str, default='resnet_3blocks',
                    help='net for G')
# Restricted to the optimizers main() knows how to build; anything else used
# to surface later as an unbound `optimizer_G`.
parser.add_argument('--opt_G', type=str, default='adam',
                    choices=['adam', 'sgd', 'momentum', 'rmsprop'],
                    help='optimizer for G')
parser.add_argument('--lr_G', type=float, default=0.0002,
                    help='initial learning rate for adam')
# The scheduler options below are consumed by get_scheduler(optimizer_G, args)
# from models.generator (not visible here).
parser.add_argument('--lr_policy_G', type=str, default='linear',
                    help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters_G', type=int, default=30,
                    help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--niter_G', type=int, default=100,
                    help='# of iter at starting learning rate')
parser.add_argument('--niter_decay_G', type=int, default=50,
                    help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--beta1_G', type=float, default=0.5,
                    help='momentum term of adam')
parser.add_argument('--ngf_G', type=int, default=256,
                    help='number of hidden unit in G')
parser.add_argument('--lbd', type=float, default=0.01,
                    help='lambda for the entropy term')
# Restricted to the datasets the data-loading code handles.
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'cifar100', 'svhn'],
                    help='dataset')
args = parser.parse_args()
print(args)
# settings
model_dir = args.model_dir
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options (worker processes, pinned host memory) are only
# requested when training on CUDA.
kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}
# setup data loader
# Only ToTensor (no mean/std normalization), so pixels stay in [0, 1] -- the
# same range train() clamps its adversarial examples to.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
])
if args.dataset == 'cifar10':
    trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
elif args.dataset == 'cifar100':
    trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform_test)
elif args.dataset == 'svhn':
    # SVHN uses a smaller perturbation budget; note this unconditionally
    # overrides any --epsilon given on the command line.
    args.epsilon = 4.0 / 255.0
    # NOTE(review): the SVHN training split uses transform_test (no random
    # crop / flip) -- presumably intentional, since flipping digits changes
    # their meaning; confirm.
    trainset = torchvision.datasets.SVHN(root='../data', split='train', download=True, transform=transform_test)
    testset = torchvision.datasets.SVHN(root='../data', split='test', download=True, transform=transform_test)
else:
    raise NotImplementedError('unsupported dataset: {!r}'.format(args.dataset))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
def grad_inv(grad):
    """Backward hook that flips the sign of the incoming gradient.

    ``train`` registers it on the adversarial input, so everything upstream
    of that tensor (the perturbation generator) receives the negated gradient
    and therefore ascends the loss the classifier is descending.
    """
    return -grad
def generate_pert(phi, return_entropy=False):
    """Draw one tanh-squashed Gaussian perturbation from generator output ``phi``.

    Channels ``[:3]`` of ``phi`` are the Gaussian mean; channels ``[3:]``,
    passed through softplus, are its standard deviation. A single
    reparameterized sample is pushed through ``tanh`` so every element lies
    in (-1, 1).

    Args:
        phi: rank-4 generator output (indexed as ``phi[:, c, :, :]``);
            presumably (batch, 6, H, W), matching the 6 output channels
            main() requests from define_G -- TODO confirm.
        return_entropy: when True, also return a single-sample Monte-Carlo
            estimate of the entropy (mean of -log p over all elements,
            additive constants dropped).

    Returns:
        The perturbation tensor, or ``(perturbation, entropy)`` when
        ``return_entropy`` is True.
    """
    loc = phi[:, :3, :, :]
    scale = F.softplus(phi[:, 3:, :, :])
    eps = torch.randn_like(scale)
    sample = torch.tanh(loc + eps * scale)

    # -log p(sample), up to a constant, for u = loc + eps*scale, sample = tanh(u):
    # Gaussian term + log-scale term + tanh change-of-variables term.
    # The 1e-8 offsets keep the logarithms finite.
    gauss_term = eps.pow(2) / 2.
    log_scale = (scale + 1e-8).log()
    log_jacobian = (1 - sample.pow(2) + 1e-8).log()
    entropy = (gauss_term + log_scale + log_jacobian).mean()

    return (sample, entropy) if return_entropy else sample
def train(args, model, device, train_loader, optimizer, epoch, G, optimizer_G):
    """Run one epoch of adversarial distributional training.

    For every batch, the generator ``G`` receives the clean images together
    with two input gradients of the classifier (at the clean point and at a
    one-step FGSM point) and outputs ``phi``, the parameters of a perturbation
    distribution. One perturbation is sampled and a single backward pass
    updates both networks: the classifier descends the cross-entropy on the
    perturbed images, while -- because ``grad_inv`` negates the gradient that
    flows back through ``x_adv`` -- the generator ascends that loss plus
    ``args.lbd`` times the entropy estimate.

    Args:
        args: parsed options; reads ``epsilon``, ``lbd`` and ``log_interval``.
        model: classifier being trained.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches; its ``.dataset``
            and ``len()`` are used only for the progress message.
        optimizer: optimizer over ``model``'s parameters.
        epoch: current epoch number, used only for logging.
        G: perturbation generator; assumed to already live on ``device``
            -- TODO confirm.
        optimizer_G: optimizer over ``G``'s parameters.
    """
    model.train()
    G.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # calculate the two-step gradients as inputs to the generator
        # (eval mode so these probing forward passes do not update train-time
        # layer state, e.g. batch-norm running statistics if the model has any)
        model.eval()
        data.requires_grad_()
        loss_ = F.cross_entropy(model(data), target)
        # gradient of the clean loss w.r.t. the input image
        grad = torch.autograd.grad(loss_, [data])[0].detach()
        data.requires_grad_(False)
        # one signed-gradient (FGSM) step of size epsilon, clamped to [0, 1]
        x_fgsm = torch.clamp(data + args.epsilon * grad.sign(), 0.0, 1.0).detach()
        x_fgsm.requires_grad_()
        # input gradient at the FGSM point
        grad_fgsm = torch.autograd.grad(F.cross_entropy(model(x_fgsm), target), [x_fgsm])[0].detach()
        x_fgsm.requires_grad_(False)
        # generator input: image and both gradients concatenated along dim 1
        phi = G(torch.cat([data, grad, grad_fgsm], 1))
        # back to train mode for the actual update step
        model.train()
        optimizer.zero_grad()
        optimizer_G.zero_grad()
        pert, entropy = generate_pert(phi, return_entropy=True)
        # pert lies in (-1, 1), so x_adv stays within epsilon of data per pixel
        x_adv = torch.clamp(data + args.epsilon * pert, 0.0, 1.0)
        # negate the gradient flowing from x_adv back into G, so G ascends
        # the classification loss that the classifier descends
        x_adv.register_hook(grad_inv)
        loss = F.cross_entropy(model(x_adv), target)
        # single backward pass: the classifier gets d(loss); G gets -d(loss)
        # through the hook plus d(-lbd * entropy), i.e. G maximizes
        # loss + lbd * entropy
        (loss - args.lbd * entropy).backward()
        optimizer.step()
        optimizer_G.step()
        # print progress
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def eval_train(model, device, train_loader):
    """Evaluate ``model`` on clean (unperturbed) training data.

    Args:
        model: classifier; switched to eval mode by this function.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches whose
            ``.dataset`` length is the total number of examples.

    Returns:
        ``(train_loss, training_accuracy)``: the mean per-example
        cross-entropy and the fraction (0..1) of correct predictions.
    """
    model.eval()
    train_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per batch and divide by the dataset size once at the end, so
            # a smaller final batch is weighted correctly. reduction='sum'
            # replaces the deprecated size_average=False.
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_examples = len(train_loader.dataset)
    train_loss /= num_examples
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, num_examples,
        100. * correct / num_examples))
    # float() keeps this a true division even under Python 2 semantics.
    training_accuracy = float(correct) / num_examples
    return train_loss, training_accuracy
def eval_test(model, device, test_loader):
    """Evaluate ``model`` on clean (unperturbed) test data.

    Args:
        model: classifier; switched to eval mode by this function.
        device: device each batch is moved to.
        test_loader: iterable of ``(data, target)`` batches whose
            ``.dataset`` length is the total number of examples.

    Returns:
        ``(test_loss, test_accuracy)``: the mean per-example cross-entropy
        and the fraction (0..1) of correct predictions.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per batch and divide by the dataset size once at the end, so
            # a smaller final batch is weighted correctly. reduction='sum'
            # replaces the deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_examples = len(test_loader.dataset)
    test_loss /= num_examples
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, num_examples,
        100. * correct / num_examples))
    # float() keeps this a true division even under Python 2 semantics.
    test_accuracy = float(correct) / num_examples
    return test_loss, test_accuracy
def adjust_learning_rate(optimizer, epoch, base_lr=None,
                         milestones=((75, 0.1), (90, 0.01), (100, 0.001))):
    """Apply a step-decayed learning rate to every param group of ``optimizer``.

    The rate is ``base_lr`` multiplied by the factor of the last milestone
    whose start epoch is <= ``epoch`` (no scaling before the first milestone).

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch number.
        base_lr: undecayed learning rate; when None, the module-level
            ``args.lr`` is used (the original behavior).
        milestones: ``(start_epoch, factor)`` pairs in increasing epoch order;
            the default reproduces the original 75/90/100 schedule.

    Returns:
        The learning rate that was applied.
    """
    if base_lr is None:
        base_lr = args.lr
    lr = base_lr
    for start_epoch, factor in milestones:
        if epoch >= start_epoch:
            lr = base_lr * factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def _make_optimizer_G(G):
    """Build the optimizer for generator ``G`` selected by ``args.opt_G``.

    Raises:
        ValueError: if ``args.opt_G`` is not adam, sgd, momentum or rmsprop
            (previously this surfaced later as an unbound ``optimizer_G``).
    """
    if args.opt_G == 'adam':
        return torch.optim.Adam(G.parameters(), lr=args.lr_G, betas=(args.beta1_G, 0.999))
    if args.opt_G == 'sgd':
        return torch.optim.SGD(G.parameters(), lr=args.lr_G, weight_decay=1e-4)
    if args.opt_G == 'momentum':
        return torch.optim.SGD(G.parameters(), lr=args.lr_G, momentum=0.9, weight_decay=1e-4)
    if args.opt_G == 'rmsprop':
        return torch.optim.RMSprop(G.parameters(), lr=args.lr_G)
    raise ValueError('unsupported --opt_G: {!r}'.format(args.opt_G))


def _count_params(module):
    """Return the total number of scalar elements in ``module``'s parameters."""
    return sum(p.numel() for p in module.parameters())


def _save_checkpoint(epoch, model, optimizer, G, optimizer_G):
    """Write classifier, generator and both optimizer states for ``epoch`` to ``model_dir``."""
    torch.save(model.state_dict(),
               os.path.join(model_dir, 'model-wideres-epoch{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(),
               os.path.join(model_dir, 'opt-wideres-checkpoint_epoch{}.tar'.format(epoch)))
    torch.save(G.state_dict(),
               os.path.join(model_dir, 'generator-epoch{}.pt'.format(epoch)))
    torch.save(optimizer_G.state_dict(),
               os.path.join(model_dir, 'optG_epoch{}.tar'.format(epoch)))


def main():
    """Build the classifier and generator, then train, evaluate and checkpoint.

    Runs ``args.epochs`` epochs. After each epoch the generator scheduler is
    stepped, clean train/test loss and accuracy are printed, and a checkpoint
    is written every ``args.save_freq`` epochs and on every epoch after 70.
    """
    # init model, ResNet18() can be also used here for training
    model = WideResNet(depth=28, num_classes=100 if args.dataset == 'cifar100' else 10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # G: 9 input channels (image + two input gradients, concatenated in
    # train()) -> 6 output channels (mean and pre-softplus std consumed by
    # generate_pert()). .to(device) is a no-op if define_G already placed it.
    G = define_G(9, 6, args.ngf_G, args.net_G).to(device)
    optimizer_G = _make_optimizer_G(G)
    scheduler_G = get_scheduler(optimizer_G, args)
    print(' + Number of params of classifier: {}'.format(_count_params(model)))
    print(' + Number of params of generator: {}'.format(_count_params(G)))
    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)
        # adversarial training
        train(args, model, device, train_loader, optimizer, epoch, G, optimizer_G)
        # stepped after the epoch's optimizer updates (PyTorch >= 1.1 order)
        scheduler_G.step()
        # evaluation on natural examples
        print('================================================================')
        eval_train(model, device, train_loader)
        eval_test(model, device, test_loader)
        print('================================================================')
        # save checkpoint
        if epoch % args.save_freq == 0 or epoch > 70:
            _save_checkpoint(epoch, model, optimizer, G, optimizer_G)


if __name__ == '__main__':
    main()