Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discrete Configuration Complex #65

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions configs/datasets/simple_configuration_graphs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Dataset identity and on-disk location
data_domain: graph
data_type: toy_dataset
data_name: simple_configuration_graphs
# Resolved via OmegaConf interpolation of the two fields above
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
# Each node carries a single scalar feature; graphs are labelled 0 or 1.
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
# Labels are per-graph, not per-node
task_level: graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Lifting transform: graph -> discrete configuration (cube) complex
transform_type: 'lifting'
transform_name: "DiscreteConfigurationComplexLifting"
# Order of the configuration complex (number of agents per configuration)
k: 2
# How per-agent features are combined: "mean", "sum", or "concat"
feature_aggregation: "concat"
# Keep edge attributes so edge agents contribute features
preserve_edge_attr: True
feature_lifting: ProjectionSum
5 changes: 5 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_simple_configuration_graphs,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -108,6 +109,10 @@ def load(self) -> torch_geometric.data.Dataset:
data = load_manual_graph()
dataset = CustomDataset([data], self.data_dir)

elif self.parameters.data_name in ["simple_configuration_graphs"]:
data = load_simple_configuration_graphs()
dataset = CustomDataset([*data], self.data_dir)

else:
raise NotImplementedError(
f"Dataset {self.parameters.data_name} not implemented"
Expand Down
55 changes: 47 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -334,6 +334,45 @@ def load_manual_graph():
)


def load_simple_configuration_graphs():
    """Generate small toy graphs to illustrate the discrete configuration complex.

    Each graph uses its node index as the (scalar) node feature and the
    negated 1-based edge index as the (scalar) edge attribute.

    Returns
    -------
    tuple of torch_geometric.data.Data
        ``(x_data, y_data, g_data)``: an X-shaped star (5 nodes, label 0),
        a Y-shaped star (4 nodes, label 0), and a "g"-shaped graph
        containing a cycle (4 nodes, label 1).
    """

    def _build(edges, num_nodes, label):
        # Build one labelled Data object from an edge list.
        graph = nx.Graph()
        graph.add_edges_from(edges)
        return torch_geometric.data.Data(
            x=torch.arange(num_nodes).unsqueeze(1).float(),
            y=torch.tensor([label]),
            edge_index=torch.tensor(list(graph.edges())).T.long(),
            num_nodes=num_nodes,
            edge_attr=-torch.arange(1, len(edges) + 1).unsqueeze(1).float(),
        )

    # Y shaped graph: one centre vertex joined to three leaves
    y_data = _build([(0, 1), (0, 2), (0, 3)], num_nodes=4, label=0)

    # X shaped graph: one centre vertex joined to four leaves.
    # Bug fix: this graph has 5 nodes (0..4); num_nodes was previously 4.
    x_data = _build([(0, 1), (0, 2), (0, 3), (0, 4)], num_nodes=5, label=0)

    # g shaped graph: a 3-cycle with one pendant vertex
    g_data = _build([(0, 1), (1, 2), (2, 0), (2, 3)], num_nodes=4, label=1)

    return x_data, y_data, g_data


def get_Planetoid_pyg(cfg):
r"""Loads Planetoid graph datasets from torch_geometric.

Expand Down
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2cell.discrete_configuration_complex_lifting import (
DiscreteConfigurationComplexLifting,
)
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
Expand All @@ -31,6 +34,7 @@
"OneHotDegreeFeatures": OneHotDegreeFeatures,
"NodeFeaturesToFloat": NodeFeaturesToFloat,
"KeepOnlyConnectedComponent": KeepOnlyConnectedComponent,
"DiscreteConfigurationComplexLifting": DiscreteConfigurationComplexLifting,
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
from itertools import permutations
from typing import ClassVar

import networkx as nx
import torch
import torch_geometric
from toponetx.classes import CellComplex

from modules.transforms.liftings.graph2cell.base import Graph2CellLifting
from modules.utils.utils import edge_cycle_to_vertex_cycle

Vertex = int
Edge = tuple[Vertex, Vertex]
ConfigurationTuple = tuple[Vertex | Edge]


class DiscreteConfigurationComplexLifting(Graph2CellLifting):
    r"""Lifts graphs to cell complexes by generating the k-th *discrete configuration complex* $D_k(G)$ of the graph. This is a cube complex, which is similar to a simplicial complex except each n-dimensional cell is homeomorphic to an n-dimensional cube rather than an n-dimensional simplex.

    The discrete configuration complex of order k consists of all sets of k unique edges or vertices of $G$, with the additional constraint that if an edge e is in a cell, then neither of the endpoints of e are in the cell. For examples of different graphs and their configuration complexes, see the tutorial.

    Note that since TopoNetX only supports cell complexes of dimension 2, if you generate a configuration complex of order k > 2 this will only produce the 2-skeleton.

    Parameters
    ----------
    k: int,
        The order of the configuration complex, i.e. the number of 'agents' in a single configuration. Must be a positive integer.
    preserve_edge_attr : bool, optional
        Whether to preserve edge attributes. Default is True.
    feature_aggregation: str, optional
        For a k-agent configuration, the method by which the features are aggregated. Can be "mean", "sum", or "concat". Default is "concat".
    **kwargs : optional
        Additional arguments for the class.

    Raises
    ------
    ValueError
        If ``k`` is not positive or ``feature_aggregation`` is unrecognized.
    """

    def __init__(
        self,
        k: int,
        preserve_edge_attr: bool = True,
        feature_aggregation="concat",
        **kwargs,
    ):
        # Validate early: a non-positive k produces no configurations at all.
        if k < 1:
            raise ValueError("k must be a positive integer")
        self.k = k
        # TopoNetX caps cell complexes at dimension 2, so we only ever
        # materialize the 2-skeleton regardless of k.
        self.complex_dim = 2
        if feature_aggregation not in ["mean", "sum", "concat"]:
            raise ValueError(
                "feature_aggregation must be one of 'mean', 'sum', 'concat'"
            )
        self.feature_aggregation = feature_aggregation
        super().__init__(preserve_edge_attr=preserve_edge_attr, **kwargs)

    def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
        r"""Applies the full lifting (topology + features) to the input data.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data to be lifted.

        Returns
        -------
        torch_geometric.data.Data
            The lifted data.
        """
        # Unlike the base class, we do not pass the initial data to the final data
        # This is because the configuration complex has a completely different 1-skeleton from the original graph
        lifted_topology = self.lift_topology(data)
        lifted_topology = self.feature_lifting(lifted_topology)
        return torch_geometric.data.Data(y=data.y, **lifted_topology)

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Generates the cubical complex of discrete graph configurations.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data to be lifted.

        Returns
        -------
        dict
            The lifted topology.

        Raises
        ------
        ValueError
            If the input graph is directed.
        """
        G = self._generate_graph_from_data(data)
        if G.is_directed():
            raise ValueError("Directed Graphs are not supported.")

        Configuration = generate_configuration_class(
            G, self.feature_aggregation, self.contains_edge_attr
        )

        # The vertices of the configuration complex are just tuples of k
        # vertices; generating their upwards neighbors recursively
        # enumerates every higher-dimensional cell.
        for dim_0_configuration_tuple in permutations(G, self.k):
            configuration = Configuration(dim_0_configuration_tuple)
            configuration.generate_upwards_neighbors()

        # Bucket cells by dimension. Always materialize ranks 0..2 so the
        # skeleton construction below is safe even when k < 2 (bug fix:
        # previously cells[2] raised KeyError for k == 1).
        cells = {i: [] for i in range(max(self.k, 2) + 1)}
        for conf in Configuration.instances.values():
            features = conf.features()
            attrs = {"features": features} if features is not None else {}
            cell = (conf.contents, attrs)
            cells[conf.dim].append(cell)

        # TopoNetX only supports cells of dimension <= 2
        cc = CellComplex()
        for node, attrs in cells[0]:
            cc.add_node(node, **attrs)
        for edge, attrs in cells[1]:
            cc.add_edge(edge[0], edge[1], **attrs)
        for cell, attrs in cells[2]:
            cell_vertices = edge_cycle_to_vertex_cycle(cell)
            cc.add_cell(cell_vertices, rank=2, **attrs)

        return self._get_lifted_topology(cc, G)


def generate_configuration_class(
    graph: nx.Graph, feature_aggregation: str, edge_features: bool
):
    """Class factory for the Configuration class.

    Builds a fresh ``Configuration`` class closed over *graph*,
    *feature_aggregation* and *edge_features*, so each lifted graph gets its
    own isolated ``instances`` registry.

    Parameters
    ----------
    graph : nx.Graph
        The graph on which the configurations are defined.
    feature_aggregation : str
        How agent features are combined: "mean", "sum", or "concat".
    edge_features : bool
        Whether edge features are present on *graph* and should be used.

    Returns
    -------
    type
        The generated ``Configuration`` class.
    """

    class Configuration:
        """Represents a single legal configuration of k agents on a graph G.

        A legal configuration is a tuple of k edges and vertices of G where
        all the vertices and endpoints are **distinct**: no two edges sharing
        an endpoint, and no adjacent (edge, vertex) pair, can simultaneously
        be contained in the configuration. Each configuration corresponds to
        a cell, and the number of edges in the configuration is the dimension.
        """

        # Registry mapping each configuration tuple to its unique
        # Configuration object (scoped to this generated class, i.e. to one
        # lifted graph).
        instances: ClassVar[dict[ConfigurationTuple, "Configuration"]] = {}

        def __new__(cls, configuration_tuple: ConfigurationTuple):
            # Ensure that a configuration tuple corresponds to a *unique* configuration object
            key = configuration_tuple
            if key not in cls.instances:
                cls.instances[key] = super().__new__(cls)

            return cls.instances[key]

        def __init__(self, configuration_tuple: ConfigurationTuple) -> None:
            # If this object was already initialized earlier, maintain current state
            # (__init__ runs on every Configuration(...) call even when
            # __new__ returned a cached instance).
            if hasattr(self, "initialized"):
                return

            self.initialized = True
            self.configuration_tuple = configuration_tuple
            # All vertices touched by the configuration (vertex agents plus
            # edge endpoints); used to test legality of adding a new edge.
            self.neighborhood = set()
            # Dimension of the corresponding cell = number of edge agents.
            self.dim = 0
            for agent in configuration_tuple:
                if isinstance(agent, Vertex):
                    self.neighborhood.add(agent)
                else:
                    self.neighborhood.update(set(agent))
                    self.dim += 1

            if self.dim == 0:
                # 0-cells are identified by the tuple itself; higher cells
                # accumulate the contents of the lower cells they contain.
                self.contents = configuration_tuple
            else:
                self.contents = []

            self._upwards_neighbors_generated = False

        def features(self):
            """Generate the features for the configuration by combining the edge and vertex features."""
            features = []
            for agent in self.configuration_tuple:
                if isinstance(agent, Vertex):
                    features.append(graph.nodes[agent]["features"])
                elif edge_features:
                    # Edge agent; only contributes when edge features exist.
                    features.append(graph.edges[agent]["features"])

            if not features:
                return None

            if feature_aggregation == "mean":
                try:
                    return torch.stack(features, dim=0).mean(dim=0)
                except Exception as e:
                    raise ValueError(
                        "Failed to mean feature tensors. This may be because edge features and vertex features have different shapes. If this is the case, use feature_aggregation='concat', or disable edge features."
                    ) from e
            elif feature_aggregation == "sum":
                try:
                    return torch.stack(features, dim=0).sum(dim=0)
                except Exception as e:
                    raise ValueError(
                        "Failed to sum feature tensors. This may be because edge features and vertex features have different shapes. If this is the case, use feature_aggregation='concat', or disable edge features."
                    ) from e
            elif feature_aggregation == "concat":
                return torch.concatenate(features, dim=-1)
            else:
                raise ValueError(
                    f"Unrecognized feature_aggregation: {feature_aggregation}"
                )

        def generate_upwards_neighbors(self):
            """For the configuration self of dimension d, generate the configurations of dimension d+1 containing it."""
            # Guard against re-processing: the recursion below revisits
            # configurations through many paths.
            if self._upwards_neighbors_generated:
                return
            self._upwards_neighbors_generated = True
            # A (d+1)-cell arises from sliding one vertex agent onto an
            # incident edge.
            for i, agent in enumerate(self.configuration_tuple):
                if isinstance(agent, Vertex):
                    for neighbor in graph[agent]:
                        self._generate_single_neighbor(i, agent, neighbor)

        def _generate_single_neighbor(
            self, index: int, vertex_agent: int, neighbor: int
        ):
            """Generate a configuration containing self by moving an agent from a vertex onto an edge."""
            # If adding the edge (vertex_agent, neighbor) would produce an illegal configuration, ignore it
            if neighbor in self.neighborhood:
                return

            # We always orient edges (min -> max) to maintain uniqueness of configuration tuples
            new_edge = (min(vertex_agent, neighbor), max(vertex_agent, neighbor))

            # Remove the vertex at index and replace it with new edge
            new_configuration_tuple = (
                *self.configuration_tuple[:index],
                new_edge,
                *self.configuration_tuple[index + 1 :],
            )
            # Record self as part of the boundary contents of the new,
            # higher-dimensional configuration, then recurse upwards.
            new_configuration = Configuration(new_configuration_tuple)
            new_configuration.contents.append(self.contents)
            new_configuration.generate_upwards_neighbors()

    return Configuration
5 changes: 2 additions & 3 deletions modules/transforms/liftings/lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph
if self.preserve_edge_attr and self._data_has_edge_attr(data):
# In case edge features are given, assign features to every edge
edge_index, edge_attr = (
data.edge_index,
data.edge_attr
(data.edge_index, data.edge_attr)
if is_undirected(data.edge_index, data.edge_attr)
else to_undirected(data.edge_index, data.edge_attr),
else to_undirected(data.edge_index, data.edge_attr)
)
edges = [
(i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1))
Expand Down
Loading
Loading