Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Universal Strict Lifting (Hypergraph to Combinatorial) #47

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions configs/datasets/manual_hypergraph_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: hypergraph
data_type: toy_dataset
data_name: manual
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: node
5 changes: 5 additions & 0 deletions configs/models/combinatorial/hmc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
in_channels: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
negative_slope: 0.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
transform_type: "lifting"
transform_name: "UniversalStrictLifting"
10 changes: 10 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_manual_hypergraph,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -203,4 +204,13 @@ def load(
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
# Manual hypergraph
if self.parameters.data_name in ["manual"]:
root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])
self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])

data = load_manual_hypergraph()
return CustomDataset([data], self.parameters.data_dir)

return load_hypergraph_pickle_dataset(self.parameters)
130 changes: 114 additions & 16 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,71 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity


def get_combinatorial_complex_connectivity(complex, max_rank=None):
r"""Gets the connectivity matrices for the combinatorial complex.

Parameters
----------
complex : topnetx.CombinatorialComplex
Combinatorial complex.
max_rank : int
Maximum rank of the complex.

Returns
-------
dict
Dictionary containing the connectivity matrices.
"""
if max_rank is None:
max_rank = complex.dim
practical_shape = list(
np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape)))
)

connectivity = {}

for rank_idx in range(max_rank + 1):
if rank_idx > 0:
try:
connectivity[f"incidence_{rank_idx}"] = from_sparse(
complex.incidence_matrix(rank=rank_idx - 1, to_rank=rank_idx)
)
except ValueError:
connectivity[
f"incidence_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)

try:
connectivity[f"adjacency_{rank_idx}"] = from_sparse(
complex.adjacency_matrix(rank=rank_idx, via_rank=rank_idx + 1)
)
except ValueError:
connectivity[f"adjacency_{rank_idx}"] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)

connectivity["shape"] = practical_shape

return connectivity


def generate_zero_sparse_connectivity(m, n):
r"""Generates a zero sparse connectivity matrix.

Expand Down Expand Up @@ -218,17 +268,13 @@ def load_hypergraph_pickle_dataset(cfg):

print(f"number of hyperedges: {len(hypergraph)}")

edge_idx = 0 # num_nodes
node_list = []
edge_list = []
for he in hypergraph:
cur_he = hypergraph[he]
cur_size = len(cur_he)

node_list += list(cur_he)
edge_list += [edge_idx] * cur_size

edge_idx += 1
for edge_idx, cur_he in enumerate(hypergraph.values()):
cur_size = len(cur_he)
node_list.extend(cur_he)
edge_list.extend([edge_idx] * cur_size)

# check that every node is in some hyperedge
if len(np.unique(node_list)) != num_nodes:
Expand Down Expand Up @@ -334,6 +380,58 @@ def load_manual_graph():
)


def load_manual_hypergraph():
"""Create a manual hypergraph for testing purposes."""
# Define the vertices (just 8 vertices)
vertices = [i for i in range(8)]
y = [0, 1, 1, 1, 0, 0, 0, 0]
# Define the hyperedges
hyperedges = [
[0, 1, 2, 3],
[4, 5, 6, 7],
[0, 1, 2],
[0, 1, 3],
[0, 2, 3],
[1, 2, 3],
[0, 1],
[0, 2],
[0, 3],
[1, 2],
[1, 3],
[2, 3],
[3, 4],
[4, 5],
[4, 7],
[5, 6],
[6, 7],
]

# Generate feature from 0 to 7
x = torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000]).unsqueeze(1).float()
labels = torch.tensor(y, dtype=torch.long)

node_list = []
edge_list = []

for edge_idx, he in enumerate(hyperedges):
cur_size = len(he)
node_list += he
edge_list += [edge_idx] * cur_size

edge_index = np.array([node_list, edge_list], dtype=int)
edge_index = torch.LongTensor(edge_index)

incidence_hyperedges = torch.sparse_coo_tensor(
edge_index,
values=torch.ones(edge_index.shape[1]),
size=(len(vertices), len(hyperedges)),
)

return Data(
x=x, edge_index=edge_index, y=labels, incidence_hyperedges=incidence_hyperedges
)


def get_Planetoid_pyg(cfg):
r"""Loads Planetoid graph datasets from torch_geometric.

Expand Down
78 changes: 78 additions & 0 deletions modules/models/combinatorial/hmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
from topomodelx.nn.combinatorial.hmc import HMC


class HMCModel(torch.nn.Module):
r"""A simple HMC model that runs over combinatorial complex data.
Note that some parameters are defined by the considered dataset.

Parameters
----------
model_config : Dict | DictConfig
Model configuration.
dataset_config : Dict | DictConfig
Dataset configuration.
"""

def __init__(self, model_config, dataset_config):
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
negative_slope = model_config["negative_slope"]

super().__init__()

in_channels_layer = [in_channels, in_channels, in_channels]
int_channels_layer = [hidden_channels, hidden_channels, hidden_channels]
out_channels_layer = [hidden_channels, hidden_channels, hidden_channels]

channels_per_layer = [
[in_channels_layer, int_channels_layer, out_channels_layer]
]

for _ in range(1, n_layers):
in_channels_layer = [hidden_channels, hidden_channels, hidden_channels]
int_channels_layer = [hidden_channels, hidden_channels, hidden_channels]
out_channels_layer = [hidden_channels, hidden_channels, hidden_channels]

channels_per_layer.append(
[in_channels_layer, int_channels_layer, out_channels_layer]
)

self.base_model = HMC(
channels_per_layer=channels_per_layer, negative_slope=negative_slope
)
self.linear = torch.nn.Linear(hidden_channels, out_channels)

def forward(self, data):
r"""Forward pass of the model.

Parameters
----------
data : torch_geometric.data.Data
Input data.

Returns
-------
torch.Tensor
Output tensor.
"""
x = self.base_model(
data.x_0,
data.x_1,
data.x_2,
data.adjacency_0,
data.adjacency_1,
data.adjacency_2,
data.incidence_1,
data.incidence_2,
)[1]

x = self.linear(x)

return torch.sigmoid(x)
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.hypergraph2combinatorial.universal_strict_lifting import (
UniversalStrictLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -23,6 +26,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Hypergraph -> Combinatorial Complex
"UniversalStrictLifting": UniversalStrictLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
52 changes: 52 additions & 0 deletions modules/transforms/liftings/hypergraph2combinatorial/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
from toponetx import CombinatorialComplex

from modules.data.utils.utils import get_combinatorial_complex_connectivity
from modules.transforms.liftings.lifting import HypergraphLifting


class Hypergraph2CombinatorialLifting(HypergraphLifting):
r"""Abstract class for lifting hypergraphs to combinatorial complexes.

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.type = "hypergraph2combinatorial"

def _get_lifted_topology(self, combinatorial_complex: CombinatorialComplex) -> dict:
r"""Returns the lifted topology.

Parameters
----------
combinatorial_complex : CombinatorialComplex
The combinatorial complex.

Returns
-------
dict
The lifted topology.
"""
lifted_topology = get_combinatorial_complex_connectivity(combinatorial_complex)

# Feature liftings

features = combinatorial_complex.get_cell_attributes("features")

for i in range(combinatorial_complex.dim + 1):
x = [
feat
for cell, feat in features
if combinatorial_complex.cells.get_rank(cell) == i
]
if x:
lifted_topology[f"x_{i}"] = torch.stack(x)
else:
num_cells = len(combinatorial_complex.skeleton(i))
lifted_topology[f"x_{i}"] = torch.zeros(num_cells, 1)

return lifted_topology
Loading
Loading