add support for local model weights
LukasStruppek committed Jun 22, 2022
1 parent d80aef7 commit 3e8a920
Showing 2 changed files with 100 additions and 9 deletions.
65 changes: 65 additions & 0 deletions configs/attacking/default_attacking_local.yaml
@@ -0,0 +1,65 @@
---
target_model:
  architecture: densenet169 # architecture of the target model
  num_classes: 530 # number of output neurons
  weights: trained_models/facescrub/densenet169_facescrub.pt # path to the weight file
evaluation_model:
  architecture: inception-v3 # architecture of the evaluation model
  num_classes: 530 # number of output neurons
  weights: trained_models/facescrub/inception_v3_facescrub.pt # path to the weight file

stylegan_model: stylegan2-ada-pytorch/ffhq.pkl # Path to the StyleGAN2 weight file.
seed: 42 # Seed used for splitting the datasets and initializing the attack.
dataset: facescrub # Target dataset, select one of [facescrub, celeba_identities, stanford_dogs_cropped, stanford_dogs_uncropped].

candidates:
  num_candidates: 200 # Number of latent vectors to optimize for each target.
  candidate_search:
    search_space_size: 2000 # Number of randomly sampled latent vectors from which the candidates are selected.
    center_crop: 800 # Crop generated images.
    resize: 224 # Resize generated images (after cropping).
    horizontal_flip: true # Flip the generated images horizontally in 50% of the cases.
    batch_size: 25 # Batch size during the sampling process (single GPU).
    truncation_psi: 0.5 # Truncation psi for StyleGAN.
    truncation_cutoff: 8 # Truncation cutoff for StyleGAN.

attack:
  batch_size: 25 # Batch size per GPU.
  num_epochs: 50 # Number of optimization iterations per batch.
  targets: 0 # Specify the targeted classes: a single class index, a list of indices, or all.
  discriminator_loss_weight: 0.0 # Weight of the discriminator loss term; 0.0 disables it.
  single_w: true # Optimize a single 512-dimensional w vector. Otherwise, a distinct vector is optimized for each AdaIN operation.
  clip: false # Clip generated images to the range [-1, 1].
  transformations: # Transformations applied during the optimization.
    CenterCrop:
      size: 800
    Resize:
      size: 224
    RandomResizedCrop:
      size: [224, 224]
      scale: [0.9, 1.0]
      ratio: [1.0, 1.0]

  optimizer: # Optimizer for the latent vector optimization; any optimizer from torch.optim can be given by name (see the sketch below the config).
    Adam:
      lr: 0.005
      weight_decay: 0
      betas: [0.1, 0.1]

  lr_scheduler: # Optional learning rate scheduler from torch.optim.lr_scheduler.
    MultiStepLR:
      milestones: [30, 40]
      gamma: 0.1

final_selection:
  samples_per_target: 50 # Number of samples to select from the set of optimized latent vectors.
  approach: transforms # Currently, transforms is the only available option.
  iterations: 100 # Number of iterations for which random transformations are applied.


wandb: # Options for WandB logging.
  enable_logging: false # Activate logging.
  wandb_init_args: # WandB init arguments.
    project: model_inversion_attacks
    save_code: true
    name: resnest101_facescrub
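Since the optimizer and transformations blocks are plain mappings from class names to keyword arguments, any optimizer from torch.optim and any transformation from torchvision.transforms can be configured by name. The sketch below illustrates how such blocks can be resolved; the helpers build_optimizer and build_transforms are illustrative only, and the repository's actual parser may differ in detail.

import torch
import torch.optim as optim
import torchvision.transforms as T

# Excerpts mirroring the 'attack' section of the config above.
optimizer_cfg = {'Adam': {'lr': 0.005, 'weight_decay': 0, 'betas': [0.1, 0.1]}}
transform_cfg = {
    'CenterCrop': {'size': 800},
    'Resize': {'size': 224},
    'RandomResizedCrop': {'size': [224, 224], 'scale': [0.9, 1.0], 'ratio': [1.0, 1.0]},
}


def build_optimizer(params, cfg):
    # Resolve the single optimizer entry by name against torch.optim.
    (name, kwargs), = cfg.items()
    return getattr(optim, name)(params, **kwargs)


def build_transforms(cfg):
    # Resolve each transformation entry by name against torchvision.transforms.
    return T.Compose([getattr(T, name)(**kwargs) for name, kwargs in cfg.items()])


latent = torch.nn.Parameter(torch.zeros(1, 512))  # stand-in for a StyleGAN w vector
optimizer = build_optimizer([latent], optimizer_cfg)
transforms = build_transforms(transform_cfg)
print(optimizer)
print(transforms)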
44 changes: 35 additions & 9 deletions utils/attack_config_parser.py
@@ -6,32 +6,56 @@
 import torch.optim as optim
 import torchvision.transforms as T
 import yaml
-from attacks.initial_selection import find_initial_w
-from matplotlib.pyplot import fill
-import wandb
+from models.classifier import Classifier

+from attacks.initial_selection import find_initial_w
+import wandb
 from utils.wandb import load_model


 class AttackConfigParser:

     def __init__(self, config_file):
         with open(config_file, 'r') as file:
             config = yaml.safe_load(file)
         self._config = config

     def create_target_model(self):
-        model = load_model(self._config['wandb_target_run'])
+        if 'wandb_target_run' in self._config:
+            model = load_model(self._config['wandb_target_run'])
+        elif 'target_model' in self._config:
+            config = self._config['target_model']
+            model = Classifier(num_classes=config['num_classes'],
+                               architecture=config['architecture'])
+            model.load_state_dict(torch.load(config['weights']))
+        else:
+            raise RuntimeError('No target model stated in the config file.')
+
         model.eval()
         self.model = model
         return model

     def get_target_dataset(self):
-        api = wandb.Api(timeout=60)
-        run = api.run(self._config['wandb_target_run'])
-        return run.config['Dataset'].strip().lower()
+        try:
+            api = wandb.Api(timeout=60)
+            run = api.run(self._config['wandb_target_run'])
+            return run.config['Dataset'].strip().lower()
+        except:
+            return self._config['dataset']

     def create_evaluation_model(self):
-        evaluation_model = load_model(self._config['wandb_evaluation_run'])
+        if 'wandb_evaluation_run' in self._config:
+            evaluation_model = load_model(self._config['wandb_evaluation_run'])
+        elif 'evaluation_model' in self._config:
+            config = self._config['evaluation_model']
+            evaluation_model = Classifier(num_classes=config['num_classes'],
+                                          architecture=config['architecture'])
+            evaluation_model.load_state_dict(torch.load(config['weights']))
+        else:
+            raise RuntimeError(
+                'No evaluation model stated in the config file.')
+
         evaluation_model.eval()
         self.evaluation_model = evaluation_model
         return evaluation_model
@@ -214,15 +238,17 @@ def fid_evaluation(self):
     def attack_center_crop(self):
         if 'transformations' in self._config['attack']:
             if 'CenterCrop' in self._config['attack']['transformations']:
-                return self._config['attack']['transformations']['CenterCrop']['size']
+                return self._config['attack']['transformations']['CenterCrop'][
+                    'size']
         else:
             return None

     @property
     def attack_resize(self):
         if 'transformations' in self._config['attack']:
             if 'Resize' in self._config['attack']['transformations']:
-                return self._config['attack']['transformations']['Resize']['size']
+                return self._config['attack']['transformations']['Resize'][
+                    'size']
         else:
             return None
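Taken together, the two files enable a fully local setup: when no wandb_target_run or wandb_evaluation_run key is present, create_target_model() and create_evaluation_model() fall back to the target_model and evaluation_model blocks and load the state dicts from the given .pt paths, and get_target_dataset() falls back to the dataset entry. A hypothetical usage sketch (the repository's actual attack entry point may differ):

from utils.attack_config_parser import AttackConfigParser

config = AttackConfigParser('configs/attacking/default_attacking_local.yaml')

target_model = config.create_target_model()          # densenet169, weights loaded from the local .pt file
evaluation_model = config.create_evaluation_model()  # inception-v3, weights loaded from the local .pt file
dataset = config.get_target_dataset()                # no wandb_target_run present -> returns 'facescrub'

print(type(target_model).__name__, dataset)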

