Skip to content

Commit

Permalink
ENH: Refactor the dataset parameters tuples
Browse files Browse the repository at this point in the history
Refactor the dataset parameters tuples:
- The dataset tuples do not provide an effective fetcher, but only the
parameters needed to make the fetcher. Thus, the `fetcher_` prefix is
removed.
- Rename the method that provides the fetcher parameters accordingly.
- The fetcher names are all built in the same way, and thus this is put
into a method for the sake of best coding practices. Also, prefer naming
the fetchers directly using the Dataset enum values.
  • Loading branch information
jhlegarreta committed Mar 3, 2023
1 parent e488075 commit 41d5931
Showing 1 changed file with 48 additions and 43 deletions.
91 changes: 48 additions & 43 deletions tractolearn/tractoio/dataset_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,7 @@ def fetcher():
return fetcher


fetch_bundle_label_config = (
"fetch_bundle_label_config",
bundle_label_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["rbx_atlas_v10.json"],
["rbx_atlas_v10.json"],
Expand All @@ -341,8 +340,7 @@ def fetcher():
False,
)

fetch_contrastive_ae_weights = (
"fetch_contrastive_ae_weights",
contrastive_ae_weights = (
TRACTOLEARN_DATASETS_URL + "7562790/files/",
["best_model_contrastive_tractoinferno_hcp.pt"],
["best_model_contrastive_tractoinferno_hcp.pt"],
Expand All @@ -353,8 +351,7 @@ def fetcher():
False,
)

fetch_mni2009cnonlinsymm_anat = (
"fetch_mni2009cnonlinsymm_anat",
mni2009cnonlinsymm_anat = (
TRACTOLEARN_DATASETS_URL + "7562790/files/",
["mni_masked.nii.gz"],
["mni_masked.nii.gz"],
Expand All @@ -365,8 +362,7 @@ def fetcher():
False,
)

fetch_generative_loa_cone_config = (
"fetch_generative_loa_cone_config",
generative_loa_cone_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["degree.json"],
["degree.json"],
Expand All @@ -377,8 +373,7 @@ def fetcher():
False,
)

fetch_generative_seed_streamline_ratio_config = (
"fetch_generative_seed_streamline_ratio_config",
generative_seed_streamline_ratio_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["ratio.json"],
["ratio.json"],
Expand All @@ -389,8 +384,7 @@ def fetcher():
False,
)

fetch_generative_streamline_max_count_config = (
"fetch_generative_streamline_max_count_config",
generative_streamline_max_count_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["max_total_sampling.json"],
["max_total_sampling.json"],
Expand All @@ -401,8 +395,7 @@ def fetcher():
False,
)

fetch_generative_streamline_req_count_config = (
"fetch_generative_streamline_req_count_config",
generative_streamline_req_count_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["number_rejection_sampling.json"],
["number_rejection_sampling.json"],
Expand All @@ -413,8 +406,7 @@ def fetcher():
False,
)

fetch_generative_wm_tisue_criterion_config = (
"fetch_generative_wm_tisue_criterion_config",
generative_wm_tisue_criterion_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["white_matter_mask.json"],
["white_matter_mask.json"],
Expand All @@ -425,8 +417,7 @@ def fetcher():
False,
)

fetch_recobundlesx_atlas = (
"fetch_recobundlesx_atlas",
recobundlesx_atlas = (
TRACTOLEARN_DATASETS_URL + "7562635/files/",
["atlas.zip"],
["atlas.zip"],
Expand All @@ -437,8 +428,7 @@ def fetcher():
True,
)

fetch_recobundlesx_config = (
"fetch_recobundlesx_config",
recobundlesx_config = (
TRACTOLEARN_DATASETS_URL + "7562635/files/",
["config.zip"],
["config.zip"],
Expand All @@ -449,8 +439,7 @@ def fetcher():
True,
)

fetch_tractoinferno_hcp_contrastive_threshold_config = (
"fetch_tractoinferno_hcp_contrastive_threshold_config",
tractoinferno_hcp_contrastive_threshold_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["thresholds_contrastive_tractoinferno_hcp.json"],
["thresholds_contrastive_tractoinferno_hcp.json"],
Expand All @@ -461,8 +450,7 @@ def fetcher():
False,
)

fetch_tractoinferno_hcp_ref_tractography = (
"fetch_tractoinferno_hcp_ref_tractography",
tractoinferno_hcp_ref_tractography = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["data_tractoinferno_hcp_qbx.hdf5"],
["data_tractoinferno_hcp_qbx.hdf5"],
Expand All @@ -474,43 +462,59 @@ def fetcher():
)


def get_fetcher_method(name):
"""Provide the fetcher method corresponding to the method name.
def _get_fetcher_data(name):
"""Provide the fetcher method parameters corresponding to the method name.
Returns
-------
callable
Fetcher method.
Tuple
Fetcher method parameters.
"""

if name == Dataset.BUNDLE_LABEL_CONFIG.name:
return fetch_bundle_label_config
return bundle_label_config
elif name == Dataset.CONTRASTIVE_AUTOENCODER_WEIGHTS.name:
return fetch_contrastive_ae_weights
return contrastive_ae_weights
elif name == Dataset.MNI2009CNONLINSYMM_ANAT.name:
return fetch_mni2009cnonlinsymm_anat
return mni2009cnonlinsymm_anat
elif name == Dataset.GENERATIVE_LOA_CONE_CONFIG.name:
return fetch_generative_loa_cone_config
return generative_loa_cone_config
elif name == Dataset.GENERATIVE_SEED_STRML_RATIO_CONFIG.name:
return fetch_generative_seed_streamline_ratio_config
return generative_seed_streamline_ratio_config
elif name == Dataset.GENERATIVE_STRML_MAX_COUNT_CONFIG.name:
return fetch_generative_streamline_max_count_config
return generative_streamline_max_count_config
elif name == Dataset.GENERATIVE_STRML_RQ_COUNT_CONFIG.name:
return fetch_generative_streamline_req_count_config
return generative_streamline_req_count_config
elif name == Dataset.GENERATIVE_WM_TISSUE_CRITERION_CONFIG.name:
return fetch_generative_wm_tisue_criterion_config
return generative_wm_tisue_criterion_config
elif name == Dataset.RECOBUNDLESX_ATLAS.name:
return fetch_recobundlesx_atlas
return recobundlesx_atlas
elif name == Dataset.RECOBUNDLESX_CONFIG.name:
return fetch_recobundlesx_config
return recobundlesx_config
elif name == Dataset.TRACTOINFERNO_HCP_CONTRASTIVE_THR_CONFIG.name:
return fetch_tractoinferno_hcp_contrastive_threshold_config
return tractoinferno_hcp_contrastive_threshold_config
elif name == Dataset.TRACTOINFERNO_HCP_REF_TRACTOGRAPHY.name:
return fetch_tractoinferno_hcp_ref_tractography
return tractoinferno_hcp_ref_tractography
else:
raise DatasetError(_unknown_dataset_msg(name))


def _compose_fetcher_name(name):
"""Compose a name for the fetcher given the dataset name.
Parameters
----------
name : string
Dataset name.
Returns
-------
string
Fetcher name for dataset.
"""

return "fetcher_" + Dataset[name].value


def provide_dataset_description():
"""Provide the description of the available datasets.
Expand All @@ -525,7 +529,7 @@ def provide_dataset_description():
descr = list()

for elem in list(Dataset):
params = get_fetcher_method(elem.name)
params = _get_fetcher_data(elem.name)
descr.append(
elem.value
+ ": "
Expand Down Expand Up @@ -556,8 +560,9 @@ def retrieve_dataset(name, path):

logger.info(f"\nDataset: {name}")

params = get_fetcher_method(name)
files, folder = _make_fetcher(path, *params)()
params = _get_fetcher_data(name)
fetcher_name = _compose_fetcher_name(name)
files, folder = _make_fetcher(path, fetcher_name, *params)()

file_basename = list(files.keys())[0]

Expand Down

0 comments on commit 41d5931

Please sign in to comment.