Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Merge pull request #399 from yuyu2172/cub-label-test
Browse files Browse the repository at this point in the history
Add return_bb option to CUBDatasets and add a test
  • Loading branch information
Hakuyume authored Oct 5, 2017
2 parents 2bd9ca5 + ca519f5 commit 82f8ef8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 36 deletions.
38 changes: 22 additions & 16 deletions chainercv/datasets/cub/cub_keypoint_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,39 @@ class CUBKeypointDataset(CUBDatasetBase):
A keypoint mask array indicates whether a keypoint is visible in the
image or not. This is a boolean array of shape :math:`(K,)`.
A bounding box is a one-dimensional array of shape :math:`(4,)`.
The elements of the bounding box correspond to
:obj:`(y_min, x_min, y_max, x_max)`, where the four attributes are
coordinates of the top left and the bottom right vertices.
This information can optionally be retrieved from the dataset
by setting :obj:`return_bb = True`.
A mask image of the bird shows how likely the bird is located at a
given pixel. If the value is close to 255, it is more likely that a bird
is located at that pixel. The shape of this array is :math:`(1, H, W)`,
where :math:`H` and :math:`W` are height and width of the image
respectively.
This information can optionally be retrieved from the dataset
by setting :obj:`return_mask = True`.
Args:
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
crop_bbox (bool): If true, this class returns an image cropped
by the bounding box of the bird inside it.
return_bb (bool): If :obj:`True`, this returns a bounding box
around a bird. The default value is :obj:`False`.
mask_dir (string): Path to the root of the mask data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_mask (bool): Decide whether to include mask image of the bird
in a tuple served for a query.
in a tuple served for a query. The default value is :obj:`False`.
"""

def __init__(self, data_dir='auto', crop_bbox=True,
def __init__(self, data_dir='auto', return_bb=False,
mask_dir='auto', return_mask=False):
super(CUBKeypointDataset, self).__init__(
data_dir=data_dir, crop_bbox=crop_bbox)
data_dir=data_dir, return_bb=return_bb)
self.return_mask = return_mask

# load keypoint
Expand Down Expand Up @@ -89,21 +98,18 @@ def get_example(self, i):
keypoint = np.array(self.kp_dict[i], dtype=np.float32)
kp_mask = np.array(self.kp_mask_dict[i], dtype=np.bool)

if self.crop_bbox:
# (y_min, x_min, y_max, x_max)
bbox = self.bboxes[i].astype(np.int32)
img = img[:, bbox[0]: bbox[2], bbox[1]: bbox[3]]
keypoint[:, :2] = keypoint[:, :2] - bbox[:2]

if not self.return_mask:
return img, keypoint, kp_mask
if self.return_bb:
return img, keypoint, kp_mask, self.bbs[i]
else:
return img, keypoint, kp_mask

path, _ = os.path.splitext(self.paths[i])
mask = utils.read_image(
os.path.join(self.mask_dir, path + '.png'),
dtype=np.uint8,
color=False)
if self.crop_bbox:
mask = mask[:, bbox[0]: bbox[2], bbox[1]: bbox[3]]

return img, keypoint, kp_mask, mask
if self.return_bb:
return img, keypoint, kp_mask, self.bbs[i], mask
else:
return img, keypoint, kp_mask, mask
28 changes: 16 additions & 12 deletions chainercv/datasets/cub/cub_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,31 @@ class CUBLabelDataset(CUBDatasetBase):
The image is in RGB and CHW format.
The class ids are between 0 and 199.
There are 200 labels of birds in total.
A bounding box is a one-dimensional array of shape :math:`(4,)`.
The elements of the bounding box correspond to
:obj:`(y_min, x_min, y_max, x_max)`, where the four attributes are
coordinates of the top left and the bottom right vertices.
This information can optionally be retrieved from the dataset
by setting :obj:`return_bb = True`.
Args:
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
crop_bbox (bool): If true, this class returns an image cropped
by the bounding box of the bird inside it.
return_bb (bool): If :obj:`True`, this returns a bounding box
around a bird. The default value is :obj:`False`.
"""

def __init__(self, data_dir='auto', crop_bbox=True):
def __init__(self, data_dir='auto', return_bb=False):
super(CUBLabelDataset, self).__init__(
data_dir=data_dir, crop_bbox=crop_bbox)
data_dir=data_dir, return_bb=return_bb)

image_class_labels_file = os.path.join(
self.data_dir, 'image_class_labels.txt')
self._data_labels = [int(d_label.split()[1]) - 1 for
d_label in open(image_class_labels_file)]
labels = [int(d_label.split()[1]) - 1 for
d_label in open(image_class_labels_file)]
self._labels = np.array(labels, dtype=np.int32)

def get_example(self, i):
"""Returns the i-th example.
Expand All @@ -50,10 +56,8 @@ def get_example(self, i):
img = utils.read_image(
os.path.join(self.data_dir, 'images', self.paths[i]),
color=True)
label = self._labels[i]

if self.crop_bbox:
# (y_min, x_min, y_max, x_max)
bbox = self.bboxes[i].astype(np.int32)
img = img[:, bbox[0]: bbox[2], bbox[1]: bbox[3]]
label = self._data_labels[i]
if self.return_bb:
return img, label, self.bbs[i]
return img, label
16 changes: 8 additions & 8 deletions chainercv/datasets/cub/cub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CUBDatasetBase(chainer.dataset.DatasetMixin):
"""

def __init__(self, data_dir='auto', mask_dir='auto', crop_bbox=True):
def __init__(self, data_dir='auto', mask_dir='auto', return_bb=False):
if data_dir == 'auto':
data_dir = get_cub()
if mask_dir == 'auto':
Expand All @@ -56,22 +56,22 @@ def __init__(self, data_dir='auto', mask_dir='auto', crop_bbox=True):
self.mask_dir = mask_dir

imgs_file = os.path.join(data_dir, 'images.txt')
bboxes_file = os.path.join(data_dir, 'bounding_boxes.txt')
bbs_file = os.path.join(data_dir, 'bounding_boxes.txt')

self.paths = [
line.strip().split()[1] for line in open(imgs_file)]

# (x_min, y_min, width, height)
bboxes = np.array([
bbs = np.array([
tuple(map(float, line.split()[1:5]))
for line in open(bboxes_file)])
for line in open(bbs_file)])
# (x_min, y_min, width, height) -> (x_min, y_min, x_max, y_max)
bboxes[:, 2:] += bboxes[:, :2]
bbs[:, 2:] += bbs[:, :2]
# (x_min, y_min, x_max, y_max) -> (y_min, x_min, y_max, x_max)
bboxes[:] = bboxes[:, [1, 0, 3, 2]]
self.bboxes = bboxes.astype(np.float32)
bbs[:] = bbs[:, [1, 0, 3, 2]]
self.bbs = bbs.astype(np.float32)

self.crop_bbox = crop_bbox
self.return_bb = return_bb

def __len__(self):
return len(self.paths)
Expand Down
33 changes: 33 additions & 0 deletions tests/datasets_tests/cub_tests/test_cub_label_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

import numpy as np

from chainer import testing
from chainer.testing import attr

from chainercv.datasets import cub_label_names
from chainercv.datasets import CUBLabelDataset
from chainercv.utils import assert_is_bbox
from chainercv.utils import assert_is_label_dataset


@testing.parameterize(
    {'return_bb': True},
    {'return_bb': False},
)
class TestCUBLabelDataset(unittest.TestCase):
    """Checks CUBLabelDataset with and without bounding-box output.

    The ``return_bb`` attribute is injected by ``testing.parameterize`` and
    forwarded to the dataset constructor.
    """

    def setUp(self):
        # Build the dataset under test for the current parameter set.
        dataset = CUBLabelDataset(return_bb=self.return_bb)
        self.dataset = dataset

    @attr.slow
    def test_cub_label_dataset(self):
        n_label = len(cub_label_names)
        assert_is_label_dataset(self.dataset, n_label, n_example=10)

        # Without bounding boxes there is nothing further to verify.
        if not self.return_bb:
            return

        # Pick one of the first ten examples at random and validate its box.
        candidates = np.arange(10)
        index = np.random.choice(candidates)
        example = self.dataset[index]
        _, _, bb = example
        # assert_is_bbox is given the single box with a leading axis added.
        batched_bb = bb[np.newaxis]
        assert_is_bbox(batched_bb)


# Hand this module's name and file path to chainer's test runner
# (presumably this runs the tests when the file is executed as a script —
# confirm against chainer.testing.run_module).
testing.run_module(__name__, __file__)

0 comments on commit 82f8ef8

Please sign in to comment.