Modularity Maximization Lifting (Graph to Hypergraph) #49

Open · wants to merge 10 commits into base: main
@@ -0,0 +1,5 @@
transform_type: 'lifting'
transform_name: "ModularityMaximizationLifting"
num_communities: 2
k_neighbors: 3
feature_lifting: ProjectionSum
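
As a hedged illustration of how these values relate to the code added below: transform_name names the lifting class registered in data_transform.py, and num_communities / k_neighbors mirror the constructor parameters of the same name. The snippet is a sketch only; the repository's actual config-loading machinery may differ.

# Illustrative sketch only: parse the config above and read the keys that map
# onto ModularityMaximizationLifting's constructor parameters. The actual
# config loading in the repository may differ.
import yaml

config_text = """
transform_type: 'lifting'
transform_name: "ModularityMaximizationLifting"
num_communities: 2
k_neighbors: 3
feature_lifting: ProjectionSum
"""

cfg = yaml.safe_load(config_text)
print(cfg["transform_name"])                       # ModularityMaximizationLifting
print(cfg["num_communities"], cfg["k_neighbors"])  # 2 3
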
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
@@ -12,13 +12,17 @@
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
from modules.transforms.liftings.graph2hypergraph.modularity_maximization_lifting import (
ModularityMaximizationLifting,
)
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
"ModularityMaximizationLifting": ModularityMaximizationLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
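
With the registration above in place, here is a minimal sketch of resolving the new lifting by name from the TRANSFORMS registry, assuming the base class's constructor defaults are sufficient:

# Minimal sketch: look up the newly registered lifting by name and instantiate
# it with the parameters from the config above.
from modules.transforms.data_transform import TRANSFORMS

lifting_cls = TRANSFORMS["ModularityMaximizationLifting"]
lifting = lifting_cls(num_communities=2, k_neighbors=3)
print(type(lifting).__name__)  # ModularityMaximizationLifting
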
modules/transforms/liftings/graph2hypergraph/modularity_maximization_lifting.py
@@ -0,0 +1,162 @@
import torch
import torch_geometric

from modules.transforms.liftings.graph2hypergraph.base import Graph2HypergraphLifting


class ModularityMaximizationLifting(Graph2HypergraphLifting):
r"""Lifts graphs to hypergraph domain using modularity maximization and community detection.

This method creates hyperedges based on the community structure of the graph and
k-nearest neighbors within each community.

Parameters
----------
num_communities : int, optional
The number of communities to detect. Default is 2.
k_neighbors : int, optional
The number of nearest neighbors to consider within each community. Default is 3.
**kwargs : optional
Additional arguments for the base class.
"""

def __init__(self, num_communities=2, k_neighbors=3, **kwargs):
super().__init__(**kwargs)
self.num_communities = num_communities
self.k_neighbors = k_neighbors

def modularity_matrix(self, data):
r"""Compute the modularity matrix B of the graph.

.. math:: B_{ij} = A_{ij} - \frac{k_i k_j}{2m},

where :math:`A` is the adjacency matrix, :math:`k_i` is the degree of node :math:`i`,
and :math:`m` is the number of edges.

Parameters
----------
data : torch_geometric.data.Data
The input graph data.

Returns
-------
torch.Tensor
The modularity matrix B.
"""
a = torch.zeros((data.num_nodes, data.num_nodes))
a[data.edge_index[0], data.edge_index[1]] = 1
k = a.sum(dim=1)
m = data.edge_index.size(1) / 2  # assumes each undirected edge appears in both directions of edge_index
return a - torch.outer(k, k) / (2 * m)

def kmeans(self, x, n_clusters, n_iterations=100):
r"""Perform k-means clustering on the input data.

Note: This implementation uses random initialization, so results may vary
between runs even for the same input data.

Parameters
----------
x : torch.Tensor
The input data to cluster.
n_clusters : int
The number of clusters to form.
n_iterations : int, optional
The maximum number of iterations. Default is 100.

Returns
-------
torch.Tensor
The cluster assignments for each input point.

Warning
-------
Due to random initialization of centroids, the resulting hyperedges
may differ each time the code is run, even with the same input.
"""
# Initialize cluster centers randomly
centroids = x[torch.randperm(x.shape[0])[:n_clusters]]
cluster_assignments = torch.zeros(x.shape[0], dtype=torch.long)
for _ in range(n_iterations):
# Assign points to the nearest centroid
distances = torch.cdist(x, centroids)
cluster_assignments = torch.argmin(distances, dim=1)

# Update centroids
new_centroids = torch.stack(
[x[cluster_assignments == k].mean(dim=0) for k in range(n_clusters)]
)

if torch.allclose(centroids, new_centroids):
break

centroids = new_centroids

return cluster_assignments

def detect_communities(self, b):
r"""Detect communities using spectral clustering on the modularity matrix.

Parameters
----------
b : torch.Tensor
The modularity matrix.

Returns
-------
torch.Tensor
The community assignments for each node.
"""
eigvals, eigvecs = torch.linalg.eigh(b)
leading_eigvecs = eigvecs[
:, torch.argsort(eigvals, descending=True)[: self.num_communities]
]

# Use implemented k-means clustering on the leading eigenvectors
return self.kmeans(leading_eigvecs, self.num_communities)

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lift the graph topology to a hypergraph based on community structure and k-nearest neighbors.

Parameters
----------
data : torch_geometric.data.Data
The input graph data.

Returns
-------
dict
A dictionary containing the incidence matrix of the hypergraph, number of hyperedges,
and the original node features.
"""
b = self.modularity_matrix(data)
community_assignments = self.detect_communities(b)

num_nodes = data.x.shape[0]
num_hyperedges = num_nodes
incidence_matrix = torch.zeros(num_nodes, num_nodes)

for i in range(num_nodes):
# Find nodes in the same community
same_community = (
(community_assignments == community_assignments[i]).nonzero().view(-1)
)

# Calculate distances to nodes in the same community
distances = torch.norm(
data.x[i].unsqueeze(0) - data.x[same_community], dim=1
)

# Select k nearest neighbors within the community
k = min(self.k_neighbors, len(same_community))
_, nearest_indices = torch.topk(distances, k, largest=False)
nearest_neighbors = same_community[nearest_indices]

# Create a hyperedge
incidence_matrix[i, nearest_neighbors] = 1
incidence_matrix[i, i] = 1 # Include the node itself

incidence_matrix = incidence_matrix.to_sparse_coo()

return {
"incidence_hyperedges": incidence_matrix,
"num_hyperedges": num_hyperedges,
"x_0": data.x,
}
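
For a quick end-to-end picture, here is a short usage sketch based on the same manual graph and seed used by the tests that follow; the comments describe the returned dictionary rather than exact values.

# Usage sketch mirroring the test setup below (seed and manual graph taken
# from the tests; k-means centroids are initialized randomly).
import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2hypergraph.modularity_maximization_lifting import (
    ModularityMaximizationLifting,
)

torch.manual_seed(42)

data = load_manual_graph()
lifting = ModularityMaximizationLifting(num_communities=2, k_neighbors=3)
lifted = lifting.lift_topology(data.clone())

print(lifted["num_hyperedges"])                   # one hyperedge per node
print(lifted["incidence_hyperedges"].to_dense())  # dense node-to-hyperedge incidence matrix
print(lifted["x_0"].shape)                        # node features are passed through unchanged
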
@@ -0,0 +1,121 @@
import pytest
import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2hypergraph.modularity_maximization_lifting import (
ModularityMaximizationLifting,
)


class TestModularityMaximizationLifting:
"""Test the ModularityMaximizationLifting class."""

def setup_method(self):
# Load the graph
self.data = load_manual_graph()

# Initialize the ModularityMaximizationLifting class
self.lifting = ModularityMaximizationLifting(num_communities=2, k_neighbors=3)

def test_kmeans(self):
# Set a random seed for reproducibility
torch.manual_seed(42)

# Test the kmeans method
x = torch.tensor(
[
[1.0, 1.0],
[2.0, 1.0],
[3.0, 2.0],
[4.0, 2.0],
[5.0, 3.0],
[6.0, 4.0],
[7.0, 4.0],
[8.0, 4.0],
]
)
n_clusters = 2
n_iterations = 100
kmeans_clusters = self.lifting.kmeans(x, n_clusters, n_iterations)

expected_clusters = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0])

assert (
kmeans_clusters == expected_clusters
).all(), "Something is wrong with kmeans."

def test_modularity_matrix(self):
# Test the modularity_matrix method
data_modularity_matrix = self.lifting.modularity_matrix(self.data.clone())

expected_modularity_matrix = torch.tensor(
[
[-1.2308, 0.3846, -0.2308, -0.3077, 1.0000, -0.6154, 0.0000, 1.0000],
[-0.6154, -0.3077, 0.3846, -0.1538, 1.0000, -0.3077, 0.0000, 0.0000],
[-1.2308, -0.6154, -1.2308, 0.6923, 1.0000, 0.3846, 0.0000, 1.0000],
[-0.3077, -0.1538, -0.3077, -0.0769, 0.0000, -0.1538, 1.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.6154, -0.3077, -0.6154, -0.1538, 0.0000, -0.3077, 1.0000, 1.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
)

# Round the modularity matrix to 4 decimal places for comparison
number_of_digits = 4
data_modularity_matrix_rounded = (
data_modularity_matrix * 10**number_of_digits
).round() / (10**number_of_digits)

assert (
expected_modularity_matrix == data_modularity_matrix_rounded
).all(), "Something is wrong with modularity matrix."

def test_detect_communities(self):
# Set a random seed for reproducibility
torch.manual_seed(42)

# Run the modularity matrix which is tested above
b = self.lifting.modularity_matrix(self.data.clone())

# Test the detect_communities method
detected_communities = self.lifting.detect_communities(b)

expected_communities = torch.tensor([0, 0, 0, 1, 0, 0, 0, 0])

assert (
detected_communities == expected_communities
).all(), "Something is wrong with detect communities."

def test_lift_topology(self):
# Set a random seed for reproducibility
torch.manual_seed(42)

# Test the lift_topology method
lifted_data = self.lifting.lift_topology(self.data.clone())

expected_n_hyperedges = self.data.num_nodes

expected_incidence_1 = torch.tensor(
[
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
]
)

assert (
expected_incidence_1 == lifted_data["incidence_hyperedges"].to_dense()
).all(), "Something is wrong with incidence_hyperedges."
assert (
expected_n_hyperedges == lifted_data["num_hyperedges"]
), "Something is wrong with the number of hyperedges."


if __name__ == "__main__":
pytest.main([__file__])