Skip to content

Commit

Permalink
update readme to describe how to plot metrics with matplotlib + fixed…
Browse files Browse the repository at this point in the history
… bug in coco2pascal.yaml + added the ablation on t_pi
  • Loading branch information
mboudiaf committed Mar 29, 2021
1 parent 7175525 commit 5900a52
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 15 deletions.
23 changes: 17 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Region Proportion Regularized Inference (RePRI) for Few-Shot Segmentation

In this repo, we provide the code for our paper : "Few-Shot Segmentation Without Meta-Learning: A Good Transductive Inference Is All You Need?", available at https://arxiv.org/abs/2012.06166:
Update 03/21: Paper accepted at CVPR 2021 !

Code for the paper : "Few-Shot Segmentation Without Meta-Learning: A Good Transductive Inference Is All You Need?", freely available at https://arxiv.org/abs/2012.06166:

<img src="figures/intro_image.png" width="800" height="400"/>

Expand Down Expand Up @@ -45,7 +47,7 @@ data
| ├── JPEGImages
| └── SegmentationClassAug
```
**Pascal** : The JPEG images can be found in the PascalVOC 2012 toolkit to be downloaded at [PascalVOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) and [SegmentationClassAug](https://mycuhk-my.sharepoint.com/personal/1155122171_link_cuhk_edu_hk/_layouts/15/onedrive.aspx?id=%2Fpersonal%2F1155122171%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2FTPAMI%20Submission%2FPFENet%5Fcheckpoints%2Fgt%5Fvoc%2Ezip&parent=%2Fpersonal%2F1155122171%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2FTPAMI%20Submission%2FPFENet%5Fcheckpoints&originalPath=aHR0cHM6Ly9teWN1aGstbXkuc2hhcmVwb2ludC5jb20vOnU6L2cvcGVyc29uYWwvMTE1NTEyMjE3MV9saW5rX2N1aGtfZWR1X2hrL0VSZ3lTb05ZYjdoQnF2REJFOHo0cVZzQmg2dTNLaVdOQllEWUJNZWcxemdFS0E_cnRpbWU9ZTVBTWNtdTgyRWc) (pre-processed ground-truth masks).
**Pascal** : The JPEG images can be found in the PascalVOC 2012 toolkit to be downloaded at [PascalVOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) and [SegmentationClassAug](https://mycuhk-my.sharepoint.com/personal/1155122171_link_cuhk_edu_hk/_layouts/15/onedrive.aspx?id=%2Fpersonal%2F1155122171%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2FTPAMI%20Submission%2FPFENet%5Fcheckpoints%2Fgt%5Fvoc%2Ezip&parent=%2Fpersonal%2F1155122171%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2FTPAMI%20Submission%2FPFENet%5Fcheckpoints&originalPath=aHR0cHM6Ly9teWN1aGstbXkuc2hhcmVwb2ludC5jb20vOnU6L2cvcGVyc29uYWwvMTE1NTEyMjE3MV9saW5rX2N1aGtfZWR1X2hrL0VSZ3lTb05ZYjdoQnF2REJFOHo0cVZzQmg2dTNLaVdOQllEWUJNZWcxemdFS0E_cnRpbWU9ZTVBTWNtdTgyRWc) (pre-processed ground-truth masks).

**Coco** : Coco 2014 train, validation images and annotations can be downloaded at [Coco](https://cocodataset.org/#download). Once this is done, you will have to generate the subfolders coco/train and coco/val (ground truth masks). Both folders can be generated by executing the python script data/coco/create_masks.py (note that the script uses the package pycocotools that can be found at https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools):

Expand Down Expand Up @@ -101,7 +103,7 @@ This script will test successively on all folds of the current dataset. Below ar
Results :
|(1 shot/5 shot)| Arch | Fold-0 | Fold-1 | Fold-2 | Fold-3 | Mean |
| --- | --- | --- | --- | --- | --- | --- |
| RePRI | Resnet-50 | 59.8 / 64.6 | 68.3 / 71.4 | 62.1 / 71.1 | 48.5 / 59.3 | 59.7 / 66.6 |
| RePRI | Resnet-50 | 60.2 / 64.5 | 67.0 / 70.8 | 61.7 / 71.7 | 47.5 / 60.3 | 59.1 / 66.8 |
| Oracle-RePRI | Resnet-50 | 72.4 / 75.1 | 78.0 / 80.8 | 77.1 / 81.4 | 65.8 / 74.4 | 73.3 / 77.9 |
| RePRI | Resnet-101 | 59.6 / 66.2 | 68.3 / 71.4 | 62.2 / 67.0 | 47.2 / 57.7 | 59.4 / 65.6 |
| Oracle-RePRI | Resnet-101 | 73.9 / 76.8 | 79.7 / 81.7 | 76.1 / 79.5 | 65.1 / 74.5 | 73.7 / 78.1 |
Expand All @@ -117,7 +119,7 @@ bash scripts/test.sh pascal 5 [0] 50 # 5-shot
Results :
|(1 shot/5 shot)| Arch | Fold-0 | Fold-1 | Fold-2 | Fold-3 | Mean |
| --- | --- | --- | --- | --- | --- | --- |
| RePRI | Resnet-50 | 32.0 / 39.3 | 38.7 / 45.4 | 32.7 / 39.7 | 33.1 / 41.8 | 34.1/41.6 |
| RePRI | Resnet-50 | 31.2 / 38.5 | 38.1 / 46.2 | 33.3 / 40.0 | 33.0 / 43.6 | 34.0/42.1 |
| Oracle-RePRI | Resnet-50 | 49.3 / 51.5 | 51.4 / 60.8 | 38.2 / 54.7 | 41.6 / 55.2 | 45.1 / 55.5 |

Command :
Expand All @@ -136,7 +138,7 @@ Results :

|(1 shot/5 shot)| Arch | Fold-0 | Fold-1 | Fold-2 | Fold-3 | Mean |
| --- | --- | --- | --- | --- | --- | --- |
| RePRI | Resnet-50 | 52.8 / 57.7 | 64.0 / 66.1 | 64.1 / 67.6 | 71.5 / 73.1 | 63.1 / 66.2 |
| RePRI | Resnet-50 | 52.2 / 56.5 | 64.3 / 68.2 | 64.8 / 70.0 | 71.6 / 76.2 | 63.2 / 67.7 |
| Oracle-RePRI | Resnet-50 | 69.6 / 73.5 | 71.7 / 74.9 | 77.6 / 82.2 | 86.2 / 88.1 | 76.2 / 79.7 |


Expand All @@ -150,6 +152,9 @@ bash scripts/test.sh coco2pascal 5 [0] 50 # 5-shot

## Monitoring metrics

This code offers two options to visualize/plot metrics during training

### Live monitoring with visdom
For both training and testing, you can monitor metrics using visdom_logger (https://github.com/luizgh/visdom_logger). To install this package, simply clone the repo and install it with pip:
```
git clone https://github.com/luizgh/visdom_logger.git
Expand All @@ -159,9 +164,15 @@ For both training and testing, you can monitor metrics using visdom_logger (http
```
python -m visdom.server -port 8098
```

Finally, add the line visdom_port 8098 in the options in scripts/train.sh or scripts/test.sh, and metrics will be displayed at this port. You can monitor them through your navigator.

### Good old fashioned matplotlib

Alternatively, this code also saves important metrics (training loss, accuracy and validation loss and accuracy) as training progresses in the form of numpy files (.npy). Then, you can plot these metrics with:
```python
bash scripts/plot_training.sh model_ckpt
```

## Contact

For further questions or details, please post an issue or directly reach out to Malik Boudiaf (malik.boudiaf.1@etsmtl.net)
Expand Down
1 change: 1 addition & 0 deletions config_files/coco2pascal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ MODEL:
bottleneck_dim: 512

EVALUATION:
episodic_val: True
shot: 1
random_shot: False
episodic: True
Expand Down
39 changes: 39 additions & 0 deletions scripts/ablation_t_pi.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
DATA=$1
SHOT=$2
GPU=$3
LAYERS=$4

SPLITS="0 1 2 3"
PIs="1 6 9 12 15 20 25 30 30 35 40 45 50"
if [ $SHOT == 1 ]
then
bsz_val="500"
elif [ $SHOT == 5 ]
then
bsz_val="100"
elif [ $SHOT == 10 ]
then
bsz_val="50"
fi

for PI in $PIs
do
for SPLIT in $SPLITS
do
dirname="results/test/arch=resnet-${LAYERS}/data=${DATA}/shot=${SHOT}/split=${SPLIT}"
mkdir -p -- "$dirname"
python3 -m src.test --config config_files/${DATA}.yaml \
--opts train_split ${SPLIT} \
batch_size_val ${bsz_val} \
shot ${SHOT} \
layers ${LAYERS} \
FB_param_update "[${PI}]" \
temperature 20.0 \
adapt_iter 50 \
cls_lr 0.025 \
gpus ${GPU} \
test_num 1000 \
n_runs 3 \
| tee ${dirname}/log_${PI}.txt
done
done
3 changes: 3 additions & 0 deletions scripts/plot_ablation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
folder="results/"

python3 -m src.plot --ablation_plot --folder $folder --figsize 20 9
2 changes: 1 addition & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fi

for SPLIT in $SPLITS
do
dirname="results/test/resnet-${LAYERS}/${DATA}/shot_${SHOT}/split_${SPLIT}"
dirname="results/test/arch=resnet-${LAYERS}/data=${DATA}/shot=shot_${SHOT}/split=split_${SPLIT}"
mkdir -p -- "$dirname"
python3 -m src.test --config config_files/${DATA}.yaml \
--opts train_split ${SPLIT} \
Expand Down
76 changes: 68 additions & 8 deletions src/plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from pathlib import Path
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from itertools import cycle
from collections import defaultdict
import argparse
matplotlib.use("Agg")
plt.style.use('ggplot')

colors = ["c", "r", "g", "b", "m", 'y', 'k', 'chartreuse', 'coral', 'gold', 'lavender',
colors = ["g", "m", 'y', 'k', 'chartreuse', 'coral', 'gold', 'lavender',
'silver', 'tan', 'teal', 'wheat', 'orchid', 'orange', 'tomato']

styles = ['--', '-.', ':', '-']
Expand All @@ -16,14 +16,15 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='Plot training metrics')
parser.add_argument('--folder', type=str, help='Folder to search')
parser.add_argument('--fontsize', type=int, default=11)
parser.add_argument('--figsize', type=list, default=[10, 10])
parser.add_argument('--figsize', type=int, nargs="+", default=[10, 10])
parser.add_argument('--ablation_plot', action='store_true')

args = parser.parse_args()
return args


def make_plot(args: argparse.Namespace,
filename: str) -> None:
def make_training_plot(args: argparse.Namespace,
filename: str) -> None:
plt.rc('font', size=args.fontsize)

fig = plt.Figure(args.figsize)
Expand All @@ -50,7 +51,66 @@ def make_plot(args: argparse.Namespace,
fig.savefig(p.joinpath('{}.png'.format(filename.split('.')[0])))


def nested_dd():
return defaultdict(nested_dd)


def make_ablation_plot(args: argparse.Namespace):
p = Path(args.folder)
all_files = p.glob(f'**/*.txt')

sota = {'pascal': {1: 0.608, 5: 0.620},
'coco': {1: 0.358, 5: 0.390}}
res_dic = nested_dd()
for file in all_files:
shot = eval([part.split('=')[1] for part in file.parts if 'shot' in part][0])
data = [part.split('=')[1] for part in file.parts if 'data' in part][0]
split = eval([part.split('=')[1] for part in file.parts if 'split' in part][0])
tpi = eval(file.stem.split('_')[1])

res = process_logfile(file)
res_dic[data][shot][tpi][split] = np.mean(res)

plt.rc('font', size=30)
plt.rc('font')
fig, axes = plt.subplots(1, 2, figsize=args.figsize)
ax = fig.gca()
for i, data in enumerate(['pascal', 'coco']):
ax = axes[i]
for style, color, shot in zip(cycle(styles), cycle(colors), res_dic[data]):
tpis = np.array(list(res_dic[data][shot].keys()))
mean = np.array([np.mean([res_dic[data][shot][tpi][split] for split in res_dic[data][shot][tpi]]) \
for tpi in res_dic[data][shot]])

# x = np.linspace(0, n_epochs - 1, (n_epochs * iter_per_epoch))
# y = np.reshape(array, (n_epochs * iter_per_epoch))
sort_index = np.argsort(tpis)
ax.plot(tpis[sort_index], mean[sort_index],
label=f'RePRI ({shot} shot)', color=color, linestyle=style, linewidth=3)
ax.plot(tpis[sort_index], sota[data][shot] * np.ones(len(tpis)),
label=f'SOTA ({shot} shot)', linewidth=3.5, color=color, linestyle=':')
if data == 'pascal':
ax.set_ylabel('Average mIoU (over 4 folds)', size=30)
ax.set_xlabel('$t_\pi$', size=40)
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(0.78, 1.17), ncol=2, shadow=True)
fig.subplots_adjust(wspace=0.13)
fig.tight_layout()
fig.savefig(p.joinpath(f"ablation.pdf"), bbox_inches='tight')


def process_logfile(path: str):
with open(path, 'r') as f:
res_lines = [line for line in f.readlines() if line.split('-')[0] == 'mIoU']
res_lines = [eval(line.split(' ')[-1][:-2]) for line in res_lines]
res = np.array(res_lines)
return res


if __name__ == "__main__":
args = parse_args()
for filename in ['val_mIou.npy', 'val_loss.npy', 'train_mIou.npy', 'train_loss.npy']:
make_plot(args=args, filename=filename)
if args.ablation_plot:
make_ablation_plot(args)
else:
for filename in ['val_mIou.npy', 'val_loss.npy', 'train_mIou.npy', 'train_loss.npy']:
make_training_plot(args=args, filename=filename)

0 comments on commit 5900a52

Please sign in to comment.