-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
57 lines (45 loc) · 1.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import print_function
from matplotlib import pyplot as plt
import sys
import torch
import numpy as np
from args import args
from utils.trainer import Trainer
from utils.merge_iterator import Merge_Iterator
from utils.model_utils import model_selector
from utils.model_utils import set_seed, _init_weight
def get_dataloaders(args):
    """Load the datasets for the dataset named in ``args.dataset``.

    The dataset-specific module is imported lazily, so only the module for
    the selected dataset is loaded.

    Args:
        args: Run configuration. ``dataset``, ``num_clients``,
            ``disjoint_classes`` and ``imbalanced`` are read for the config
            banner; the whole object is forwarded to the selected module's
            ``get_datasets``.

    Returns:
        Whatever the selected module's ``get_datasets(args)`` returns.

    Raises:
        SystemExit: If ``args.dataset`` is not one of the supported names.
            The exit carries a message, so the process ends with a non-zero
            status instead of silently reporting success.
    """
    import importlib

    # Supported dataset name -> module that provides ``get_datasets``.
    dataset_modules = {
        'MNIST': 'datasets.mnist',
        'Fashion_MNIST': 'datasets.fashionmnist',
        'CIFAR10': 'datasets.cifar10',
        'CIFAR100': 'datasets.cifar100',
    }

    print(f'Data config: \n\t Dataset: {args.dataset}, num_clients: {args.num_clients}, disjoint classes: {args.disjoint_classes}, imbalanced:{args.imbalanced}')

    module_name = dataset_modules.get(args.dataset)
    if module_name is None:
        supported = ', '.join(dataset_modules)
        # A string argument makes sys.exit print to stderr and exit with status 1.
        sys.exit(f'Unsupported dataset {args.dataset!r}; expected one of: {supported}')

    get_datasets = importlib.import_module(module_name).get_datasets
    return get_datasets(args)
def main():
    """Run either baseline training or the merge-iterator pipeline.

    Reads the module-level ``args`` object: selects the device, prints the
    configuration, seeds via ``set_seed``, then branches on
    ``args.baseline``:

    * baseline: build the model with ``model_selector``, print its trainable
      parameter count, and fit it with ``Trainer``.
    * otherwise: hand the loaders and weights returned by
      ``get_dataloaders`` to ``Merge_Iterator`` and run it.

    Side effect: sets ``args.device`` on the shared ``args`` object.
    """
    import os

    # Training settings
    args.device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f'Using Device {args.device}')
    print(args)
    set_seed(args.seed)

    if args.baseline:
        train_loader1, test_dataset = get_dataloaders(args)
        model = model_selector(args)
        # numel() is always an int, including for 0-dim parameters, so the
        # printed total never degrades to a float.
        params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f'Model parameters: {params}')
        save_path = f'{args.model}{args.dataset}_baseline.pt'
        trainer = Trainer(args, [train_loader1, test_dataset], model, args.device, save_path)
        trainer.fit(log_output=True)
    else:
        # Join as path components so base_dir works with or without a
        # trailing separator; the final '' keeps the trailing separator.
        weight_dir = os.path.join(args.base_dir, 'iwa_weights', '')
        train_loader_list, test_loader, train_weight_list = get_dataloaders(args)
        merge_iterator = Merge_Iterator(
            args, train_loader_list, test_loader, train_weight_list, args.device, weight_dir
        )
        merge_iterator.run()
# Script entry point: start the run only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()