Skip to content

Commit

Permalink
feat(datasets) Add label distribution visualization (#3451)
Browse files Browse the repository at this point in the history
Co-authored-by: jafermarq <javier@flower.ai>
  • Loading branch information
adam-narozniak and jafermarq authored Jun 6, 2024
1 parent 72244a8 commit 097b803
Show file tree
Hide file tree
Showing 12 changed files with 1,039 additions and 0 deletions.
3 changes: 3 additions & 0 deletions datasets/flwr_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

from flwr_datasets import partitioner, preprocessor
from flwr_datasets import utils as utils
from flwr_datasets import visualization
from flwr_datasets.common.version import package_version as _package_version
from flwr_datasets.federated_dataset import FederatedDataset

# Names exported by the package, i.e. what `from flwr_datasets import *` binds.
# NOTE(review): "metrics" is listed here, but no `metrics` import is visible in
# this part of the file - confirm the submodule is bound at package level.
__all__ = [
    "FederatedDataset",
    "partitioner",
    "metrics",
    "visualization",
    "preprocessor",
    "utils",
]
Expand Down
23 changes: 23 additions & 0 deletions datasets/flwr_datasets/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metrics package."""


from flwr_datasets.metrics.utils import compute_counts, compute_frequency

# Public API of the metrics package: the helpers imported above from
# `flwr_datasets.metrics.utils`.
__all__ = [
    "compute_counts",
    "compute_frequency",
]
78 changes: 78 additions & 0 deletions datasets/flwr_datasets/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for metrics computation."""


from typing import List, Union

import pandas as pd


def compute_counts(
    labels: Union[List[int], List[str]], unique_labels: Union[List[int], List[str]]
) -> pd.Series:
    """Count the occurrences of every label (absolute frequency).

    Labels that are listed in ``unique_labels`` but never occur in ``labels``
    are reported with a count of zero instead of being left out.

    Parameters
    ----------
    labels : Union[List[int], List[str]]
        The labels observed in the dataset.
    unique_labels : Union[List[int], List[str]]
        The reference collection of all possible labels. It must not contain
        duplicates.

    Returns
    -------
    label_counts : pd.Series
        Series with the labels as index and their integer counts as values.

    Raises
    ------
    ValueError
        If ``unique_labels`` contains duplicated entries.
    """
    has_duplicates = len(set(unique_labels)) != len(unique_labels)
    if has_duplicates:
        raise ValueError("unique_labels must contain unique elements only.")

    observed_counts = pd.Series(labels).value_counts()
    zero_baseline = pd.Series(data=0, index=unique_labels)
    # fill_value=0 makes a label that is missing on one side contribute zero
    # rather than turning the sum into NaN.
    combined = zero_baseline.add(observed_counts, fill_value=0)
    return combined.astype(int)


def compute_frequency(
    labels: Union[List[int], List[str]], unique_labels: Union[List[int], List[str]]
) -> pd.Series:
    """Compute the share of every label (relative frequency).

    The counts produced by ``compute_counts`` are divided by the number of
    labels. Labels from ``unique_labels`` that never occur get a value of zero.

    Parameters
    ----------
    labels : Union[List[int], List[str]]
        The labels observed in the dataset.
    unique_labels : Union[List[int], List[str]]
        The reference collection of all possible labels, used so that absent
        labels still appear in the result.

    Returns
    -------
    label_frequencies : pd.Series
        Series with the labels as index and float frequencies as values.
    """
    label_counts = compute_counts(labels, unique_labels)
    num_labels = len(labels)
    if num_labels == 0:
        # Nothing to divide by: hand back the all-zero counts as floats.
        return label_counts.astype(float)
    return label_counts.divide(num_labels)
89 changes: 89 additions & 0 deletions datasets/flwr_datasets/metrics/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics utils."""
# pylint: disable=no-self-use


import unittest

import pandas as pd
from parameterized import parameterized

from flwr_datasets.metrics.utils import compute_counts, compute_frequency


class TestMetricsUtils(unittest.TestCase):
    """Unit tests for ``compute_counts`` and ``compute_frequency``."""

    @parameterized.expand(  # type: ignore
        [
            (
                [1, 2, 2, 3],
                [1, 2, 3, 4],
                pd.Series(data=[1, 2, 1, 0], index=[1, 2, 3, 4]),
            ),
            (
                [],
                [1, 2, 3],
                pd.Series(data=[0, 0, 0], index=[1, 2, 3]),
            ),
            (
                [1, 1, 2],
                [1, 2, 3, 4],
                pd.Series(data=[2, 1, 0, 0], index=[1, 2, 3, 4]),
            ),
        ]
    )
    def test_compute_counts(self, labels, unique_labels, expected) -> None:
        """Check the counts, including zero entries for absent labels."""
        actual = compute_counts(labels, unique_labels)
        pd.testing.assert_series_equal(actual, expected)

    @parameterized.expand(  # type: ignore
        [
            (
                [1, 1, 2, 2, 2, 3],
                [1, 2, 3, 4],
                pd.Series(data=[0.3333, 0.5, 0.1667, 0.0], index=[1, 2, 3, 4]),
            ),
            (
                [],
                [1, 2, 3],
                pd.Series(data=[0.0, 0.0, 0.0], index=[1, 2, 3]),
            ),
            (
                ["a", "b", "b", "c"],
                ["a", "b", "c", "d"],
                pd.Series(data=[0.25, 0.50, 0.25, 0.0], index=["a", "b", "c", "d"]),
            ),
        ]
    )
    def test_compute_distribution(self, labels, unique_labels, expected) -> None:
        """Check the frequencies against rounded reference values."""
        actual = compute_frequency(labels, unique_labels)
        # The expected values are rounded, hence the absolute tolerance.
        pd.testing.assert_series_equal(actual, expected, atol=0.001)

    @parameterized.expand(  # type: ignore
        [
            (["a", "b", "b", "c"], ["a", "b", "c"]),
            ([1, 2, 2, 3, 3, 3, 4], [1, 2, 3, 4]),
        ]
    )
    def test_distribution_sum_to_one(self, labels, unique_labels) -> None:
        """Check that the frequencies of a non-empty input add up to one."""
        frequencies = compute_frequency(labels, unique_labels)
        self.assertAlmostEqual(frequencies.sum(), 1.0)

    def test_compute_counts_non_unique_labels(self) -> None:
        """Check that duplicated reference labels raise ValueError."""
        with self.assertRaises(ValueError):
            compute_counts([1, 2, 3], [1, 2, 2, 3])

    def test_compute_distribution_non_unique_labels(self) -> None:
        """Check that duplicated reference labels raise ValueError."""
        with self.assertRaises(ValueError):
            compute_frequency([1, 1, 2, 3], [1, 1, 2, 3])


# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
24 changes: 24 additions & 0 deletions datasets/flwr_datasets/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualization package."""


from .comparison_label_distribution import plot_comparison_label_distribution
from .label_distribution import plot_label_distributions

# Public API of the visualization package: the two plotting functions imported
# above from the sibling modules.
__all__ = [
    "plot_label_distributions",
    "plot_comparison_label_distribution",
]
143 changes: 143 additions & 0 deletions datasets/flwr_datasets/visualization/bar_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Label distribution bar plotting."""


from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import pandas as pd
from matplotlib import colors as mcolors
from matplotlib import pyplot as plt
from matplotlib.axes import Axes


# pylint: disable=too-many-arguments,too-many-locals,too-many-branches
def _plot_bar(
    dataframe: pd.DataFrame,
    axis: Optional[Axes],
    figsize: Optional[Tuple[float, float]],
    title: str,
    colormap: Optional[Union[str, mcolors.Colormap]],
    partition_id_axis: str,
    size_unit: str,
    legend: bool,
    legend_title: Optional[str],
    plot_kwargs: Optional[Dict[str, Any]],
    legend_kwargs: Optional[Dict[str, Any]],
) -> Axes:
    """Draw the label distribution stored in ``dataframe`` as a bar plot.

    The explicit arguments ``title``, ``colormap`` and ``legend_title`` take
    precedence over the equally named entries of ``plot_kwargs`` /
    ``legend_kwargs``. Every other default is applied only when the
    corresponding key is absent from those dicts. The dicts passed by the
    caller are never modified; the function works on shallow copies.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Data to plot via ``dataframe.plot``. The number of rows is used as the
        number of partitions and the column names as legend labels.
    axis : Optional[Axes]
        Axes to draw on. If None, a new figure and axes are created.
    figsize : Optional[Tuple[float, float]]
        Size of the newly created figure. Only used when ``axis`` is None; if
        it is None as well, the size is derived from ``partition_id_axis`` and
        the number of partitions.
    title : str
        Title of the plot.
    colormap : Optional[Union[str, mcolors.Colormap]]
        Colormap of the bars. If None and ``plot_kwargs`` has no "colormap"
        entry, "RdYlGn" is used.
    partition_id_axis : str
        "x" produces a "bar" plot, any other value a "barh" plot (unless
        ``plot_kwargs`` specifies "kind").
    size_unit : str
        Used to pick the default axis labels: "absolute" gives "Count", any
        other value "Percent %".
    legend : bool
        Whether to add a figure-level legend.
    legend_title : Optional[str]
        Title of the legend. If None and ``legend_kwargs`` has no "title"
        entry, "Labels" is used.
    plot_kwargs : Optional[Dict[str, Any]]
        Extra keyword arguments forwarded to ``dataframe.plot``.
    legend_kwargs : Optional[Dict[str, Any]]
        Extra keyword arguments forwarded to ``axis.figure.legend``.

    Returns
    -------
    axis : Axes
        The axes returned by ``dataframe.plot``.
    """
    if axis is None:
        if figsize is None:
            figsize = _initialize_figsize(
                partition_id_axis=partition_id_axis, num_partitions=dataframe.shape[0]
            )
        _, axis = plt.subplots(figsize=figsize)

    # Copy so that the defaults filled in below do not leak into the caller's
    # dict (which may be reused for another plot).
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)

    plot_kwargs.setdefault("kind", "bar" if partition_id_axis == "x" else "barh")

    # Handle non-optional parameters
    plot_kwargs["title"] = title

    # Handle optional parameters
    if colormap is not None:
        plot_kwargs["colormap"] = colormap
    else:
        plot_kwargs.setdefault("colormap", "RdYlGn")

    # Default axis labels are applied only when the caller specified neither.
    if "xlabel" not in plot_kwargs and "ylabel" not in plot_kwargs:
        xlabel, ylabel = _initialize_xy_labels(
            size_unit=size_unit, partition_id_axis=partition_id_axis
        )
        plot_kwargs["xlabel"] = xlabel
        plot_kwargs["ylabel"] = ylabel

    # Make the x ticks readable (they appear 90 degrees rotated by default)
    plot_kwargs.setdefault("rot", 0)

    # Legend is handled separately (via the figure-level legend call below,
    # not in the plot())
    plot_kwargs.setdefault("legend", False)

    # Make the bar plot stacked
    plot_kwargs.setdefault("stacked", True)

    axis = dataframe.plot(
        ax=axis,
        **plot_kwargs,
    )

    if legend:
        # Same reasoning as for plot_kwargs: never write into the caller's dict.
        legend_kwargs = {} if legend_kwargs is None else dict(legend_kwargs)

        if legend_title is not None:
            legend_kwargs["title"] = legend_title
        else:
            legend_kwargs.setdefault("title", "Labels")

        legend_kwargs.setdefault("loc", "outside center right")

        if "bbox_to_anchor" not in legend_kwargs:
            # Shift the legend further right for longer label names, capped so
            # that it does not drift arbitrarily far from the axes.
            max_len_label_str = max(
                (len(str(column)) for column in dataframe.columns), default=0
            )
            shift = min(0.05 + max_len_label_str / 100, 0.15)
            legend_kwargs["bbox_to_anchor"] = (1.0 + shift, 0.5)

        handles, legend_labels = axis.get_legend_handles_labels()
        # The entries are passed in reversed order.
        _ = axis.figure.legend(
            handles=handles[::-1], labels=legend_labels[::-1], **legend_kwargs
        )

    # Heuristic to make the partition id on xticks non-overlapping
    if partition_id_axis == "x":
        xticklabels = axis.get_xticklabels()
        if len(xticklabels) > 20:
            # Make every other xtick label not visible
            for label in xticklabels[1::2]:
                label.set_visible(False)
    return axis


def _initialize_figsize(
partition_id_axis: str,
num_partitions: int,
) -> Tuple[float, float]:
figsize = (0.0, 0.0)
if partition_id_axis == "x":
figsize = (6.4, 4.8)
elif partition_id_axis == "y":
figsize = (6.4, np.sqrt(num_partitions))
return figsize


def _initialize_xy_labels(size_unit: str, partition_id_axis: str) -> Tuple[str, str]:
xlabel = "Partition ID"
ylabel = "Count" if size_unit == "absolute" else "Percent %"

if partition_id_axis == "y":
xlabel, ylabel = ylabel, xlabel

return xlabel, ylabel
Loading

0 comments on commit 097b803

Please sign in to comment.