From c490f57e1263a646c8a2c020f64b7c94e0259034 Mon Sep 17 00:00:00 2001 From: Yusuke Niitani Date: Fri, 15 Mar 2019 19:57:53 +0900 Subject: [PATCH] merge to fpn --- chainercv/links/__init__.py | 2 + chainercv/links/model/fpn/__init__.py | 2 + chainercv/links/model/fpn/faster_rcnn.py | 46 +++- .../links/model/fpn/faster_rcnn_fpn_resnet.py | 76 +++++- .../model/{mask_rcnn => fpn}/keypoint_head.py | 4 +- chainercv/links/model/fpn/keypoint_utils.py | 52 ++++ chainercv/links/model/fpn/mask_utils.py | 47 ---- chainercv/links/model/mask_rcnn/__init__.py | 11 - chainercv/links/model/mask_rcnn/mask_rcnn.py | 253 ------------------ .../model/mask_rcnn/mask_rcnn_fpn_resnet.py | 137 ---------- examples/fpn/demo.py | 30 ++- examples/mask_rcnn/demo.py | 75 ------ 12 files changed, 201 insertions(+), 534 deletions(-) rename chainercv/links/model/{mask_rcnn => fpn}/keypoint_head.py (98%) create mode 100644 chainercv/links/model/fpn/keypoint_utils.py delete mode 100644 chainercv/links/model/mask_rcnn/__init__.py delete mode 100644 chainercv/links/model/mask_rcnn/mask_rcnn.py delete mode 100644 chainercv/links/model/mask_rcnn/mask_rcnn_fpn_resnet.py delete mode 100644 examples/mask_rcnn/demo.py diff --git a/chainercv/links/__init__.py b/chainercv/links/__init__.py index 72b4d32106..aa91f30b77 100644 --- a/chainercv/links/__init__.py +++ b/chainercv/links/__init__.py @@ -11,6 +11,8 @@ from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import FasterRCNNVGG16 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet101 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet50 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA from chainercv.links.model.resnet import ResNet101 # NOQA diff --git a/chainercv/links/model/fpn/__init__.py b/chainercv/links/model/fpn/__init__.py index 7f2f16d62e..d55ac5471c 100644 --- a/chainercv/links/model/fpn/__init__.py +++ b/chainercv/links/model/fpn/__init__.py @@ -1,6 +1,8 @@ from chainercv.links.model.fpn.faster_rcnn import FasterRCNN # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet101 # NOQA +from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet50 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA from chainercv.links.model.fpn.fpn import FPN # NOQA diff --git a/chainercv/links/model/fpn/faster_rcnn.py b/chainercv/links/model/fpn/faster_rcnn.py index 68b4506233..56c11ba7fb 100644 --- a/chainercv/links/model/fpn/faster_rcnn.py +++ b/chainercv/links/model/fpn/faster_rcnn.py @@ -50,10 +50,11 @@ class FasterRCNN(chainer.Chain): """ stride = 32 - _accepted_return_values = ('rois', 'bboxes', 'labels', 'scores', 'masks') + _accepted_return_values = ('rois', 'bboxes', 'labels', 'scores', + 'masks', 'points', 'point_scores') def __init__(self, extractor, rpn, bbox_head, - mask_head, return_values, + mask_head, keypoint_head, return_values, min_size=800, max_size=1333): for value_name in return_values: if value_name not in self._accepted_return_values: @@ -64,8 +65,10 @@ def __init__(self, extractor, rpn, bbox_head, self._store_rpn_outputs = 'rois' in self._return_values self._run_bbox = any([key in self._return_values - for key in ['bboxes', 'labels', 'scores', 'masks']]) + for key in ['bboxes', 'labels', 'scores', + 'masks', 'points', 'point_scores']]) self._run_mask = 'masks' in self._return_values + self._run_keypoint = 'points' in self._return_values super(FasterRCNN, self).__init__() with self.init_scope(): @@ -75,6 +78,8 @@ def __init__(self, extractor, rpn, bbox_head, self.bbox_head = bbox_head if self._run_mask: self.mask_head = mask_head + if self._run_keypoint: + self.keypoint_head = keypoint_head self.min_size = min_size self.max_size = max_size @@ -174,10 +179,9 @@ def predict(self, imgs): scores_cpu = [cuda.to_cpu(score) for score in scores] output.update({'bboxes': bboxes_cpu, 'labels': labels_cpu, 'scores': scores_cpu}) - - if self._run_mask: rescaled_bboxes = [bbox * scale - for scale, bbox in zip(scales, bboxes)] + for scale, bbox in zip(scales, bboxes)] + if self._run_mask: # Change bboxes to RoI and RoI indices format mask_rois_before_reordering, mask_roi_indices_before_reordering =\ _list_to_flat(rescaled_bboxes) @@ -200,6 +204,36 @@ def predict(self, imgs): # Currently MaskHead only supports numpy inputs masks_cpu = self.mask_head.decode(segms, bboxes_cpu, labels_cpu, sizes) output.update({'masks': masks_cpu}) + + if self._run_keypoint: + (point_rois_before_reordering, + point_roi_indices_before_reordering) = _list_to_flat( + rescaled_bboxes) + point_rois, point_roi_indices, order =\ + self.keypoint_head.distribute( + point_rois_before_reordering, + point_roi_indices_before_reordering) + with chainer.using_config( + 'train', False), chainer.no_backprop_mode(): + point_maps = self.keypoint_head( + hs, point_rois, point_roi_indices).data + point_maps = point_maps[order] + point_maps = _flat_to_list( + point_maps, point_roi_indices_before_reordering, len(imgs)) + point_maps = [point_map if point_map is not None else + self.xp.zeros( + (0, self.keypoint_head.n_point, + self.keypoint_head.point_map_size, + self.keypoint_head.point_map_size), + dtype=np.float32) + for point_map in point_maps] + point_maps = [ + chainer.backends.cuda.to_cpu(point_map) + for point_map in point_maps] + points_cpu, point_scores_cpu = self.keypoint_head.decode( + point_maps, bboxes_cpu) + output.update( + {'points': points_cpu, 'point_scores': point_scores_cpu}) return tuple([output[key] for key in self._return_values]) def prepare(self, imgs): diff --git a/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py b/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py index debadb10ea..99e87c342b 100644 --- a/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py +++ b/chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py @@ -7,6 +7,7 @@ from chainercv.links.model.fpn.faster_rcnn import FasterRCNN from chainercv.links.model.fpn.fpn import FPN from chainercv.links.model.fpn.bbox_head import BboxHead +from chainercv.links.model.fpn.keypoint_head import KeypointHead from chainercv.links.model.fpn.mask_head import MaskHead from chainercv.links.model.fpn.rpn import RPN from chainercv.links.model.resnet import ResNet101 @@ -45,6 +46,7 @@ class FasterRCNNFPNResNet(FasterRCNN): """ def __init__(self, n_fg_class=None, pretrained_model=None, + n_point=None, return_values=['bboxes', 'labels', 'scores'], min_size=800, max_size=1333): param, path = utils.prepare_pretrained_model( @@ -63,6 +65,7 @@ def __init__(self, n_fg_class=None, pretrained_model=None, rpn=RPN(extractor.scales), bbox_head=BboxHead(param['n_fg_class'] + 1, extractor.scales), mask_head=MaskHead(param['n_fg_class'] + 1, extractor.scales), + keypoint_head=KeypointHead(n_point, extractor.scales), return_values=return_values, min_size=min_size, max_size=max_size ) @@ -72,7 +75,7 @@ def __init__(self, n_fg_class=None, pretrained_model=None, self.extractor.base, self._base(pretrained_model='imagenet', arch='he')) elif path: - chainer.serializers.load_npz(path, self) + chainer.serializers.load_npz(path, self, strict=False) class MaskRCNNFPNResNet(FasterRCNNFPNResNet): @@ -91,7 +94,30 @@ class MaskRCNNFPNResNet(FasterRCNNFPNResNet): def __init__(self, n_fg_class=None, pretrained_model=None, min_size=800, max_size=1333): super(MaskRCNNFPNResNet, self).__init__( - n_fg_class, pretrained_model, ['masks', 'labels', 'scores'], + n_fg_class, pretrained_model, None, + ['masks', 'labels', 'scores'], + min_size, max_size) + + +class KeypointRCNNFPNResNet(FasterRCNNFPNResNet): + """Feature Pyramid Networks with ResNet-50. + + This is a model of Feature Pyramid Networks [#]_. + This model uses :class:`~chainercv.links.ResNet50` as + its base feature extractor. + + .. [#] Tsung-Yi Lin et al. + Feature Pyramid Networks for Object Detection. CVPR 2017 + + + """ + + def __init__(self, n_fg_class=None, pretrained_model=None, + n_point=None, + min_size=800, max_size=1333): + super(KeypointRCNNFPNResNet, self).__init__( + n_fg_class, pretrained_model, n_point, + ['points', 'labels', 'scores', 'point_scores', 'bboxes'], min_size, max_size) @@ -189,6 +215,52 @@ class MaskRCNNFPNResNet101(MaskRCNNFPNResNet): } +class KeypointRCNNFPNResNet50(KeypointRCNNFPNResNet): + """Feature Pyramid Networks with ResNet-50. + + This is a model of Feature Pyramid Networks [#]_. + This model uses :class:`~chainercv.links.ResNet50` as + its base feature extractor. + + .. [#] Tsung-Yi Lin et al. + Feature Pyramid Networks for Object Detection. CVPR 2017 + + + """ + + _base = ResNet50 + _models = { + 'coco': { + 'param': {'n_fg_class': 80}, + 'url': 'https://chainercv-models.preferred.jp/' + 'faster_rcnn_fpn_resnet50_mask_coco_trained_2019_03_15.npz', + 'cv2': True + }, + } + + +class KeypointRCNNFPNResNet101(KeypointRCNNFPNResNet): + """Feature Pyramid Networks with ResNet-50. + + This is a model of Feature Pyramid Networks [#]_. + This model uses :class:`~chainercv.links.ResNet50` as + its base feature extractor. + + .. [#] Tsung-Yi Lin et al. + Feature Pyramid Networks for Object Detection. CVPR 2017 + + + """ + + _base = ResNet50 + _models = { + 'coco': { + 'param': {'n_fg_class': 80}, + 'url': '', + 'cv2': True + }, + } + def _copyparams(dst, src): if isinstance(dst, chainer.Chain): diff --git a/chainercv/links/model/mask_rcnn/keypoint_head.py b/chainercv/links/model/fpn/keypoint_head.py similarity index 98% rename from chainercv/links/model/mask_rcnn/keypoint_head.py rename to chainercv/links/model/fpn/keypoint_head.py index f53a44a102..c0dd00679d 100644 --- a/chainercv/links/model/mask_rcnn/keypoint_head.py +++ b/chainercv/links/model/fpn/keypoint_head.py @@ -15,8 +15,8 @@ from chainercv.transforms.image.resize import resize from chainercv.utils.bbox.bbox_iou import bbox_iou -from chainercv.links.model.mask_rcnn.misc import point_to_roi_points -from chainercv.links.model.mask_rcnn.misc import within_bbox +from chainercv.links.model.fpn.keypoint_utils import point_to_roi_points +from chainercv.links.model.fpn.keypoint_utils import within_bbox # make a bilinear interpolation kernel diff --git a/chainercv/links/model/fpn/keypoint_utils.py b/chainercv/links/model/fpn/keypoint_utils.py new file mode 100644 index 0000000000..adc5070528 --- /dev/null +++ b/chainercv/links/model/fpn/keypoint_utils.py @@ -0,0 +1,52 @@ +from __future__ import division + +import numpy as np + +import chainer + + +def point_to_roi_points( + point, visible, bbox, point_map_size): + xp = chainer.backends.cuda.get_array_module(point) + + R, K, _ = point.shape + + roi_point = xp.zeros((len(bbox), K, 2)) + roi_visible = xp.zeros((len(bbox), K), dtype=np.bool) + + offset_y = bbox[:, 0] + offset_x = bbox[:, 1] + scale_y = point_map_size / (bbox[:, 2] - bbox[:, 0]) + scale_x = point_map_size / (bbox[:, 3] - bbox[:, 1]) + + for k in range(K): + y_boundary_index = xp.where(point[:, k, 0] == bbox[:, 2])[0] + x_boundary_index = xp.where(point[:, k, 1] == bbox[:, 3])[0] + + ys = (point[:, k, 0] - offset_y) * scale_y + ys = xp.floor(ys) + if len(y_boundary_index) > 0: + ys[y_boundary_index] = point_map_size - 1 + xs = (point[:, k, 1] - offset_x) * scale_x + xs = xp.floor(xs) + if len(x_boundary_index) > 0: + xs[x_boundary_index] = point_map_size - 1 + + valid = xp.logical_and( + xp.logical_and( + xp.logical_and(ys >= 0, xs >= 0), + xp.logical_and(ys < point_map_size, xs < point_map_size)), + visible[:, k]) + + roi_point[:, k, 0] = ys + roi_point[:, k, 1] = xs + roi_visible[:, k] = valid + return roi_point, roi_visible + + +def within_bbox(point, bbox): + y_within = (point[:, :, 0] >= bbox[:, 0][:, None]) & ( + point[:, :, 0] <= bbox[:, 2][:, None]) + x_within = (point[:, :, 1] >= bbox[:, 1][:, None]) & ( + point[:, :, 1] <= bbox[:, 3][:, None]) + return y_within & x_within diff --git a/chainercv/links/model/fpn/mask_utils.py b/chainercv/links/model/fpn/mask_utils.py index c8cba87076..5c28e20232 100644 --- a/chainercv/links/model/fpn/mask_utils.py +++ b/chainercv/links/model/fpn/mask_utils.py @@ -155,50 +155,3 @@ def _expand_boxes(bbox, scale): expanded_bbox[:, 3] = x_c + w_half return expanded_bbox - - -def point_to_roi_points( - point, visible, bbox, point_map_size): - xp = chainer.backends.cuda.get_array_module(point) - - R, K, _ = point.shape - - roi_point = xp.zeros((len(bbox), K, 2)) - roi_visible = xp.zeros((len(bbox), K), dtype=np.bool) - - offset_y = bbox[:, 0] - offset_x = bbox[:, 1] - scale_y = point_map_size / (bbox[:, 2] - bbox[:, 0]) - scale_x = point_map_size / (bbox[:, 3] - bbox[:, 1]) - - for k in range(K): - y_boundary_index = xp.where(point[:, k, 0] == bbox[:, 2])[0] - x_boundary_index = xp.where(point[:, k, 1] == bbox[:, 3])[0] - - ys = (point[:, k, 0] - offset_y) * scale_y - ys = xp.floor(ys) - if len(y_boundary_index) > 0: - ys[y_boundary_index] = point_map_size - 1 - xs = (point[:, k, 1] - offset_x) * scale_x - xs = xp.floor(xs) - if len(x_boundary_index) > 0: - xs[x_boundary_index] = point_map_size - 1 - - valid = xp.logical_and( - xp.logical_and( - xp.logical_and(ys >= 0, xs >= 0), - xp.logical_and(ys < point_map_size, xs < point_map_size)), - visible[:, k]) - - roi_point[:, k, 0] = ys - roi_point[:, k, 1] = xs - roi_visible[:, k] = valid - return roi_point, roi_visible - - -def within_bbox(point, bbox): - y_within = (point[:, :, 0] >= bbox[:, 0][:, None]) & ( - point[:, :, 0] <= bbox[:, 2][:, None]) - x_within = (point[:, :, 1] >= bbox[:, 1][:, None]) & ( - point[:, :, 1] <= bbox[:, 3][:, None]) - return y_within & x_within diff --git a/chainercv/links/model/mask_rcnn/__init__.py b/chainercv/links/model/mask_rcnn/__init__.py deleted file mode 100644 index 3391efe1f9..0000000000 --- a/chainercv/links/model/mask_rcnn/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from chainercv.links.model.mask_rcnn.keypoint_head import KeypointHead # NOQA -from chainercv.links.model.mask_rcnn.keypoint_head import keypoint_loss_post # NOQA -from chainercv.links.model.mask_rcnn.keypoint_head import keypoint_loss_pre # NOQA -from chainercv.links.model.mask_rcnn.mask_head import mask_loss_post # NOQA -from chainercv.links.model.mask_rcnn.mask_head import mask_loss_pre # NOQA -from chainercv.links.model.mask_rcnn.mask_head import MaskHead # NOQA -from chainercv.links.model.mask_rcnn.mask_rcnn import MaskRCNN # NOQA -from chainercv.links.model.mask_rcnn.mask_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA -from chainercv.links.model.mask_rcnn.mask_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA -from chainercv.links.model.mask_rcnn.misc import mask_to_segm # NOQA -from chainercv.links.model.mask_rcnn.misc import segm_to_mask # NOQA diff --git a/chainercv/links/model/mask_rcnn/mask_rcnn.py b/chainercv/links/model/mask_rcnn/mask_rcnn.py deleted file mode 100644 index 8bb88f9789..0000000000 --- a/chainercv/links/model/mask_rcnn/mask_rcnn.py +++ /dev/null @@ -1,253 +0,0 @@ -from __future__ import division - -import numpy as np - -import chainer -from chainer.backends import cuda -import chainer.functions as F - -from chainercv.links.model.mask_rcnn.misc import scale_img - - -class MaskRCNN(chainer.Chain): - - """Base class of Mask R-CNN. - - This is a base class of Mask R-CNN [#]_. - - .. [#] Kaiming He et al. Mask R-CNN. ICCV 2017 - - Args: - extractor (Link): A link that extracts feature maps. - This link must have :obj:`scales`, :obj:`mean` and - :meth:`__call__`. - rpn (Link): A link that has the same interface as - :class:`~chainercv.links.model.fpn.RPN`. - Please refer to the documentation found there. - head (Link): A link that has the same interface as - :class:`~chainercv.links.model.fpn.Head`. - Please refer to the documentation found there. - mask_head (Link): A link that has the same interface as - :class:`~chainercv.links.model.mask_rcnn.MaskRCNN`. - Please refer to the documentation found there. - - Parameters: - nms_thresh (float): The threshold value - for :func:`~chainercv.utils.non_maximum_suppression`. - The default value is :obj:`0.5`. - This value can be changed directly or by using :meth:`use_preset`. - score_thresh (float): The threshold value for confidence score. - If a bounding box whose confidence score is lower than this value, - the bounding box will be suppressed. - The default value is :obj:`0.7`. - This value can be changed directly or by using :meth:`use_preset`. - - """ - - min_size = 800 - max_size = 1333 - stride = 32 - - def __init__(self, extractor, rpn, head, mask_head, - keypoint_head, mode='mask'): - super(MaskRCNN, self).__init__() - with self.init_scope(): - self.extractor = extractor - self.rpn = rpn - self.head = head - if mode == 'mask': - self.mask_head = mask_head - elif mode =='keypoint': - self.keypoint_head = keypoint_head - self.mode = mode - - self.use_preset('visualize') - - def use_preset(self, preset): - """Use the given preset during prediction. - - This method changes values of :obj:`nms_thresh` and - :obj:`score_thresh`. These values are a threshold value - used for non maximum suppression and a threshold value - to discard low confidence proposals in :meth:`predict`, - respectively. - - If the attributes need to be changed to something - other than the values provided in the presets, please modify - them by directly accessing the public attributes. - - Args: - preset ({'visualize', 'evaluate'}): A string to determine the - preset to use. - """ - - if preset == 'visualize': - self.nms_thresh = 0.5 - self.score_thresh = 0.7 - elif preset == 'evaluate': - self.nms_thresh = 0.5 - self.score_thresh = 0.05 - else: - raise ValueError('preset must be visualize or evaluate') - - def __call__(self, x): - assert(not chainer.config.train) - hs = self.extractor(x) - rpn_locs, rpn_confs = self.rpn(hs) - anchors = self.rpn.anchors(h.shape[2:] for h in hs) - rois, roi_indices = self.rpn.decode( - rpn_locs, rpn_confs, anchors, x.shape) - rois, roi_indices = self.head.distribute(rois, roi_indices) - return hs, rois, roi_indices - - def predict(self, imgs): - """Segment object instances from images. - - This method predicts instance-aware object regions for each image. - - Args: - imgs (iterable of numpy.ndarray): Arrays holding images of shape - :math:`(B, C, H, W)`. All images are in CHW and RGB format - and the range of their value is :math:`[0, 255]`. - - Returns: - tuple of lists: - This method returns a tuple of three lists, - :obj:`(masks, labels, scores)`. - - * **masks**: A list of boolean arrays of shape :math:`(R, H, W)`, \ - where :math:`R` is the number of masks in a image. \ - Each pixel holds value if it is inside the object inside or not. - * **labels** : A list of integer arrays of shape :math:`(R,)`. \ - Each value indicates the class of the masks. \ - Values are in range :math:`[0, L - 1]`, where :math:`L` is the \ - number of the foreground classes. - * **scores** : A list of float arrays of shape :math:`(R,)`. \ - Each value indicates how confident the prediction is. - - """ - - sizes = [img.shape[1:] for img in imgs] - x, scales = self.prepare(imgs) - - with chainer.using_config('train', False), chainer.no_backprop_mode(): - hs, rois, roi_indices = self(x) - head_locs, head_confs = self.head(hs, rois, roi_indices) - bboxes, labels, scores = self.head.decode( - rois, roi_indices, head_locs, head_confs, - scales, sizes, self.nms_thresh, self.score_thresh) - - rescaled_bboxes = [bbox * scale for scale, bbox in zip(scales, bboxes)] - if self.mode == 'mask': - # Change bboxes to RoI and RoI indices format - mask_rois_before_reordering, mask_roi_indices_before_reordering =\ - _list_to_flat(rescaled_bboxes) - mask_rois, mask_roi_indices, order = self.mask_head.distribute( - mask_rois_before_reordering, mask_roi_indices_before_reordering) - with chainer.using_config('train', False), chainer.no_backprop_mode(): - segms = F.sigmoid( - self.mask_head(hs, mask_rois, mask_roi_indices)).data - # Put the order of proposals back to the one used by bbox head. - segms = segms[order] - segms = _flat_to_list( - segms, mask_roi_indices_before_reordering, len(imgs)) - segms = [segm if segm is not None else - self.xp.zeros( - (0, self.mask_head.segm_size, self.mask_head.segm_size), - dtype=np.float32) - for segm in segms] - - segms = [chainer.backends.cuda.to_cpu(segm) for segm in segms] - bboxes = [chainer.backends.cuda.to_cpu(bbox / scale) - for bbox, scale in zip(rescaled_bboxes, scales)] - labels = [chainer.backends.cuda.to_cpu(label) for label in labels] - # Currently MaskHead only supports numpy inputs - masks = self.mask_head.decode(segms, bboxes, labels, sizes) - scores = [cuda.to_cpu(score) for score in scores] - return masks, labels, scores - elif self.mode == 'keypoint': - (point_rois_before_reordering, - point_roi_indices_before_reordering) = _list_to_flat( - rescaled_bboxes) - point_rois, point_roi_indices, order =\ - self.keypoint_head.distribute( - point_rois_before_reordering, - point_roi_indices_before_reordering) - with chainer.using_config('train', False), chainer.no_backprop_mode(): - point_maps = self.keypoint_head( - hs, point_rois, point_roi_indices).data - point_maps = point_maps[order] - point_maps = _flat_to_list( - point_maps, point_roi_indices_before_reordering, len(imgs)) - point_maps = [point_map if point_map is not None else - self.xp.zeros( - (0, self.keypoint_head.n_point, - self.keypoint_head.point_map_size, - self.keypoint_head.point_map_size), - dtype=np.float32) - for point_map in point_maps] - point_maps = [ - chainer.backends.cuda.to_cpu(point_map) - for point_map in point_maps] - bboxes = [chainer.cuda.to_cpu(bbox / scale) - for bbox, scale in zip(rescaled_bboxes, scales)] - points, point_scores = self.keypoint_head.decode( - point_maps, bboxes) - labels = [cuda.to_cpu(label) for label in labels] - scores = [cuda.to_cpu(score) for score in scores] - return points, labels, scores, point_scores, bboxes - - def prepare(self, imgs): - """Preprocess images. - - Args: - imgs (iterable of numpy.ndarray): Arrays holding images. - All images are in CHW and RGB format - and the range of their value is :math:`[0, 255]`. - - Returns: - Two arrays: preprocessed images and \ - scales that were caluclated in prepocessing. - - """ - scales = [] - resized_imgs = [] - for img in imgs: - img, scale = scale_img( - img, self.min_size, self.max_size) - img -= self.extractor.mean - scales.append(scale) - resized_imgs.append(img) - pad_size = np.array( - [im.shape[1:] for im in resized_imgs]).max(axis=0) - pad_size = ( - np.ceil(pad_size / self.stride) * self.stride).astype(int) - x = np.zeros( - (len(imgs), 3, pad_size[0], pad_size[1]), dtype=np.float32) - for i, im in enumerate(resized_imgs): - _, H, W = im.shape - x[i, :, :H, :W] = im - x = self.xp.array(x) - - return x, scales - - -def _list_to_flat(array_list): - xp = chainer.backends.cuda.get_array_module(array_list[0]) - - indices = xp.concatenate( - [i * xp.ones((len(array),), dtype=np.int32) for - i, array in enumerate(array_list)], axis=0) - flat = xp.concatenate(array_list, axis=0) - return flat, indices - - -def _flat_to_list(flat, indices, B): - array_list = [] - for i in range(B): - array = flat[indices == i] - if len(array) > 0: - array_list.append(array) - else: - array_list.append(None) - return array_list diff --git a/chainercv/links/model/mask_rcnn/mask_rcnn_fpn_resnet.py b/chainercv/links/model/mask_rcnn/mask_rcnn_fpn_resnet.py deleted file mode 100644 index 3048ce80cf..0000000000 --- a/chainercv/links/model/mask_rcnn/mask_rcnn_fpn_resnet.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import division - -import chainer -import chainer.functions as F - -from chainercv.links.model.fpn import FPN -from chainercv.links.model.fpn import Head -from chainercv.links.model.fpn import RPN -from chainercv.links.model.mask_rcnn.keypoint_head import KeypointHead -from chainercv.links.model.mask_rcnn.mask_head import MaskHead -from chainercv.links.model.mask_rcnn.mask_rcnn import MaskRCNN -from chainercv.links.model.resnet import ResNet101 -from chainercv.links.model.resnet import ResNet50 -from chainercv import utils - -from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import _copyparams - - -class MaskRCNNFPNResNet(MaskRCNN): - - """Base class for Mask R-CNN with ResNet backbone. - - A subclass of this class should have :obj:`_base` and :obj:`_models`. - """ - - def __init__(self, n_fg_class=None, pretrained_model=None, - n_point=17, mode='mask'): - param, path = utils.prepare_pretrained_model( - {'n_fg_class': n_fg_class}, pretrained_model, self._models) - - base = self._base(n_class=1, arch='he') - base.pick = ('res2', 'res3', 'res4', 'res5') - base.pool1 = lambda x: F.max_pooling_2d( - x, 3, stride=2, pad=1, cover_all=False) - base.remove_unused() - extractor = FPN( - base, len(base.pick), (1 / 4, 1 / 8, 1 / 16, 1 / 32, 1 / 64)) - - n_class = param['n_fg_class'] + 1 - super(MaskRCNNFPNResNet, self).__init__( - extractor=extractor, - rpn=RPN(extractor.scales), - head=Head(n_class, extractor.scales), - mask_head=MaskHead(n_class, extractor.scales), - keypoint_head=KeypointHead(n_point, extractor.scales), - mode=mode, - ) - if path == 'imagenet': - _copyparams( - self.extractor.base, - self._base(pretrained_model='imagenet', arch='he')) - elif path: - chainer.serializers.load_npz(path, self) - - -class MaskRCNNFPNResNet50(MaskRCNNFPNResNet): - - """Mask R-CNN with ResNet-50. - - This is a model of Mask R-CNN [#]_. - This model uses :class:`~chainercv.links.ResNet50` as - its base feature extractor. - - .. [#] Kaiming He et al. Mask R-CNN. ICCV 2017 - - Args: - n_fg_class (int): The number of classes excluding the background. - pretrained_model (string): The weight file to be loaded. - This can take :obj:`'coco'`, `filepath` or :obj:`None`. - The default value is :obj:`None`. - - * :obj:`'coco'`: Load weights trained on train split of \ - MS COCO 2017. \ - The weight file is downloaded and cached automatically. \ - :obj:`n_fg_class` must be :obj:`80` or :obj:`None`. - * :obj:`'imagenet'`: Load weights of ResNet-50 trained on \ - ImageNet. \ - The weight file is downloaded and cached automatically. \ - This option initializes weights partially and the rests are \ - initialized randomly. In this case, :obj:`n_fg_class` \ - can be set to any number. - * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \ - must be specified properly. - * :obj:`None`: Do not load weights. - - """ - - _base = ResNet50 - _models = { - 'coco': { - 'param': {'n_fg_class': 80}, - 'url': None, - 'cv2': True - }, - } - - -class MaskRCNNFPNResNet101(MaskRCNNFPNResNet): - - """Mask R-CNN with ResNet-101. - - This is a model of Mask R-CNN [#]_. - This model uses :class:`~chainercv.links.ResNet101` as - its base feature extractor. - - .. [#] Kaiming He et al. Mask R-CNN. ICCV 2017 - - Args: - n_fg_class (int): The number of classes excluding the background. - pretrained_model (string): The weight file to be loaded. - This can take :obj:`'coco'`, `filepath` or :obj:`None`. - The default value is :obj:`None`. - - * :obj:`'coco'`: Load weights trained on train split of \ - MS COCO 2017. \ - The weight file is downloaded and cached automatically. \ - :obj:`n_fg_class` must be :obj:`80` or :obj:`None`. - * :obj:`'imagenet'`: Load weights of ResNet-101 trained on \ - ImageNet. \ - The weight file is downloaded and cached automatically. \ - This option initializes weights partially and the rests are \ - initialized randomly. In this case, :obj:`n_fg_class` \ - can be set to any number. - * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \ - must be specified properly. - * :obj:`None`: Do not load weights. - - """ - - _base = ResNet101 - _models = { - 'coco': { - 'param': {'n_fg_class': 80}, - 'url': None, - 'cv2': True - }, - } diff --git a/examples/fpn/demo.py b/examples/fpn/demo.py index 0d615cacfb..b11a844eb6 100644 --- a/examples/fpn/demo.py +++ b/examples/fpn/demo.py @@ -5,13 +5,17 @@ from chainercv.datasets import coco_bbox_label_names from chainercv.datasets import coco_instance_segmentation_label_names +from chainercv.datasets import coco_keypoint_names from chainercv.links import FasterRCNNFPNResNet101 from chainercv.links import FasterRCNNFPNResNet50 +from chainercv.links import KeypointRCNNFPNResNet101 +from chainercv.links import KeypointRCNNFPNResNet50 from chainercv.links import MaskRCNNFPNResNet101 from chainercv.links import MaskRCNNFPNResNet50 from chainercv import utils from chainercv.visualizations import vis_bbox from chainercv.visualizations import vis_instance_segmentation +from chainercv.visualizations import vis_keypoint_coco def main(): @@ -19,7 +23,8 @@ def main(): parser.add_argument( '--model', choices=('faster_rcnn_fpn_resnet50', 'faster_rcnn_fpn_resnet101', - 'mask_rcnn_fpn_resnet50', 'mask_rcnn_fpn_resnet101'), + 'mask_rcnn_fpn_resnet50', 'mask_rcnn_fpn_resnet101', + 'keypoint_rcnn_fpn_resnet50', 'keypoint_rcnn_fpn_resnet101'), default='faster_rcnn_fpn_resnet50') parser.add_argument('--gpu', type=int, default=-1) parser.add_argument('--pretrained-model', default='coco') @@ -46,6 +51,18 @@ def main(): model = MaskRCNNFPNResNet101( n_fg_class=len(coco_instance_segmentation_label_names), pretrained_model=args.pretrained_model) + elif args.model == 'keypoint_rcnn_fpn_resnet50': + mode = 'keypoint' + model = KeypointRCNNFPNResNet50( + n_fg_class=1, + pretrained_model=args.pretrained_model, + n_point=len(coco_keypoint_names[0])) + elif args.model == 'keypoint_rcnn_fpn_resnet101': + mode = 'keypoint' + model = KeypointRCNNFPNResNet101( + n_fg_class=1, + pretrained_model=args.pretrained_model, + n_point=len(coco_keypoint_names[0])) if args.gpu >= 0: chainer.cuda.get_device_from_id(args.gpu).use() @@ -69,6 +86,17 @@ def main(): vis_instance_segmentation( img, mask, label, score, label_names=coco_instance_segmentation_label_names) + elif mode == 'keypoint': + points, labels, scores, point_scores, bboxes = model.predict([img]) + point = points[0] + label = labels[0] + score = scores[0] + point_score = point_scores[0] + bbox = bboxes[0] + ax = vis_keypoint_coco( + img, point, None, point_score) + vis_bbox(None, bbox, label, score=score, + label_names=coco_bbox_label_names, ax=ax) plt.show() diff --git a/examples/mask_rcnn/demo.py b/examples/mask_rcnn/demo.py deleted file mode 100644 index 81659c862b..0000000000 --- a/examples/mask_rcnn/demo.py +++ /dev/null @@ -1,75 +0,0 @@ -import argparse -import matplotlib.pyplot as plt - -import chainer - -import chainercv -from chainercv.datasets import coco_instance_segmentation_label_names -from chainercv import utils - -from chainercv.links import MaskRCNNFPNResNet101 -from chainercv.links import MaskRCNNFPNResNet50 - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--model', - choices=('mask_rcnn_fpn_resnet50', 'mask_rcnn_fpn_resnet101'), - default='mask_rcnn_fpn_resnet50' - ) - parser.add_argument('--gpu', type=int, default=-1) - parser.add_argument('--pretrained-model', default='coco') - parser.add_argument( - '--mode', - choices=('mask', 'keypoint'), - default='mask') - parser.add_argument('image') - args = parser.parse_args() - - if args.mode == 'mask': - n_fg_class = len(coco_instance_segmentation_label_names) - elif args.mode == 'keypoint': - n_fg_class = 1 - if args.model == 'mask_rcnn_fpn_resnet50': - model = MaskRCNNFPNResNet50( - n_fg_class=n_fg_class, - pretrained_model=args.pretrained_model, - mode=args.mode - ) - elif args.model == 'mask_rcnn_fpn_resnet101': - model = MaskRCNNFPNResNet101( - n_fg_class=n_fg_class, - pretrained_model=args.pretrained_model, - mode=args.mode - ) - - if args.gpu >= 0: - chainer.cuda.get_device_from_id(args.gpu).use() - model.to_gpu() - - img = utils.read_image(args.image) - if args.mode == 'mask': - masks, labels, scores = model.predict([img]) - mask = masks[0] - label = labels[0] - score = scores[0] - chainercv.visualizations.vis_instance_segmentation( - img, mask, label, score, - label_names=coco_instance_segmentation_label_names) - plt.show() - elif args.mode == 'keypoint': - points, labels, scores, point_scores, bboxes = model.predict([img]) - point = points[0] - label = labels[0] - score = scores[0] - point_score = point_scores[0] - bbox = bboxes[0] - ax = chainercv.visualizations.vis_keypoint_coco( - img, point, None, point_score) - chainercv.visualizations.vis_bbox(None, bbox, score=score, ax=ax) - plt.show() - - -if __name__ == '__main__': - main()