Commit

fix number of anchor ratio
HiKapok committed Aug 11, 2018
1 parent 1ace507 commit 7cc976d
Showing 7 changed files with 20 additions and 18 deletions.
26 changes: 13 additions & 13 deletions README.md
@@ -4,8 +4,8 @@ This repository contains codes of the reimplementation of [SSD: Single Shot Mult

There are already some TensorFlow based SSD reimplementation codes on GitHub; the main special features of this repo include:

- state-of-the-art performance (77.4% mAP) when training from a VGG-16 pre-trained model (SSD300-VGG16).
- state-of-the-art performance (77.9% mAP) when training from a VGG-16 pre-trained model (SSD300-VGG16).
- the model is trained using TensorFlow high level API [tf.estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). Although TensorFlow provides many APIs, the Estimator API is highly recommended to yield scalable, high-performance models.
- all codes were written in pure TensorFlow ops (no numpy operations) to ensure performance and portability.
- using the SSD augmentation pipeline described in the original paper.
- PyTorch-like model definition using high-level [tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers) API for better readability ^-^.
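For readers unfamiliar with the tf.estimator workflow mentioned above, here is a minimal generic sketch of the Estimator contract (TF 1.x style); the toy model, names, and parameters are illustrative only and are not taken from this repository:

```python
import tensorflow as tf

def toy_model_fn(features, labels, mode, params):
    # Minimal model_fn: build a graph, return loss and train_op in an EstimatorSpec.
    # (Handles only TRAIN/EVAL for brevity.)
    logits = tf.layers.dense(features['x'], units=params['num_classes'])
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(
    model_fn=toy_model_fn, model_dir='./toy_logs', params={'num_classes': 21})
# estimator.train(input_fn=...) / estimator.evaluate(input_fn=...) then drive the loops.
```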
@@ -23,7 +23,7 @@ There are already some TensorFlow based SSD reimplementation codes on GitHub, th
| |->...
|->VOC2012/
| |->Annotations/
| |->ImageSets/
| |->...
|->VOC2007TEST/
| |->Annotations/
@@ -38,13 +38,13 @@ There are already some TensorFlow based SSD reimplementation codes on GitHub, th
- Run the following script to start training:

```sh
python train_ssd.py
```
- Run the following scripts for evaluation and to get the mAP:

```sh
python eval_ssd.py
python voc_eval.py
```
Note: you first need to modify some directories in voc_eval.py.
- Run the following script for visualization:
@@ -65,13 +65,13 @@ All the code was tested under TensorFlow 1.6, Python 3.5, Ubuntu 16.04 with CUD

## Results (VOC07 Metric)

This implementation (SSD300-VGG16) yields **mAP 77.4%** on the PASCAL VOC 2007 test dataset (the original performance described in the paper is 77.2% mAP); the details are as follows:
This implementation (SSD300-VGG16) yields **mAP 77.9%** on the PASCAL VOC 2007 test dataset (the original performance described in the paper is 77.2% mAP); the details are as follows:

| sofa | bird | pottedplant | bus | diningtable | cow | bottle | horse | aeroplane | motorbike |
|:-------|:-----:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| 78.8 | 76.3 | 53.3 | 86.2 | 77.7 | 83.0 | 52.7 | 85.5 | 82.3 | 82.2 |
| 80.1 | 75.9 | 54.1 | 85.4 | 77.8 | 85.2 | 48.7 | 85.8 | 83.6 | 82.3 |
| **sheep** | **train** | **boat** | **bicycle** | **chair** | **cat** | **tvmonitor** | **person** | **car** | **dog** |
| 77.2 | 87.3 | 69.7 | 83.3 | 59.0 | 88.2 | 74.6 | 79.6 | 84.8 | 85.1 |
| 80.3 | 85.9 | 71.6 | 83.7 | 62.6 | 89.0 | 74.8 | 79.3 | 85.6 | 86.6 |

You can download the trained model (VOC07+12 Train) from [GoogleDrive](https://drive.google.com/open?id=1sr3khWzrXZtcS5mmkQDL00y07Rj7erW5) for further research.

@@ -98,22 +98,22 @@ Here are the training logs and some detection results:
- Why: There may be some inconsistency between different TensorFlow versions.
- How: If you get this error, try changing the default value of checkpoint_path to './model/vgg16.ckpt' in [train_ssd.py](https://github.com/HiKapok/SSD.TensorFlow/blob/86e3fa600d8d07122e9366ae664dea8c3c87c622/train_ssd.py#L107). For more information, see [issue6](https://github.com/HiKapok/SSD.TensorFlow/issues/6) and [issue9](https://github.com/HiKapok/SSD.TensorFlow/issues/9).
- NaN loss during training
- Why: This is caused by the default learning rate, which is a little too high for some TensorFlow versions.
- How: I don't know the details of the different behavior between versions. There are two workarounds:
- Adding warm-up: change the code [here](https://github.com/HiKapok/SSD.TensorFlow/blob/d9cf250df81c8af29985c03d76636b2b8b19f089/train_ssd.py#L99) to the following snippet (a sketch of how these flag values become a learning-rate schedule appears after this FAQ list):

```python
tf.app.flags.DEFINE_string(
'decay_boundaries', '2000, 80000, 100000',
'decay_boundaries', '1000, 80000, 100000',
'Learning rate decay boundaries by global_step (comma-separated list).')
tf.app.flags.DEFINE_string(
'lr_decay_factors', '0.1, 1, 0.1, 0.01',
'The values of learning_rate decay factor for each segment between boundaries (comma-separated list).')
```
- Lower the learning rate and run more steps until convergence.
- Why does this re-implementation perform better than the reported performance?
- I don't know
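
For context on the warm-up workaround above, here is a rough sketch, assuming the flags are consumed with tf.train.piecewise_constant (function name and wiring are illustrative, not the repository's exact code), of how the decay_boundaries and lr_decay_factors strings become a learning-rate schedule:

```python
import tensorflow as tf

def build_piecewise_lr(base_lr, decay_boundaries, lr_decay_factors, global_step):
    # '1000, 80000, 100000' -> [1000, 80000, 100000]
    boundaries = [int(b.strip()) for b in decay_boundaries.split(',')]
    # '0.1, 1, 0.1, 0.01' -> [0.1*lr, lr, 0.1*lr, 0.01*lr]; the first segment
    # (steps 0..1000) runs at 0.1*lr, which acts as the warm-up phase.
    values = [base_lr * float(f.strip()) for f in lr_decay_factors.split(',')]
    return tf.train.piecewise_constant(global_step, boundaries, values)

# Example: warm up at 0.1*lr for 1000 steps, use lr until 80k, then decay at 80k/100k.
learning_rate = build_piecewise_lr(1e-3, '1000, 80000, 100000', '0.1, 1, 0.1, 0.01',
                                   tf.train.get_or_create_global_step())
```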


## ##
Apache License, Version 2.0
Binary file modified demo/demo1.jpg
Binary file modified demo/demo2.jpg
Binary file modified demo/demo3.jpg
3 changes: 2 additions & 1 deletion eval_ssd.py
@@ -151,7 +151,8 @@ def ssd_model_fn(features, labels, mode, params):
anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(positive_threshold=None, ignore_threshold=None, prior_scaling=[0.1, 0.1, 0.2, 0.2])
all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
all_anchor_ratios = [(1., 2., .5), (1., 2., 3., .5, 0.3333), (1., 2., 3., .5, 0.3333), (1., 2., 3., .5, 0.3333), (1., 2., .5), (1., 2., .5)]
#all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]

with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
backbone = ssd_net.VGG16Backbone(params['data_format'])
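The substantive change in this commit is that each anchor-ratio list above now includes the ratio 1.0 explicitly. As a rough sketch of why this determines the number of anchors per feature-map cell, following the SSD paper's anchor construction rather than this repository's AnchorEncoder internals (names below are illustrative):

```python
import math

def anchor_boxes(scale, extra_scales, ratios):
    # One (w, h) per aspect ratio at the base scale: w = s * sqrt(r), h = s / sqrt(r),
    # plus one extra square anchor per extra scale (sqrt(s_k * s_{k+1}) in the paper).
    boxes = [(scale * math.sqrt(r), scale / math.sqrt(r)) for r in ratios]
    boxes += [(s, s) for s in extra_scales]
    return boxes

# With ratio 1.0 listed explicitly, conv4_3 gets 3 + 1 = 4 anchors per cell and the
# middle layers get 5 + 1 = 6, matching the 4/6/6/6/4/4 layout of SSD300.
print(len(anchor_boxes(30., (42.43,), (1., 2., .5))))              # -> 4
print(len(anchor_boxes(60., (82.17,), (1., 2., 3., .5, 0.3333))))  # -> 6
```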
3 changes: 2 additions & 1 deletion simple_ssd_demo.py
@@ -86,7 +86,8 @@ def main(_):

all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
all_anchor_ratios = [(1., 2., .5), (1., 2., 3., .5, 0.3333), (1., 2., 3., .5, 0.3333), (1., 2., 3., .5, 0.3333), (1., 2., .5), (1., 2., .5)]
# all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]

with tf.variable_scope(FLAGS.model_scope, default_name=None, values=[features], reuse=tf.AUTO_REUSE):
backbone = ssd_net.VGG16Backbone(FLAGS.data_format)
6 changes: 3 additions & 3 deletions train_ssd.py
@@ -100,10 +100,10 @@
'The minimal end learning rate used by a polynomial decay learning rate.')
# for learning rate piecewise_constant decay
tf.app.flags.DEFINE_string(
'decay_boundaries', '80000, 100000',
'decay_boundaries', '1000, 80000, 100000',
'Learning rate decay boundaries by global_step (comma-separated list).')
tf.app.flags.DEFINE_string(
'lr_decay_factors', '1, 0.1, 0.01',
'lr_decay_factors', '0.1, 1, 0.1, 0.01',
'The values of learning_rate decay factor for each segment between boundaries (comma-separated list).')
# checkpoint related configuration
tf.app.flags.DEFINE_string(
@@ -176,7 +176,7 @@ def input_fn():

all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
all_anchor_ratios = [(1., 2., .5), (1., 2., 3., .5, 0.3333), (1., 2., 3., .5, 0.3333), (1., 2., 3., .5, 0.3333), (1., 2., .5), (1., 2., .5)]
all_layer_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)]
all_layer_strides = [8, 16, 32, 64, 100, 300]
total_layers = len(all_layer_shapes)
