-
Notifications
You must be signed in to change notification settings - Fork 27
/
data_provider_factory.py
50 lines (38 loc) · 1.31 KB
/
data_provider_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# ---------------------------------------------------------------
# CleanNet implementation based on https://arxiv.org/abs/1711.07131.
# "CleanNet: Transfer Learning for Scalable Image Classifier Training with Label Noise"
# Kuang-Huei Lee, Xiaodong He, Lei Zhang, Linjun Yang
#
# Written by Kuang-Huei Lee, 2018
# Licensed under the MSR-LA Full Rights License [see license.txt]
# ---------------------------------------------------------------
"""Data provider factory"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import math
import os
import sys
import numpy as np
import tensorflow as tf
import data_provider_trainval
import data_provider_inference
# Registry used by get_data_batcher(): maps each accepted dataset/provider
# name to the module that implements it. Each registered module is called as
# module.get_data_batcher(mode, opt).
datasets_map = dict(
    trainval=data_provider_trainval,
    inference=data_provider_inference,
)
def get_data_batcher(name, mode, opt):
    """Return a data batcher from the provider module registered as `name`.

    Looks `name` up in the module-level `datasets_map` registry and delegates
    to that provider module's own `get_data_batcher(mode, opt)`.

    Args:
      name: String, the registry key of the data provider (a key of
        `datasets_map`, e.g. 'trainval' or 'inference').
      mode: train/val/unverified/inference
      opt: Options, passed through unchanged to the provider.

    Returns:
      Whatever the selected provider's `get_data_batcher(mode, opt)` returns
      (a data batcher).

    Raises:
      ValueError: If the dataset `name` is unknown.
    """
    # Single lookup instead of a membership test followed by a subscript.
    # Registry values are modules, so None only signals a missing key.
    provider = datasets_map.get(name)
    if provider is None:
        # List the valid choices so a mistyped name is easy to diagnose.
        raise ValueError('Name of dataset unknown %s; expected one of %s'
                         % (name, sorted(datasets_map)))
    return provider.get_data_batcher(mode, opt)