Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
merge to fpn
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Mar 15, 2019
1 parent 0840f0c commit 462726f
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 541 deletions.
2 changes: 2 additions & 0 deletions chainercv/links/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import FasterRCNNVGG16 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet50 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA
from chainercv.links.model.resnet import ResNet101 # NOQA
Expand Down
2 changes: 2 additions & 0 deletions chainercv/links/model/fpn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from chainercv.links.model.fpn.faster_rcnn import FasterRCNN # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import KeypointRCNNFPNResNet50 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import MaskRCNNFPNResNet50 # NOQA
from chainercv.links.model.fpn.fpn import FPN # NOQA
Expand Down
46 changes: 40 additions & 6 deletions chainercv/links/model/fpn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ class FasterRCNN(chainer.Chain):
"""

stride = 32
_accepted_return_values = ('rois', 'bboxes', 'labels', 'scores', 'masks')
_accepted_return_values = ('rois', 'bboxes', 'labels', 'scores',
'masks', 'points', 'point_scores')

def __init__(self, extractor, rpn, bbox_head,
mask_head, return_values,
mask_head, keypoint_head, return_values,
min_size=800, max_size=1333):
for value_name in return_values:
if value_name not in self._accepted_return_values:
Expand All @@ -64,8 +65,10 @@ def __init__(self, extractor, rpn, bbox_head,

self._store_rpn_outputs = 'rois' in self._return_values
self._run_bbox = any([key in self._return_values
for key in ['bboxes', 'labels', 'scores', 'masks']])
for key in ['bboxes', 'labels', 'scores',
'masks', 'points', 'point_scores']])
self._run_mask = 'masks' in self._return_values
self._run_keypoint = 'points' in self._return_values
super(FasterRCNN, self).__init__()

with self.init_scope():
Expand All @@ -75,6 +78,8 @@ def __init__(self, extractor, rpn, bbox_head,
self.bbox_head = bbox_head
if self._run_mask:
self.mask_head = mask_head
if self._run_keypoint:
self.keypoint_head = keypoint_head

self.min_size = min_size
self.max_size = max_size
Expand Down Expand Up @@ -174,10 +179,9 @@ def predict(self, imgs):
scores_cpu = [cuda.to_cpu(score) for score in scores]
output.update({'bboxes': bboxes_cpu, 'labels': labels_cpu,
'scores': scores_cpu})

if self._run_mask:
rescaled_bboxes = [bbox * scale
for scale, bbox in zip(scales, bboxes)]
for scale, bbox in zip(scales, bboxes)]
if self._run_mask:
# Change bboxes to RoI and RoI indices format
mask_rois_before_reordering, mask_roi_indices_before_reordering =\
_list_to_flat(rescaled_bboxes)
Expand All @@ -200,6 +204,36 @@ def predict(self, imgs):
# Currently MaskHead only supports numpy inputs
masks_cpu = self.mask_head.decode(segms, bboxes_cpu, labels_cpu, sizes)
output.update({'masks': masks_cpu})

if self._run_keypoint:
(point_rois_before_reordering,
point_roi_indices_before_reordering) = _list_to_flat(
rescaled_bboxes)
point_rois, point_roi_indices, order =\
self.keypoint_head.distribute(
point_rois_before_reordering,
point_roi_indices_before_reordering)
with chainer.using_config(
'train', False), chainer.no_backprop_mode():
point_maps = self.keypoint_head(
hs, point_rois, point_roi_indices).data
point_maps = point_maps[order]
point_maps = _flat_to_list(
point_maps, point_roi_indices_before_reordering, len(imgs))
point_maps = [point_map if point_map is not None else
self.xp.zeros(
(0, self.keypoint_head.n_point,
self.keypoint_head.point_map_size,
self.keypoint_head.point_map_size),
dtype=np.float32)
for point_map in point_maps]
point_maps = [
chainer.backends.cuda.to_cpu(point_map)
for point_map in point_maps]
points_cpu, point_scores_cpu = self.keypoint_head.decode(
point_maps, bboxes_cpu)
output.update(
{'points': points_cpu, 'point_scores': point_scores_cpu})
return tuple([output[key] for key in self._return_values])

def prepare(self, imgs):
Expand Down
83 changes: 80 additions & 3 deletions chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chainercv.links.model.fpn.faster_rcnn import FasterRCNN
from chainercv.links.model.fpn.fpn import FPN
from chainercv.links.model.fpn.bbox_head import BboxHead
from chainercv.links.model.fpn.keypoint_head import KeypointHead
from chainercv.links.model.fpn.mask_head import MaskHead
from chainercv.links.model.fpn.rpn import RPN
from chainercv.links.model.resnet import ResNet101
Expand Down Expand Up @@ -45,10 +46,12 @@ class FasterRCNNFPNResNet(FasterRCNN):
"""

def __init__(self, n_fg_class=None, pretrained_model=None,
n_point=None,
return_values=['bboxes', 'labels', 'scores'],
min_size=800, max_size=1333):
param, path = utils.prepare_pretrained_model(
{'n_fg_class': n_fg_class}, pretrained_model, self._models)
{'n_fg_class': n_fg_class, 'n_point': n_point},
pretrained_model, self._models)

base = self._base(n_class=1, arch='he')
base.pick = ('res2', 'res3', 'res4', 'res5')
Expand All @@ -58,11 +61,16 @@ def __init__(self, n_fg_class=None, pretrained_model=None,
extractor = FPN(
base, len(base.pick), (1 / 4, 1 / 8, 1 / 16, 1 / 32, 1 / 64))

if param['n_point'] is not None:
keypoint_head = KeypointHead(param['n_point'], extractor.scales)
else:
keypoint_head = None
super(FasterRCNNFPNResNet, self).__init__(
extractor=extractor,
rpn=RPN(extractor.scales),
bbox_head=BboxHead(param['n_fg_class'] + 1, extractor.scales),
mask_head=MaskHead(param['n_fg_class'] + 1, extractor.scales),
keypoint_head=keypoint_head,
return_values=return_values,
min_size=min_size, max_size=max_size
)
Expand All @@ -72,7 +80,7 @@ def __init__(self, n_fg_class=None, pretrained_model=None,
self.extractor.base,
self._base(pretrained_model='imagenet', arch='he'))
elif path:
chainer.serializers.load_npz(path, self)
chainer.serializers.load_npz(path, self, strict=False)


class MaskRCNNFPNResNet(FasterRCNNFPNResNet):
Expand All @@ -91,7 +99,30 @@ class MaskRCNNFPNResNet(FasterRCNNFPNResNet):
def __init__(self, n_fg_class=None, pretrained_model=None,
min_size=800, max_size=1333):
super(MaskRCNNFPNResNet, self).__init__(
n_fg_class, pretrained_model, ['masks', 'labels', 'scores'],
n_fg_class, pretrained_model, None,
['masks', 'labels', 'scores'],
min_size, max_size)


class KeypointRCNNFPNResNet(FasterRCNNFPNResNet):
    """Base class of the keypoint-predicting FPN ResNet detectors.

    This class only specializes :class:`FasterRCNNFPNResNet`: it forwards
    its arguments to the parent constructor and fixes the values returned
    by prediction to ``points``, ``labels``, ``scores``, ``point_scores``
    and ``bboxes`` (in that order).  Concrete subclasses are expected to
    provide the ``_base`` and ``_models`` class attributes that the parent
    constructor reads.

    Args:
        n_fg_class (int): Number of foreground classes. May be left as
            :obj:`None` and is then passed through unchanged to the parent.
        pretrained_model (string): Forwarded unchanged to the parent
            constructor.
        n_point (int): Number of keypoints, forwarded unchanged to the
            parent constructor.
        min_size (int): Forwarded unchanged to the parent constructor.
        max_size (int): Forwarded unchanged to the parent constructor.

    """

    def __init__(self, n_fg_class=None, pretrained_model=None,
                 n_point=None,
                 min_size=800, max_size=1333):
        # The order of this list is the order of the prediction outputs.
        keypoint_return_values = [
            'points', 'labels', 'scores', 'point_scores', 'bboxes']
        super(KeypointRCNNFPNResNet, self).__init__(
            n_fg_class=n_fg_class,
            pretrained_model=pretrained_model,
            n_point=n_point,
            return_values=keypoint_return_values,
            min_size=min_size,
            max_size=max_size)


Expand Down Expand Up @@ -189,6 +220,52 @@ class MaskRCNNFPNResNet101(MaskRCNNFPNResNet):
}


class KeypointRCNNFPNResNet50(KeypointRCNNFPNResNet):
    """Keypoint-predicting FPN detector with a ResNet-50 base network.

    This class only supplies configuration for
    :class:`KeypointRCNNFPNResNet`: :class:`~chainercv.links.ResNet50` is
    used as the base feature extractor, and one pretrained-model entry,
    ``'coco'``, is registered with ``n_fg_class=1`` and ``n_point=17``.

    The Feature Pyramid Network design follows [#]_.

    .. [#] Tsung-Yi Lin et al.
       Feature Pyramid Networks for Object Detection. CVPR 2017

    """

    _base = ResNet50
    _models = {
        'coco': dict(
            param={'n_fg_class': 1, 'n_point': 17},
            url=(
                'https://chainercv-models.preferred.jp/'
                'faster_rcnn_fpn_resnet50_keypoint_coco_converted_2019_03_15.npz'
            ),
            cv2=True,
        ),
    }


class KeypointRCNNFPNResNet101(KeypointRCNNFPNResNet):
    """Keypoint-predicting FPN detector with a ResNet-101 base network.

    This class only supplies configuration for
    :class:`KeypointRCNNFPNResNet`: :class:`~chainercv.links.ResNet101` is
    used as the base feature extractor.

    The Feature Pyramid Network design follows [#]_.

    .. [#] Tsung-Yi Lin et al.
       Feature Pyramid Networks for Object Detection. CVPR 2017

    """

    # The ResNet-101 variant must build on ResNet101; the previous value
    # (ResNet50) silently produced the same network as the ResNet-50 class.
    _base = ResNet101
    # NOTE(review): this 'coco' entry looks like a placeholder -- its URL is
    # empty and, unlike the ResNet-50 sibling, it declares no 'n_point'.
    # Confirm the intended parameters once real weights are published.
    _models = {
        'coco': {
            'param': {'n_fg_class': 80},
            'url': '',
            'cv2': True
        },
    }


def _copyparams(dst, src):
if isinstance(dst, chainer.Chain):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from chainercv.transforms.image.resize import resize
from chainercv.utils.bbox.bbox_iou import bbox_iou

from chainercv.links.model.mask_rcnn.misc import point_to_roi_points
from chainercv.links.model.mask_rcnn.misc import within_bbox
from chainercv.links.model.fpn.keypoint_utils import point_to_roi_points
from chainercv.links.model.fpn.keypoint_utils import within_bbox


# make a bilinear interpolation kernel
Expand Down
52 changes: 52 additions & 0 deletions chainercv/links/model/fpn/keypoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import division

import numpy as np

import chainer


def point_to_roi_points(
        point, visible, bbox, point_map_size):
    """Convert keypoint coordinates into per-box point-map cell indices.

    Every keypoint is shifted by the top-left corner of its bounding box,
    scaled so that the box spans ``point_map_size`` cells along each axis,
    and floored.  A coordinate lying exactly on the bottom or right edge of
    its box is assigned to the last cell (``point_map_size - 1``) instead
    of landing one cell past the map.

    Args:
        point (array): Keypoint coordinates. Its shape must unpack into
            three axes and it is read as ``point[r, k] = (y, x)``.
        visible (array): Per-keypoint flags read as ``visible[:, k]`` and
            combined with the in-range test by a logical AND.
        bbox (array): Bounding boxes whose columns are read as
            ``(y_min, x_min, y_max, x_max)``. Row ``r`` is paired with
            ``point[r]``.
        point_map_size (int): Number of cells along each side of the
            point map.

    Returns:
        tuple of two arrays ``(roi_point, roi_visible)``:

        * **roi_point**: Array of shape ``(len(bbox), K, 2)`` holding the \
            floored ``(y, x)`` cell indices, stored in the default dtype \
            of ``xp.zeros``.
        * **roi_visible**: Boolean array of shape ``(len(bbox), K)`` that \
            is true where both indices fall in ``[0, point_map_size)`` \
            and the corresponding ``visible`` entry is true.

    """
    xp = chainer.backends.cuda.get_array_module(point)

    # The three-way unpack doubles as a rank check; only K is used below.
    _, K, _ = point.shape
    n_roi = len(bbox)

    roi_point = xp.zeros((n_roi, K, 2))
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
    # NumPy 1.20 and removed in NumPy 1.24, where it raises AttributeError.
    roi_visible = xp.zeros((n_roi, K), dtype=bool)

    offset_y = bbox[:, 0]
    offset_x = bbox[:, 1]
    # NOTE(review): a box with zero height or width divides by zero here;
    # callers are presumably expected to pass non-degenerate boxes.
    scale_y = point_map_size / (bbox[:, 2] - bbox[:, 0])
    scale_x = point_map_size / (bbox[:, 3] - bbox[:, 1])

    last_cell = point_map_size - 1
    for k in range(K):
        # Points sitting exactly on the far edge would floor to
        # ``point_map_size``; remember them so they can be clamped.
        y_boundary_index = xp.where(point[:, k, 0] == bbox[:, 2])[0]
        x_boundary_index = xp.where(point[:, k, 1] == bbox[:, 3])[0]

        ys = xp.floor((point[:, k, 0] - offset_y) * scale_y)
        if len(y_boundary_index) > 0:
            ys[y_boundary_index] = last_cell
        xs = xp.floor((point[:, k, 1] - offset_x) * scale_x)
        if len(x_boundary_index) > 0:
            xs[x_boundary_index] = last_cell

        valid = xp.logical_and(
            xp.logical_and(
                xp.logical_and(ys >= 0, xs >= 0),
                xp.logical_and(ys < point_map_size, xs < point_map_size)),
            visible[:, k])

        roi_point[:, k, 0] = ys
        roi_point[:, k, 1] = xs
        roi_visible[:, k] = valid
    return roi_point, roi_visible


def within_bbox(point, bbox):
    """Test which keypoints lie inside their bounding boxes.

    Both box edges are inclusive.

    Args:
        point (array): Keypoint coordinates read as
            ``point[r, k] = (y, x)``.
        bbox (array): Bounding boxes whose columns are read as
            ``(y_min, x_min, y_max, x_max)``. Row ``r`` is compared
            against every keypoint in ``point[r]``.

    Returns:
        array:
        Boolean array indexed as ``[r, k]`` that is true where keypoint
        ``k`` of row ``r`` lies within box ``r`` along both axes.

    """
    ys = point[:, :, 0]
    xs = point[:, :, 1]
    # Turn each box column into a column vector so it broadcasts over the
    # keypoint axis.
    y_min, x_min, y_max, x_max = (bbox[:, col][:, None] for col in range(4))

    inside_y = (ys >= y_min) & (ys <= y_max)
    inside_x = (xs >= x_min) & (xs <= x_max)
    return inside_y & inside_x
47 changes: 0 additions & 47 deletions chainercv/links/model/fpn/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,50 +155,3 @@ def _expand_boxes(bbox, scale):
expanded_bbox[:, 3] = x_c + w_half

return expanded_bbox


def point_to_roi_points(
point, visible, bbox, point_map_size):
xp = chainer.backends.cuda.get_array_module(point)

R, K, _ = point.shape

roi_point = xp.zeros((len(bbox), K, 2))
roi_visible = xp.zeros((len(bbox), K), dtype=np.bool)

offset_y = bbox[:, 0]
offset_x = bbox[:, 1]
scale_y = point_map_size / (bbox[:, 2] - bbox[:, 0])
scale_x = point_map_size / (bbox[:, 3] - bbox[:, 1])

for k in range(K):
y_boundary_index = xp.where(point[:, k, 0] == bbox[:, 2])[0]
x_boundary_index = xp.where(point[:, k, 1] == bbox[:, 3])[0]

ys = (point[:, k, 0] - offset_y) * scale_y
ys = xp.floor(ys)
if len(y_boundary_index) > 0:
ys[y_boundary_index] = point_map_size - 1
xs = (point[:, k, 1] - offset_x) * scale_x
xs = xp.floor(xs)
if len(x_boundary_index) > 0:
xs[x_boundary_index] = point_map_size - 1

valid = xp.logical_and(
xp.logical_and(
xp.logical_and(ys >= 0, xs >= 0),
xp.logical_and(ys < point_map_size, xs < point_map_size)),
visible[:, k])

roi_point[:, k, 0] = ys
roi_point[:, k, 1] = xs
roi_visible[:, k] = valid
return roi_point, roi_visible


def within_bbox(point, bbox):
y_within = (point[:, :, 0] >= bbox[:, 0][:, None]) & (
point[:, :, 0] <= bbox[:, 2][:, None])
x_within = (point[:, :, 1] >= bbox[:, 1][:, None]) & (
point[:, :, 1] <= bbox[:, 3][:, None])
return y_within & x_within
11 changes: 0 additions & 11 deletions chainercv/links/model/mask_rcnn/__init__.py

This file was deleted.

Loading

0 comments on commit 462726f

Please sign in to comment.