From b4cfd2f5c405710f2d95a0d6cb7f994584158cf5 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Fri, 21 Jun 2024 17:20:23 +0800 Subject: [PATCH] Refactor codes, add test cases and comments --- .../partitioner/semantic_partitioner.py | 296 +++++++++--------- .../partitioner/semantic_partitioner_test.py | 109 +++---- 2 files changed, 202 insertions(+), 203 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/semantic_partitioner.py b/datasets/flwr_datasets/partitioner/semantic_partitioner.py index 4f64cf3ee96..e31b94e0fe9 100644 --- a/datasets/flwr_datasets/partitioner/semantic_partitioner.py +++ b/datasets/flwr_datasets/partitioner/semantic_partitioner.py @@ -16,16 +16,9 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np -import torch -from scipy.optimize import linear_sum_assignment -from sklearn.decomposition import PCA -from sklearn.mixture import GaussianMixture -from sklearn.preprocessing import StandardScaler -from torch.distributions import MultivariateNormal, kl_divergence -from torchvision import models import datasets from flwr_datasets.common.typing import NDArrayFloat @@ -48,28 +41,31 @@ class SemanticPartitioner(Partitioner): (Cited from section 4.1 in the paper) - Semantic partitioner's goal is to reverse-engineer the federated dataset-generating - process so that each client possesses semantically similar data. For example, for - the EMNIST dataset, we expect every client (writer) to (i) write in a consistent style - for each digit (intra-client intra-label similarity) and (ii) use a consistent writing - style across all digits (intra-client inter-label similarity). A simple approach might - be to cluster similar examples together and sample client data from clusters. However, - if one directly clusters the entire dataset, the resulting clusters may end up largely - correlated to labels. To disentangle the effect of label heterogeneity and semantic - heterogeneity, we propose the following algorithm to enforce intra-client intra-label - similarity and intra-client inter-label similarity in two separate stages. - - • Stage 1: For each label, we embed examples using a pretrained neural network - (extracting semantic features), and fit a Gaussian Mixture Model to cluster pretrained - embeddings into groups. Note that this results in multiple groups per label. - This stage enforces intra-client intra-label consistency. - - • Stage 2: To package the clusters from different labels into clients, we aim to compute - an optimal multi-partite matching with cost-matrix defined by KL-divergence between - the Gaussian clusters. To reduce complexity, we heuristically solve the optimal multi-partite - matching by progressively solving the optimal bipartite matching at each time for - randomly-chosen label pairs. - This stage enforces intra-client inter-label consistency. + Semantic partitioner's goal is to reverse-engineer the federated + dataset-generating process so that each client possesses semantically + similar data. For example, for the EMNIST dataset, we expect every client + (writer) to (i) write in a consistent style for each digit + (intra-client intra-label similarity) and (ii) use a consistent writing style + across all digits (intra-client inter-label similarity). A simple approach + might be to cluster similar examples together and sample client data from + clusters. 
However, if one directly clusters the entire dataset, the resulting + clusters may end up largely correlated to labels. To disentangle the effect + of label heterogeneity and semantic heterogeneity, we propose the following + algorithm to enforce intra-client intra-label similarity and intra-client + inter-label similarity in two separate stages. + + • Stage 1: For each label, we embed examples using a pretrained neural + network (extracting semantic features), and fit a Gaussian Mixture Model + to cluster pretrained embeddings into groups. Note that this results + in multiple groups per label. This stage enforces intra-client + intra-label consistency. + + • Stage 2: To package the clusters from different labels into clients, + we aim to compute an optimal multi-partite matching with cost-matrix + defined by KL-divergence between the Gaussian clusters. To reduce complexity, + we heuristically solve the optimal multi-partite matching by progressively + solving the optimal bipartite matching at each time for randomly-chosen + label pairs. This stage enforces intra-client inter-label consistency. Parameters ---------- @@ -80,6 +76,8 @@ class SemanticPartitioner(Partitioner): efficient_net_type: int The type of pretrained EfficientNet model. Options: [0, 1, 2, 3, 4, 5, 6, 7], corresponding to EfficientNet B0-B7 models. + batch_size: int + The batch size for EfficientNet extracting embeddings. pca_components: int The number of PCA components for dimensionality reduction. gmm_max_iter: int @@ -111,12 +109,13 @@ class SemanticPartitioner(Partitioner): >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) >>> partition = fds.load_partition(0) >>> print(partition[0]) # Print the first example - >>> {'image': , 'label': 3} + {'image': , 'label': 3} >>> partition_sizes = partition_sizes = [ >>> len(fds.load_partition(partition_id)) for partition_id in range(5) >>> ] >>> print(sorted(partition_sizes)) - >>> [3163, 5278, 5496, 6320, 9522] + [3163, 5278, 5496, 6320, 9522] """ def __init__( # pylint: disable=R0913 @@ -124,6 +123,7 @@ def __init__( # pylint: disable=R0913 num_partitions: int, partition_by: str, efficient_net_type: int = 3, + batch_size: int = 32, pca_components: int = 128, gmm_max_iter: int = 100, gmm_init_params: str = "kmeans", @@ -133,19 +133,10 @@ def __init__( # pylint: disable=R0913 ) -> None: super().__init__() # Attributes based on the constructor - _efficient_nets_dict = [ - (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT), - (models.efficientnet_b1, models.EfficientNet_B1_Weights.DEFAULT), - (models.efficientnet_b2, models.EfficientNet_B2_Weights.DEFAULT), - (models.efficientnet_b3, models.EfficientNet_B3_Weights.DEFAULT), - (models.efficientnet_b4, models.EfficientNet_B4_Weights.DEFAULT), - (models.efficientnet_b5, models.EfficientNet_B5_Weights.DEFAULT), - (models.efficientnet_b6, models.EfficientNet_B6_Weights.DEFAULT), - (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT), - ] self._num_partitions = num_partitions self._partition_by = partition_by self._efficient_net_type = efficient_net_type + self._batch_size = batch_size self._efficient_output_size = 1280 # fixed in EfficientNet class self._pca_components = pca_components self._gmm_max_iter = gmm_max_iter @@ -160,12 +151,6 @@ def __init__( # pylint: disable=R0913 self._check_variable_validation() # Utility attributes # The attributes below are determined during the first call to load_partition - self._efficient_net_backbone: Callable[[Any], models.EfficientNet] = ( 
- _efficient_nets_dict[self._efficient_net_type][0] - ) - self._efficient_net_pretrained_weight: models.WeightsEnum = ( - _efficient_nets_dict[self._efficient_net_type][1] - ) self._unique_classes: Optional[Union[List[int], List[str]]] = None self._partition_id_to_indices: Dict[int, List[int]] = {} self._partition_id_to_indices_determined = False @@ -186,7 +171,7 @@ def load_partition(self, partition_id: int) -> datasets.Dataset: # The partitioning is done lazily - only when the first partition is # requested. Only the first call creates the indices assignments for all the # partition indices. - self._check_data_validation_if_needed() + self._check_data_type_if_needed() self._check_num_partitions_correctness_if_needed() self._check_pca_components_validation_if_needed() self._determine_partition_id_to_indices_if_needed() @@ -199,29 +184,57 @@ def num_partitions(self) -> int: self._determine_partition_id_to_indices_if_needed() return self._num_partitions - def _subsample(self, embeddings: NDArrayFloat, num_samples: int): + def _subsample(self, embeddings: NDArrayFloat, num_samples: int) -> NDArrayFloat: if len(embeddings) < num_samples: return embeddings idx_samples = self._rng_numpy.choice( len(embeddings), num_samples, replace=False ) - return embeddings[idx_samples] + return embeddings[idx_samples] # type: ignore + # pylint: disable=C0415, R0915 def _determine_partition_id_to_indices_if_needed(self) -> None: """Create an assignment of indices to the partition indices.""" if self._partition_id_to_indices_determined: return - - efficient_net: models.EfficientNet = self._efficient_net_backbone( - weights=self._efficient_net_pretrained_weight - ) + try: + import torch + from scipy.optimize import linear_sum_assignment + from sklearn.decomposition import PCA + from sklearn.mixture import GaussianMixture + from sklearn.preprocessing import StandardScaler + from torch.distributions import MultivariateNormal, kl_divergence + from torchvision import models + except ImportError: + raise ImportError( + "SemanticPartitioner requires scikit-learn, torch, " + "torchvision, scipy, and numpy." 
+            ) from None
+        efficient_nets_dict = [
+            (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT),
+            (models.efficientnet_b1, models.EfficientNet_B1_Weights.DEFAULT),
+            (models.efficientnet_b2, models.EfficientNet_B2_Weights.DEFAULT),
+            (models.efficientnet_b3, models.EfficientNet_B3_Weights.DEFAULT),
+            (models.efficientnet_b4, models.EfficientNet_B4_Weights.DEFAULT),
+            (models.efficientnet_b5, models.EfficientNet_B5_Weights.DEFAULT),
+            (models.efficientnet_b6, models.EfficientNet_B6_Weights.DEFAULT),
+            (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT),
+        ]
+        backbone = efficient_nets_dict[self._efficient_net_type][0]
+        pretrained_weight = efficient_nets_dict[self._efficient_net_type][1]
+        efficient_net: models.EfficientNet = backbone(weights=pretrained_weight)
         efficient_net.classifier = torch.nn.Flatten()
+        device = torch.device("cpu")
         if self._use_cuda:
             if torch.cuda.is_available():
                 device = torch.device("cuda")
             else:
-                warnings("No detected CUDA device, the device fallbacks to CPU.")
+                warnings.warn(
+                    "No CUDA device detected, falling back to CPU.",
+                    UserWarning,
+                    stacklevel=1,
+                )
         efficient_net.to(device)
         efficient_net.eval()
@@ -229,29 +242,31 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
         self._unique_classes = self.dataset.unique(self._partition_by)
         assert self._unique_classes is not None

-        # Change targets list data type to numpy
-        batch_size = 8
+        # Use EfficientNet to extract embeddings
         images = self._preprocess_dataset_images()
-        embeddings = []
+        embedding_list = []
         with torch.no_grad():
-            for i in range(0, images.shape[0], batch_size):
-                x = torch.tensor(
-                    images[i : i + batch_size], dtype=torch.float, device=device
+            for i in range(0, images.shape[0], self._batch_size):
+                batch = torch.tensor(
+                    images[i : i + self._batch_size], dtype=torch.float, device=device
                 )
-                if x.shape[1] == 1:
-                    x = x.broadcast_to((x.shape[0], 3, *x.shape[2:]))
-                embeddings.append(efficient_net(x).cpu().numpy())
-        embeddings = np.concatenate(embeddings)
-        embeddings = StandardScaler(with_std=False).fit_transform(embeddings)
+                if batch.shape[1] == 1:
+                    batch = batch.broadcast_to((batch.shape[0], 3, *batch.shape[2:]))
+                embedding_list.append(efficient_net(batch).cpu().numpy())
+        embeddings_scaled: NDArrayFloat = StandardScaler(with_std=False).fit_transform(
+            np.concatenate(embedding_list)
+        )

-        if 0 < self._pca_components < embeddings.shape[1]:
+        if 0 < self._pca_components < embeddings_scaled.shape[1]:
             pca = PCA(n_components=self._pca_components, random_state=self._seed)
-            pca.fit(self._subsample(embeddings, 100000))
+            # Subsampling to 100000 embeddings follows the official implementation
+            pca.fit(self._subsample(embeddings_scaled, 100000))

         targets = np.array(self.dataset[self._partition_by], dtype=np.int64)
         label_cluster_means = [None for _ in self._unique_classes]
         label_cluster_trils = [None for _ in self._unique_classes]
+        # Use a Gaussian Mixture Model to cluster the embeddings
         gmm = GaussianMixture(
             n_components=self._num_partitions,
             max_iter=self._gmm_max_iter,
@@ -260,14 +275,14 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
             random_state=self._seed,
         )

-        label_cluster_list = [
+        label_cluster_list: List[List[List[int]]] = [
             [[] for _ in range(self._num_partitions)] for _ in self._unique_classes
         ]
         for label in self._unique_classes:
             idx_current_label = np.where(targets == label)[0]
+            # Subsampling to 10000 embeddings per label follows the official
+            # implementation
             embeddings_of_current_label = self._subsample(
-                embeddings[idx_current_label], 10000
+                embeddings_scaled[idx_current_label], 10000
             )

             gmm.fit(embeddings_of_current_label)
@@ -275,47 +290,55 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
             cluster_list = gmm.predict(embeddings_of_current_label)

             for idx, cluster in zip(idx_current_label.tolist(), cluster_list):
-                label_cluster_list[label][cluster].append(idx)
+                label_cluster_list[label][cluster].append(idx)  # type: ignore

-            label_cluster_means[label] = torch.tensor(gmm.means_)
-            label_cluster_trils[label] = torch.linalg.cholesky(
+            label_cluster_means[label] = torch.tensor(gmm.means_)  # type: ignore
+            label_cluster_trils[label] = torch.linalg.cholesky(  # type: ignore
                 torch.from_numpy(gmm.covariances_)
             )

-        cluster_assignment = [
-            [None for _ in range(self._num_partitions)] for _ in self._unique_classes
-        ]
+        # Match the per-label clusters to partitions (clients).
+        # clusters[label][j] stores the partition that cluster j of `label`
+        # is assigned to.
+        clusters: List[List[int]] = [
+            [-1] * self._num_partitions for _ in self._unique_classes
+        ]
         partitions = list(range(self._num_partitions))
         unmatched_labels = list(self._unique_classes)

-        latest_matched_label = self._rng_numpy.choice(unmatched_labels)
-        cluster_assignment[latest_matched_label] = partitions
+        latest_matched_label = self._rng_numpy.choice(unmatched_labels)  # type: ignore
+        clusters[latest_matched_label] = partitions

         unmatched_labels.remove(latest_matched_label)

         while unmatched_labels:
-            label_to_match = self._rng_numpy.choice(unmatched_labels)
-
-            cost_matrix = (
-                _pairwise_kl_div(
-                    means_1=label_cluster_means[latest_matched_label],
-                    trils_1=label_cluster_trils[latest_matched_label],
-                    means_2=label_cluster_means[label_to_match],
-                    trils_2=label_cluster_trils[label_to_match],
-                    device=device,
-                )
-                .cpu()
-                .numpy()
+            label_to_match = self._rng_numpy.choice(unmatched_labels)  # type: ignore
+
+            num_dist_1, num_dist_2 = (
+                label_cluster_means[latest_matched_label].shape[0],
+                label_cluster_means[label_to_match].shape[0],
             )
+            cost_matrix = torch.zeros((num_dist_1, num_dist_2), device=device)
+
+            for i in range(num_dist_1):
+                for j in range(num_dist_2):
+                    cost_matrix[i, j] = kl_divergence(
+                        MultivariateNormal(
+                            loc=label_cluster_means[latest_matched_label][i],
+                            scale_tril=label_cluster_trils[latest_matched_label][i],
+                        ),
+                        MultivariateNormal(
+                            loc=label_cluster_means[label_to_match][j],
+                            scale_tril=label_cluster_trils[label_to_match][j],
+                        ),
+                    )
+            cost_matrix = cost_matrix.cpu().numpy()

             optimal_local_assignment = linear_sum_assignment(cost_matrix)

             for client_id in partitions:
-                cluster_assignment[label_to_match][
-                    optimal_local_assignment[1][client_id]
-                ] = cluster_assignment[latest_matched_label][
-                    optimal_local_assignment[0][client_id]
-                ]
+                clusters[label_to_match][optimal_local_assignment[1][client_id]] = (
+                    clusters[latest_matched_label][
+                        optimal_local_assignment[0][client_id]
+                    ]
+                )

             unmatched_labels.remove(label_to_match)
             latest_matched_label = label_to_match
@@ -323,9 +346,9 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
         partition_id_to_indices: Dict[int, List[int]] = {i: [] for i in partitions}

         for label in self._unique_classes:
-            for partition_id in partitions:
-                partition_id_to_indices[cluster_assignment[label][partition_id]].extend(
-                    label_cluster_list[label][partition_id]
+            for i in partitions:
+                partition_id_to_indices[clusters[label][i]].extend(  # type: ignore
+                    label_cluster_list[label][i]  # type: ignore
                 )

         # Shuffle the indices not to have the datasets with targets in sequences like
@@ -337,24 +360,25 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
         self._partition_id_to_indices = partition_id_to_indices
         self._partition_id_to_indices_determined = True

-    def _preprocess_dataset_images(self):
-        images = np.array(self.dataset[self._data_column_name], dtype=np.float32)
-        if len(images.shape) == 3:  # 1D
+    def _preprocess_dataset_images(self) -> NDArrayFloat:
+        images = np.array(self.dataset[self._data_column_name], dtype=float)
+        if len(images.shape) == 3:  # [B, H, W]
             images = np.reshape(
                 images, (images.shape[0], 1, images.shape[1], images.shape[2])
             )
         elif len(images.shape) == 4:  # 2D
-            x, y, z = images.shape[1:]
-            if z < x and z < y:  # [H, W, C]
+            # [H, W, C]
+            if images.shape[3] < min(images.shape[1], images.shape[2]):
                 images = np.transpose(images, (0, 3, 1, 2))
-            elif x < y and x < z:  # [C, H, W]
+            # [C, H, W]
+            elif images.shape[1] < min(images.shape[2], images.shape[3]):
                 pass
             else:
                 raise ValueError(f"The image shape is not supported. Now: {images.shape}")
         return images

     def _check_num_partitions_correctness_if_needed(self) -> None:
-        """Test num_partitions when the dataset is given (in load_partition)."""
+        """Test whether the number of partitions is valid."""
         if not self._partition_id_to_indices_determined:
             if self._num_partitions > self.dataset.num_rows:
                 raise ValueError(
@@ -365,19 +389,19 @@ def _check_num_partitions_correctness_if_needed(self) -> None:
     def _check_pca_components_validation_if_needed(self) -> None:
         """Test whether pca_components is in the valid range."""
         if not self._partition_id_to_indices_determined:
-
             if self._pca_components > min(
                 self.dataset.num_rows, self._efficient_output_size
             ):
                 raise ValueError(
                     "The pca_components needs to be smaller than "
-                    f"min(the number of samples = {self.dataset.num_rows}, efficient net output size = 1280) "
+                    f"min(the number of samples = {self.dataset.num_rows}, "
+                    "efficient net output size = 1280) "
                     "in the dataset or the output size of the efficient net. "
                     f"Now: {self._pca_components}."
                 )

-    def _check_data_validation_if_needed(self):
-        """Test whether dataset is image dataset"""
+    def _check_data_type_if_needed(self) -> None:
+        """Test whether the data is image-like."""
         if not self._partition_id_to_indices_determined:
             features_dict = self.dataset.features.to_dict()
             self._data_column_name = list(features_dict.keys())[0]
             try:
                 data = np.array(
                     self.dataset[self._data_column_name][0], dtype=np.float32
                 )
-            except:
-                raise TypeError(
-                    "The dataset needs to be image dataset. "
-                    f"Now: {type(self.dataset[self._data_column_name][0])}."
-                )
+            except ValueError:
+                raise TypeError(
+                    "The dataset needs to be an image dataset, i.e., the data "
+                    "must be convertible to np.ndarray."
+                ) from None

-            if not (2 <= len(data.shape) <= 3):
+            if not 2 <= len(data.shape) <= 3:
                 raise ValueError(
                     "The image shape is not supported. "
                     "The image shape should among {[H, W], [C, H, W], [H, W, C]}. "
                     f"Now: {data.shape}. "
                 )
-            elif len(data.shape) == 3:
-                x, y, z = data.shape
-                if not ((x < y and x < z) or (z < x and z < y)):
+            if len(data.shape) == 3:
+                smallest_axis = min(enumerate(data.shape), key=lambda x: x[1])[0]
+                # The channel axis (the smallest one) should be first or last.
+                if smallest_axis not in [0, 2]:
                     raise ValueError(
                         "The 3D image shape should be [C, H, W] or [H, W, C]. "
                         f"Now: {data.shape}. "
                     )

-    def _check_variable_validation(self):
+    def _check_variable_validation(self) -> None:
         """Test class variables validation."""
         if not self._num_partitions > 0:
             raise ValueError("The number of partitions needs to be greater than zero.")
-        if not (0 <= self._efficient_net_type < 8):
+        if not (
+            isinstance(self._efficient_net_type, int)
+            and 0 <= self._efficient_net_type <= 7
+        ):
             raise ValueError(
-                "The efficient net type needs to be in the range of 0 to 7, indicates EfficientNet-B0 ~ B7"
+                "The efficient_net_type needs to be an integer in the range 0 to 7, "
+                "indicating EfficientNet-B0 to B7."
             )
+        if self._batch_size <= 0:
+            raise ValueError("The batch size needs to be greater than zero.")
         if self._gmm_init_params not in ["kmeans", "k-means++", "random"]:
             raise ValueError(
                 "The gmm_init_params needs to be in [kmeans, k-means++, random]"
@@ -423,34 +450,15 @@ def _check_variable_validation(self):
         raise ValueError("The pca components needs to be greater than zero.")


-def _pairwise_kl_div(
-    means_1: torch.Tensor,
-    trils_1: torch.Tensor,
-    means_2: torch.Tensor,
-    trils_2: torch.Tensor,
-    device: torch.device,
-):
-    num_dist_1, num_dist_2 = means_1.shape[0], means_2.shape[0]
-    pairwise_kl_matrix = torch.zeros((num_dist_1, num_dist_2), device=device)
-
-    for i in range(means_1.shape[0]):
-        for j in range(means_2.shape[0]):
-            pairwise_kl_matrix[i, j] = kl_divergence(
-                MultivariateNormal(means_1[i], scale_tril=trils_1[i]),
-                MultivariateNormal(means_2[j], scale_tril=trils_2[j]),
-            )
-    return pairwise_kl_matrix
-
-
 if __name__ == "__main__":
     # ===================== Test with custom Dataset =====================
     from datasets import Dataset

-    data = {
+    data_dict = {
         "image": [np.random.randn(28, 28) for _ in range(50)],
         "label": [i % 3 for i in range(50)],
     }
-    dataset = Dataset.from_dict(data)
+    dataset = Dataset.from_dict(data_dict)
     partitioner = SemanticPartitioner(
         num_partitions=5, partition_by="label", pca_components=30
     )
diff --git a/datasets/flwr_datasets/partitioner/semantic_partitioner_test.py b/datasets/flwr_datasets/partitioner/semantic_partitioner_test.py
index 730fd8e7545..dcd77d1287f 100644
--- a/datasets/flwr_datasets/partitioner/semantic_partitioner_test.py
+++ b/datasets/flwr_datasets/partitioner/semantic_partitioner_test.py
@@ -22,18 +22,19 @@
 import numpy as np
 from parameterized import parameterized
-from torchvision import models

 from datasets import Dataset
 from flwr_datasets.partitioner.semantic_partitioner import SemanticPartitioner


+# pylint: disable=R0913
 def _dummy_setup(
-    data_shape: tuple = (28, 28, 1),
+    data_shape: Tuple[int, ...] = (28, 28, 1),
     num_partitions: int = 3,
     num_rows: int = 10,
     partition_by: str = "label",
     efficient_net_type: int = 0,
+    batch_size: int = 32,
     pca_components: int = 6,
     gmm_max_iter: int = 2,
     gmm_init_params: str = "random",
@@ -48,6 +49,7 @@ def _dummy_setup(
         num_partitions=num_partitions,
         partition_by=partition_by,
         efficient_net_type=efficient_net_type,
+        batch_size=batch_size,
         pca_components=pca_components,
         gmm_max_iter=gmm_max_iter,
         gmm_init_params=gmm_init_params,
@@ -59,32 +61,34 @@ class TestSemanticPartitionerSuccess(unittest.TestCase):
     """Test SemanticPartitioner used with no exceptions."""

-    @parameterized.expand(
+    # pylint: disable=R0913
+    @parameterized.expand(  # type: ignore
         [
-            # data_shape, num_partitions, num_rows, partition_by, efficient_net_type, pca_components, gmm_max_iter, gmm_init_params
-            ((28, 28, 1), 3, 50, "label", 0, 128, 2, "kmeans"),
-            ((1, 28, 28), 5, 100, "label", 2, 256, 1, "random"),
-            ((32, 32, 3), 5, 100, "label", 7, 256, 1, "k-means++"),
+            # data_shape, num_partitions, num_rows, partition_by,
+            # efficient_net_type, batch_size, pca_components, gmm_max_iter,
+            # gmm_init_params
+            ((28, 28, 1), 3, 50, "label", 0, 32, 128, 2, "kmeans"),
+            ((1, 28, 28), 5, 100, "label", 2, 64, 256, 1, "random"),
+            ((32, 32, 3), 5, 100, "label", 7, 16, 256, 1, "k-means++"),
         ]
     )
     def test_valid_initialization(
         self,
-        data_shape: tuple,
+        data_shape: Tuple[int],
         num_partitions: int,
         num_rows: int,
         partition_by: str,
         efficient_net_type: int,
+        batch_size: int,
         pca_components: int,
         gmm_max_iter: int,
         gmm_init_params: str,
     ) -> None:
-        """Test if alpha is correct scaled based on the given num_partitions."""
+        """Test whether initialization is successful."""
         _, partitioner = _dummy_setup(
             data_shape=data_shape,
             num_partitions=num_partitions,
             num_rows=num_rows,
             partition_by=partition_by,
             efficient_net_type=efficient_net_type,
+            batch_size=batch_size,
             pca_components=pca_components,
             gmm_max_iter=gmm_max_iter,
             gmm_init_params=gmm_init_params,
@@ -108,44 +112,28 @@ def test_valid_initialization(
             ),
         )

-    @parameterized.expand([((28, 28, 1),), ((3, 32, 32),), ((28, 28),)])
-    def test_data_shape(self, data_shape: tuple):
+    # pylint: disable=R0201
+    @parameterized.expand([((28, 28, 1),), ((3, 32, 32),), ((28, 28),)])  # type: ignore
+    def test_data_shape(self, data_shape: Tuple[int]) -> None:
         """Test if data_shape is correct."""
         _, partitioner = _dummy_setup(data_shape=data_shape)
         partitioner.load_partition(0)

-    @parameterized.expand([(0,), (1,), (2,), (3,)])
-    def test_efficient_net_config(self, efficient_net_type: int):
-        """Test if efficient_net_backbone and efficient_net_pretrained_weight are correct."""
-        efficient_nets_dict = [
-            (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT),
-            (models.efficientnet_b1, models.EfficientNet_B1_Weights.DEFAULT),
-            (models.efficientnet_b2, models.EfficientNet_B2_Weights.DEFAULT),
-            (models.efficientnet_b3, models.EfficientNet_B3_Weights.DEFAULT),
-            (models.efficientnet_b4, models.EfficientNet_B4_Weights.DEFAULT),
-            (models.efficientnet_b5, models.EfficientNet_B5_Weights.DEFAULT),
-            (models.efficientnet_b6, models.EfficientNet_B6_Weights.DEFAULT),
-            (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT),
-        ]
+    @parameterized.expand([(0,), (3,), (7,)])  # type: ignore
+    def test_efficient_net(self, efficient_net_type: int) -> None:
+        """Test that a valid efficient_net_type loads a partition successfully."""
         _, partitioner = _dummy_setup(efficient_net_type=efficient_net_type)
-        self.assertEqual(
-            (
-                partitioner._efficient_net_backbone,
-                partitioner._efficient_net_pretrained_weight,
-            ),
-            (
-                efficient_nets_dict[efficient_net_type][0],
-                efficient_nets_dict[efficient_net_type][1],
- ), - ) + partitioner.load_partition(0) - @parameterized.expand([(64,), (96,), (32,)]) + @parameterized.expand([(64,), (96,), (32,)]) # type: ignore def test_pca_components(self, pca_components: int) -> None: - """Test if pca_components is correct scaled based on the given num_partitions.""" + """Test if pca_components is correct.""" _, partitioner = _dummy_setup(num_rows=100, pca_components=pca_components) self.assertEqual(partitioner._pca_components, pca_components) - @parameterized.expand([(1, "random"), (2, "kmeans"), (2, "k-means++")]) + @parameterized.expand( # type: ignore + [(1, "random"), (2, "kmeans"), (2, "k-means++")] + ) def test_gaussian_mixture_model( self, gmm_max_iter: int, gmm_init_params: str ) -> None: @@ -160,7 +148,6 @@ def test_gaussian_mixture_model( def test_determine_partition_id_to_indices(self) -> None: """Test the determine_nod_id_to_indices matches the flag after the call.""" - _, partitioner = _dummy_setup() partitioner._determine_partition_id_to_indices_if_needed() self.assertTrue( @@ -172,12 +159,12 @@ def test_determine_partition_id_to_indices(self) -> None: class TestSemanticPartitionerFailure(unittest.TestCase): """Test SemanticPartitioner failures (exceptions) by incorrect usage.""" - def test_invalid_dataset_type(self): + def test_invalid_dataset_type(self) -> None: """Test if raises when the dataset is not an image dataset.""" alphabets = list(string.ascii_uppercase) data = { "letters": [alphabets[i % len(alphabets)] for i in range(300)], - "label": [i for i in range(300)], + "label": list(range(300)), } dataset = Dataset.from_dict(data) partitioner = SemanticPartitioner(num_partitions=3, partition_by="label") @@ -185,44 +172,48 @@ def test_invalid_dataset_type(self): with self.assertRaises(TypeError): partitioner.load_partition(0) - @parameterized.expand([((28, 1, 28),), ((3, 3, 32, 32),), ((28,),)]) - def test_invalid_data_shape(self, data_shape: tuple): + @parameterized.expand([(0,), (-1,)]) # type: ignore + def test_invalid_batch_size(self, batch_size: int) -> None: + """Test if raises when the batch_size is not a positive integer.""" + with self.assertRaises(ValueError): + _, partitioner = _dummy_setup(batch_size=batch_size) + partitioner.load_partition(0) + + @parameterized.expand([((28, 1, 28),), ((3, 3, 32, 32),), ((28,),)]) # type: ignore + def test_invalid_data_shape(self, data_shape: Tuple[int]) -> None: """Test if raises when the data_shape is not a tuple of length 2.""" with self.assertRaises(ValueError): _, partitioner = _dummy_setup(data_shape=data_shape) partitioner.load_partition(0) @parameterized.expand([(-2,), (-1,), (3,), (4,)]) # type: ignore - def test_load_invalid_partition_index(self, partition_id): + def test_load_invalid_partition_index(self, partition_id: int) -> None: """Test if raises when the load_partition is above the num_partitions.""" _, partitioner = _dummy_setup(num_partitions=3) with self.assertRaises(KeyError): partitioner.load_partition(partition_id) - @parameterized.expand([(-1,), (2.5,), (9,), (8), (7.0,)]) - def test_invalid_efficient_net_type(self, efficient_net_type): - """ - Test if efficient_net_type is not an integer or not in range [0, 7].""" + @parameterized.expand([(-1,), (2.5,), (9,), (8), (7.0,)]) # type: ignore + def test_invalid_efficient_net_type(self, efficient_net_type: int) -> None: + """Test if efficient_net_type is not an integer or not in range [0, 7].""" with self.assertRaises((ValueError, TypeError)): - SemanticPartitioner(efficient_net_type=efficient_net_type) + 
_dummy_setup(efficient_net_type=efficient_net_type) - @parameterized.expand( # type: ignore - [(0,), (-1,), (11,), (100,)] - ) # num_partitions, - def test_invalid_num_partitions(self, num_partitions): + @parameterized.expand([(0,), (-1,), (11,), (100,)]) # type: ignore + def test_invalid_num_partitions(self, num_partitions: int) -> None: """Test if 0 is invalid num_partitions.""" with self.assertRaises(ValueError): _, partitioner = _dummy_setup(num_partitions=num_partitions, num_rows=10) partitioner.load_partition(0) - @parameterized.expand([(0,), (-1,), (2.0,), (11,)]) - def test_invalid_pca_components(self, pca_components): + @parameterized.expand([(0,), (-1,), (2.0,), (11,)]) # type: ignore + def test_invalid_pca_components(self, pca_components: int) -> None: """Test if pca_components is not a positive integer.""" with self.assertRaises((ValueError, TypeError)): _, partitioner = _dummy_setup(pca_components=pca_components) partitioner.load_partition(0) - @parameterized.expand( + @parameterized.expand( # type: ignore [ (0, "random"), (-1, "keams"), @@ -232,10 +223,10 @@ def test_invalid_pca_components(self, pca_components): (10, "kmeans++"), ] ) - def test_invalid_gaussian_mixture_config(self, gmm_max_iter, gmm_init_params): - """ - Test if gmm_max_iter is not a positive integer or gmm_init_params is not one of the allowed values. - """ + def test_invalid_gaussian_mixture_config( + self, gmm_max_iter: int, gmm_init_params: str + ) -> None: + """Test if gmm_max_iter and gmm_init_params are not valid.""" with self.assertRaises(ValueError): _, partitioner = _dummy_setup( gmm_max_iter=gmm_max_iter, gmm_init_params=gmm_init_params