-
Notifications
You must be signed in to change notification settings - Fork 6
/
xray_dataloader.py
216 lines (170 loc) · 7.99 KB
/
xray_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
################################################################################
# CSE 190: Programming Assignment 3
# Fall 2018
# Code author: Jenny Hamer
#
#
# Description:
# This code defines a custom PyTorch Dataset object suited for the
# NIH ChestX-ray14 dataset of 14 common thorax diseases. This dataset contains
# 112,120 images (frontal-view X-rays) from 30,805 unique patients. Each image
# may be labeled with a single disease or multiple (multi-label). The nominative
# labels are mapped to an integer between 0-13, which is later converted into
# an n-hot binary encoded label.
#
#
# Dataset citation:
# X. Wang, Y. Peng, L. Lu. Hospital-scale Chest X-ray Database and Benchmarks on
# Weakly-Supervised Classification and Localization of Common Thorax Diseases.
# Department of Radiology and Imaging Sciences, September 2017.
# https://arxiv.org/pdf/1705.02315.pdf
################################################################################
# PyTorch imports
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data.sampler import SubsetRandomSampler
# Other libraries for data manipulation and visualization
import os
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Uncomment for Python2
# from __future__ import print_function
class ChestXrayDataset(Dataset):
    """Custom Dataset class for the Chest X-Ray Dataset.

    The expected dataset is stored in the "/datasets/ChestXray-NIHCC/" on ieng6.
    Each sample is returned as an (image, label) tuple, where the label is an
    n-hot tensor with one slot per entry of ``classes``.
    """

    def __init__(self, transform=None, color='L'):
        """Read the dataset metadata and store the loading options.

        Args:
        -----
            - transform: A torchvision.transforms object -
                         transformations to apply to each image
                         (Can be "transforms.Compose([transforms])").
                         None (the default) selects transforms.ToTensor().
            - color: Specifies image-color format to convert to
                     (default is L: 8-bit pixels, black and white)

        Attributes:
        -----------
            - image_dir: The absolute filepath to the dataset on ieng6
            - image_info: A Pandas DataFrame of the dataset metadata
            - image_filenames: The "Image Index" column (image file names)
            - labels: The "Finding Labels" column (one label string per sample)
            - classes: A dictionary mapping each int in [0, 13] to a disease name
        """
        # The default transform is resolved here rather than in the signature
        # so that no transform object is built when the class is defined.
        self.transform = transforms.ToTensor() if transform is None else transform
        self.color = color
        self.image_dir = "/datasets/ChestXray-NIHCC/images/"
        self.image_info = pd.read_csv("/datasets/ChestXray-NIHCC/Data_Entry_2017.csv")
        self.image_filenames = self.image_info["Image Index"]
        self.labels = self.image_info["Finding Labels"]
        self.classes = {0: "Atelectasis", 1: "Cardiomegaly", 2: "Effusion",
                        3: "Infiltration", 4: "Mass", 5: "Nodule", 6: "Pneumonia",
                        7: "Pneumothorax", 8: "Consolidation", 9: "Edema",
                        10: "Emphysema", 11: "Fibrosis",
                        12: "Pleural_Thickening", 13: "Hernia"}

    def __len__(self):
        """Return the total number of data samples."""
        return len(self.image_filenames)

    def __getitem__(self, ind):
        """Returns the image and its label at the index 'ind'
        (after applying transformations to the image, if specified).

        Params:
        -------
            - ind: (int) The index of the image to get

        Returns:
        --------
            - A tuple (image, label)
        """
        # Compose the path to the image file from the image_dir + image_name.
        # Positional .iloc replaces the .ix indexer, which newer pandas lacks.
        image_name = self.image_filenames.iloc[ind]
        image_path = os.path.join(self.image_dir, image_name)

        # Load the image; the context manager closes the underlying file once
        # the converted copy has been produced.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert(mode=str(self.color))

        # If a transform is specified, apply it
        if self.transform is not None:
            image = self.transform(image)

        # Verify that image is in Tensor format; convert it otherwise
        # (e.g. when the supplied transform does not end with ToTensor).
        if not isinstance(image, torch.Tensor):
            image = transforms.ToTensor()(image)

        # Convert the label string into its binary encoding
        label = self.convert_label(self.labels.iloc[ind], self.classes)

        # Return the image and its label
        return (image, label)

    def convert_label(self, label, classes):
        """Convert a label string to its n-hot encoding.

        Params:
        -------
            - label: a string of conditions corresponding to an image's class
            - classes: a dictionary mapping each class index to its name

        Returns:
        --------
            - binary_label: (Tensor) a binary encoding of the multi-class label,
                            of length len(classes)
        """
        binary_label = torch.zeros(len(classes))
        for key, value in classes.items():
            # A class is marked present when its name occurs in the label string
            if value in label:
                binary_label[key] = 1.0
        return binary_label
def create_split_loaders(batch_size, seed, transform=transforms.ToTensor(),
                         p_val=0.1, p_test=0.2, shuffle=True,
                         show_sample=False, extras=None):
    """ Creates the DataLoader objects for the training, validation, and test sets.

    Params:
    -------
        - batch_size: (int) mini-batch size to load at a time
        - seed: (int) Seed for random generator (use for testing/reproducibility)
        - transform: A torchvision.transforms object - transformations to apply to each image
                     (Can be "transforms.Compose([transforms])")
        - p_val: (float) Percent (as decimal) of dataset to use for validation
        - p_test: (float) Percent (as decimal) of the data remaining after the
                  validation split to use for testing
        - shuffle: (bool) Indicate whether to shuffle the dataset before splitting
        - show_sample: (bool) Plot a mini-example as a grid of the dataset
                       (accepted for compatibility; not used by this function)
        - extras: (dict)
            If CUDA/GPU computing is supported, contains:
            - num_workers: (int) Number of subprocesses to use while loading the dataset
            - pin_memory: (bool) For use with CUDA - copy tensors into pinned memory
                          (set to True if using a GPU)
            Otherwise, extras is None or an empty dict. A key that is left out
            falls back to num_workers=0 / pin_memory=False.

    Returns:
    --------
        - train_loader: (DataLoader) The iterator for the training set
        - val_loader: (DataLoader) The iterator for the validation set
        - test_loader: (DataLoader) The iterator for the test set
    """
    # Create a ChestXrayDataset object
    dataset = ChestXrayDataset(transform)

    # Dimensions and indices of the full dataset
    dataset_size = len(dataset)
    all_indices = list(range(dataset_size))

    # Shuffle dataset before dividing into training, validation & test sets.
    # This seeds NumPy's global random generator.
    if shuffle:
        np.random.seed(seed)
        np.random.shuffle(all_indices)

    # Create the validation split from the full dataset
    val_split = int(np.floor(p_val * dataset_size))
    train_ind, val_ind = all_indices[val_split:], all_indices[:val_split]

    # Separate a test split from the remaining (non-validation) indices
    test_split = int(np.floor(p_test * len(train_ind)))
    train_ind, test_ind = train_ind[test_split:], train_ind[:test_split]

    # Optional loader settings; a missing key keeps its default instead of
    # raising KeyError.
    options = extras if extras else {}
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": options.get("num_workers", 0),
        "pin_memory": options.get("pin_memory", False),
    }

    # Define the training, test, & validation DataLoaders, using a
    # SubsetRandomSampler as the iterator for each subset
    train_loader = DataLoader(dataset, sampler=SubsetRandomSampler(train_ind),
                              **loader_kwargs)
    test_loader = DataLoader(dataset, sampler=SubsetRandomSampler(test_ind),
                             **loader_kwargs)
    val_loader = DataLoader(dataset, sampler=SubsetRandomSampler(val_ind),
                            **loader_kwargs)

    # Return the training, validation, test DataLoader objects
    return (train_loader, val_loader, test_loader)