Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Merge pull request #443 from yuyu2172/cub-prob-map
Browse files Browse the repository at this point in the history
Add prob_map option to CUBDataset
  • Loading branch information
Hakuyume authored Oct 6, 2017
2 parents fb2b5b5 + 446b3d8 commit cc18d47
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 45 deletions.
64 changes: 39 additions & 25 deletions chainercv/datasets/cub/cub_keypoint_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ class CUBKeypointDataset(CUBDatasetBase):
:obj:`img, keypoint, kp_mask`, a tuple of an image, keypoints
and a keypoint mask that indicates visible keypoints in the image.
The data type of the three elements are :obj:`float32, float32, bool`.
If :obj:`return_mask = True`, :obj:`mask` will be returned as well,
making the returned tuple to be of length four. :obj:`mask` is a
:obj:`uint8` image which indicates the region of the image
where a bird locates.
If :obj:`return_bb = True`, a bounding box :obj:`bb` is appended to the
tuple.
If :obj:`return_prob_map = True`, a probability map :obj:`prob_map` is
appended.
keypoints are packed into a two dimensional array of shape
:math:`(K, 2)`, where :math:`K` is the number of keypoints.
Expand All @@ -42,33 +42,34 @@ class CUBKeypointDataset(CUBDatasetBase):
This information can optionally be retrieved from the dataset
by setting :obj:`return_bb = True`.
A mask image of the bird shows how likely the bird is located at a
given pixel. If the value is close to 255, more likely that a bird
locates at that pixel. The shape of this array is :math:`(1, H, W)`,
The probability map of a bird shows how likely the bird is located at each
pixel. If the value is close to 1, it is likely that the bird
locates at that pixel. The shape of this array is :math:`(H, W)`,
where :math:`H` and :math:`W` are height and width of the image
respectively.
This information can optionally be retrieved from the dataset
by setting :obj:`return_mask = True`.
by setting :obj:`return_prob_map = True`.
Args:
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_bb (bool): If :obj:`True`, this returns a bounding box
around a bird. The default value is :obj:`False`.
mask_dir (string): Path to the root of the mask data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_mask (bool): Decide whether to include mask image of the bird
in a tuple served for a query. The default value is :obj:`False`.
prob_map_dir (string): Path to the root of the probability maps.
If this is :obj:`auto`, this class will automatically download data
for you under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_prob_map (bool): Decide whether to include a probability map of
the bird in a tuple served for a query. The default value is
:obj:`False`.
"""

def __init__(self, data_dir='auto', return_bb=False,
mask_dir='auto', return_mask=False):
prob_map_dir='auto', return_prob_map=False):
super(CUBKeypointDataset, self).__init__(
data_dir=data_dir, return_bb=return_bb)
self.return_mask = return_mask
data_dir=data_dir, return_bb=return_bb,
prob_map_dir=prob_map_dir, return_prob_map=return_prob_map)

# load keypoint
parts_loc_file = os.path.join(self.data_dir, 'parts', 'part_locs.txt')
Expand All @@ -91,25 +92,38 @@ def __init__(self, data_dir='auto', return_bb=False,
self.kp_mask_dict[id_].append(kp_mask)

def get_example(self, i):
# this i is transformed to id for the entire dataset
"""Returns the i-th example.
Args:
i (int): The index of the example.
Returns:
tuple of an image, keypoints and a keypoint mask.
The image is in CHW format and its color channel is ordered in
RGB.
If :obj:`return_bb = True`,
a bounding box is appended to the returned value.
If :obj:`return_mask = True`,
a probability map is appended to the returned value.
"""
img = utils.read_image(
os.path.join(self.data_dir, 'images', self.paths[i]),
color=True)
keypoint = np.array(self.kp_dict[i], dtype=np.float32)
kp_mask = np.array(self.kp_mask_dict[i], dtype=np.bool)

if not self.return_mask:
if not self.return_prob_map:
if self.return_bb:
return img, keypoint, kp_mask, self.bbs[i]
else:
return img, keypoint, kp_mask

path, _ = os.path.splitext(self.paths[i])
mask = utils.read_image(
os.path.join(self.mask_dir, path + '.png'),
dtype=np.uint8,
color=False)
prob_map = utils.read_image(self.prob_map_paths[i],
dtype=np.uint8, color=False)
prob_map = prob_map.astype(np.float32) / 255 # [0, 255] -> [0, 1]
prob_map = prob_map[0] # (1, H, W) --> (H, W)
if self.return_bb:
return img, keypoint, kp_mask, self.bbs[i], mask
return img, keypoint, kp_mask, self.bbs[i], prob_map
else:
return img, keypoint, kp_mask, mask
return img, keypoint, kp_mask, prob_map
45 changes: 41 additions & 4 deletions chainercv/datasets/cub/cub_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class CUBLabelDataset(CUBDatasetBase):
:obj:`img, label`, a tuple of an image and class id.
The image is in RGB and CHW format.
The class id are between 0 and 199.
If :obj:`return_bb = True`, a bounding box :obj:`bb` is appended to the
tuple.
If :obj:`return_prob_map = True`, a probability map :obj:`prob_map` is
appended.
A bounding box is a one-dimensional array of shape :math:`(4,)`.
The elements of the bounding box corresponds to
Expand All @@ -24,18 +28,34 @@ class CUBLabelDataset(CUBDatasetBase):
This information can optionally be retrieved from the dataset
by setting :obj:`return_bb = True`.
The probability map of a bird shows how likely the bird is located at each
pixel. If the value is close to 1, it is likely that the bird
locates at that pixel. The shape of this array is :math:`(H, W)`,
where :math:`H` and :math:`W` are height and width of the image
respectively.
This information can optionally be retrieved from the dataset
by setting :obj:`return_prob_map = True`.
Args:
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_bb (bool): If :obj:`True`, this returns a bounding box
around a bird. The default value is :obj:`False`.
prob_map_dir (string): Path to the root of the probability maps.
If this is :obj:`auto`, this class will automatically download data
for you under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_prob_map (bool): Decide whether to include a probability map of
the bird in a tuple served for a query. The default value is
:obj:`False`.
"""

def __init__(self, data_dir='auto', return_bb=False):
def __init__(self, data_dir='auto', return_bb=False,
prob_map_dir='auto', return_prob_map=False):
super(CUBLabelDataset, self).__init__(
data_dir=data_dir, return_bb=return_bb)
data_dir=data_dir, return_bb=return_bb,
prob_map_dir=prob_map_dir, return_prob_map=return_prob_map)

image_class_labels_file = os.path.join(
self.data_dir, 'image_class_labels.txt')
Expand All @@ -51,13 +71,30 @@ def get_example(self, i):
Returns:
tuple of an image and its label.
The image is in CHW format and its color channel is ordered in
RGB.
If :obj:`return_bb = True`,
a bounding box is appended to the returned value.
If :obj:`return_mask = True`,
a probability map is appended to the returned value.
"""
img = utils.read_image(
os.path.join(self.data_dir, 'images', self.paths[i]),
color=True)
label = self._labels[i]

if not self.return_prob_map:
if self.return_bb:
return img, label, self.bbs[i]
else:
return img, label

prob_map = utils.read_image(self.prob_map_paths[i],
dtype=np.uint8, color=False)
prob_map = prob_map.astype(np.float32) / 255 # [0, 255] -> [0, 1]
prob_map = prob_map[0] # (1, H, W) --> (H, W)
if self.return_bb:
return img, label, self.bbs[i]
return img, label
return img, label, self.bbs[i], prob_map
else:
return img, label, prob_map
24 changes: 15 additions & 9 deletions chainercv/datasets/cub/cub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
root = 'pfnet/chainercv/cub'
url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/'\
'CUB_200_2011.tgz'
mask_url = 'http://www.vision.caltech.edu/visipedia-data/'\
prob_map_url = 'http://www.vision.caltech.edu/visipedia-data/'\
'CUB-200-2011/segmentations.tgz'


Expand All @@ -27,17 +27,17 @@ def get_cub():
return base_path


def get_cub_mask():
def get_cub_prob_map():
data_root = download.get_dataset_directory(root)
base_path = os.path.join(data_root, 'segmentations')
if os.path.exists(base_path):
# skip downloading
return base_path

download_file_path_mask = utils.cached_download(mask_url)
ext_mask = os.path.splitext(mask_url)[1]
prob_map_download_file_path = utils.cached_download(prob_map_url)
prob_map_ext = os.path.splitext(prob_map_url)[1]
utils.extractall(
download_file_path_mask, data_root, ext_mask)
prob_map_download_file_path, data_root, prob_map_ext)
return base_path


Expand All @@ -47,13 +47,14 @@ class CUBDatasetBase(chainer.dataset.DatasetMixin):
"""

def __init__(self, data_dir='auto', mask_dir='auto', return_bb=False):
def __init__(self, data_dir='auto', return_bb=False,
prob_map_dir='auto', return_prob_map=False):
if data_dir == 'auto':
data_dir = get_cub()
if mask_dir == 'auto':
mask_dir = get_cub_mask()
if prob_map_dir == 'auto':
prob_map_dir = get_cub_prob_map()
self.data_dir = data_dir
self.mask_dir = mask_dir
self.prob_map_dir = prob_map_dir

imgs_file = os.path.join(data_dir, 'images.txt')
bbs_file = os.path.join(data_dir, 'bounding_boxes.txt')
Expand All @@ -71,7 +72,12 @@ def __init__(self, data_dir='auto', mask_dir='auto', return_bb=False):
bbs[:] = bbs[:, [1, 0, 3, 2]]
self.bbs = bbs.astype(np.float32)

self.prob_map_paths = [
os.path.join(self.prob_map_dir, os.path.splitext(path)[0] + '.png')
for path in self.paths]

self.return_bb = return_bb
self.return_prob_map = return_prob_map

def __len__(self):
return len(self.paths)
Expand Down
22 changes: 15 additions & 7 deletions tests/datasets_tests/cub_tests/test_cub_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,31 @@
from chainercv.utils import assert_is_label_dataset


@testing.parameterize(
{'return_bb': True},
{'return_bb': False}
)
@testing.parameterize(*testing.product({
'return_bb': [True, False],
'return_prob_map': [True, False]
}))
class TestCUBLabelDataset(unittest.TestCase):

def setUp(self):
self.dataset = CUBLabelDataset(return_bb=self.return_bb)
self.dataset = CUBLabelDataset(
return_bb=self.return_bb, return_prob_map=self.return_prob_map)

@attr.slow
def test_cub_label_dataset(self):
assert_is_label_dataset(
self.dataset, len(cub_label_names), n_example=10)
idx = np.random.choice(np.arange(10))
if self.return_bb:
idx = np.random.choice(np.arange(10))
_, _, bb = self.dataset[idx]
bb = self.dataset[idx][2]
assert_is_bbox(bb[np.newaxis])
if self.return_prob_map:
img = self.dataset[idx][0]
prob_map = self.dataset[idx][-1]
self.assertEqual(prob_map.dtype, np.float32)
self.assertEqual(prob_map.shape, img.shape[1:])
self.assertTrue(np.min(prob_map) >= 0)
self.assertTrue(np.max(prob_map) <= 1)


testing.run_module(__name__, __file__)

0 comments on commit cc18d47

Please sign in to comment.