From bf74d3f910a098e696e652f824a2deeda72a73d2 Mon Sep 17 00:00:00 2001 From: Lupin1998 <1070535169@qq.com> Date: Sun, 10 Sep 2023 22:11:14 +0100 Subject: [PATCH] add agedb --- .../fullysupervised_agedb_122_0.yaml | 46 +++++++++ semilearn/core/criterions/regression.py | 41 +++++++- semilearn/core/utils/build.py | 5 +- semilearn/datasets/__init__.py | 2 +- semilearn/datasets/cv_datasets/__init__.py | 1 + semilearn/datasets/cv_datasets/agedb.py | 96 +++++++++++++++++++ 6 files changed, 188 insertions(+), 3 deletions(-) create mode 100644 config/usb_cv/fullysupervised/fullysupervised_agedb_122_0.yaml create mode 100644 semilearn/datasets/cv_datasets/agedb.py diff --git a/config/usb_cv/fullysupervised/fullysupervised_agedb_122_0.yaml b/config/usb_cv/fullysupervised/fullysupervised_agedb_122_0.yaml new file mode 100644 index 0000000..f9d4842 --- /dev/null +++ b/config/usb_cv/fullysupervised/fullysupervised_agedb_122_0.yaml @@ -0,0 +1,46 @@ +algorithm: fullysupervised +save_dir: ./saved_models/usb_cv/ +save_name: fullysupervised_agedb_122_0 +resume: False +load_path: ./saved_models/usb_cv/fullysupervised_agedb_122_0/latest_model.pth +overwrite: True +use_tensorboard: True +use_wandb: False +epoch: 200 +num_train_iter: 204800 +num_log_iter: 256 +num_eval_iter: 2048 +batch_size: 32 +eval_batch_size: 64 +num_warmup_iter: 5120 +num_labels: 122 +uratio: 1 +ema_m: 0.0 +img_size: 224 +crop_ratio: 0.875 +optim: AdamW +lr: 0.001 +layer_decay: 0.65 +momentum: 0.9 +weight_decay: 0.0005 +amp: False +clip: 0.0 +use_cat: True +net: vit_small_patch16_224 +net_from_name: False +data_dir: ./data/ +dataset: agedb +train_sampler: RandomSampler +num_classes: 1 +loss_type: 'l1_loss' +num_workers: 4 +seed: 0 +world_size: 1 +rank: 0 +multiprocessing_distributed: True +dist_url: tcp://127.0.0.1:10021 +dist_backend: nccl +gpu: None +use_pretrain: True +pretrain_path: https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_small_patch16_224_mlp_im_1k_224.pth 
+find_unused_parameters: False diff --git a/semilearn/core/criterions/regression.py b/semilearn/core/criterions/regression.py index c9c338c..66eda72 100644 --- a/semilearn/core/criterions/regression.py +++ b/semilearn/core/criterions/regression.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. +import torch import torch.nn as nn from torch.nn import functional as F @@ -27,6 +28,44 @@ def l2_loss(logits, target, reduction='mean', **kwargs): return loss +def focal_l1_loss(logits, target, reduction='mean', activate='sigmoid', beta=0.2, gamma=1.0, **kwargs): + """Calculate Focal L1 loss.""" + target = target.type_as(logits) + loss = F.l1_loss(logits, target, reduction='none') + loss *= (torch.tanh(beta * torch.abs(logits - target))) ** gamma if activate == 'tanh' else \ + (2 * torch.sigmoid(beta * torch.abs(logits - target)) - 1) ** gamma + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + return loss + + +def focal_l2_loss(logits, target, reduction='mean', activate='sigmoid', beta=0.2, gamma=1.0, **kwargs): + """Calculate Focal L2 loss.""" + target = target.type_as(logits) + loss = F.mse_loss(logits, target, reduction='none') + loss *= (torch.tanh(beta * torch.abs(logits - target))) ** gamma if activate == 'tanh' else \ + (2 * torch.sigmoid(beta * torch.abs(logits - target)) - 1) ** gamma + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + return loss + + +def huber_loss(logits, target, reduction='mean', beta=1.0, **kwargs): + """Calculate Smooth L1 loss.""" + l1_loss = F.l1_loss(logits, target, reduction='none') + cond = l1_loss < beta + loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta) + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + return loss + + class RegLoss(nn.Module): """ Wrapper for regression loss @@ -36,7 +75,7 @@ def __init__(self, **kwargs): super(RegLoss, self).__init__() self.mode = mode - 
self.loss_list = ["l2_loss", "l1_loss"] + self.loss_list = ["l2_loss", "l1_loss", "focal_l1_loss", "focal_l2_loss", "huber_loss"] assert mode in self.loss_list self.criterion = eval(self.mode) diff --git a/semilearn/core/utils/build.py b/semilearn/core/utils/build.py index 3a0e500..3b0e7ab 100644 --- a/semilearn/core/utils/build.py +++ b/semilearn/core/utils/build.py @@ -70,7 +70,7 @@ def get_dataset(args, algorithm, dataset, num_labels, num_classes, data_dir='./d data_dir: data folder include_lb_to_ulb: flag of including labeled data into unlabeled data """ - from semilearn.datasets import get_eurosat, get_medmnist, get_semi_aves, get_cifar, get_svhn, get_stl10, get_imagenet, get_imdb_wiki, get_json_dset, get_pkl_dset + from semilearn.datasets import get_agedb, get_eurosat, get_medmnist, get_semi_aves, get_cifar, get_svhn, get_stl10, get_imagenet, get_imdb_wiki, get_json_dset, get_pkl_dset if dataset == "eurosat": lb_dset, ulb_dset, eval_dset = get_eurosat(args, algorithm, dataset, num_labels, num_classes, data_dir=data_dir, include_lb_to_ulb=include_lb_to_ulb) @@ -96,6 +96,9 @@ def get_dataset(args, algorithm, dataset, num_labels, num_classes, data_dir='./d elif dataset in ["imagenet", "imagenet127"]: lb_dset, ulb_dset, eval_dset = get_imagenet(args, algorithm, dataset, num_labels, num_classes, data_dir=data_dir, include_lb_to_ulb=include_lb_to_ulb) test_dset = None + elif dataset == "agedb": + lb_dset, ulb_dset, eval_dset = get_agedb(args, algorithm, dataset, num_labels, data_dir=data_dir) + test_dset = None elif dataset == "imdb_wiki": lb_dset, ulb_dset, eval_dset = get_imdb_wiki(args, algorithm, dataset, num_labels, data_dir=data_dir) test_dset = None diff --git a/semilearn/datasets/__init__.py b/semilearn/datasets/__init__.py index ea1d26d..25dc740 100644 --- a/semilearn/datasets/__init__.py +++ b/semilearn/datasets/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
from semilearn.datasets.utils import split_ssl_data, get_collactor -from semilearn.datasets.cv_datasets import (get_cifar, get_eurosat, get_imagenet, get_imdb_wiki, +from semilearn.datasets.cv_datasets import (get_cifar, get_eurosat, get_imagenet, get_agedb, get_imdb_wiki, get_medmnist, get_semi_aves, get_stl10, get_svhn, get_food101) from semilearn.datasets.nlp_datasets import get_json_dset from semilearn.datasets.audio_datasets import get_pkl_dset diff --git a/semilearn/datasets/cv_datasets/__init__.py b/semilearn/datasets/cv_datasets/__init__.py index e55aa85..8f84c6f 100644 --- a/semilearn/datasets/cv_datasets/__init__.py +++ b/semilearn/datasets/cv_datasets/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from .agedb import get_agedb from .aves import get_semi_aves from .cifar import get_cifar from .eurosat import get_eurosat diff --git a/semilearn/datasets/cv_datasets/agedb.py b/semilearn/datasets/cv_datasets/agedb.py new file mode 100644 index 0000000..b8df7ff --- /dev/null +++ b/semilearn/datasets/cv_datasets/agedb.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import os +import json +import torchvision +import numpy as np +import math +import pandas as pd +from PIL import Image + +from torchvision import transforms +from .datasetbase import BasicDataset +from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation +from semilearn.datasets.utils import split_ssl_data + + +class AgeDBIDataset(BasicDataset): + + def __init__(self, + alg, + data, + targets=None, + num_classes=None, + transform=None, + is_ulb=False, + strong_transform=None, + onehot=False, + *args, + **kwargs): + super(AgeDBIDataset, self).__init__(alg=alg, data=data, targets=targets, num_classes=num_classes, + transform=transform, is_ulb=is_ulb, strong_transform=strong_transform, onehot=onehot, *args, **kwargs) + self.data_dir = kwargs.get('data_dir', '') + + def __sample__(self, idx): + img = Image.open(os.path.join(self.data_dir, self.data[idx])).convert('RGB') + label = np.asarray([self.targets[idx]]).astype('float32') + return img, label + + +def get_agedb(args, alg, name=None, num_labels=1000, num_classes=1, data_dir='./data', include_lb_to_ulb=True): + + data_dir = os.path.join(data_dir, 'agedb') + df = pd.read_csv(os.path.join(data_dir, "agedb.csv")) + df_train, df_val, df_test = df[df['split'] == 'train'], df[df['split'] == 'val'], df[df['split'] == 'test'] + train_labels, train_data = df_train['age'].tolist(), df_train['path'].tolist() + test_labels, test_data = df_test['age'].tolist(), df_test['path'].tolist() + # print(df_train['age'].shape, df_test['age'].shape) # (12208,) (2140,) + + imgnet_mean = (0.485, 0.456, 0.406) + imgnet_std = (0.229, 0.224, 0.225) + img_size = args.img_size + crop_ratio = args.crop_ratio + + transform_weak = transforms.Compose([ + transforms.Resize((img_size, img_size)), + transforms.RandomCrop(img_size, padding=16, padding_mode="reflect"), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(imgnet_mean, imgnet_std), + ]) + + transform_strong = 
transforms.Compose([ + transforms.Resize(int(math.floor(img_size / crop_ratio))), + RandomResizedCropAndInterpolation((img_size, img_size), scale=(0.2, 1.)), + transforms.RandomHorizontalFlip(), + RandAugment(3, 10), + transforms.ToTensor(), + transforms.Normalize(imgnet_mean, imgnet_std) + ]) + + transform_val = transforms.Compose([ + transforms.Resize((img_size, img_size)), + transforms.ToTensor(), + transforms.Normalize(imgnet_mean, imgnet_std), + ]) + + lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data(args, train_data, train_labels, num_classes=1, + lb_num_labels=num_labels, + ulb_num_labels=args.ulb_num_labels, + lb_imbalance_ratio=args.lb_imb_ratio, + ulb_imbalance_ratio=args.ulb_imb_ratio, + include_lb_to_ulb=include_lb_to_ulb) + + if alg == 'fullysupervised': + lb_data = train_data + lb_targets = train_labels + + lb_dset = AgeDBIDataset(alg, lb_data, lb_targets, num_classes, + transform_weak, False, None, False, data_dir=data_dir) + ulb_dset = AgeDBIDataset(alg, ulb_data, ulb_targets, num_classes, + transform_weak, True, transform_strong, False, data_dir=data_dir) + eval_dset = AgeDBIDataset(alg, test_data, test_labels, num_classes, + transform_val, False, None, False, data_dir=data_dir) + + return lb_dset, ulb_dset, eval_dset