This is a PyTorch re-implementation of Bayesian Convolutional Neural Networks.
(A Chainer implementation is also available: bayesian_unet)
In this project, we assume the following two scenarios, especially for medical imaging.
- Two-dimensional segmentation / regression with the 2D U-Net. (e.g., 2D x-ray, laparoscopic images, and CT slices)
- Three-dimensional segmentation / regression with the 3D U-Net. (e.g., 3D CT volumes)
This is part of the following works.
@article{hiasa2019automated,
title={Automated Muscle Segmentation from Clinical CT using Bayesian U-Net for Personalized Musculoskeletal Modeling},
author={Hiasa, Yuta and Otake, Yoshito and Takao, Masaki and Ogawa, Takeshi and Sugano, Nobuhiko and Sato, Yoshinobu},
journal={IEEE Transactions on Medical Imaging},
year={2019 (in press)},
doi={10.1109/TMI.2019.2940555},
}
@article{sakamoto2019bayesian,
title={Bayesian Segmentation of Hip and Thigh Muscles in Metal Artifact-Contaminated CT using Convolutional Neural Network-Enhanced Normalized Metal Artifact Reduction},
author={Sakamoto, Mitsuki and Hiasa, Yuta and Otake, Yoshito and Takao, Masaki and Suzuki, Yuki and Sugano, Nobuhiko and Sato, Yoshinobu},
journal={Journal of Signal Processing Systems},
year={2019 (in press)},
doi={10.1007/s11265-019-01507-z},
}
@inproceedings{hiasa2018surgical,
title={Surgical tools segmentation in laparoscopic images using convolutional neural networks with uncertainty estimation and semi-supervised learning},
author={Hiasa, Y and Otake, Y and Nakatani, S and Harada, H and Kanaji, S and Kakeji, Y and Sato, Yoshinobu},
booktitle={Proc. International Conference of Computer Assisted Radiology and Surgery},
pages={14--15},
year={2018}
}
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
- PyTorch 1.4
- Install PyTorch and dependencies from https://pytorch.org/
- Install Pytorch Trainer
git clone https://github.com/yuta-hi/pytorch-trainer
cd pytorch-trainer
python setup.py install
- For other requirements, see requirements.txt.
- Install from this repository
git clone https://github.com/yuta-hi/pytorch_bayesian_unet
cd pytorch_bayesian_unet
python setup.py install
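After installation, a quick sanity check can confirm that the packages are visible to the interpreter. The sketch below uses only the standard library. The import name `pytorch_trainer` is an assumption based on the repository name, while `pytorch_bcnn` is the name used in the snippets of this document.

```python
import importlib.util


def check_environment(packages=("torch", "pytorch_trainer", "pytorch_bcnn")):
    """Map each top-level package name to True if it can be located."""
    return {name: importlib.util.find_spec(name) is not None
            for name in packages}


# Report which of the expected packages are visible to this interpreter.
for name, found in check_environment().items():
    print("{:16s} {}".format(name, "ok" if found else "MISSING"))
```

A package reported as MISSING usually means the corresponding `python setup.py install` step was run in a different environment.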
The data sets we used are medical images, which are difficult to share due to ethical issues. We therefore prepared the following examples using synthetic or public data sets.
Approximation of the function
python examples/curve_regression/train_and_test_epistemic.py
python examples/curve_regression/train_and_test_epistemic.py --test_on_test
python examples/curve_regression/train_and_test_epistemic_aleatoric.py
python examples/curve_regression/train_and_test_epistemic_aleatoric.py --test_on_test
Classification of ten digits. A subset of the samples is used as the training data set: in the default setting, 1,000 samples are used for training and 1,000 for validation. The distribution of the predicted variance for correct and wrong predictions on the test data set (10,000 samples) is visualized.
python examples/mnist_classification/train_and_test_epistemic.py
python examples/mnist_classification/train_and_test_epistemic.py --test_on_test
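The comparison of predicted variance between correct and wrong predictions can be sketched as follows. This is a minimal illustration, not the example script itself: the helper name `split_variance_by_correctness` is hypothetical and the numbers are made-up toy values.

```python
from statistics import mean


def split_variance_by_correctness(predictions, targets, variances):
    """Group per-sample variances by whether the prediction was correct."""
    correct, wrong = [], []
    for pred, target, var in zip(predictions, targets, variances):
        (correct if pred == target else wrong).append(var)
    return correct, wrong


# Toy values for illustration only (not results from the example).
predictions = [3, 1, 4, 1, 5]
targets = [3, 1, 4, 7, 9]
variances = [0.01, 0.02, 0.01, 0.20, 0.30]

correct_var, wrong_var = split_variance_by_correctness(
    predictions, targets, variances)
print("mean variance (correct):", mean(correct_var))
print("mean variance (wrong):  ", mean(wrong_var))
```

The two lists can then be passed to any histogram routine to draw the two distributions side by side.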
Segmentation of surgical instruments from laparoscopic images. The data set is downloaded from https://endovissub-instrument.grand-challenge.org/ . The training and test data sets consist of 160 and 140 images, respectively.
python examples/miccai_endovis_segmentation/preprocess.py # download the dataset and convert label format
python examples/miccai_endovis_segmentation/train_and_test_epistemic.py
python examples/miccai_endovis_segmentation/train_and_test_epistemic.py --test_on_test
Aerial-to-map translation. This example focuses on how adversarial training affects uncertainty behavior. It mainly follows the previous work [P. Isola, et al.]. In this example, the generator is replaced with a Bayesian U-Net to obtain uncertainty estimates, and spectral normalization [T. Miyato et al.] is applied to the patch discriminator to stabilize the optimization.
cd examples/map_synthesis
python preprocess.py # download and normalize the dataset
python train_and_test_pix2pix.py --out logs/pix2pix
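For readers unfamiliar with spectral normalization: it divides a weight matrix by its largest singular value, which is typically estimated by power iteration. The following is a plain-Python sketch of that idea on a small matrix; the function names are hypothetical. In PyTorch, `torch.nn.utils.spectral_norm` provides this for layers, but whether the example script uses that helper is an assumption.

```python
import math
import random


def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def l2_normalize(v, eps=1e-12):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def spectral_norm(W, n_iter=50, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = random.Random(seed)
    Wt = [list(col) for col in zip(*W)]  # transpose of W
    u = l2_normalize([rng.gauss(0, 1) for _ in W])
    for _ in range(n_iter):
        v = l2_normalize(matvec(Wt, u))
        u = l2_normalize(matvec(W, v))
    # sigma = u^T W v
    return sum(ui * wi for ui, wi in zip(u, matvec(W, v)))


def spectrally_normalize(W, n_iter=50, seed=0):
    """Rescale W so that its largest singular value is about 1."""
    sigma = spectral_norm(W, n_iter=n_iter, seed=seed)
    return [[w / sigma for w in row] for row in W]


W = [[3.0, 0.0], [0.0, 1.0]]  # singular values are 3 and 1
sigma = spectral_norm(W)
W_sn = spectrally_normalize(W)
```

Bounding the largest singular value of each layer limits how sharply the discriminator output can change, which is the stabilizing effect referred to above.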
Note that this is under construction.
Ongoing.
Please follow the descriptions below to define these objects.
- datasets
- data augmentor
- data normalizer
- model
- visualizer
- validator
- inferencer
- (optional) singularity image
You can define your own dataset as shown below. PNG, JPG, BMP, and MetaImage (MHD, MHA) formats are supported.
- [case #1] 2D images
from collections import OrderedDict
import numpy as np
from pytorch_bcnn.datasets import ImageDataset
data_root = './data'
patients = ['ID0', 'ID1', 'ID2'] # NOTE: 3 patients
class_list = ['background', 'liver', 'tumor']
augmentor = None # NOTE: please set if you have..
normalizer = None # NOTE: please set if you have..
dtypes = OrderedDict({
    'image': np.float32,
    'label': np.int64, # NOTE: if categorical label
    # 'mask': np.uint8, # NOTE: please set if you have..
})
filenames = OrderedDict({
    'image': '{root}/{patient}/*_image.mhd',
    'label': '{root}/{patient}/*_label.mhd',
    # 'mask' : '{root}/{patient}/*_mask.mhd', # NOTE: please set if you have..
})
dataset = ImageDataset(data_root, patients, classes=class_list,
                       dtypes=dtypes, filenames=filenames, augmentor=augmentor, normalizer=normalizer)
- [case #2] 3D volumes
from pytorch_bcnn.datasets import VolumeDataset
...
dataset = VolumeDataset(data_root, patients, classes=class_list,
dtypes=dtypes, filenames=filenames, augmentor=augmentor, normalizer=normalizer)
- [case #3] Custom dataset
from pytorch_bcnn.datasets import BaseDataset
class CustomDataset(BaseDataset):
    # NOTE: placeholder; implement the methods required by BaseDataset here
    ...
    raise NotImplementedError()
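The `filenames` templates in case #1 combine `str.format` placeholders (`{root}`, `{patient}`) with a glob wildcard. The sketch below shows how such templates could be expanded into file lists using only the standard library. The helper `expand_filenames` is hypothetical, and how the dataset classes resolve the templates internally is an assumption.

```python
import glob
import os
import tempfile
from collections import OrderedDict


def expand_filenames(root, patients, filenames):
    """Expand '{root}/{patient}/...' templates into sorted file lists."""
    files = OrderedDict((key, []) for key in filenames)
    for patient in patients:
        for key, template in filenames.items():
            pattern = template.format(root=root, patient=patient)
            files[key].extend(sorted(glob.glob(pattern)))
    return files


templates = OrderedDict([
    ('image', '{root}/{patient}/*_image.mhd'),
    ('label', '{root}/{patient}/*_label.mhd'),
])

# Build a throwaway directory tree with empty files to demonstrate matching.
with tempfile.TemporaryDirectory() as root:
    for patient in ('ID0', 'ID1'):
        os.makedirs(os.path.join(root, patient))
        for name in ('000_image.mhd', '000_label.mhd'):
            open(os.path.join(root, patient, name), 'w').close()
    found = expand_filenames(root, ['ID0', 'ID1'], templates)

n_images = len(found['image'])
n_labels = len(found['label'])
```

Running such a check on your own `data_root` before training is a cheap way to catch a mismatch between the number of images and labels.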
You can use the data augmentor, which applies geometric transformations with stochastic behavior.
from pytorch_bcnn.data.augmentor import DataAugmentor
from pytorch_bcnn.data.augmentor import Crop2D, Flip2D, Affine2D
from pytorch_bcnn.data.augmentor import Crop3D, Flip3D, Affine3D
augmentor = DataAugmentor()
augmentor.add(Crop2D(size=(300,400)))
augmentor.add(Flip2D(axis=1))
augmentor.add(Affine2D(rotation=15.,
                       translate=(10.,10.),
                       shear=0.25,
                       zoom=(0.8, 1.2),
                       keep_aspect_ratio=True,
                       fill_mode=('nearest', 'constant'),
                       cval=(0.,0.),
                       interp_order=(3,0)))
augmentor.summary('augment.json')
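A point worth illustrating about stochastic geometric augmentation is that one random decision must be shared by the image and its label, otherwise the pair no longer corresponds. The class below is a hypothetical stand-in written with plain lists; it is not the library's `Flip2D`.

```python
import random


class RandomFlip:
    """Hypothetical stochastic horizontal flip applied jointly to a pair."""

    def __init__(self, p=0.5, seed=None):
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, image, label):
        # Draw once, then apply the same decision to both arrays.
        if self.rng.random() < self.p:
            image = [row[::-1] for row in image]
            label = [row[::-1] for row in label]
        return image, label


image = [[10, 20], [30, 40]]
label = [[0, 1], [1, 0]]

always = RandomFlip(p=1.0)
never = RandomFlip(p=0.0)
flipped_image, flipped_label = always(image, label)
same_image, same_label = never(image, label)
```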
You can use the data normalizer based on intensity transformation.
from pytorch_bcnn.data.normalizer import Normalizer
from pytorch_bcnn.data.normalizer import Clip2D, Subtract2D, Divide2D, Quantize2D
from pytorch_bcnn.data.normalizer import Clip3D, Subtract3D, Divide3D, Quantize3D
normalizer = Normalizer()
normalizer.add(Clip2D((-150, 350)))
normalizer.add(Quantize2D(8))
normalizer.add(Subtract2D(0.5))
normalizer.add(Divide2D(1./255.))
normalizer.summary('norm.json')
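The normalizer applies its steps in the order they were added. The sketch below illustrates that ordered composition with plain Python. `Pipeline`, `clip`, and `rescale` are hypothetical names, and they do not reproduce the exact semantics of `Clip2D`, `Quantize2D`, `Subtract2D`, or `Divide2D`.

```python
class Pipeline:
    """Hypothetical stand-in: applies the added steps in insertion order."""

    def __init__(self):
        self._steps = []

    def add(self, step):
        self._steps.append(step)

    def __call__(self, values):
        for step in self._steps:
            values = step(values)
        return values


def clip(low, high):
    """Return a step that clamps every value to [low, high]."""
    return lambda values: [min(max(v, low), high) for v in values]


def rescale(low, high):
    """Return a step that maps [low, high] linearly onto [0, 1]."""
    return lambda values: [(v - low) / (high - low) for v in values]


pipeline = Pipeline()
pipeline.add(clip(-150, 350))     # same intensity window as the snippet above
pipeline.add(rescale(-150, 350))
normalized = pipeline([-1000, 100, 2000])
```

Because the steps are order-dependent, clipping before rescaling keeps the output inside [0, 1], whereas the reverse order would not.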
- [case #1] Segmentation
from functools import partial
from pytorch_bcnn.models import BayesianUNet
from pytorch_bcnn.links import Classifier
predictor = BayesianUNet(ndim=2,
                         in_channels=1,
                         out_channels=3,
                         nlayer=5,
                         nfilter=32)
# NOTE: softmax_cross_entropy and class_weight must be imported / defined beforehand
lossfun = partial(softmax_cross_entropy,
                  normalize=False, class_weight=class_weight)
model = Classifier(predictor,
                   lossfun=lossfun)
- [case #2] Regression
from pytorch_bcnn.links import Regressor
from pytorch_bcnn.functions.loss import sigmoid_soft_cross_entropy
import torch.nn.functional as F
...
lossfun = F.mse_loss
# lossfun = sigmoid_soft_cross_entropy # NOTE: if you want..
model = Regressor(predictor,
lossfun=lossfun)
- [case #3] Other problems (e.g., multi-task)
from pytorch_bcnn.models import UNetBase
# NOTE: the _default_*_param values must be imported / defined beforehand
class MultiTaskUNet(UNetBase):
    def __init__(self,
                 ndim,
                 in_channels,
                 foo, # TODO
                 bar, # TODO
                 nfilter=32,
                 nlayer=5,
                 conv_param=_default_conv_param,
                 pool_param=_default_pool_param,
                 upconv_param=_default_upconv_param,
                 norm_param=_default_norm_param,
                 activation_param=_default_activation_param,
                 dropout_param=_default_dropout_param,
                 residual=False,
                 ):
        super(MultiTaskUNet, self).__init__(
            ndim,
            in_channels,
            nfilter,
            nlayer,
            conv_param,
            pool_param,
            upconv_param,
            norm_param,
            activation_param,
            dropout_param,
            residual,)
        self._foo = foo
        self._bar = bar
        pass # TODO: foo, bar
    def forward(self, x):
        h = super().forward(x)
        # TODO: foo, bar
        raise NotImplementedError('foo is bar..')
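The `class_weight` passed to the loss in case #1 is not defined in the snippet. One common heuristic, shown here only as an illustration and not necessarily what the authors used, is to weight each class inversely to its frequency in the training labels. The helper name is hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels, n_classes):
    """Weight class c by total / (n_classes * count_c); 0 if c is absent."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (n_classes * counts[c]) if counts[c] else 0.0
            for c in range(n_classes)]


# Toy flattened label map: 6 background, 3 liver, 1 tumor pixel.
labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]
weights = inverse_frequency_weights(labels, n_classes=3)
```

Rare classes such as small tumors receive larger weights, so they contribute more to the loss per pixel. The resulting list would still need to be converted to whatever tensor type the loss function expects.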
- [case #1] 2D segmentation
import numpy as np
from pytorch_bcnn.visualizer import ImageVisualizer
transforms = {
    'x': lambda x: x,
    'y': lambda x: np.argmax(x, axis=0),
    't': lambda x: x,
}
_cmap = np.array([
    [0,0,0], # NOTE: background (black)
    [1,0,0], # liver (red)
    [0,1,0]]) # tumor (green)
cmaps = {
    'x': None,
    'y': _cmap,
    't': _cmap,
}
clims = {
    'x': (0., 255.),
    'y': None,
    't': None,
}
visualizer = ImageVisualizer(transforms=transforms,
                             cmaps=cmaps,
                             clims=clims)
- [case #2] 2D regression
import numpy as np
import matplotlib.pyplot as plt
from pytorch_bcnn.visualizer import ImageVisualizer
def sigmoid(x):
    # NOTE: not defined in the original snippet; a plain NumPy version is assumed
    return 1. / (1. + np.exp(-x))
def alpha_blend(heatmaps, cmap='jet'):
    # Blend per-channel heatmaps into one RGB image, one color per channel
    assert heatmaps.ndim == 3
    ch, w, h = heatmaps.shape
    ret = np.zeros((3, w, h))
    mapper = plt.get_cmap(cmap, ch)
    for i in range(ch):
        color = np.ones((3, w, h)) \
              * np.asarray(mapper(i)[:3]).reshape(-1,1,1)
        ret += (color * heatmaps[i])
    return ret
transforms = {
    'x': None,
    'y': lambda x: alpha_blend(sigmoid(x)),
    't': lambda x: alpha_blend(x),
}
clims = {
    'x': (0., 255.),
    'y': (0., 1.),
    't': (0., 1.),
}
cmaps = None
visualizer = ImageVisualizer(transforms=transforms,
                             cmaps=cmaps,
                             clims=clims)
To visualize 3D volumes, you can pass the volume renderer to the transforms
as described above.
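To make the role of `transforms` and `cmaps` in the segmentation case concrete, the sketch below reduces per-class scores to a label map and then looks up one RGB triple per label. It uses plain lists and hypothetical helper names. That `ImageVisualizer` applies the color map by index lookup in this way is an assumption.

```python
def argmax_over_classes(scores):
    """Reduce scores[class][row][col] to a label map[row][col]."""
    n_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    return [[max(range(n_classes), key=lambda c: scores[c][i][j])
             for j in range(width)]
            for i in range(height)]


def colorize(label_map, cmap):
    """Replace each class index in a 2D label map by its RGB triple."""
    return [[tuple(cmap[c]) for c in row] for row in label_map]


cmap = [(0, 0, 0),  # background (black)
        (1, 0, 0),  # liver (red)
        (0, 1, 0)]  # tumor (green)

# Toy scores for 3 classes on a 1 x 2 image.
scores = [[[0.90, 0.10]],
          [[0.05, 0.70]],
          [[0.05, 0.20]]]

labels = argmax_over_classes(scores)
rgb = colorize(labels, cmap)
```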
from pytorch_bcnn.extensions import Validator
...
valid_file = 'iter_{.updater.iteration:08}.png'
n_vis = 20 # NOTE: number of samples for visualization
trainer.extend(Validator(valid_iter, model, valid_file,
visualizer=visualizer, n_vis=n_vis,
device=device))
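The `valid_file` template uses Python's format-string attribute lookup on a single positional argument. The sketch below shows how it resolves when formatted with an object exposing `updater.iteration`. The stand-in object is hypothetical, and that the validator formats the template with the trainer is an assumption based on the template's shape.

```python
from types import SimpleNamespace

valid_file = 'iter_{.updater.iteration:08}.png'

# Stand-in for a trainer object; only the attributes used by the template.
trainer = SimpleNamespace(updater=SimpleNamespace(iteration=1200))

# '{.updater.iteration:08}' reads the attribute chain from the first
# positional argument and zero-pads the integer to 8 digits.
filename = valid_file.format(trainer)
print(filename)
```

Zero-padding keeps the output files in chronological order when they are sorted by name.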
- [case #1] Segmentation / Classification
from functools import partial
from pytorch_bcnn.links import MCSampler
from pytorch_bcnn.inference import Inferencer
import torch
mc_iteration = 50
model = MCSampler(predictor, # NOTE: e.g., BayesianUNet
                  mc_iteration=mc_iteration,
                  activation=partial(torch.softmax, dim=1),
                  reduce_mean=partial(torch.argmax, dim=1),
                  reduce_var=partial(torch.mean, dim=1))
infer = Inferencer(test_iter, model, device=device)
estimated_labels, predicted_variances = infer.run()
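The idea behind Monte Carlo sampling at inference time can be sketched without PyTorch: run a stochastic predictor several times, average the softmax outputs, take the argmax of the mean as the label, and summarize the per-class variance as one uncertainty value. The toy model and function names below are hypothetical. The mapping to `activation`, `reduce_mean`, and `reduce_var` reflects how the arguments above read, not a confirmed description of `MCSampler` internals.

```python
import math
import random
from statistics import mean, pvariance


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stochastic_predictor(features, weights, rng, drop_rate):
    """Toy linear classifier with input dropout kept active at test time."""
    kept = [0.0 if rng.random() < drop_rate else x / (1.0 - drop_rate)
            for x in features]
    return [sum(w * x for w, x in zip(row, kept)) for row in weights]


def mc_predict(features, weights, mc_iteration=50, drop_rate=0.5, seed=0):
    rng = random.Random(seed)
    samples = [softmax(stochastic_predictor(features, weights, rng, drop_rate))
               for _ in range(mc_iteration)]
    per_class = list(zip(*samples))             # regroup samples by class
    mean_prob = [mean(p) for p in per_class]    # predictive mean
    var_prob = [pvariance(p) for p in per_class]  # predictive variance
    label = max(range(len(mean_prob)), key=mean_prob.__getitem__)
    return label, mean(var_prob)


features = [1.0, 2.0]
weights = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # 3 classes, 2 inputs

label, uncertainty = mc_predict(features, weights)
label_det, uncertainty_det = mc_predict(features, weights, drop_rate=0.0)
```

With `drop_rate=0.0` every pass is identical, so the variance collapses to zero. The uncertainty estimate comes entirely from the randomness that remains active during inference.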
cd recipe
make all