Skip to content

Commit

Permalink
add agedb
Browse files Browse the repository at this point in the history
  • Loading branch information
Lupin1998 committed Sep 10, 2023
1 parent df4f37f commit bf74d3f
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 3 deletions.
46 changes: 46 additions & 0 deletions config/usb_cv/fullysupervised/fullysupervised_agedb_122_0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Fully-supervised configuration for the `agedb` dataset with a ViT-S/16 backbone.
algorithm: fullysupervised

# --- checkpointing / logging ---
save_dir: ./saved_models/usb_cv/
save_name: fullysupervised_agedb_122_0
resume: False
# NOTE(review): the path below contains a doubled '/' after `usb_cv` — presumably harmless, confirm.
load_path: ./saved_models/usb_cv//fullysupervised_agedb_122_0/latest_model.pth
overwrite: True
use_tensorboard: True
use_wandb: False

# --- schedule ---
epoch: 200
num_train_iter: 204800
num_log_iter: 256
num_eval_iter: 2048
batch_size: 32
eval_batch_size: 64
num_warmup_iter: 5120
num_labels: 122
uratio: 1
ema_m: 0.0

# --- input geometry ---
img_size: 224
crop_ratio: 0.875

# --- optimizer ---
optim: AdamW
lr: 0.001
layer_decay: 0.65
momentum: 0.9
weight_decay: 0.0005
amp: False
clip: 0.0
use_cat: True

# --- backbone ---
net: vit_small_patch16_224
net_from_name: False

# --- data / objective ---
data_dir: ./data/
dataset: agedb
train_sampler: RandomSampler
# A single output combined with `l1_loss` below — presumably a scalar regression target; confirm.
num_classes: 1
loss_type: 'l1_loss'
num_workers: 4
seed: 0

# --- distributed ---
world_size: 1
rank: 0
multiprocessing_distributed: True
dist_url: tcp://127.0.0.1:10021
dist_backend: nccl
# NOTE(review): a plain `None` scalar is read by YAML as the string "None", not null —
# confirm the config loader converts it as intended.
gpu: None

# --- pretrained weights ---
use_pretrain: True
pretrain_path: https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_small_patch16_224_mlp_im_1k_224.pth
find_unused_parameters: False
41 changes: 40 additions & 1 deletion semilearn/core/criterions/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.


import torch
import torch.nn as nn

from torch.nn import functional as F
Expand All @@ -27,6 +28,44 @@ def l2_loss(logits, target, reduction='mean', **kwargs):
return loss


def focal_l1_loss(logits, target, reduction='mean', activate='sigmoid', beta=0.2, gamma=1.0, **kwargs):
    """
    Focal L1 loss: the element-wise absolute error, re-weighted by a
    tanh- or sigmoid-based factor of that same absolute error.

    Args:
        logits: predicted values
        target: ground-truth values; cast to the dtype of ``logits``
        reduction: 'mean' or 'sum'; any other value returns the element-wise loss
        activate: 'tanh' selects the tanh weighting, any other value the sigmoid one
        beta: scale applied to the absolute error inside the weighting
        gamma: exponent applied to the weighting
    """
    target = target.type_as(logits)
    loss = F.l1_loss(logits, target, reduction='none')

    abs_err = torch.abs(logits - target)
    if activate == 'tanh':
        weight = torch.tanh(beta * abs_err) ** gamma
    else:
        # 2 * sigmoid(x) - 1 maps a non-negative x into [0, 1)
        weight = (2 * torch.sigmoid(beta * abs_err) - 1) ** gamma
    loss *= weight

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss


def focal_l2_loss(logits, target, reduction='mean', activate='sigmoid', beta=0.2, gamma=1.0, **kwargs):
    """
    Focal L2 loss: the element-wise squared error, re-weighted by a
    tanh- or sigmoid-based factor of the absolute error.

    Args:
        logits: predicted values
        target: ground-truth values; cast to the dtype of ``logits``
        reduction: 'mean' or 'sum'; any other value returns the element-wise loss
        activate: 'tanh' selects the tanh weighting, any other value the sigmoid one
        beta: scale applied to the absolute error inside the weighting
        gamma: exponent applied to the weighting
    """
    target = target.type_as(logits)
    loss = F.mse_loss(logits, target, reduction='none')

    scaled_err = beta * torch.abs(logits - target)
    if activate == 'tanh':
        focal_weight = torch.tanh(scaled_err)
    else:
        # 2 * sigmoid(x) - 1 maps a non-negative x into [0, 1)
        focal_weight = 2 * torch.sigmoid(scaled_err) - 1
    loss *= focal_weight ** gamma

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss


def huber_loss(logits, target, reduction='mean', beta=1.0, **kwargs):
    """
    Calculate Smooth L1 loss.

    Element-wise, with ``d = |logits - target|``:
        0.5 * d ** 2 / beta   if d < beta
        d - 0.5 * beta        otherwise

    Args:
        logits: predicted values
        target: ground-truth values; cast to the dtype of ``logits`` so the
            result keeps that dtype (same convention as the focal losses)
        reduction: 'mean' or 'sum'; any other value returns the element-wise loss
        beta: threshold between the quadratic and the linear regime;
            ``beta == 0`` degenerates to the plain L1 loss
    """
    target = target.type_as(logits)
    abs_err = F.l1_loss(logits, target, reduction='none')
    if beta == 0:
        # The quadratic branch would divide by zero; even though torch.where
        # never selects it, the inf/nan it produces poisons the gradient.
        loss = abs_err
    else:
        loss = torch.where(abs_err < beta, 0.5 * abs_err ** 2 / beta, abs_err - 0.5 * beta)
    if reduction == 'mean':
        loss = loss.mean()
    elif reduction == 'sum':
        loss = loss.sum()
    return loss


class RegLoss(nn.Module):
"""
Wrapper for regression loss
Expand All @@ -36,7 +75,7 @@ def __init__(self,
**kwargs):
super(RegLoss, self).__init__()
self.mode = mode
self.loss_list = ["l2_loss", "l1_loss"]
self.loss_list = ["l2_loss", "l1_loss", "focal_l1_loss", "focal_l2_loss", "huber_loss"]
assert mode in self.loss_list
self.criterion = eval(self.mode)

Expand Down
5 changes: 4 additions & 1 deletion semilearn/core/utils/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_dataset(args, algorithm, dataset, num_labels, num_classes, data_dir='./d
data_dir: data folder
include_lb_to_ulb: flag of including labeled data into unlabeled data
"""
from semilearn.datasets import get_eurosat, get_medmnist, get_semi_aves, get_cifar, get_svhn, get_stl10, get_imagenet, get_imdb_wiki, get_json_dset, get_pkl_dset
from semilearn.datasets import get_agedb, get_eurosat, get_medmnist, get_semi_aves, get_cifar, get_svhn, get_stl10, get_imagenet, get_imdb_wiki, get_json_dset, get_pkl_dset

if dataset == "eurosat":
lb_dset, ulb_dset, eval_dset = get_eurosat(args, algorithm, dataset, num_labels, num_classes, data_dir=data_dir, include_lb_to_ulb=include_lb_to_ulb)
Expand All @@ -96,6 +96,9 @@ def get_dataset(args, algorithm, dataset, num_labels, num_classes, data_dir='./d
elif dataset in ["imagenet", "imagenet127"]:
lb_dset, ulb_dset, eval_dset = get_imagenet(args, algorithm, dataset, num_labels, num_classes, data_dir=data_dir, include_lb_to_ulb=include_lb_to_ulb)
test_dset = None
elif dataset == "agedb":
lb_dset, ulb_dset, eval_dset = get_agedb(args, algorithm, dataset, num_labels, data_dir=data_dir)
test_dset = None
elif dataset == "imdb_wiki":
lb_dset, ulb_dset, eval_dset = get_imdb_wiki(args, algorithm, dataset, num_labels, data_dir=data_dir)
test_dset = None
Expand Down
2 changes: 1 addition & 1 deletion semilearn/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.

from semilearn.datasets.utils import split_ssl_data, get_collactor
from semilearn.datasets.cv_datasets import (get_cifar, get_eurosat, get_imagenet, get_imdb_wiki,
from semilearn.datasets.cv_datasets import (get_cifar, get_eurosat, get_imagenet, get_agedb, get_imdb_wiki,
get_medmnist, get_semi_aves, get_stl10, get_svhn, get_food101)
from semilearn.datasets.nlp_datasets import get_json_dset
from semilearn.datasets.audio_datasets import get_pkl_dset
Expand Down
1 change: 1 addition & 0 deletions semilearn/datasets/cv_datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .agedb import get_agedb
from .aves import get_semi_aves
from .cifar import get_cifar
from .eurosat import get_eurosat
Expand Down
96 changes: 96 additions & 0 deletions semilearn/datasets/cv_datasets/agedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import json
import torchvision
import numpy as np
import math
import pandas as pd
from PIL import Image

from torchvision import transforms
from .datasetbase import BasicDataset
from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation
from semilearn.datasets.utils import split_ssl_data


class AgeDBIDataset(BasicDataset):
    """
    Dataset whose ``data`` entries are image paths relative to ``data_dir``
    and whose ``targets`` entries are returned as one-element float32 arrays.
    """

    def __init__(self,
                 alg,
                 data,
                 targets=None,
                 num_classes=None,
                 transform=None,
                 is_ulb=False,
                 strong_transform=None,
                 onehot=False,
                 *args,
                 **kwargs):
        # Everything, including any extra kwargs such as `data_dir`, is forwarded to the base class.
        super().__init__(
            alg=alg,
            data=data,
            targets=targets,
            num_classes=num_classes,
            transform=transform,
            is_ulb=is_ulb,
            strong_transform=strong_transform,
            onehot=onehot,
            *args,
            **kwargs,
        )
        # Root folder the stored image paths are resolved against; empty when not supplied.
        self.data_dir = kwargs.get('data_dir', '')

    def __sample__(self, idx):
        """Load the RGB image at position ``idx`` together with its label."""
        img_path = os.path.join(self.data_dir, self.data[idx])
        img = Image.open(img_path).convert('RGB')
        raw_label = self.targets[idx]
        label = np.asarray([raw_label]).astype('float32')
        return img, label


def get_agedb(args, alg, name=None, num_labels=1000, num_classes=1, data_dir='./data', include_lb_to_ulb=True):
    """
    Build the labeled, unlabeled and evaluation datasets for AgeDB.

    Reads ``<data_dir>/agedb/agedb.csv`` and uses its ``split``, ``age`` and
    ``path`` columns. Rows with ``split == 'train'`` feed the labeled and
    unlabeled sets; rows with ``split == 'test'`` feed the evaluation set.

    Args:
        args: run configuration; this function reads ``img_size``, ``crop_ratio``,
            ``ulb_num_labels``, ``lb_imb_ratio`` and ``ulb_imb_ratio`` from it
        alg: algorithm name; 'fullysupervised' uses the whole train split as labeled data
        name: unused here
        num_labels: number of labeled samples requested from the split helper
        num_classes: forwarded to the dataset objects
        data_dir: root data folder containing the ``agedb`` sub-folder
        include_lb_to_ulb: flag of including labeled data into unlabeled data

    Returns:
        (lb_dset, ulb_dset, eval_dset)
    """
    data_dir = os.path.join(data_dir, 'agedb')
    df = pd.read_csv(os.path.join(data_dir, "agedb.csv"))
    # Only the train and test splits are consumed; rows marked 'val' are not used.
    df_train, df_test = df[df['split'] == 'train'], df[df['split'] == 'test']
    train_labels, train_data = df_train['age'].tolist(), df_train['path'].tolist()
    test_labels, test_data = df_test['age'].tolist(), df_test['path'].tolist()

    imgnet_mean = (0.485, 0.456, 0.406)
    imgnet_std = (0.229, 0.224, 0.225)
    img_size = args.img_size
    crop_ratio = args.crop_ratio

    transform_weak = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomCrop(img_size, padding=16, padding_mode="reflect"),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imgnet_mean, imgnet_std),
    ])

    transform_strong = transforms.Compose([
        transforms.Resize(int(math.floor(img_size / crop_ratio))),
        RandomResizedCropAndInterpolation((img_size, img_size), scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        RandAugment(3, 10),
        transforms.ToTensor(),
        transforms.Normalize(imgnet_mean, imgnet_std)
    ])

    transform_val = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(imgnet_mean, imgnet_std),
    ])

    # NOTE(review): `num_classes=1` is hard-coded here rather than taken from the
    # `num_classes` argument — confirm this is what the split helper expects.
    lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data(args, train_data, train_labels, num_classes=1,
                                                                lb_num_labels=num_labels,
                                                                ulb_num_labels=args.ulb_num_labels,
                                                                lb_imbalance_ratio=args.lb_imb_ratio,
                                                                ulb_imbalance_ratio=args.ulb_imb_ratio,
                                                                include_lb_to_ulb=include_lb_to_ulb)

    # Fully-supervised training ignores the labeled subset and uses the whole train split.
    if alg == 'fullysupervised':
        lb_data = train_data
        lb_targets = train_labels

    lb_dset = AgeDBIDataset(alg, lb_data, lb_targets, num_classes,
                            transform_weak, False, None, False, data_dir=data_dir)
    ulb_dset = AgeDBIDataset(alg, ulb_data, ulb_targets, num_classes,
                             transform_weak, True, transform_strong, False, data_dir=data_dir)
    eval_dset = AgeDBIDataset(alg, test_data, test_labels, num_classes,
                              transform_val, False, None, False, data_dir=data_dir)

    return lb_dset, ulb_dset, eval_dset

0 comments on commit bf74d3f

Please sign in to comment.